careamics 0.1.0rc3__py3-none-any.whl → 0.1.0rc5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. More details are available on the original diff page.

Files changed (66)
  1. careamics/__init__.py +8 -6
  2. careamics/careamist.py +30 -29
  3. careamics/config/__init__.py +12 -9
  4. careamics/config/algorithm_model.py +5 -5
  5. careamics/config/architectures/unet_model.py +1 -0
  6. careamics/config/callback_model.py +1 -0
  7. careamics/config/configuration_example.py +87 -0
  8. careamics/config/configuration_factory.py +285 -78
  9. careamics/config/configuration_model.py +22 -23
  10. careamics/config/data_model.py +62 -160
  11. careamics/config/inference_model.py +20 -21
  12. careamics/config/references/algorithm_descriptions.py +1 -0
  13. careamics/config/references/references.py +1 -0
  14. careamics/config/support/supported_extraction_strategies.py +1 -0
  15. careamics/config/support/supported_optimizers.py +3 -3
  16. careamics/config/training_model.py +2 -1
  17. careamics/config/transformations/n2v_manipulate_model.py +2 -1
  18. careamics/config/transformations/nd_flip_model.py +7 -12
  19. careamics/config/transformations/normalize_model.py +2 -1
  20. careamics/config/transformations/transform_model.py +1 -0
  21. careamics/config/transformations/xy_random_rotate90_model.py +7 -9
  22. careamics/config/validators/validator_utils.py +1 -0
  23. careamics/conftest.py +1 -0
  24. careamics/dataset/dataset_utils/__init__.py +0 -1
  25. careamics/dataset/dataset_utils/dataset_utils.py +1 -0
  26. careamics/dataset/in_memory_dataset.py +17 -48
  27. careamics/dataset/iterable_dataset.py +16 -71
  28. careamics/dataset/patching/__init__.py +0 -7
  29. careamics/dataset/patching/patching.py +1 -0
  30. careamics/dataset/patching/sequential_patching.py +6 -6
  31. careamics/dataset/patching/tiled_patching.py +10 -6
  32. careamics/lightning_datamodule.py +123 -49
  33. careamics/lightning_module.py +7 -7
  34. careamics/lightning_prediction_datamodule.py +59 -48
  35. careamics/losses/__init__.py +0 -1
  36. careamics/losses/loss_factory.py +1 -0
  37. careamics/model_io/__init__.py +0 -1
  38. careamics/model_io/bioimage/_readme_factory.py +2 -1
  39. careamics/model_io/bioimage/bioimage_utils.py +1 -0
  40. careamics/model_io/bioimage/model_description.py +4 -3
  41. careamics/model_io/bmz_io.py +8 -7
  42. careamics/model_io/model_io_utils.py +4 -4
  43. careamics/models/layers.py +1 -0
  44. careamics/models/model_factory.py +1 -0
  45. careamics/models/unet.py +91 -17
  46. careamics/prediction/stitch_prediction.py +1 -0
  47. careamics/transforms/__init__.py +2 -23
  48. careamics/transforms/compose.py +98 -0
  49. careamics/transforms/n2v_manipulate.py +18 -23
  50. careamics/transforms/nd_flip.py +38 -64
  51. careamics/transforms/normalize.py +45 -34
  52. careamics/transforms/pixel_manipulation.py +2 -2
  53. careamics/transforms/transform.py +33 -0
  54. careamics/transforms/tta.py +2 -2
  55. careamics/transforms/xy_random_rotate90.py +41 -68
  56. careamics/utils/__init__.py +0 -1
  57. careamics/utils/context.py +1 -0
  58. careamics/utils/logging.py +1 -0
  59. careamics/utils/metrics.py +1 -0
  60. careamics/utils/torch_utils.py +1 -0
  61. {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc5.dist-info}/METADATA +16 -61
  62. careamics-0.1.0rc5.dist-info/RECORD +111 -0
  63. careamics/dataset/patching/patch_transform.py +0 -44
  64. careamics-0.1.0rc3.dist-info/RECORD +0 -109
  65. {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc5.dist-info}/WHEEL +0 -0
  66. {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc5.dist-info}/licenses/LICENSE +0 -0
@@ -1,6 +1,5 @@
1
1
  """Model I/O utilities."""
2
2
 
3
-
4
3
  __all__ = ["load_pretrained", "export_to_bmz"]
5
4
 
6
5
 
@@ -1,4 +1,5 @@
1
1
  """Functions used to create a README.md file for BMZ export."""
2
+
2
3
  from pathlib import Path
3
4
  from typing import Optional
4
5
 
@@ -117,4 +118,4 @@ def readme_factory(
117
118
 
118
119
  readme.write_text("".join(description))
119
120
 
120
- return readme
121
+ return readme.absolute()
@@ -1,4 +1,5 @@
1
1
  """Bioimage.io utils."""
2
+
2
3
  from pathlib import Path
3
4
  from typing import Union
4
5
 
@@ -1,4 +1,5 @@
1
1
  """Module use to build BMZ model description."""
2
+
2
3
  from pathlib import Path
3
4
  from typing import List, Optional, Tuple, Union
4
5
 
@@ -26,14 +27,14 @@ from bioimageio.spec.model.v0_5 import (
26
27
  WeightsDescr,
27
28
  )
28
29
 
29
- from careamics.config import Configuration, DataModel
30
+ from careamics.config import Configuration, DataConfig
30
31
 
31
32
  from ._readme_factory import readme_factory
32
33
 
33
34
 
34
35
  def _create_axes(
35
36
  array: np.ndarray,
36
- data_config: DataModel,
37
+ data_config: DataConfig,
37
38
  channel_names: Optional[List[str]] = None,
38
39
  is_input: bool = True,
39
40
  ) -> List[AxisBase]:
@@ -100,7 +101,7 @@ def _create_axes(
100
101
  def _create_inputs_ouputs(
101
102
  input_array: np.ndarray,
102
103
  output_array: np.ndarray,
103
- data_config: DataModel,
104
+ data_config: DataConfig,
104
105
  input_path: Union[Path, str],
105
106
  output_path: Union[Path, str],
106
107
  channel_names: Optional[List[str]] = None,
@@ -1,4 +1,5 @@
1
1
  """Function to export to the BioImage Model Zoo format."""
2
+
2
3
  import tempfile
3
4
  from pathlib import Path
4
5
  from typing import List, Optional, Tuple, Union
@@ -11,7 +12,7 @@ from torch import __version__, load, save
11
12
 
12
13
  from careamics.config import Configuration, load_configuration, save_configuration
13
14
  from careamics.config.support import SupportedArchitecture
14
- from careamics.lightning_module import CAREamicsKiln
15
+ from careamics.lightning_module import CAREamicsModule
15
16
 
16
17
  from .bioimage import (
17
18
  create_env_text,
@@ -21,7 +22,7 @@ from .bioimage import (
21
22
  )
22
23
 
23
24
 
24
- def _export_state_dict(model: CAREamicsKiln, path: Union[Path, str]) -> Path:
25
+ def _export_state_dict(model: CAREamicsModule, path: Union[Path, str]) -> Path:
25
26
  """
26
27
  Export the model state dictionary to a file.
27
28
 
@@ -51,7 +52,7 @@ def _export_state_dict(model: CAREamicsKiln, path: Union[Path, str]) -> Path:
51
52
  return path
52
53
 
53
54
 
54
- def _load_state_dict(model: CAREamicsKiln, path: Union[Path, str]) -> None:
55
+ def _load_state_dict(model: CAREamicsModule, path: Union[Path, str]) -> None:
55
56
  """
56
57
  Load a model from a state dictionary.
57
58
 
@@ -73,7 +74,7 @@ def _load_state_dict(model: CAREamicsKiln, path: Union[Path, str]) -> None:
73
74
 
74
75
  # TODO break down in subfunctions
75
76
  def export_to_bmz(
76
- model: CAREamicsKiln,
77
+ model: CAREamicsModule,
77
78
  config: Configuration,
78
79
  path: Union[Path, str],
79
80
  name: str,
@@ -177,7 +178,7 @@ def export_to_bmz(
177
178
  )
178
179
 
179
180
  # test model description
180
- summary: ValidationSummary = test_model(model_description)
181
+ summary: ValidationSummary = test_model(model_description, decimal=0)
181
182
  if summary.status == "failed":
182
183
  raise ValueError(f"Model description test failed: {summary}")
183
184
 
@@ -185,7 +186,7 @@ def export_to_bmz(
185
186
  save_bioimageio_package(model_description, output_path=path)
186
187
 
187
188
 
188
- def load_from_bmz(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuration]:
189
+ def load_from_bmz(path: Union[Path, str]) -> Tuple[CAREamicsModule, Configuration]:
189
190
  """Load a model from a BioImage Model Zoo archive.
190
191
 
191
192
  Parameters
@@ -223,7 +224,7 @@ def load_from_bmz(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuration]
223
224
  config = load_configuration(config_path)
224
225
 
225
226
  # create careamics lightning module
226
- model = CAREamicsKiln(algorithm_config=config.algorithm_config)
227
+ model = CAREamicsModule(algorithm_config=config.algorithm_config)
227
228
 
228
229
  # load model state dictionary
229
230
  _load_state_dict(model, weights_path)
@@ -6,12 +6,12 @@ from typing import Tuple, Union
6
6
  from torch import load
7
7
 
8
8
  from careamics.config import Configuration
9
- from careamics.lightning_module import CAREamicsKiln
9
+ from careamics.lightning_module import CAREamicsModule
10
10
  from careamics.model_io.bmz_io import load_from_bmz
11
11
  from careamics.utils import check_path_exists
12
12
 
13
13
 
14
- def load_pretrained(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuration]:
14
+ def load_pretrained(path: Union[Path, str]) -> Tuple[CAREamicsModule, Configuration]:
15
15
  """
16
16
  Load a pretrained model from a checkpoint or a BioImage Model Zoo model.
17
17
 
@@ -44,7 +44,7 @@ def load_pretrained(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuratio
44
44
  )
45
45
 
46
46
 
47
- def _load_checkpoint(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuration]:
47
+ def _load_checkpoint(path: Union[Path, str]) -> Tuple[CAREamicsModule, Configuration]:
48
48
  """
49
49
  Load a model from a checkpoint and return both model and configuration.
50
50
 
@@ -75,6 +75,6 @@ def _load_checkpoint(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configurati
75
75
  f"checkpoint: {checkpoint.keys()}"
76
76
  ) from e
77
77
 
78
- model = CAREamicsKiln.load_from_checkpoint(path)
78
+ model = CAREamicsModule.load_from_checkpoint(path)
79
79
 
80
80
  return model, Configuration(**cfg_dict)
@@ -3,6 +3,7 @@ Layer module.
3
3
 
4
4
  This submodule contains layers used in the CAREamics models.
5
5
  """
6
+
6
7
  from typing import List, Optional, Tuple, Union
7
8
 
8
9
  import torch
@@ -3,6 +3,7 @@ Model factory.
3
3
 
4
4
  Model creation factory functions.
5
5
  """
6
+
6
7
  from typing import Union
7
8
 
8
9
  import torch
careamics/models/unet.py CHANGED
@@ -3,7 +3,8 @@ UNet model.
3
3
 
4
4
  A UNet encoder, decoder and complete model.
5
5
  """
6
- from typing import Any, List, Union
6
+
7
+ from typing import Any, List, Tuple, Union
7
8
 
8
9
  import torch
9
10
  import torch.nn as nn
@@ -33,6 +34,9 @@ class UnetEncoder(nn.Module):
33
34
  Dropout probability, by default 0.0.
34
35
  pool_kernel : int, optional
35
36
  Kernel size for the max pooling layers, by default 2.
37
+ groups: int, optional
38
+ Number of blocked connections from input channels to output
39
+ channels, by default 1.
36
40
  """
37
41
 
38
42
  def __init__(
@@ -45,6 +49,7 @@ class UnetEncoder(nn.Module):
45
49
  dropout: float = 0.0,
46
50
  pool_kernel: int = 2,
47
51
  n2v2: bool = False,
52
+ groups: int = 1,
48
53
  ) -> None:
49
54
  """
50
55
  Constructor.
@@ -65,6 +70,9 @@ class UnetEncoder(nn.Module):
65
70
  Dropout probability, by default 0.0.
66
71
  pool_kernel : int, optional
67
72
  Kernel size for the max pooling layers, by default 2.
73
+ groups: int, optional
74
+ Number of blocked connections from input channels to output
75
+ channels, by default 1.
68
76
  """
69
77
  super().__init__()
70
78
 
@@ -77,7 +85,7 @@ class UnetEncoder(nn.Module):
77
85
  encoder_blocks = []
78
86
 
79
87
  for n in range(depth):
80
- out_channels = num_channels_init * (2**n)
88
+ out_channels = num_channels_init * (2**n) * groups
81
89
  in_channels = in_channels if n == 0 else out_channels // 2
82
90
  encoder_blocks.append(
83
91
  Conv_Block(
@@ -86,6 +94,7 @@ class UnetEncoder(nn.Module):
86
94
  out_channels=out_channels,
87
95
  dropout_perc=dropout,
88
96
  use_batch_norm=use_batch_norm,
97
+ groups=groups,
89
98
  )
90
99
  )
91
100
  encoder_blocks.append(self.pooling)
@@ -131,6 +140,9 @@ class UnetDecoder(nn.Module):
131
140
  Whether to use batch normalization, by default True.
132
141
  dropout : float, optional
133
142
  Dropout probability, by default 0.0.
143
+ groups: int, optional
144
+ Number of blocked connections from input channels to output
145
+ channels, by default 1.
134
146
  """
135
147
 
136
148
  def __init__(
@@ -141,6 +153,7 @@ class UnetDecoder(nn.Module):
141
153
  use_batch_norm: bool = True,
142
154
  dropout: float = 0.0,
143
155
  n2v2: bool = False,
156
+ groups: int = 1,
144
157
  ) -> None:
145
158
  """
146
159
  Constructor.
@@ -157,15 +170,19 @@ class UnetDecoder(nn.Module):
157
170
  Whether to use batch normalization, by default True.
158
171
  dropout : float, optional
159
172
  Dropout probability, by default 0.0.
173
+ groups: int, optional
174
+ Number of blocked connections from input channels to output
175
+ channels, by default 1.
160
176
  """
161
177
  super().__init__()
162
178
 
163
179
  upsampling = nn.Upsample(
164
180
  scale_factor=2, mode="bilinear" if conv_dim == 2 else "trilinear"
165
181
  )
166
- in_channels = out_channels = num_channels_init * 2 ** (depth - 1)
182
+ in_channels = out_channels = num_channels_init * groups * (2 ** (depth - 1))
167
183
 
168
184
  self.n2v2 = n2v2
185
+ self.groups = groups
169
186
 
170
187
  self.bottleneck = Conv_Block(
171
188
  conv_dim,
@@ -174,34 +191,32 @@ class UnetDecoder(nn.Module):
174
191
  intermediate_channel_multiplier=2,
175
192
  use_batch_norm=use_batch_norm,
176
193
  dropout_perc=dropout,
194
+ groups=self.groups,
177
195
  )
178
196
 
179
- decoder_blocks = []
197
+ decoder_blocks: List[nn.Module] = []
180
198
  for n in range(depth):
181
199
  decoder_blocks.append(upsampling)
182
- in_channels = (
183
- num_channels_init ** (depth - n)
184
- if (self.n2v2 and n == depth - 1)
185
- else num_channels_init * 2 ** (depth - n)
186
- )
200
+ in_channels = (num_channels_init * 2 ** (depth - n)) * groups
187
201
  out_channels = in_channels // 2
188
202
  decoder_blocks.append(
189
203
  Conv_Block(
190
204
  conv_dim,
191
- in_channels=in_channels + in_channels // 2
192
- if n > 0
193
- else in_channels,
205
+ in_channels=(
206
+ in_channels + in_channels // 2 if n > 0 else in_channels
207
+ ),
194
208
  out_channels=out_channels,
195
209
  intermediate_channel_multiplier=2,
196
210
  dropout_perc=dropout,
197
211
  activation="ReLU",
198
212
  use_batch_norm=use_batch_norm,
213
+ groups=groups,
199
214
  )
200
215
  )
201
216
 
202
217
  self.decoder_blocks = nn.ModuleList(decoder_blocks)
203
218
 
204
- def forward(self, *features: List[torch.Tensor]) -> torch.Tensor:
219
+ def forward(self, *features: torch.Tensor) -> torch.Tensor:
205
220
  """
206
221
  Forward pass.
207
222
 
@@ -217,20 +232,70 @@ class UnetDecoder(nn.Module):
217
232
  Output of the decoder.
218
233
  """
219
234
  x: torch.Tensor = features[0]
220
- skip_connections: torch.Tensor = features[1:][::-1]
235
+ skip_connections: Tuple[torch.Tensor, ...] = features[-1:0:-1]
221
236
 
222
237
  x = self.bottleneck(x)
223
238
 
224
239
  for i, module in enumerate(self.decoder_blocks):
225
240
  x = module(x)
226
241
  if isinstance(module, nn.Upsample):
242
+ # divide index by 2 because of upsampling layers
243
+ skip_connection: torch.Tensor = skip_connections[i // 2]
227
244
  if self.n2v2:
228
245
  if x.shape != skip_connections[-1].shape:
229
- x = torch.cat([x, skip_connections[i // 2]], axis=1)
246
+ x = self._interleave(x, skip_connection, self.groups)
230
247
  else:
231
- x = torch.cat([x, skip_connections[i // 2]], axis=1)
248
+ x = self._interleave(x, skip_connection, self.groups)
232
249
  return x
233
250
 
251
+ @staticmethod
252
+ def _interleave(A: torch.Tensor, B: torch.Tensor, groups: int) -> torch.Tensor:
253
+ """
254
+ Splits the tensors `A` and `B` into equally sized groups along the
255
+ channel axis (axis=1); then concatenates the groups in alternating
256
+ order along the channel axis, starting with the first group from tensor
257
+ A.
258
+
259
+ Parameters
260
+ ----------
261
+ A: torch.Tensor
262
+ B: torch.Tensor
263
+ groups: int
264
+ The number of groups.
265
+
266
+ Returns
267
+ -------
268
+ torch.Tensor
269
+
270
+ Raises
271
+ ------
272
+ ValueError:
273
+ If either of `A` or `B`'s channel axis is not divisible by `groups`.
274
+ """
275
+ if (A.shape[1] % groups != 0) or (B.shape[1] % groups != 0):
276
+ raise ValueError(f"Number of channels not divisible by {groups} groups.")
277
+
278
+ m = A.shape[1] // groups
279
+ n = B.shape[1] // groups
280
+
281
+ A_groups: List[torch.Tensor] = [
282
+ A[:, i * m : (i + 1) * m] for i in range(groups)
283
+ ]
284
+ B_groups: List[torch.Tensor] = [
285
+ B[:, i * n : (i + 1) * n] for i in range(groups)
286
+ ]
287
+
288
+ interleaved = torch.cat(
289
+ [
290
+ tensor_list[i]
291
+ for i in range(groups)
292
+ for tensor_list in [A_groups, B_groups]
293
+ ],
294
+ dim=1,
295
+ )
296
+
297
+ return interleaved
298
+
234
299
 
235
300
  class UNet(nn.Module):
236
301
  """
@@ -273,6 +338,7 @@ class UNet(nn.Module):
273
338
  pool_kernel: int = 2,
274
339
  final_activation: Union[SupportedActivation, str] = SupportedActivation.NONE,
275
340
  n2v2: bool = False,
341
+ independent_channels: bool = True,
276
342
  **kwargs: Any,
277
343
  ) -> None:
278
344
  """
@@ -298,9 +364,14 @@ class UNet(nn.Module):
298
364
  Kernel size of the pooling layers, by default 2.
299
365
  last_activation : Optional[Callable], optional
300
366
  Activation function to use for the last layer, by default None.
367
+ independent_channels : bool
368
+ Whether to train parallel independent networks for each channel, by
369
+ default True.
301
370
  """
302
371
  super().__init__()
303
372
 
373
+ groups = in_channels if independent_channels else 1
374
+
304
375
  self.encoder = UnetEncoder(
305
376
  conv_dims,
306
377
  in_channels=in_channels,
@@ -310,6 +381,7 @@ class UNet(nn.Module):
310
381
  dropout=dropout,
311
382
  pool_kernel=pool_kernel,
312
383
  n2v2=n2v2,
384
+ groups=groups,
313
385
  )
314
386
 
315
387
  self.decoder = UnetDecoder(
@@ -319,11 +391,13 @@ class UNet(nn.Module):
319
391
  use_batch_norm=use_batch_norm,
320
392
  dropout=dropout,
321
393
  n2v2=n2v2,
394
+ groups=groups,
322
395
  )
323
396
  self.final_conv = getattr(nn, f"Conv{conv_dims}d")(
324
- in_channels=num_channels_init,
397
+ in_channels=num_channels_init * groups,
325
398
  out_channels=num_classes,
326
399
  kernel_size=1,
400
+ groups=groups,
327
401
  )
328
402
  self.final_activation = get_activation(final_activation)
329
403
 
@@ -3,6 +3,7 @@ Prediction convenience functions.
3
3
 
4
4
  These functions are used during prediction.
5
5
  """
6
+
6
7
  from typing import List
7
8
 
8
9
  import numpy as np
@@ -8,34 +8,13 @@ __all__ = [
8
8
  "ImageRestorationTTA",
9
9
  "Denormalize",
10
10
  "Normalize",
11
+ "Compose",
11
12
  ]
12
13
 
13
14
 
15
+ from .compose import Compose, get_all_transforms
14
16
  from .n2v_manipulate import N2VManipulate
15
17
  from .nd_flip import NDFlip
16
18
  from .normalize import Denormalize, Normalize
17
19
  from .tta import ImageRestorationTTA
18
20
  from .xy_random_rotate90 import XYRandomRotate90
19
-
20
- ALL_TRANSFORMS = {
21
- "Normalize": Normalize,
22
- "N2VManipulate": N2VManipulate,
23
- "NDFlip": NDFlip,
24
- "XYRandomRotate90": XYRandomRotate90,
25
- }
26
-
27
-
28
- def get_all_transforms() -> dict:
29
- """Return all the transforms accepted by CAREamics.
30
-
31
- Note that while CAREamics accepts any `Compose` transforms from Albumentations (see
32
- https://albumentations.ai/), only a few transformations are explicitely supported
33
- (see `SupportedTransform`).
34
-
35
- Returns
36
- -------
37
- dict
38
- A dictionary with all the transforms accepted by CAREamics, where the keys are
39
- the transform names and the values are the transform classes.
40
- """
41
- return ALL_TRANSFORMS
@@ -0,0 +1,98 @@
1
+ """A class chaining transforms together."""
2
+
3
+ from typing import Callable, List, Optional, Tuple
4
+
5
+ import numpy as np
6
+
7
+ from careamics.config.data_model import TRANSFORMS_UNION
8
+
9
+ from .n2v_manipulate import N2VManipulate
10
+ from .nd_flip import NDFlip
11
+ from .normalize import Normalize
12
+ from .transform import Transform
13
+ from .xy_random_rotate90 import XYRandomRotate90
14
+
15
+ ALL_TRANSFORMS = {
16
+ "Normalize": Normalize,
17
+ "N2VManipulate": N2VManipulate,
18
+ "NDFlip": NDFlip,
19
+ "XYRandomRotate90": XYRandomRotate90,
20
+ }
21
+
22
+
23
+ def get_all_transforms() -> dict:
24
+ """Return all the transforms accepted by CAREamics.
25
+
26
+ Returns
27
+ -------
28
+ dict
29
+ A dictionary with all the transforms accepted by CAREamics, where the keys are
30
+ the transform names and the values are the transform classes.
31
+ """
32
+ return ALL_TRANSFORMS
33
+
34
+
35
+ class Compose:
36
+ """A class chaining transforms together."""
37
+
38
+ def __init__(self, transform_list: List[TRANSFORMS_UNION]) -> None:
39
+ """Instantiate a Compose object.
40
+
41
+ Parameters
42
+ ----------
43
+ transform_list : List[TRANSFORMS_UNION]
44
+ A list of dictionaries where each dictionary contains the name of a
45
+ transform and its parameters.
46
+ """
47
+ # retrieve all available transforms
48
+ all_transforms = get_all_transforms()
49
+
50
+ # instantiate all transforms
51
+ transforms = [all_transforms[t.name](**t.model_dump()) for t in transform_list]
52
+
53
+ self._callable_transforms = self._chain_transforms(transforms)
54
+
55
+ def _chain_transforms(self, transforms: List[Transform]) -> Callable:
56
+ """Chain the transforms together.
57
+
58
+ Parameters
59
+ ----------
60
+ transforms : List[Transform]
61
+ A list of transforms to chain together.
62
+
63
+ Returns
64
+ -------
65
+ Callable
66
+ A callable that applies the transforms in order to the input data.
67
+ """
68
+
69
+ def _chain(
70
+ patch: np.ndarray, target: Optional[np.ndarray]
71
+ ) -> Tuple[np.ndarray, ...]:
72
+ params = (patch, target)
73
+
74
+ for t in transforms:
75
+ params = t(*params)
76
+
77
+ return params
78
+
79
+ return _chain
80
+
81
+ def __call__(
82
+ self, patch: np.ndarray, target: Optional[np.ndarray] = None
83
+ ) -> Tuple[np.ndarray, ...]:
84
+ """Apply the transforms to the input data.
85
+
86
+ Parameters
87
+ ----------
88
+ patch : np.ndarray
89
+ The input data.
90
+ target : Optional[np.ndarray], optional
91
+ Target data, by default None
92
+
93
+ Returns
94
+ -------
95
+ Tuple[np.ndarray, ...]
96
+ The output of the transformations.
97
+ """
98
+ return self._callable_transforms(patch, target)
@@ -1,19 +1,19 @@
1
1
  from typing import Any, Literal, Optional, Tuple
2
2
 
3
3
  import numpy as np
4
- from albumentations import ImageOnlyTransform
5
4
 
6
5
  from careamics.config.support import SupportedPixelManipulation, SupportedStructAxis
6
+ from careamics.transforms.transform import Transform
7
7
 
8
8
  from .pixel_manipulation import median_manipulate, uniform_manipulate
9
9
  from .struct_mask_parameters import StructMaskParameters
10
10
 
11
11
 
12
- class N2VManipulate(ImageOnlyTransform):
12
+ class N2VManipulate(Transform):
13
13
  """
14
14
  Default augmentation for the N2V model.
15
15
 
16
- This transform expects (Z)YXC dimensions.
16
+ This transform expects C(Z)YX dimensions.
17
17
 
18
18
  Parameters
19
19
  ----------
@@ -33,6 +33,7 @@ class N2VManipulate(ImageOnlyTransform):
33
33
  remove_center: bool = True,
34
34
  struct_mask_axis: Literal["horizontal", "vertical", "none"] = "none",
35
35
  struct_mask_span: int = 5,
36
+ seed: Optional[int] = None, # TODO use in pixel manipulation
36
37
  ):
37
38
  """Constructor.
38
39
 
@@ -50,8 +51,9 @@ class N2VManipulate(ImageOnlyTransform):
50
51
  StructN2V mask axis, by default "none"
51
52
  struct_mask_span : int, optional
52
53
  StructN2V mask span, by default 5
54
+ seed : Optional[int], optional
55
+ Random seed, by default None
53
56
  """
54
- super().__init__(p=1)
55
57
  self.masked_pixel_percentage = masked_pixel_percentage
56
58
  self.roi_size = roi_size
57
59
  self.strategy = strategy
@@ -65,23 +67,26 @@ class N2VManipulate(ImageOnlyTransform):
65
67
  span=struct_mask_span,
66
68
  )
67
69
 
68
- def apply(
69
- self, patch: np.ndarray, **kwargs: Any
70
+ # numpy random generator
71
+ self.rng = np.random.default_rng(seed=seed)
72
+
73
+ def __call__(
74
+ self, patch: np.ndarray, *args: Any, **kwargs: Any
70
75
  ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
71
76
  """Apply the transform to the image.
72
77
 
73
78
  Parameters
74
79
  ----------
75
80
  image : np.ndarray
76
- Image or image patch, 2D or 3D, shape (y, x, c) or (z, y, x, c).
81
+ Image or image patch, 2D or 3D, shape C(Z)YX.
77
82
  """
78
83
  masked = np.zeros_like(patch)
79
84
  mask = np.zeros_like(patch)
80
85
  if self.strategy == SupportedPixelManipulation.UNIFORM:
81
86
  # Iterate over the channels to apply manipulation separately
82
- for c in range(patch.shape[-1]):
83
- masked[..., c], mask[..., c] = uniform_manipulate(
84
- patch=patch[..., c],
87
+ for c in range(patch.shape[0]):
88
+ masked[c, ...], mask[c, ...] = uniform_manipulate(
89
+ patch=patch[c, ...],
85
90
  mask_pixel_percentage=self.masked_pixel_percentage,
86
91
  subpatch_size=self.roi_size,
87
92
  remove_center=self.remove_center,
@@ -89,9 +94,9 @@ class N2VManipulate(ImageOnlyTransform):
89
94
  )
90
95
  elif self.strategy == SupportedPixelManipulation.MEDIAN:
91
96
  # Iterate over the channels to apply manipulation separately
92
- for c in range(patch.shape[-1]):
93
- masked[..., c], mask[..., c] = median_manipulate(
94
- patch=patch[..., c],
97
+ for c in range(patch.shape[0]):
98
+ masked[c, ...], mask[c, ...] = median_manipulate(
99
+ patch=patch[c, ...],
95
100
  mask_pixel_percentage=self.masked_pixel_percentage,
96
101
  subpatch_size=self.roi_size,
97
102
  struct_params=self.struct_mask,
@@ -101,13 +106,3 @@ class N2VManipulate(ImageOnlyTransform):
101
106
 
102
107
  # TODO why return patch?
103
108
  return masked, patch, mask
104
-
105
- def get_transform_init_args_names(self) -> Tuple[str, ...]:
106
- """Get the transform parameters.
107
-
108
- Returns
109
- -------
110
- Tuple[str, ...]
111
- Transform parameters.
112
- """
113
- return ("roi_size", "masked_pixel_percentage", "strategy", "struct_mask")