careamics 0.0.8__py3-none-any.whl → 0.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +0 -4
- careamics/careamist.py +0 -1
- careamics/config/__init__.py +1 -13
- careamics/config/algorithms/care_algorithm_model.py +84 -0
- careamics/config/algorithms/n2n_algorithm_model.py +85 -0
- careamics/config/algorithms/n2v_algorithm_model.py +269 -1
- careamics/config/configuration.py +21 -13
- careamics/config/configuration_factories.py +179 -187
- careamics/config/configuration_io.py +2 -2
- careamics/config/data/__init__.py +1 -4
- careamics/config/data/data_model.py +46 -62
- careamics/config/support/supported_transforms.py +1 -1
- careamics/config/transformations/__init__.py +0 -2
- careamics/config/transformations/n2v_manipulate_model.py +15 -0
- careamics/config/transformations/transform_unions.py +0 -13
- careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
- careamics/dataset/in_memory_dataset.py +3 -10
- careamics/dataset/in_memory_pred_dataset.py +3 -5
- careamics/dataset/in_memory_tiled_pred_dataset.py +2 -2
- careamics/dataset/iterable_dataset.py +2 -2
- careamics/dataset/iterable_pred_dataset.py +3 -5
- careamics/dataset/iterable_tiled_pred_dataset.py +3 -3
- careamics/dataset_ng/dataset/__init__.py +3 -0
- careamics/dataset_ng/dataset/dataset.py +184 -0
- careamics/dataset_ng/demo_dataset.ipynb +271 -0
- careamics/dataset_ng/demo_patch_extractor.py +53 -0
- careamics/dataset_ng/demo_patch_extractor_factory.py +37 -0
- careamics/dataset_ng/patch_extractor/__init__.py +10 -0
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +111 -0
- careamics/dataset_ng/patch_extractor/image_stack/__init__.py +9 -0
- careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +53 -0
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +55 -0
- careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +163 -0
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +140 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +29 -0
- careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +208 -0
- careamics/dataset_ng/patching_strategies/__init__.py +11 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +82 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +338 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +75 -0
- careamics/lightning/lightning_module.py +78 -27
- careamics/lightning/train_data_module.py +8 -39
- careamics/losses/fcn/losses.py +17 -10
- careamics/lvae_training/eval_utils.py +21 -8
- careamics/model_io/bioimage/bioimage_utils.py +5 -3
- careamics/model_io/bioimage/model_description.py +3 -3
- careamics/model_io/bmz_io.py +2 -2
- careamics/model_io/model_io_utils.py +2 -2
- careamics/transforms/__init__.py +2 -1
- careamics/transforms/compose.py +5 -15
- careamics/transforms/n2v_manipulate_torch.py +143 -0
- careamics/transforms/pixel_manipulation.py +1 -0
- careamics/transforms/pixel_manipulation_torch.py +418 -0
- careamics/utils/version.py +38 -0
- {careamics-0.0.8.dist-info → careamics-0.0.10.dist-info}/METADATA +7 -8
- {careamics-0.0.8.dist-info → careamics-0.0.10.dist-info}/RECORD +59 -42
- careamics/config/care_configuration.py +0 -100
- careamics/config/data/n2v_data_model.py +0 -193
- careamics/config/n2n_configuration.py +0 -101
- careamics/config/n2v_configuration.py +0 -266
- {careamics-0.0.8.dist-info → careamics-0.0.10.dist-info}/WHEEL +0 -0
- {careamics-0.0.8.dist-info → careamics-0.0.10.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.8.dist-info → careamics-0.0.10.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,338 @@
|
|
|
1
|
+
"""A module for random patching strategies."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from .patching_strategy_protocol import PatchSpecs
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class RandomPatchingStrategy:
    """
    A patching strategy for sampling random patches, it implements the
    `PatchingStrategy` `Protocol`.

    The output of `get_patch_spec` will be random, i.e. if the same index is given
    twice the two outputs can be different.

    However the strategy still ensures that there will be a known number of patches for
    each sample in each image stack. This is achieved through defining a set of bins
    that map to each sample in each image stack. Whichever bin an `index` passed to
    `get_patch_spec` falls into, determines the `"data_idx"` and `"sample_idx"` in
    the returned `PatchSpecs`, but the `"coords"` will be random.

    The number of patches in each sample is the number of spatial pixels in the
    sample divided by the number of pixels in a single patch, rounded up (see
    `_calc_n_patches`).
    """

    def __init__(
        self,
        data_shapes: Sequence[Sequence[int]],
        patch_size: Sequence[int],
        seed: Optional[int] = None,
    ):
        """
        A patching strategy for sampling random patches.

        Parameters
        ----------
        data_shapes : sequence of (sequence of int)
            The shapes of the underlying data. Each element is the dimension of the
            axes SC(Z)YX.
        patch_size : sequence of int
            The size of the patch. The sequence will have length 2 or 3, for 2D and 3D
            data respectively.
        seed : int, optional
            An optional seed to ensure the reproducibility of the random patches.
        """
        self.rng = np.random.default_rng(seed=seed)
        self.patch_size = patch_size
        self.data_shapes = data_shapes

        # these bins will determine which image stack and sample a patch comes from
        # the image_stack_cumulative_patches map a patch index to each image stack
        # the sample_cumulative_patches map a patch index to each sample
        # the image_stack_cumulative_samples map a sample index to each image stack
        (
            self.image_stack_cumulative_patches,
            self.sample_cumulative_patches,
            self.image_stack_cumulative_samples,
        ) = self._calc_bins(self.data_shapes, self.patch_size)

    @property
    def n_patches(self) -> int:
        """
        The number of patches that this patching strategy will return.

        It also determines the maximum index that can be given to `get_patch_spec`.
        """
        # with no image stacks there are no bins, and therefore no patches
        if len(self.image_stack_cumulative_patches) == 0:
            return 0
        # last bin boundary will be total patches
        return int(self.image_stack_cumulative_patches[-1])

    def get_patch_spec(self, index: int) -> PatchSpecs:
        """Return the patch specs for a given index.

        Parameters
        ----------
        index : int
            A patch index, in the range `[0, n_patches)`.

        Returns
        -------
        PatchSpecs
            A dictionary that specifies a single patch in a series of `ImageStacks`.

        Raises
        ------
        IndexError
            If `index` is negative or not smaller than `n_patches`.
        """
        # negative indices must be rejected too: `np.digitize` would otherwise map
        # them silently to the first image stack and sample
        if not 0 <= index < self.n_patches:
            raise IndexError(
                f"Index {index} out of bounds for RandomPatchingStrategy with number "
                f"of patches {self.n_patches}"
            )
        # digitize returns the bin that `index` belongs to
        data_index = np.digitize(index, bins=self.image_stack_cumulative_patches).item()
        # maps to a particular sample within the whole series of image stacks
        # (not just a single image stack)
        total_samples_index = np.digitize(
            index, bins=self.sample_cumulative_patches
        ).item()

        data_shape = self.data_shapes[data_index]
        spatial_shape = data_shape[2:]

        # calculate sample index relative to image stack:
        # subtract the total number of samples in the previous image stacks
        if data_index == 0:
            n_previous_samples = 0
        else:
            n_previous_samples = self.image_stack_cumulative_samples[data_index - 1]
        sample_index = total_samples_index - n_previous_samples
        coords = _generate_random_coords(spatial_shape, self.patch_size, self.rng)
        return {
            "data_idx": data_index,
            "sample_idx": sample_index,
            "coords": coords,
            "patch_size": self.patch_size,
        }

    @staticmethod
    def _calc_bins(
        data_shapes: Sequence[Sequence[int]], patch_size: Sequence[int]
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """Calculate bins used to map an index to an image_stack and a sample.

        The number of patches in each sample is computed by `_calc_n_patches`.

        Parameters
        ----------
        data_shapes : sequence of (sequence of int)
            The shapes of the underlying data. Each element is the dimension of the
            axes SC(Z)YX.
        patch_size : sequence of int
            The size of the patch. The sequence will have length 2 or 3, for 2D and 3D
            data respectively.

        Returns
        -------
        image_stack_cumulative_patches: tuple of int
            The bins that map a patch index to an image stack. E.g. if a patch index
            falls below the first bin boundary it belongs to the first image stack, if
            a patch index falls between the first bin boundary and the second bin
            boundary it belongs to the second image stack, and so on.
        sample_cumulative_patches: tuple of int
            The bins that map a patch index to a sample. E.g. if a patch index
            falls below the first bin boundary it belongs to the first sample, if
            a patch index falls between the first bin boundary and the second bin
            boundary it belongs to the second sample, and so on.
        image_stack_cumulative_samples: tuple of int
            The bins that map a sample index to an image stack. E.g. if a sample index
            falls below the first bin boundary it belongs to the first image stack, if
            a sample index falls between the first bin boundary and the second bin
            boundary it belongs to the second image stack, and so on.
        """
        patches_per_image_stack: list[int] = []
        patches_per_sample: list[int] = []
        samples_per_image_stack: list[int] = []
        for data_shape in data_shapes:
            spatial_shape = data_shape[2:]
            n_samples = data_shape[0]
            n_single_sample_patches = _calc_n_patches(spatial_shape, patch_size)
            # multiply by number of samples in image_stack
            patches_per_image_stack.append(n_single_sample_patches * n_samples)
            # list of length `sample` filled with `n_single_sample_patches`
            patches_per_sample.extend([n_single_sample_patches] * n_samples)
            # number of samples in each image stack
            samples_per_image_stack.append(n_samples)

        # cumulative sum creates the bins; convert to Python ints so downstream
        # indices (`sample_idx`, `n_patches`) are plain ints, not numpy scalars
        return (
            tuple(int(v) for v in np.cumsum(patches_per_image_stack)),
            tuple(int(v) for v in np.cumsum(patches_per_sample)),
            tuple(int(v) for v in np.cumsum(samples_per_image_stack)),
        )
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
class FixedRandomPatchingStrategy:
    """
    A patching strategy for sampling random patches, it implements the
    `PatchingStrategy` `Protocol`.

    The output of `get_patch_spec` will be deterministic, i.e. if the same index is
    given twice the two outputs will be the same. All patch coordinates are drawn
    once, at initialisation, and stored in `fixed_patch_specs`.

    The number of patches in each sample is the number of spatial pixels in the
    sample divided by the number of pixels in a single patch, rounded up (see
    `_calc_n_patches`).
    """

    def __init__(
        self,
        data_shapes: Sequence[Sequence[int]],
        patch_size: Sequence[int],
        seed: Optional[int] = None,
    ):
        """A patching strategy for sampling random patches.

        Parameters
        ----------
        data_shapes : sequence of (sequence of int)
            The shapes of the underlying data. Each element is the dimension of the
            axes SC(Z)YX.
        patch_size : sequence of int
            The size of the patch. The sequence will have length 2 or 3, for 2D and 3D
            data respectively.
        seed : int, optional
            An optional seed to ensure the reproducibility of the random patches.
        """
        self.rng = np.random.default_rng(seed=seed)
        self.patch_size = patch_size
        self.data_shapes = data_shapes

        # simply generate all the patches at initialisation, so they will be fixed
        self.fixed_patch_specs: list[PatchSpecs] = []
        for data_idx, data_shape in enumerate(self.data_shapes):
            spatial_shape = data_shape[2:]
            n_patches = _calc_n_patches(spatial_shape, self.patch_size)
            for sample_idx in range(data_shape[0]):
                for _ in range(n_patches):
                    random_coords = _generate_random_coords(
                        spatial_shape, self.patch_size, self.rng
                    )
                    patch_specs: PatchSpecs = {
                        "data_idx": data_idx,
                        "sample_idx": sample_idx,
                        "coords": random_coords,
                        "patch_size": self.patch_size,
                    }
                    self.fixed_patch_specs.append(patch_specs)

    @property
    def n_patches(self) -> int:
        """
        The number of patches that this patching strategy will return.

        It also determines the maximum index that can be given to `get_patch_spec`.
        """
        return len(self.fixed_patch_specs)

    def get_patch_spec(self, index: int) -> PatchSpecs:
        """Return the patch specs for a given index.

        Parameters
        ----------
        index : int
            A patch index, in the range `[0, n_patches)`.

        Returns
        -------
        PatchSpecs
            A dictionary that specifies a single patch in a series of `ImageStacks`.

        Raises
        ------
        IndexError
            If `index` is negative or not smaller than `n_patches`.
        """
        # reject negative indices explicitly: plain list indexing would otherwise
        # silently return patches counted from the end
        if not 0 <= index < self.n_patches:
            raise IndexError(
                f"Index {index} out of bounds for FixedRandomPatchingStrategy with "
                f"number of patches, {self.n_patches}"
            )
        # simply index the pre-generated patches to get the correct patch
        return self.fixed_patch_specs[index]
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def _generate_random_coords(
    spatial_shape: Sequence[int], patch_size: Sequence[int], rng: np.random.Generator
) -> tuple[int, ...]:
    """Generate random patch coordinates for a given `spatial_shape` and `patch_size`.

    The coords are the top-left (and first z-slice for 3D data) of a patch. The
    sequence will have length 2 or 3, for 2D and 3D data respectively. Along each
    axis the coordinate is drawn uniformly from `[0, spatial_shape - patch_size]`
    (both ends inclusive), i.e. from every position where the patch still fits.

    Parameters
    ----------
    spatial_shape : sequence of int
        The dimension of the axes (Z)YX, a sequence of length 2 or 3, for 2D and 3D
        data respectively.
    patch_size : sequence of int
        The size of the patch. The sequence will have length 2 or 3, for 2D and 3D
        data respectively.
    rng : numpy.random.Generator
        A numpy generator to ensure the reproducibility of the random patches.

    Returns
    -------
    coords: tuple of int
        The top-left (and first z-slice for 3D data) coords of a patch. The tuple will
        have length 2 or 3, for 2D and 3D data respectively.

    Raises
    ------
    ValueError
        Raises if the number of spatial dimensions do not match the number of patch
        dimensions.
    ValueError
        Raises if the patch is larger than the spatial shape along any axis.
    """
    if len(patch_size) != len(spatial_shape):
        raise ValueError(
            f"Number of patch dimension {len(patch_size)}, do not match the number of "
            f"spatial dimensions {len(spatial_shape)}, for `patch_size={patch_size}` "
            f"and `spatial_shape={spatial_shape}`."
        )
    # largest start coordinate along each axis at which the patch still fits
    max_coords = np.asarray(spatial_shape, dtype=int) - np.asarray(
        patch_size, dtype=int
    )
    if np.any(max_coords < 0):
        raise ValueError(
            f"Patch size {patch_size} is larger than the spatial shape "
            f"{spatial_shape} along at least one axis."
        )
    # `endpoint=True` makes the upper bound inclusive, so the last valid start
    # position can be sampled and a patch equal in size to the data is allowed
    return tuple(
        rng.integers(
            np.zeros(len(patch_size), dtype=int),
            max_coords,
            endpoint=True,
            dtype=int,
        ).tolist()
    )
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
def _calc_n_patches(spatial_shape: Sequence[int], patch_size: Sequence[int]) -> int:
    """
    Calculate the number of patches for a given `spatial_shape` and `patch_size`.

    The count is the total number of spatial pixels divided by the number of
    pixels in a single patch, rounded up to the next integer.

    Parameters
    ----------
    spatial_shape : sequence of int
        The dimension of the axes (Z)YX, a sequence of length 2 or 3, for 2D and 3D
        data respectively.
    patch_size : sequence of int
        The size of the patch. The sequence will have length 2 or 3, for 2D and 3D
        data respectively.

    Returns
    -------
    int
        The number of patches.

    Raises
    ------
    ValueError
        If `patch_size` and `spatial_shape` have a different number of dimensions.
    """
    n_patch_dims = len(patch_size)
    n_spatial_dims = len(spatial_shape)
    if n_patch_dims != n_spatial_dims:
        raise ValueError(
            f"Number of patch dimension {n_patch_dims}, do not match the number of "
            f"spatial dimensions {n_spatial_dims}, for `patch_size={patch_size}` "
            f"and `spatial_shape={spatial_shape}`."
        )
    total_pixels = np.prod(spatial_shape)
    pixels_per_patch = np.prod(patch_size)
    return int(np.ceil(total_pixels / pixels_per_patch))
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
from collections.abc import Sequence
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from typing_extensions import ParamSpec
|
|
7
|
+
|
|
8
|
+
from .patching_strategy_protocol import PatchSpecs
|
|
9
|
+
|
|
10
|
+
# NOTE(review): `P` is not referenced anywhere in this module; it looks like a
# leftover and is a candidate for removal (check external imports first).
P = ParamSpec("P")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# TODO: this is an unfinished prototype based on current tiling implementation
|
|
14
|
+
# not guaranteed to work!
|
|
15
|
+
class SequentialPatchingStrategy:
    """
    A patching strategy that tiles every sample with a regular grid of patches.

    Along each spatial axis, patches start at 0 and advance by
    `patch_size - overlap`. If the last grid patch does not reach the end of the
    axis, one extra patch aligned to the end of that axis is added, so the whole
    array is covered. All patch specs are computed once at initialisation, so
    `get_patch_spec` is deterministic.
    """

    def __init__(
        self,
        data_shapes: Sequence[Sequence[int]],
        patch_size: Sequence[int],
        overlap: Optional[Sequence[int]] = None,
    ):
        """
        A patching strategy for sampling patches on a regular grid.

        Parameters
        ----------
        data_shapes : sequence of (sequence of int)
            The shapes of the underlying data. The first axis is treated as the
            sample axis and the last `len(patch_size)` axes as the spatial axes.
        patch_size : sequence of int
            The size of the patch along each spatial axis.
        overlap : sequence of int, optional
            The overlap between neighbouring patches along each spatial axis. If
            `None`, patches do not overlap.

        Raises
        ------
        ValueError
            If the patch is larger than the data along a spatial axis, or if the
            overlap is not smaller than the patch size along an axis.
        """
        self.data_shapes = data_shapes
        self.patch_size = patch_size
        if overlap is None:
            overlap = [0] * len(patch_size)
        self.overlap = np.asarray(overlap)

        self.patch_specs: list[PatchSpecs] = self._initialize_patch_specs()

    @property
    def n_patches(self) -> int:
        """The number of patches that this patching strategy will return."""
        return len(self.patch_specs)

    def get_patch_spec(self, index: int) -> PatchSpecs:
        """Return the patch specs for a given index.

        Parameters
        ----------
        index : int
            A patch index, in the range `[0, n_patches)`.

        Returns
        -------
        PatchSpecs
            A dictionary that specifies a single patch in a series of `ImageStacks`.

        Raises
        ------
        IndexError
            If `index` is negative or not smaller than `n_patches`.
        """
        # reject negative indices explicitly, like the random strategies do, rather
        # than letting list indexing count from the end
        if not 0 <= index < self.n_patches:
            raise IndexError(
                f"Index {index} out of bounds for SequentialPatchingStrategy with "
                f"number of patches {self.n_patches}"
            )
        return self.patch_specs[index]

    def _compute_coords_1d(
        self, patch_size: int, spatial_shape: int, overlap: int
    ) -> list[tuple[int, int]]:
        """Return the `(start, end)` pairs of the patches along a single axis.

        Raises
        ------
        ValueError
            If `patch_size > spatial_shape`, or if `overlap >= patch_size` (which
            would give a non-positive step).
        """
        if patch_size > spatial_shape:
            raise ValueError(
                f"Patch size {patch_size} is larger than the spatial dimension "
                f"{spatial_shape}."
            )
        step = patch_size - overlap
        if step <= 0:
            # a non-positive step would never advance along the axis
            raise ValueError(
                f"Overlap {overlap} must be smaller than the patch size {patch_size}."
            )

        # every start position where a full patch still fits, spaced by `step`
        crop_coords = [
            (start, start + patch_size)
            for start in range(0, spatial_shape - patch_size + 1, step)
        ]

        # add one patch aligned to the end if the grid left the tail uncovered
        if crop_coords[-1][1] < spatial_shape:
            crop_coords.append((spatial_shape - patch_size, spatial_shape))

        return crop_coords

    def _initialize_patch_specs(self) -> list[PatchSpecs]:
        """Build the patch specs for every sample of every image stack."""
        patch_specs: list[PatchSpecs] = []
        for data_idx, data_shape in enumerate(self.data_shapes):

            data_spatial_shape = data_shape[-len(self.patch_size) :]
            coords_list = [
                self._compute_coords_1d(
                    self.patch_size[i], data_spatial_shape[i], self.overlap[i]
                )
                for i in range(len(self.patch_size))
            ]
            for sample_idx in range(data_shape[0]):
                # cartesian product of the per-axis patch positions gives the grid
                for crop_coord in itertools.product(*coords_list):
                    patch_specs.append(
                        PatchSpecs(
                            data_idx=data_idx,
                            sample_idx=sample_idx,
                            coords=tuple(coord[0] for coord in crop_coord),
                            patch_size=self.patch_size,
                        )
                    )
        return patch_specs
|
|
@@ -1,12 +1,17 @@
|
|
|
1
1
|
"""CAREamics Lightning module."""
|
|
2
2
|
|
|
3
|
-
from typing import Any, Callable, Optional, Union
|
|
3
|
+
from typing import Any, Callable, Literal, Optional, Union
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
import pytorch_lightning as L
|
|
7
7
|
from torch import Tensor, nn
|
|
8
8
|
|
|
9
|
-
from careamics.config import
|
|
9
|
+
from careamics.config import (
|
|
10
|
+
N2VAlgorithm,
|
|
11
|
+
UNetBasedAlgorithm,
|
|
12
|
+
VAEBasedAlgorithm,
|
|
13
|
+
algorithm_factory,
|
|
14
|
+
)
|
|
10
15
|
from careamics.config.support import (
|
|
11
16
|
SupportedAlgorithm,
|
|
12
17
|
SupportedArchitecture,
|
|
@@ -27,7 +32,11 @@ from careamics.models.lvae.noise_models import (
|
|
|
27
32
|
noise_model_factory,
|
|
28
33
|
)
|
|
29
34
|
from careamics.models.model_factory import model_factory
|
|
30
|
-
from careamics.transforms import
|
|
35
|
+
from careamics.transforms import (
|
|
36
|
+
Denormalize,
|
|
37
|
+
ImageRestorationTTA,
|
|
38
|
+
N2VManipulateTorch,
|
|
39
|
+
)
|
|
31
40
|
from careamics.utils.metrics import RunningPSNR, scale_invariant_psnr
|
|
32
41
|
from careamics.utils.torch_utils import get_optimizer, get_scheduler
|
|
33
42
|
|
|
@@ -73,13 +82,21 @@ class FCNModule(L.LightningModule):
|
|
|
73
82
|
Algorithm configuration.
|
|
74
83
|
"""
|
|
75
84
|
super().__init__()
|
|
76
|
-
|
|
85
|
+
|
|
77
86
|
if isinstance(algorithm_config, dict):
|
|
78
|
-
algorithm_config =
|
|
79
|
-
|
|
80
|
-
|
|
87
|
+
algorithm_config = algorithm_factory(algorithm_config)
|
|
88
|
+
|
|
89
|
+
# create preprocessing, model and loss function
|
|
90
|
+
if isinstance(algorithm_config, N2VAlgorithm):
|
|
91
|
+
self.use_n2v = True
|
|
92
|
+
self.n2v_preprocess: Optional[N2VManipulateTorch] = N2VManipulateTorch(
|
|
93
|
+
n2v_manipulate_config=algorithm_config.n2v_config
|
|
94
|
+
)
|
|
95
|
+
else:
|
|
96
|
+
self.use_n2v = False
|
|
97
|
+
self.n2v_preprocess = None
|
|
81
98
|
|
|
82
|
-
|
|
99
|
+
self.algorithm = algorithm_config.algorithm
|
|
83
100
|
self.model: nn.Module = model_factory(algorithm_config.model)
|
|
84
101
|
self.loss_func = loss_factory(algorithm_config.loss)
|
|
85
102
|
|
|
@@ -119,10 +136,15 @@ class FCNModule(L.LightningModule):
|
|
|
119
136
|
Any
|
|
120
137
|
Loss value.
|
|
121
138
|
"""
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
139
|
+
x, *targets = batch
|
|
140
|
+
if self.use_n2v and self.n2v_preprocess is not None:
|
|
141
|
+
x_preprocessed, *aux = self.n2v_preprocess(x)
|
|
142
|
+
else:
|
|
143
|
+
x_preprocessed = x
|
|
144
|
+
aux = []
|
|
145
|
+
|
|
146
|
+
out = self.model(x_preprocessed)
|
|
147
|
+
loss = self.loss_func(out, *aux, *targets)
|
|
126
148
|
self.log(
|
|
127
149
|
"train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
|
|
128
150
|
)
|
|
@@ -138,9 +160,15 @@ class FCNModule(L.LightningModule):
|
|
|
138
160
|
batch_idx : Any
|
|
139
161
|
Batch index.
|
|
140
162
|
"""
|
|
141
|
-
x, *
|
|
142
|
-
|
|
143
|
-
|
|
163
|
+
x, *targets = batch
|
|
164
|
+
if self.use_n2v and self.n2v_preprocess is not None:
|
|
165
|
+
x_preprocessed, *aux = self.n2v_preprocess(x)
|
|
166
|
+
else:
|
|
167
|
+
x_preprocessed = x
|
|
168
|
+
aux = []
|
|
169
|
+
|
|
170
|
+
out = self.model(x_preprocessed)
|
|
171
|
+
val_loss = self.loss_func(out, *aux, *targets)
|
|
144
172
|
|
|
145
173
|
# log validation loss
|
|
146
174
|
self.log(
|
|
@@ -177,10 +205,16 @@ class FCNModule(L.LightningModule):
|
|
|
177
205
|
and isinstance(batch[1][0], TileInformation)
|
|
178
206
|
)
|
|
179
207
|
|
|
208
|
+
# TODO add explanations for what is happening here
|
|
180
209
|
if is_tiled:
|
|
181
210
|
x, *aux = batch
|
|
211
|
+
if type(x) in [list, tuple]:
|
|
212
|
+
x = x[0]
|
|
182
213
|
else:
|
|
183
|
-
|
|
214
|
+
if type(batch) in [list, tuple]:
|
|
215
|
+
x = batch[0] # TODO change, ugly way to deal with n2v refac
|
|
216
|
+
else:
|
|
217
|
+
x = batch
|
|
184
218
|
aux = []
|
|
185
219
|
|
|
186
220
|
# apply test-time augmentation if available
|
|
@@ -593,6 +627,9 @@ def create_careamics_module(
|
|
|
593
627
|
algorithm: Union[SupportedAlgorithm, str],
|
|
594
628
|
loss: Union[SupportedLoss, str],
|
|
595
629
|
architecture: Union[SupportedArchitecture, str],
|
|
630
|
+
use_n2v2: bool = False,
|
|
631
|
+
struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
|
|
632
|
+
struct_n2v_span: int = 5,
|
|
596
633
|
model_parameters: Optional[dict] = None,
|
|
597
634
|
optimizer: Union[SupportedOptimizer, str] = "Adam",
|
|
598
635
|
optimizer_parameters: Optional[dict] = None,
|
|
@@ -612,6 +649,12 @@ def create_careamics_module(
|
|
|
612
649
|
Loss function to use for training (see SupportedLoss).
|
|
613
650
|
architecture : SupportedArchitecture or str
|
|
614
651
|
Model architecture to use for training (see SupportedArchitecture).
|
|
652
|
+
use_n2v2 : bool, default=False
|
|
653
|
+
Whether to use N2V2 or Noise2Void.
|
|
654
|
+
struct_n2v_axis : "horizontal", "vertical", or "none", default="none"
|
|
655
|
+
Axis of the StructN2V mask.
|
|
656
|
+
struct_n2v_span : int, default=5
|
|
657
|
+
Span of the StructN2V mask.
|
|
615
658
|
model_parameters : dict, optional
|
|
616
659
|
Model parameters to use for training, by default {}. Model parameters are
|
|
617
660
|
defined in the relevant `torch.nn.Module` class, or Pyddantic model (see
|
|
@@ -633,14 +676,15 @@ def create_careamics_module(
|
|
|
633
676
|
CAREamicsModule
|
|
634
677
|
CAREamics Lightning module.
|
|
635
678
|
"""
|
|
636
|
-
#
|
|
679
|
+
# TODO should use the same functions are in configuration_factory.py
|
|
680
|
+
# create an AlgorithmModel compatible dictionary
|
|
637
681
|
if lr_scheduler_parameters is None:
|
|
638
682
|
lr_scheduler_parameters = {}
|
|
639
683
|
if optimizer_parameters is None:
|
|
640
684
|
optimizer_parameters = {}
|
|
641
685
|
if model_parameters is None:
|
|
642
686
|
model_parameters = {}
|
|
643
|
-
|
|
687
|
+
algorithm_dict: dict[str, Any] = {
|
|
644
688
|
"algorithm": algorithm,
|
|
645
689
|
"loss": loss,
|
|
646
690
|
"optimizer": {
|
|
@@ -652,18 +696,25 @@ def create_careamics_module(
|
|
|
652
696
|
"parameters": lr_scheduler_parameters,
|
|
653
697
|
},
|
|
654
698
|
}
|
|
655
|
-
|
|
656
|
-
|
|
699
|
+
|
|
700
|
+
model_dict = {"architecture": architecture}
|
|
701
|
+
model_dict.update(model_parameters)
|
|
657
702
|
|
|
658
703
|
# add model parameters to algorithm configuration
|
|
659
|
-
|
|
704
|
+
algorithm_dict["model"] = model_dict
|
|
705
|
+
|
|
706
|
+
which_algo = algorithm_dict["algorithm"]
|
|
707
|
+
if which_algo in UNetBasedAlgorithm.get_compatible_algorithms():
|
|
708
|
+
algorithm_cfg = algorithm_factory(algorithm_dict)
|
|
709
|
+
|
|
710
|
+
# if use N2V
|
|
711
|
+
if isinstance(algorithm_cfg, N2VAlgorithm):
|
|
712
|
+
algorithm_cfg.n2v_config.struct_mask_axis = struct_n2v_axis
|
|
713
|
+
algorithm_cfg.n2v_config.struct_mask_span = struct_n2v_span
|
|
714
|
+
algorithm_cfg.set_n2v2(use_n2v2)
|
|
660
715
|
|
|
661
|
-
|
|
662
|
-
# TODO broken by new configutations!
|
|
663
|
-
algorithm_str = algorithm_configuration["algorithm"]
|
|
664
|
-
if algorithm_str in UNetBasedAlgorithm.get_compatible_algorithms():
|
|
665
|
-
return FCNModule(UNetBasedAlgorithm(**algorithm_configuration))
|
|
716
|
+
return FCNModule(algorithm_cfg)
|
|
666
717
|
else:
|
|
667
718
|
raise NotImplementedError(
|
|
668
|
-
f"
|
|
719
|
+
f"Algorithm {which_algo} is not implemented or unknown."
|
|
669
720
|
)
|