careamics 0.1.0rc4__py3-none-any.whl → 0.1.0rc5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +12 -11
- careamics/config/__init__.py +0 -1
- careamics/config/architectures/unet_model.py +1 -0
- careamics/config/callback_model.py +1 -0
- careamics/config/configuration_example.py +0 -2
- careamics/config/configuration_factory.py +112 -42
- careamics/config/configuration_model.py +14 -16
- careamics/config/data_model.py +59 -157
- careamics/config/inference_model.py +19 -20
- careamics/config/references/algorithm_descriptions.py +1 -0
- careamics/config/references/references.py +1 -0
- careamics/config/support/supported_extraction_strategies.py +1 -0
- careamics/config/training_model.py +1 -0
- careamics/config/transformations/n2v_manipulate_model.py +1 -0
- careamics/config/transformations/nd_flip_model.py +6 -11
- careamics/config/transformations/normalize_model.py +1 -0
- careamics/config/transformations/transform_model.py +1 -0
- careamics/config/transformations/xy_random_rotate90_model.py +6 -8
- careamics/config/validators/validator_utils.py +1 -0
- careamics/conftest.py +1 -0
- careamics/dataset/dataset_utils/__init__.py +0 -1
- careamics/dataset/dataset_utils/dataset_utils.py +1 -0
- careamics/dataset/in_memory_dataset.py +14 -45
- careamics/dataset/iterable_dataset.py +13 -68
- careamics/dataset/patching/__init__.py +0 -7
- careamics/dataset/patching/patching.py +1 -0
- careamics/dataset/patching/sequential_patching.py +6 -6
- careamics/dataset/patching/tiled_patching.py +10 -6
- careamics/lightning_datamodule.py +20 -24
- careamics/lightning_module.py +1 -1
- careamics/lightning_prediction_datamodule.py +15 -10
- careamics/losses/__init__.py +0 -1
- careamics/losses/loss_factory.py +1 -0
- careamics/model_io/__init__.py +0 -1
- careamics/model_io/bioimage/_readme_factory.py +2 -1
- careamics/model_io/bioimage/bioimage_utils.py +1 -0
- careamics/model_io/bioimage/model_description.py +1 -0
- careamics/model_io/bmz_io.py +2 -1
- careamics/models/layers.py +1 -0
- careamics/models/model_factory.py +1 -0
- careamics/models/unet.py +91 -17
- careamics/prediction/stitch_prediction.py +1 -0
- careamics/transforms/__init__.py +2 -23
- careamics/transforms/compose.py +98 -0
- careamics/transforms/n2v_manipulate.py +18 -23
- careamics/transforms/nd_flip.py +38 -64
- careamics/transforms/normalize.py +45 -34
- careamics/transforms/pixel_manipulation.py +2 -2
- careamics/transforms/transform.py +33 -0
- careamics/transforms/tta.py +2 -2
- careamics/transforms/xy_random_rotate90.py +41 -68
- careamics/utils/__init__.py +0 -1
- careamics/utils/context.py +1 -0
- careamics/utils/logging.py +1 -0
- careamics/utils/metrics.py +1 -0
- careamics/utils/torch_utils.py +1 -0
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc5.dist-info}/METADATA +16 -61
- careamics-0.1.0rc5.dist-info/RECORD +111 -0
- careamics/dataset/patching/patch_transform.py +0 -44
- careamics-0.1.0rc4.dist-info/RECORD +0 -110
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc5.dist-info}/WHEEL +0 -0
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc5.dist-info}/licenses/LICENSE +0 -0
careamics/models/unet.py
CHANGED
|
@@ -3,7 +3,8 @@ UNet model.
|
|
|
3
3
|
|
|
4
4
|
A UNet encoder, decoder and complete model.
|
|
5
5
|
"""
|
|
6
|
-
|
|
6
|
+
|
|
7
|
+
from typing import Any, List, Tuple, Union
|
|
7
8
|
|
|
8
9
|
import torch
|
|
9
10
|
import torch.nn as nn
|
|
@@ -33,6 +34,9 @@ class UnetEncoder(nn.Module):
|
|
|
33
34
|
Dropout probability, by default 0.0.
|
|
34
35
|
pool_kernel : int, optional
|
|
35
36
|
Kernel size for the max pooling layers, by default 2.
|
|
37
|
+
groups: int, optional
|
|
38
|
+
Number of blocked connections from input channels to output
|
|
39
|
+
channels, by default 1.
|
|
36
40
|
"""
|
|
37
41
|
|
|
38
42
|
def __init__(
|
|
@@ -45,6 +49,7 @@ class UnetEncoder(nn.Module):
|
|
|
45
49
|
dropout: float = 0.0,
|
|
46
50
|
pool_kernel: int = 2,
|
|
47
51
|
n2v2: bool = False,
|
|
52
|
+
groups: int = 1,
|
|
48
53
|
) -> None:
|
|
49
54
|
"""
|
|
50
55
|
Constructor.
|
|
@@ -65,6 +70,9 @@ class UnetEncoder(nn.Module):
|
|
|
65
70
|
Dropout probability, by default 0.0.
|
|
66
71
|
pool_kernel : int, optional
|
|
67
72
|
Kernel size for the max pooling layers, by default 2.
|
|
73
|
+
groups: int, optional
|
|
74
|
+
Number of blocked connections from input channels to output
|
|
75
|
+
channels, by default 1.
|
|
68
76
|
"""
|
|
69
77
|
super().__init__()
|
|
70
78
|
|
|
@@ -77,7 +85,7 @@ class UnetEncoder(nn.Module):
|
|
|
77
85
|
encoder_blocks = []
|
|
78
86
|
|
|
79
87
|
for n in range(depth):
|
|
80
|
-
out_channels = num_channels_init * (2**n)
|
|
88
|
+
out_channels = num_channels_init * (2**n) * groups
|
|
81
89
|
in_channels = in_channels if n == 0 else out_channels // 2
|
|
82
90
|
encoder_blocks.append(
|
|
83
91
|
Conv_Block(
|
|
@@ -86,6 +94,7 @@ class UnetEncoder(nn.Module):
|
|
|
86
94
|
out_channels=out_channels,
|
|
87
95
|
dropout_perc=dropout,
|
|
88
96
|
use_batch_norm=use_batch_norm,
|
|
97
|
+
groups=groups,
|
|
89
98
|
)
|
|
90
99
|
)
|
|
91
100
|
encoder_blocks.append(self.pooling)
|
|
@@ -131,6 +140,9 @@ class UnetDecoder(nn.Module):
|
|
|
131
140
|
Whether to use batch normalization, by default True.
|
|
132
141
|
dropout : float, optional
|
|
133
142
|
Dropout probability, by default 0.0.
|
|
143
|
+
groups: int, optional
|
|
144
|
+
Number of blocked connections from input channels to output
|
|
145
|
+
channels, by default 1.
|
|
134
146
|
"""
|
|
135
147
|
|
|
136
148
|
def __init__(
|
|
@@ -141,6 +153,7 @@ class UnetDecoder(nn.Module):
|
|
|
141
153
|
use_batch_norm: bool = True,
|
|
142
154
|
dropout: float = 0.0,
|
|
143
155
|
n2v2: bool = False,
|
|
156
|
+
groups: int = 1,
|
|
144
157
|
) -> None:
|
|
145
158
|
"""
|
|
146
159
|
Constructor.
|
|
@@ -157,15 +170,19 @@ class UnetDecoder(nn.Module):
|
|
|
157
170
|
Whether to use batch normalization, by default True.
|
|
158
171
|
dropout : float, optional
|
|
159
172
|
Dropout probability, by default 0.0.
|
|
173
|
+
groups: int, optional
|
|
174
|
+
Number of blocked connections from input channels to output
|
|
175
|
+
channels, by default 1.
|
|
160
176
|
"""
|
|
161
177
|
super().__init__()
|
|
162
178
|
|
|
163
179
|
upsampling = nn.Upsample(
|
|
164
180
|
scale_factor=2, mode="bilinear" if conv_dim == 2 else "trilinear"
|
|
165
181
|
)
|
|
166
|
-
in_channels = out_channels = num_channels_init * 2 ** (depth - 1)
|
|
182
|
+
in_channels = out_channels = num_channels_init * groups * (2 ** (depth - 1))
|
|
167
183
|
|
|
168
184
|
self.n2v2 = n2v2
|
|
185
|
+
self.groups = groups
|
|
169
186
|
|
|
170
187
|
self.bottleneck = Conv_Block(
|
|
171
188
|
conv_dim,
|
|
@@ -174,34 +191,32 @@ class UnetDecoder(nn.Module):
|
|
|
174
191
|
intermediate_channel_multiplier=2,
|
|
175
192
|
use_batch_norm=use_batch_norm,
|
|
176
193
|
dropout_perc=dropout,
|
|
194
|
+
groups=self.groups,
|
|
177
195
|
)
|
|
178
196
|
|
|
179
|
-
decoder_blocks = []
|
|
197
|
+
decoder_blocks: List[nn.Module] = []
|
|
180
198
|
for n in range(depth):
|
|
181
199
|
decoder_blocks.append(upsampling)
|
|
182
|
-
in_channels = (
|
|
183
|
-
num_channels_init ** (depth - n)
|
|
184
|
-
if (self.n2v2 and n == depth - 1)
|
|
185
|
-
else num_channels_init * 2 ** (depth - n)
|
|
186
|
-
)
|
|
200
|
+
in_channels = (num_channels_init * 2 ** (depth - n)) * groups
|
|
187
201
|
out_channels = in_channels // 2
|
|
188
202
|
decoder_blocks.append(
|
|
189
203
|
Conv_Block(
|
|
190
204
|
conv_dim,
|
|
191
|
-
in_channels=
|
|
192
|
-
|
|
193
|
-
|
|
205
|
+
in_channels=(
|
|
206
|
+
in_channels + in_channels // 2 if n > 0 else in_channels
|
|
207
|
+
),
|
|
194
208
|
out_channels=out_channels,
|
|
195
209
|
intermediate_channel_multiplier=2,
|
|
196
210
|
dropout_perc=dropout,
|
|
197
211
|
activation="ReLU",
|
|
198
212
|
use_batch_norm=use_batch_norm,
|
|
213
|
+
groups=groups,
|
|
199
214
|
)
|
|
200
215
|
)
|
|
201
216
|
|
|
202
217
|
self.decoder_blocks = nn.ModuleList(decoder_blocks)
|
|
203
218
|
|
|
204
|
-
def forward(self, *features:
|
|
219
|
+
def forward(self, *features: torch.Tensor) -> torch.Tensor:
|
|
205
220
|
"""
|
|
206
221
|
Forward pass.
|
|
207
222
|
|
|
@@ -217,20 +232,70 @@ class UnetDecoder(nn.Module):
|
|
|
217
232
|
Output of the decoder.
|
|
218
233
|
"""
|
|
219
234
|
x: torch.Tensor = features[0]
|
|
220
|
-
skip_connections: torch.Tensor = features[1:
|
|
235
|
+
skip_connections: Tuple[torch.Tensor, ...] = features[-1:0:-1]
|
|
221
236
|
|
|
222
237
|
x = self.bottleneck(x)
|
|
223
238
|
|
|
224
239
|
for i, module in enumerate(self.decoder_blocks):
|
|
225
240
|
x = module(x)
|
|
226
241
|
if isinstance(module, nn.Upsample):
|
|
242
|
+
# divide index by 2 because of upsampling layers
|
|
243
|
+
skip_connection: torch.Tensor = skip_connections[i // 2]
|
|
227
244
|
if self.n2v2:
|
|
228
245
|
if x.shape != skip_connections[-1].shape:
|
|
229
|
-
x =
|
|
246
|
+
x = self._interleave(x, skip_connection, self.groups)
|
|
230
247
|
else:
|
|
231
|
-
x =
|
|
248
|
+
x = self._interleave(x, skip_connection, self.groups)
|
|
232
249
|
return x
|
|
233
250
|
|
|
251
|
+
@staticmethod
|
|
252
|
+
def _interleave(A: torch.Tensor, B: torch.Tensor, groups: int) -> torch.Tensor:
|
|
253
|
+
"""
|
|
254
|
+
Splits the tensors `A` and `B` into equally sized groups along the
|
|
255
|
+
channel axis (axis=1); then concatenates the groups in alternating
|
|
256
|
+
order along the channel axis, starting with the first group from tensor
|
|
257
|
+
A.
|
|
258
|
+
|
|
259
|
+
Parameters
|
|
260
|
+
----------
|
|
261
|
+
A: torch.Tensor
|
|
262
|
+
B: torch.Tensor
|
|
263
|
+
groups: int
|
|
264
|
+
The number of groups.
|
|
265
|
+
|
|
266
|
+
Returns
|
|
267
|
+
-------
|
|
268
|
+
torch.Tensor
|
|
269
|
+
|
|
270
|
+
Raises
|
|
271
|
+
------
|
|
272
|
+
ValueError:
|
|
273
|
+
If either of `A` or `B`'s channel axis is not divisible by `groups`.
|
|
274
|
+
"""
|
|
275
|
+
if (A.shape[1] % groups != 0) or (B.shape[1] % groups != 0):
|
|
276
|
+
raise ValueError(f"Number of channels not divisible by {groups} groups.")
|
|
277
|
+
|
|
278
|
+
m = A.shape[1] // groups
|
|
279
|
+
n = B.shape[1] // groups
|
|
280
|
+
|
|
281
|
+
A_groups: List[torch.Tensor] = [
|
|
282
|
+
A[:, i * m : (i + 1) * m] for i in range(groups)
|
|
283
|
+
]
|
|
284
|
+
B_groups: List[torch.Tensor] = [
|
|
285
|
+
B[:, i * n : (i + 1) * n] for i in range(groups)
|
|
286
|
+
]
|
|
287
|
+
|
|
288
|
+
interleaved = torch.cat(
|
|
289
|
+
[
|
|
290
|
+
tensor_list[i]
|
|
291
|
+
for i in range(groups)
|
|
292
|
+
for tensor_list in [A_groups, B_groups]
|
|
293
|
+
],
|
|
294
|
+
dim=1,
|
|
295
|
+
)
|
|
296
|
+
|
|
297
|
+
return interleaved
|
|
298
|
+
|
|
234
299
|
|
|
235
300
|
class UNet(nn.Module):
|
|
236
301
|
"""
|
|
@@ -273,6 +338,7 @@ class UNet(nn.Module):
|
|
|
273
338
|
pool_kernel: int = 2,
|
|
274
339
|
final_activation: Union[SupportedActivation, str] = SupportedActivation.NONE,
|
|
275
340
|
n2v2: bool = False,
|
|
341
|
+
independent_channels: bool = True,
|
|
276
342
|
**kwargs: Any,
|
|
277
343
|
) -> None:
|
|
278
344
|
"""
|
|
@@ -298,9 +364,14 @@ class UNet(nn.Module):
|
|
|
298
364
|
Kernel size of the pooling layers, by default 2.
|
|
299
365
|
last_activation : Optional[Callable], optional
|
|
300
366
|
Activation function to use for the last layer, by default None.
|
|
367
|
+
independent_channels : bool
|
|
368
|
+
Whether to train parallel independent networks for each channel, by
|
|
369
|
+
default True.
|
|
301
370
|
"""
|
|
302
371
|
super().__init__()
|
|
303
372
|
|
|
373
|
+
groups = in_channels if independent_channels else 1
|
|
374
|
+
|
|
304
375
|
self.encoder = UnetEncoder(
|
|
305
376
|
conv_dims,
|
|
306
377
|
in_channels=in_channels,
|
|
@@ -310,6 +381,7 @@ class UNet(nn.Module):
|
|
|
310
381
|
dropout=dropout,
|
|
311
382
|
pool_kernel=pool_kernel,
|
|
312
383
|
n2v2=n2v2,
|
|
384
|
+
groups=groups,
|
|
313
385
|
)
|
|
314
386
|
|
|
315
387
|
self.decoder = UnetDecoder(
|
|
@@ -319,11 +391,13 @@ class UNet(nn.Module):
|
|
|
319
391
|
use_batch_norm=use_batch_norm,
|
|
320
392
|
dropout=dropout,
|
|
321
393
|
n2v2=n2v2,
|
|
394
|
+
groups=groups,
|
|
322
395
|
)
|
|
323
396
|
self.final_conv = getattr(nn, f"Conv{conv_dims}d")(
|
|
324
|
-
in_channels=num_channels_init,
|
|
397
|
+
in_channels=num_channels_init * groups,
|
|
325
398
|
out_channels=num_classes,
|
|
326
399
|
kernel_size=1,
|
|
400
|
+
groups=groups,
|
|
327
401
|
)
|
|
328
402
|
self.final_activation = get_activation(final_activation)
|
|
329
403
|
|
careamics/transforms/__init__.py
CHANGED
|
@@ -8,34 +8,13 @@ __all__ = [
|
|
|
8
8
|
"ImageRestorationTTA",
|
|
9
9
|
"Denormalize",
|
|
10
10
|
"Normalize",
|
|
11
|
+
"Compose",
|
|
11
12
|
]
|
|
12
13
|
|
|
13
14
|
|
|
15
|
+
from .compose import Compose, get_all_transforms
|
|
14
16
|
from .n2v_manipulate import N2VManipulate
|
|
15
17
|
from .nd_flip import NDFlip
|
|
16
18
|
from .normalize import Denormalize, Normalize
|
|
17
19
|
from .tta import ImageRestorationTTA
|
|
18
20
|
from .xy_random_rotate90 import XYRandomRotate90
|
|
19
|
-
|
|
20
|
-
ALL_TRANSFORMS = {
|
|
21
|
-
"Normalize": Normalize,
|
|
22
|
-
"N2VManipulate": N2VManipulate,
|
|
23
|
-
"NDFlip": NDFlip,
|
|
24
|
-
"XYRandomRotate90": XYRandomRotate90,
|
|
25
|
-
}
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
def get_all_transforms() -> dict:
|
|
29
|
-
"""Return all the transforms accepted by CAREamics.
|
|
30
|
-
|
|
31
|
-
Note that while CAREamics accepts any `Compose` transforms from Albumentations (see
|
|
32
|
-
https://albumentations.ai/), only a few transformations are explicitely supported
|
|
33
|
-
(see `SupportedTransform`).
|
|
34
|
-
|
|
35
|
-
Returns
|
|
36
|
-
-------
|
|
37
|
-
dict
|
|
38
|
-
A dictionary with all the transforms accepted by CAREamics, where the keys are
|
|
39
|
-
the transform names and the values are the transform classes.
|
|
40
|
-
"""
|
|
41
|
-
return ALL_TRANSFORMS
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
"""A class chaining transforms together."""
|
|
2
|
+
|
|
3
|
+
from typing import Callable, List, Optional, Tuple
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from careamics.config.data_model import TRANSFORMS_UNION
|
|
8
|
+
|
|
9
|
+
from .n2v_manipulate import N2VManipulate
|
|
10
|
+
from .nd_flip import NDFlip
|
|
11
|
+
from .normalize import Normalize
|
|
12
|
+
from .transform import Transform
|
|
13
|
+
from .xy_random_rotate90 import XYRandomRotate90
|
|
14
|
+
|
|
15
|
+
ALL_TRANSFORMS = {
|
|
16
|
+
"Normalize": Normalize,
|
|
17
|
+
"N2VManipulate": N2VManipulate,
|
|
18
|
+
"NDFlip": NDFlip,
|
|
19
|
+
"XYRandomRotate90": XYRandomRotate90,
|
|
20
|
+
}
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def get_all_transforms() -> dict:
|
|
24
|
+
"""Return all the transforms accepted by CAREamics.
|
|
25
|
+
|
|
26
|
+
Returns
|
|
27
|
+
-------
|
|
28
|
+
dict
|
|
29
|
+
A dictionary with all the transforms accepted by CAREamics, where the keys are
|
|
30
|
+
the transform names and the values are the transform classes.
|
|
31
|
+
"""
|
|
32
|
+
return ALL_TRANSFORMS
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class Compose:
|
|
36
|
+
"""A class chaining transforms together."""
|
|
37
|
+
|
|
38
|
+
def __init__(self, transform_list: List[TRANSFORMS_UNION]) -> None:
|
|
39
|
+
"""Instantiate a Compose object.
|
|
40
|
+
|
|
41
|
+
Parameters
|
|
42
|
+
----------
|
|
43
|
+
transform_list : List[TRANSFORMS_UNION]
|
|
44
|
+
A list of dictionaries where each dictionary contains the name of a
|
|
45
|
+
transform and its parameters.
|
|
46
|
+
"""
|
|
47
|
+
# retrieve all available transforms
|
|
48
|
+
all_transforms = get_all_transforms()
|
|
49
|
+
|
|
50
|
+
# instantiate all transforms
|
|
51
|
+
transforms = [all_transforms[t.name](**t.model_dump()) for t in transform_list]
|
|
52
|
+
|
|
53
|
+
self._callable_transforms = self._chain_transforms(transforms)
|
|
54
|
+
|
|
55
|
+
def _chain_transforms(self, transforms: List[Transform]) -> Callable:
|
|
56
|
+
"""Chain the transforms together.
|
|
57
|
+
|
|
58
|
+
Parameters
|
|
59
|
+
----------
|
|
60
|
+
transforms : List[Transform]
|
|
61
|
+
A list of transforms to chain together.
|
|
62
|
+
|
|
63
|
+
Returns
|
|
64
|
+
-------
|
|
65
|
+
Callable
|
|
66
|
+
A callable that applies the transforms in order to the input data.
|
|
67
|
+
"""
|
|
68
|
+
|
|
69
|
+
def _chain(
|
|
70
|
+
patch: np.ndarray, target: Optional[np.ndarray]
|
|
71
|
+
) -> Tuple[np.ndarray, ...]:
|
|
72
|
+
params = (patch, target)
|
|
73
|
+
|
|
74
|
+
for t in transforms:
|
|
75
|
+
params = t(*params)
|
|
76
|
+
|
|
77
|
+
return params
|
|
78
|
+
|
|
79
|
+
return _chain
|
|
80
|
+
|
|
81
|
+
def __call__(
|
|
82
|
+
self, patch: np.ndarray, target: Optional[np.ndarray] = None
|
|
83
|
+
) -> Tuple[np.ndarray, ...]:
|
|
84
|
+
"""Apply the transforms to the input data.
|
|
85
|
+
|
|
86
|
+
Parameters
|
|
87
|
+
----------
|
|
88
|
+
patch : np.ndarray
|
|
89
|
+
The input data.
|
|
90
|
+
target : Optional[np.ndarray], optional
|
|
91
|
+
Target data, by default None
|
|
92
|
+
|
|
93
|
+
Returns
|
|
94
|
+
-------
|
|
95
|
+
Tuple[np.ndarray, ...]
|
|
96
|
+
The output of the transformations.
|
|
97
|
+
"""
|
|
98
|
+
return self._callable_transforms(patch, target)
|
|
@@ -1,19 +1,19 @@
|
|
|
1
1
|
from typing import Any, Literal, Optional, Tuple
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
|
-
from albumentations import ImageOnlyTransform
|
|
5
4
|
|
|
6
5
|
from careamics.config.support import SupportedPixelManipulation, SupportedStructAxis
|
|
6
|
+
from careamics.transforms.transform import Transform
|
|
7
7
|
|
|
8
8
|
from .pixel_manipulation import median_manipulate, uniform_manipulate
|
|
9
9
|
from .struct_mask_parameters import StructMaskParameters
|
|
10
10
|
|
|
11
11
|
|
|
12
|
-
class N2VManipulate(
|
|
12
|
+
class N2VManipulate(Transform):
|
|
13
13
|
"""
|
|
14
14
|
Default augmentation for the N2V model.
|
|
15
15
|
|
|
16
|
-
This transform expects (Z)
|
|
16
|
+
This transform expects C(Z)YX dimensions.
|
|
17
17
|
|
|
18
18
|
Parameters
|
|
19
19
|
----------
|
|
@@ -33,6 +33,7 @@ class N2VManipulate(ImageOnlyTransform):
|
|
|
33
33
|
remove_center: bool = True,
|
|
34
34
|
struct_mask_axis: Literal["horizontal", "vertical", "none"] = "none",
|
|
35
35
|
struct_mask_span: int = 5,
|
|
36
|
+
seed: Optional[int] = None, # TODO use in pixel manipulation
|
|
36
37
|
):
|
|
37
38
|
"""Constructor.
|
|
38
39
|
|
|
@@ -50,8 +51,9 @@ class N2VManipulate(ImageOnlyTransform):
|
|
|
50
51
|
StructN2V mask axis, by default "none"
|
|
51
52
|
struct_mask_span : int, optional
|
|
52
53
|
StructN2V mask span, by default 5
|
|
54
|
+
seed : Optional[int], optional
|
|
55
|
+
Random seed, by default None
|
|
53
56
|
"""
|
|
54
|
-
super().__init__(p=1)
|
|
55
57
|
self.masked_pixel_percentage = masked_pixel_percentage
|
|
56
58
|
self.roi_size = roi_size
|
|
57
59
|
self.strategy = strategy
|
|
@@ -65,23 +67,26 @@ class N2VManipulate(ImageOnlyTransform):
|
|
|
65
67
|
span=struct_mask_span,
|
|
66
68
|
)
|
|
67
69
|
|
|
68
|
-
|
|
69
|
-
self
|
|
70
|
+
# numpy random generator
|
|
71
|
+
self.rng = np.random.default_rng(seed=seed)
|
|
72
|
+
|
|
73
|
+
def __call__(
|
|
74
|
+
self, patch: np.ndarray, *args: Any, **kwargs: Any
|
|
70
75
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
71
76
|
"""Apply the transform to the image.
|
|
72
77
|
|
|
73
78
|
Parameters
|
|
74
79
|
----------
|
|
75
80
|
image : np.ndarray
|
|
76
|
-
Image or image patch, 2D or 3D, shape (
|
|
81
|
+
Image or image patch, 2D or 3D, shape C(Z)YX.
|
|
77
82
|
"""
|
|
78
83
|
masked = np.zeros_like(patch)
|
|
79
84
|
mask = np.zeros_like(patch)
|
|
80
85
|
if self.strategy == SupportedPixelManipulation.UNIFORM:
|
|
81
86
|
# Iterate over the channels to apply manipulation separately
|
|
82
|
-
for c in range(patch.shape[
|
|
83
|
-
masked[
|
|
84
|
-
patch=patch[
|
|
87
|
+
for c in range(patch.shape[0]):
|
|
88
|
+
masked[c, ...], mask[c, ...] = uniform_manipulate(
|
|
89
|
+
patch=patch[c, ...],
|
|
85
90
|
mask_pixel_percentage=self.masked_pixel_percentage,
|
|
86
91
|
subpatch_size=self.roi_size,
|
|
87
92
|
remove_center=self.remove_center,
|
|
@@ -89,9 +94,9 @@ class N2VManipulate(ImageOnlyTransform):
|
|
|
89
94
|
)
|
|
90
95
|
elif self.strategy == SupportedPixelManipulation.MEDIAN:
|
|
91
96
|
# Iterate over the channels to apply manipulation separately
|
|
92
|
-
for c in range(patch.shape[
|
|
93
|
-
masked[
|
|
94
|
-
patch=patch[
|
|
97
|
+
for c in range(patch.shape[0]):
|
|
98
|
+
masked[c, ...], mask[c, ...] = median_manipulate(
|
|
99
|
+
patch=patch[c, ...],
|
|
95
100
|
mask_pixel_percentage=self.masked_pixel_percentage,
|
|
96
101
|
subpatch_size=self.roi_size,
|
|
97
102
|
struct_params=self.struct_mask,
|
|
@@ -101,13 +106,3 @@ class N2VManipulate(ImageOnlyTransform):
|
|
|
101
106
|
|
|
102
107
|
# TODO why return patch?
|
|
103
108
|
return masked, patch, mask
|
|
104
|
-
|
|
105
|
-
def get_transform_init_args_names(self) -> Tuple[str, ...]:
|
|
106
|
-
"""Get the transform parameters.
|
|
107
|
-
|
|
108
|
-
Returns
|
|
109
|
-
-------
|
|
110
|
-
Tuple[str, ...]
|
|
111
|
-
Transform parameters.
|
|
112
|
-
"""
|
|
113
|
-
return ("roi_size", "masked_pixel_percentage", "strategy", "struct_mask")
|
careamics/transforms/nd_flip.py
CHANGED
|
@@ -1,93 +1,67 @@
|
|
|
1
|
-
from typing import
|
|
1
|
+
from typing import Optional, Tuple
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
|
-
from albumentations import DualTransform
|
|
5
4
|
|
|
5
|
+
from careamics.transforms.transform import Transform
|
|
6
6
|
|
|
7
|
-
|
|
7
|
+
|
|
8
|
+
class NDFlip(Transform):
|
|
8
9
|
"""Flip ND arrays on a single axis.
|
|
9
10
|
|
|
10
11
|
This transform ignores singleton axes and randomly flips one of the other
|
|
11
|
-
|
|
12
|
+
last two axes.
|
|
12
13
|
|
|
13
|
-
This transform expects (Z)
|
|
14
|
+
This transform expects C(Z)YX dimensions.
|
|
14
15
|
"""
|
|
15
16
|
|
|
16
|
-
def __init__(self,
|
|
17
|
+
def __init__(self, seed: Optional[int] = None):
|
|
17
18
|
"""Constructor.
|
|
18
19
|
|
|
19
20
|
Parameters
|
|
20
21
|
----------
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
is_3D : bool, optional
|
|
24
|
-
Whether the data is 3D, by default False
|
|
25
|
-
flip_z : bool, optional
|
|
26
|
-
Whether to flip Z dimension, by default True
|
|
22
|
+
seed : Optional[int], optional
|
|
23
|
+
Random seed, by default None
|
|
27
24
|
"""
|
|
28
|
-
super().__init__(p=p)
|
|
29
|
-
|
|
30
|
-
self.is_3D = is_3D
|
|
31
|
-
self.flip_z = flip_z
|
|
32
|
-
|
|
33
25
|
# "flippable" axes
|
|
34
|
-
|
|
35
|
-
self.axis_indices = [0, 1, 2] if flip_z else [1, 2]
|
|
36
|
-
else:
|
|
37
|
-
self.axis_indices = [0, 1]
|
|
26
|
+
self.axis_indices = [-2, -1]
|
|
38
27
|
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
Returns
|
|
43
|
-
-------
|
|
44
|
-
Dict[str, int]
|
|
45
|
-
Transform parameters.
|
|
46
|
-
"""
|
|
47
|
-
return {"flip_axis": np.random.choice(self.axis_indices)}
|
|
28
|
+
# numpy random generator
|
|
29
|
+
self.rng = np.random.default_rng(seed=seed)
|
|
48
30
|
|
|
49
|
-
def
|
|
50
|
-
|
|
31
|
+
def __call__(
|
|
32
|
+
self, patch: np.ndarray, target: Optional[np.ndarray] = None
|
|
33
|
+
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
|
34
|
+
"""Apply the transform to the source patch and the target (optional).
|
|
51
35
|
|
|
52
36
|
Parameters
|
|
53
37
|
----------
|
|
54
38
|
patch : np.ndarray
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
39
|
+
Patch, 2D or 3D, shape C(Z)YX.
|
|
40
|
+
target : Optional[np.ndarray], optional
|
|
41
|
+
Target for the patch, by default None
|
|
42
|
+
|
|
43
|
+
Returns
|
|
44
|
+
-------
|
|
45
|
+
Tuple[np.ndarray, Optional[np.ndarray]]
|
|
46
|
+
Transformed patch and target.
|
|
58
47
|
"""
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
48
|
+
# choose an axis to flip
|
|
49
|
+
axis = self.rng.choice(self.axis_indices)
|
|
50
|
+
|
|
51
|
+
patch_transformed = self._apply(patch, axis)
|
|
52
|
+
target_transformed = self._apply(target, axis) if target is not None else None
|
|
64
53
|
|
|
65
|
-
return
|
|
54
|
+
return patch_transformed, target_transformed
|
|
66
55
|
|
|
67
|
-
def
|
|
68
|
-
|
|
69
|
-
) -> np.ndarray:
|
|
70
|
-
"""Apply the transform to the mask.
|
|
56
|
+
def _apply(self, patch: np.ndarray, axis: int) -> np.ndarray:
|
|
57
|
+
"""Apply the transform to the image.
|
|
71
58
|
|
|
72
59
|
Parameters
|
|
73
60
|
----------
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
raise ValueError(
|
|
79
|
-
"Incompatible mask shape and dimensionality. ZYXC patch shape "
|
|
80
|
-
"expected, but got YXC shape."
|
|
81
|
-
)
|
|
82
|
-
|
|
83
|
-
return np.ascontiguousarray(np.flip(mask, axis=flip_axis))
|
|
84
|
-
|
|
85
|
-
def get_transform_init_args_names(self, **kwargs: Any) -> Tuple[str, ...]:
|
|
86
|
-
"""Get the transform arguments names.
|
|
87
|
-
|
|
88
|
-
Returns
|
|
89
|
-
-------
|
|
90
|
-
Tuple[str, ...]
|
|
91
|
-
Transform arguments names.
|
|
61
|
+
patch : np.ndarray
|
|
62
|
+
Image or image patch, 2D or 3D, shape C(Z)YX.
|
|
63
|
+
axis : int
|
|
64
|
+
Axis to flip.
|
|
92
65
|
"""
|
|
93
|
-
|
|
66
|
+
# TODO why ascontiguousarray?
|
|
67
|
+
return np.ascontiguousarray(np.flip(patch, axis=axis))
|