careamics 0.0.3__py3-none-any.whl → 0.0.4.1__py3-none-any.whl

This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registries.

This release of careamics is flagged as potentially problematic.

Files changed (56)
  1. careamics/careamist.py +25 -17
  2. careamics/cli/__init__.py +5 -0
  3. careamics/cli/conf.py +391 -0
  4. careamics/cli/main.py +134 -0
  5. careamics/config/architectures/lvae_model.py +0 -4
  6. careamics/config/configuration_factory.py +480 -177
  7. careamics/config/configuration_model.py +1 -2
  8. careamics/config/data_model.py +1 -15
  9. careamics/config/fcn_algorithm_model.py +14 -9
  10. careamics/config/likelihood_model.py +21 -4
  11. careamics/config/nm_model.py +31 -5
  12. careamics/config/optimizer_models.py +3 -1
  13. careamics/config/support/supported_optimizers.py +1 -1
  14. careamics/config/support/supported_transforms.py +1 -0
  15. careamics/config/training_model.py +35 -6
  16. careamics/config/transformations/__init__.py +4 -1
  17. careamics/config/transformations/transform_union.py +20 -0
  18. careamics/config/vae_algorithm_model.py +2 -36
  19. careamics/dataset/tiling/lvae_tiled_patching.py +90 -8
  20. careamics/lightning/lightning_module.py +10 -8
  21. careamics/lightning/train_data_module.py +2 -2
  22. careamics/losses/loss_factory.py +3 -3
  23. careamics/losses/lvae/losses.py +2 -2
  24. careamics/lvae_training/dataset/__init__.py +15 -0
  25. careamics/lvae_training/dataset/{vae_data_config.py → config.py} +25 -81
  26. careamics/lvae_training/dataset/lc_dataset.py +28 -20
  27. careamics/lvae_training/dataset/{vae_dataset.py → multich_dataset.py} +91 -51
  28. careamics/lvae_training/dataset/multifile_dataset.py +334 -0
  29. careamics/lvae_training/dataset/types.py +43 -0
  30. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  31. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  32. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  33. careamics/lvae_training/dataset/utils/index_manager.py +232 -0
  34. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  35. careamics/lvae_training/eval_utils.py +109 -64
  36. careamics/lvae_training/get_config.py +1 -1
  37. careamics/lvae_training/train_lvae.py +1 -1
  38. careamics/model_io/bioimage/bioimage_utils.py +4 -2
  39. careamics/model_io/bmz_io.py +6 -5
  40. careamics/models/lvae/likelihoods.py +18 -9
  41. careamics/models/lvae/lvae.py +12 -16
  42. careamics/models/lvae/noise_models.py +1 -1
  43. careamics/transforms/compose.py +90 -15
  44. careamics/transforms/n2v_manipulate.py +6 -2
  45. careamics/transforms/normalize.py +14 -3
  46. careamics/transforms/xy_flip.py +16 -6
  47. careamics/transforms/xy_random_rotate90.py +16 -7
  48. careamics/utils/metrics.py +204 -24
  49. careamics/utils/serializers.py +60 -0
  50. {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/METADATA +4 -3
  51. {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/RECORD +54 -43
  52. careamics-0.0.4.1.dist-info/entry_points.txt +2 -0
  53. careamics/lvae_training/dataset/data_utils.py +0 -701
  54. careamics/lvae_training/dataset/lc_dataset_config.py +0 -13
  55. {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/WHEEL +0 -0
  56. {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/licenses/LICENSE +0 -0
careamics/model_io/bioimage/bioimage_utils.py

@@ -22,7 +22,7 @@ def get_unzip_path(zip_path: Union[Path, str]) -> Path:
     return zip_path.parent / (str(zip_path.name) + ".unzip")
 
 
-def create_env_text(pytorch_version: str) -> str:
+def create_env_text(pytorch_version: str, torchvision_version: str) -> str:
     """Create environment yaml content for the bioimage model.
 
     This installs an environment with the specified pytorch version and the latest
@@ -32,6 +32,8 @@ def create_env_text(pytorch_version: str) -> str:
     ----------
     pytorch_version : str
         Pytorch version.
+    torchvision_version : str
+        Torchvision version.
 
     Returns
     -------
@@ -43,7 +45,7 @@ def create_env_text(pytorch_version: str) -> str:
         f"dependencies:\n"
         f" - python=3.10\n"
         f" - pytorch={pytorch_version}\n"
-        f" - torchvision={pytorch_version}\n"
+        f" - torchvision={torchvision_version}\n"
         f" - pip\n"
         f" - pip:\n"
         f" - git+https://github.com/CAREamics/careamics.git\n"
careamics/model_io/bmz_io.py

@@ -8,7 +8,9 @@ import numpy as np
 import pkg_resources
 from bioimageio.core import load_description, test_model
 from bioimageio.spec import ValidationSummary, save_bioimageio_package
-from torch import __version__, load, save
+from torch import __version__ as PYTORCH_VERSION
+from torch import load, save
+from torchvision import __version__ as TORCHVISION_VERSION
 
 from careamics.config import Configuration, load_configuration, save_configuration
 from careamics.config.support import SupportedArchitecture
@@ -141,7 +143,6 @@ def export_to_bmz(
     path_to_archive.parent.mkdir(parents=True, exist_ok=True)
 
     # versions
-    pytorch_version = __version__
     careamics_version = pkg_resources.get_distribution("careamics").version
 
     # save files in temporary folder
@@ -151,7 +152,7 @@
     # create environment file
     # TODO move in bioimage module
     env_path = temp_path / "environment.yml"
-    env_path.write_text(create_env_text(pytorch_version))
+    env_path.write_text(create_env_text(PYTORCH_VERSION, TORCHVISION_VERSION))
 
     # export input and ouputs
     inputs = temp_path / "inputs.npy"
@@ -174,7 +175,7 @@
         inputs=inputs,
         outputs=outputs,
         weights_path=weight_path,
-        torch_version=pytorch_version,
+        torch_version=PYTORCH_VERSION,
         careamics_version=careamics_version,
         config_path=config_path,
         env_path=env_path,
@@ -183,7 +184,7 @@
     )
 
     # test model description
-    summary: ValidationSummary = test_model(model_description, decimal=1)
+    summary: ValidationSummary = test_model(model_description)
     if summary.status == "failed":
         raise ValueError(f"Model description test failed: {summary}")
 
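Besides the version aliases, export_to_bmz no longer forwards a decimal tolerance to the bioimageio.core model test. A condensed sketch of the validation step as it now reads, using only identifiers that appear in the hunks above; model_description stands for the description object built earlier in export_to_bmz:

from bioimageio.core import test_model
from bioimageio.spec import ValidationSummary

summary: ValidationSummary = test_model(model_description)  # decimal tolerance no longer passed
if summary.status == "failed":
    raise ValueError(f"Model description test failed: {summary}")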
careamics/models/lvae/likelihoods.py

@@ -7,6 +7,7 @@ from __future__ import annotations
 import math
 from typing import Literal, Union, TYPE_CHECKING, Any, Optional
 
+import numpy as np
 import torch
 from torch import nn
 
@@ -287,30 +288,37 @@ class NoiseModelLikelihood(LikelihoodModule):
 
     def __init__(
         self,
-        data_mean: torch.Tensor,
-        data_std: torch.Tensor,
-        noiseModel: NoiseModel,  # TODO: check the type -> couldn't manage due to circular imports...
+        data_mean: Union[np.ndarray, torch.Tensor],
+        data_std: Union[np.ndarray, torch.Tensor],
+        noiseModel: NoiseModel,
     ):
         """Constructor.
 
         Parameters
         ----------
-        data_mean: torch.Tensor
+        data_mean: Union[np.ndarray, torch.Tensor]
             The mean of the data, used to unnormalize data for noise model evaluation.
-        data_std: torch.Tensor
+        data_std: Union[np.ndarray, torch.Tensor]
             The standard deviation of the data, used to unnormalize data for noise
             model evaluation.
         noiseModel: NoiseModel
             The noise model instance used to compute the likelihood.
         """
         super().__init__()
-        self.data_mean = data_mean
-        self.data_std = data_std
+        self.data_mean = torch.Tensor(data_mean)
+        self.data_std = torch.Tensor(data_std)
         self.noiseModel = noiseModel
 
-    def set_params_to_same_device_as(
+    def _set_params_to_same_device_as(
         self, correct_device_tensor: torch.Tensor
-    ) -> None:  # TODO: needed?
+    ) -> None:
+        """Set the parameters to the same device as the input tensor.
+
+        Parameters
+        ----------
+        correct_device_tensor: torch.Tensor
+            The tensor whose device is used to set the parameters.
+        """
         if self.data_mean.device != correct_device_tensor.device:
             self.data_mean = self.data_mean.to(correct_device_tensor.device)
             self.data_std = self.data_std.to(correct_device_tensor.device)
@@ -355,6 +363,7 @@ class NoiseModelLikelihood(LikelihoodModule):
         torch.Tensor
             The log-likelihood tensor. Shape is (B, C, [Z], Y, X).
         """
+        self._set_params_to_same_device_as(x)
         predicted_s_denormalized = params["mean"] * self.data_std + self.data_mean
         x_denormalized = x * self.data_std + self.data_mean
         likelihoods = self.noiseModel.likelihood(
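In short, NoiseModelLikelihood now accepts numpy statistics, stores them as torch tensors, and synchronizes them with the input device lazily inside the likelihood computation. A minimal sketch (not part of the diff), where noise_model is a placeholder for an already trained noise model instance:

import numpy as np

from careamics.models.lvae.likelihoods import NoiseModelLikelihood

# `noise_model` is a placeholder, e.g. a trained GaussianMixtureNoiseModel.
likelihood = NoiseModelLikelihood(
    data_mean=np.array([128.0]),  # numpy stats are accepted and stored as torch.Tensor
    data_std=np.array([16.0]),
    noiseModel=noise_model,
)
# Device handling is no longer the caller's job: the renamed, now-private
# _set_params_to_same_device_as is invoked inside the likelihood evaluation.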
careamics/models/lvae/lvae.py

@@ -38,7 +38,6 @@ class LadderVAE(nn.Module):
         decoder_dropout: float,
         nonlinearity: str,
         predict_logvar: bool,
-        enable_noise_model: bool,
         analytical_kl: bool,
     ):
         """
@@ -62,15 +61,12 @@ class LadderVAE(nn.Module):
         self.decoder_dropout = decoder_dropout
         self.nonlin = nonlinearity
         self.predict_logvar = predict_logvar
-        self.enable_noise_model = enable_noise_model
-
         self.analytical_kl = analytical_kl
         # -------------------------------------------------------
 
         # -------------------------------------------------------
         # Model attributes -> Hardcoded
         self.model_type = ModelType.LadderVae  # TODO remove !
-        self.model_type = ModelType.LadderVae  # TODO remove !
         self.encoder_blocks_per_layer = 1
         self.decoder_blocks_per_layer = 1
         self.bottomup_batchnorm = True
@@ -94,13 +90,6 @@ class LadderVAE(nn.Module):
         self._stochastic_use_naive_exponential = False
         self._enable_topdown_normalize_factor = True
 
-        # Noise model attributes -> Hardcoded
-        self.noise_model_type = "gmm"
-        self.denoise_channel = (
-            "input"  # 4 values for denoise_channel {'Ch1', 'Ch2', 'input','all'}
-        )
-        self.noise_model_learnable = False
-
         # Attributes that handle LC -> Hardcoded
         self.enable_multiscale = (
             self._multiscale_count is not None and self._multiscale_count > 1
@@ -806,11 +795,18 @@ class LadderVAE(nn.Module):
 
         # return samples
 
-    # def reset_for_different_output_size(self, output_size):
-    #     for i in range(self.n_layers):
-    #         sz = output_size // 2**(1 + i)
-    #         self.bottom_up_layers[i].output_expected_shape = (sz, sz)
-    #         self.top_down_layers[i].latent_shape = (output_size, output_size)
+    def reset_for_different_output_size(self, output_size: int) -> None:
+        """Reset shape of output and latent tensors for different output size.
+
+        Used during evaluation to reset expected shapes of tensors when
+        input/output shape changes.
+        For instance, it is needed when the model was trained on, say, 64x64 sized
+        patches, but prediction is done on 128x128 patches.
+        """
+        for i in range(self.n_layers):
+            sz = output_size // 2 ** (1 + i)
+            self.bottom_up_layers[i].output_expected_shape = (sz, sz)
+            self.top_down_layers[i].latent_shape = (output_size, output_size)
 
     def pad_input(self, x):
         """
careamics/models/lvae/noise_models.py

@@ -76,7 +76,7 @@ def train_gm_noise_model(
     # TODO any training params ? Different channels ?
     noise_model = GaussianMixtureNoiseModel(model_config)
     # TODO revisit config unpacking
-    noise_model.train_noise_model(noise_model.signal, noise_model.observation)
+    noise_model.train_noise_model(model_config.signal, model_config.observation)
     return noise_model
 
 
careamics/transforms/compose.py

@@ -1,13 +1,14 @@
 """A class chaining transforms together."""
 
-from typing import Dict, List, Optional, Tuple, cast
+from typing import Dict, List, Optional, Tuple, Union, cast
 
-import numpy as np
+from numpy.typing import NDArray
 
-from careamics.config.data_model import TRANSFORMS_UNION
+from careamics.config.transformations import TransformModel
 
 from .n2v_manipulate import N2VManipulate
 from .normalize import Normalize
+from .transform import Transform
 from .xy_flip import XYFlip
 from .xy_random_rotate90 import XYRandomRotate90
 
@@ -36,7 +37,7 @@ class Compose:
 
     Parameters
     ----------
-    transform_list : List[TRANSFORMS_UNION]
+    transform_list : List[TransformModel]
         A list of dictionaries where each dictionary contains the name of a
         transform and its parameters.
 
@@ -46,26 +47,27 @@
         A callable that applies the transforms to the input data.
     """
 
-    def __init__(self, transform_list: List[TRANSFORMS_UNION]) -> None:
+    def __init__(self, transform_list: List[TransformModel]) -> None:
         """Instantiate a Compose object.
 
         Parameters
         ----------
-        transform_list : List[TRANSFORMS_UNION]
+        transform_list : List[TransformModel]
            A list of dictionaries where each dictionary contains the name of a
            transform and its parameters.
         """
         # retrieve all available transforms
-        all_transforms = get_all_transforms()
+        # TODO: correctly type hint get_all_transforms function output
+        all_transforms: dict[str, type[Transform]] = get_all_transforms()
 
         # instantiate all transforms
-        self.transforms = [
+        self.transforms: list[Transform] = [
             all_transforms[t.name](**t.model_dump()) for t in transform_list
         ]
 
     def _chain_transforms(
-        self, patch: np.ndarray, target: Optional[np.ndarray]
-    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+        self, patch: NDArray, target: Optional[NDArray]
+    ) -> Tuple[Optional[NDArray], ...]:
         """Chain transforms on the input data.
 
         Parameters
@@ -80,16 +82,56 @@ class Compose:
         Tuple[np.ndarray, Optional[np.ndarray]]
             The output of the transformations.
         """
-        params = (patch, target)
+        params: Union[
+            tuple[NDArray, Optional[NDArray]],
+            tuple[NDArray, NDArray, NDArray],  # N2VManiupulate output
+        ] = (patch, target)
 
         for t in self.transforms:
-            params = t(*params)
+            # N2VManipulate returns tuple of 3 arrays
+            # - Other transoforms return tuple of (patch, target, additional_arrays)
+            if isinstance(t, N2VManipulate):
+                patch, *_ = params
+                params = t(patch=patch)
+            else:
+                *params, _ = t(*params)  # ignore additional_arrays dict
 
         return params
 
+    def _chain_transforms_additional_arrays(
+        self,
+        patch: NDArray,
+        target: Optional[NDArray],
+        **additional_arrays: NDArray,
+    ) -> Tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
+        """Chain transforms on the input data, with additional arrays.
+
+        Parameters
+        ----------
+        patch : np.ndarray
+            Input data.
+        target : Optional[np.ndarray]
+            Target data, by default None.
+        **additional_arrays : NDArray
+            Additional arrays that will be transformed identically to `patch` and
+            `target`.
+
+        Returns
+        -------
+        Tuple[np.ndarray, Optional[np.ndarray]]
+            The output of the transformations.
+        """
+        params = {"patch": patch, "target": target, **additional_arrays}
+
+        for t in self.transforms:
+            patch, target, additional_arrays = t(**params)
+            params = {"patch": patch, "target": target, **additional_arrays}
+
+        return patch, target, additional_arrays
+
     def __call__(
-        self, patch: np.ndarray, target: Optional[np.ndarray] = None
-    ) -> Tuple[np.ndarray, ...]:
+        self, patch: NDArray, target: Optional[NDArray] = None
+    ) -> Tuple[NDArray, ...]:
         """Apply the transforms to the input data.
 
         Parameters
@@ -104,4 +146,37 @@ class Compose:
         Tuple[np.ndarray, ...]
             The output of the transformations.
         """
-        return cast(Tuple[np.ndarray, ...], self._chain_transforms(patch, target))
+        # TODO: solve casting Compose.__call__ ouput
+        return cast(Tuple[NDArray, ...], self._chain_transforms(patch, target))
+
+    def transform_with_additional_arrays(
+        self,
+        patch: NDArray,
+        target: Optional[NDArray] = None,
+        **additional_arrays: NDArray,
+    ) -> tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
+        """Apply the transforms to the input data, including additional arrays.
+
+        Parameters
+        ----------
+        patch : np.ndarray
+            The input data.
+        target : Optional[np.ndarray], optional
+            Target data, by default None.
+        **additional_arrays : NDArray
+            Additional arrays that will be transformed identically to `patch` and
+            `target`.
+
+        Returns
+        -------
+        NDArray
+            The transformed patch.
+        NDArray | None
+            The transformed target.
+        dict of {str, NDArray}
+            Transformed additional arrays. Keys correspond to the keyword argument
+            names.
+        """
+        return self._chain_transforms_additional_arrays(
+            patch, target, **additional_arrays
+        )
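The new transform_with_additional_arrays entry point lets extra arrays (e.g. masks) ride along through the same spatial transforms as the patch and target. A minimal sketch, assuming XYFlipModel and XYRandomRotate90Model are the corresponding pydantic transform models in careamics.config.transformations and can be constructed with defaults:

import numpy as np

from careamics.config.transformations import XYFlipModel, XYRandomRotate90Model
from careamics.transforms.compose import Compose

compose = Compose(transform_list=[XYFlipModel(), XYRandomRotate90Model()])

patch = np.random.rand(1, 64, 64)  # C(Z)YX
mask = np.zeros((1, 64, 64))

t_patch, t_target, extras = compose.transform_with_additional_arrays(patch, mask=mask)
# extras["mask"] underwent exactly the same flips/rotations as t_patch; t_target is None.

Note that Normalize rejects additional arrays with a NotImplementedError (see its hunk further down), so additional arrays currently only flow through the purely spatial transforms.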
careamics/transforms/n2v_manipulate.py

@@ -3,6 +3,7 @@
 from typing import Any, Literal, Optional, Tuple
 
 import numpy as np
+from numpy.typing import NDArray
 
 from careamics.config.support import SupportedPixelManipulation, SupportedStructAxis
 from careamics.transforms.transform import Transform
@@ -98,8 +99,8 @@ class N2VManipulate(Transform):
         self.rng = np.random.default_rng(seed=seed)
 
     def __call__(
-        self, patch: np.ndarray, *args: Any, **kwargs: Any
-    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+        self, patch: NDArray, *args: Any, **kwargs: Any
+    ) -> Tuple[NDArray, NDArray, NDArray]:
         """Apply the transform to the image.
 
         Parameters
@@ -142,5 +143,8 @@
         else:
             raise ValueError(f"Unknown masking strategy ({self.strategy}).")
 
+        # TODO: Output does not match other transforms, how to resolve?
+        #   - Don't include in Compose and apply after if algorithm is N2V?
+        #   - or just don't return patch? but then mask is in the target position
         # TODO why return patch?
         return masked, patch, mask
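N2VManipulate is the odd one out: it returns three arrays (masked patch, original patch, mask) rather than the (patch, target, additional arrays) triple, which is why Compose special-cases it above. A hypothetical call, with constructor arguments omitted since they are not shown in this diff:

import numpy as np

from careamics.transforms.n2v_manipulate import N2VManipulate

manip = N2VManipulate()  # constructor arguments not shown in this diff; defaults assumed
patch = np.random.rand(1, 64, 64)

masked, original, mask = manip(patch)
# `masked` is the network input; `original` and `mask` are typically consumed by the N2V loss.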
careamics/transforms/normalize.py

@@ -90,8 +90,11 @@ class Normalize(Transform):
         self.eps = 1e-6
 
     def __call__(
-        self, patch: np.ndarray, target: Optional[NDArray] = None
-    ) -> tuple[NDArray, Optional[NDArray]]:
+        self,
+        patch: np.ndarray,
+        target: Optional[NDArray] = None,
+        **additional_arrays: NDArray,
+    ) -> tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
         """Apply the transform to the source patch and the target (optional).
 
         Parameters
@@ -100,6 +103,9 @@
             Patch, 2D or 3D, shape C(Z)YX.
         target : NDArray, optional
             Target for the patch, by default None.
+        **additional_arrays : NDArray
+            Additional arrays that will be transformed identically to `patch` and
+            `target`.
 
         Returns
         -------
@@ -111,6 +117,11 @@
                 f"Number of means (got a list of size {len(self.image_means)}) and "
                 f"number of channels (got shape {patch.shape} for C(Z)YX) do not match."
             )
+        if len(additional_arrays) != 0:
+            raise NotImplementedError(
+                "Transforming additional arrays is currently not supported for "
+                "`Normalize`."
+            )
 
         # reshape mean and std and apply the normalization to the patch
         means = _reshape_stats(self.image_means, patch.ndim)
@@ -129,7 +140,7 @@
         else:
             norm_target = None
 
-        return norm_patch, norm_target
+        return norm_patch, norm_target, additional_arrays
 
     def _apply(self, patch: NDArray, mean: NDArray, std: NDArray) -> NDArray:
         """
careamics/transforms/xy_flip.py

@@ -1,8 +1,9 @@
 """XY flip transform."""
 
-from typing import Optional, Tuple
+from typing import Optional
 
 import numpy as np
+from numpy.typing import NDArray
 
 from careamics.transforms.transform import Transform
 
@@ -78,8 +79,11 @@ class XYFlip(Transform):
         self.rng = np.random.default_rng(seed=seed)
 
     def __call__(
-        self, patch: np.ndarray, target: Optional[np.ndarray] = None
-    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+        self,
+        patch: NDArray,
+        target: Optional[NDArray] = None,
+        **additional_arrays: NDArray,
+    ) -> tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
         """Apply the transform to the source patch and the target (optional).
 
         Parameters
@@ -88,6 +92,9 @@
             Patch, 2D or 3D, shape C(Z)YX.
         target : Optional[np.ndarray], optional
             Target for the patch, by default None.
+        **additional_arrays : NDArray
+            Additional arrays that will be transformed identically to `patch` and
+            `target`.
 
         Returns
         -------
@@ -95,17 +102,20 @@
             Transformed patch and target.
         """
         if self.rng.random() > self.p:
-            return patch, target
+            return patch, target, additional_arrays
 
         # choose an axis to flip
         axis = self.rng.choice(self.axis_indices)
 
         patch_transformed = self._apply(patch, axis)
         target_transformed = self._apply(target, axis) if target is not None else None
+        additional_transformed = {
+            key: self._apply(array, axis) for key, array in additional_arrays.items()
+        }
 
-        return patch_transformed, target_transformed
+        return patch_transformed, target_transformed, additional_transformed
 
-    def _apply(self, patch: np.ndarray, axis: int) -> np.ndarray:
+    def _apply(self, patch: NDArray, axis: int) -> NDArray:
         """Apply the transform to the image.
 
         Parameters
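Each spatial transform now returns a three-element result even when called directly, so individual transforms and Compose share the same calling convention. A small sketch (not part of the diff), with the remaining constructor arguments assumed to keep their defaults:

import numpy as np

from careamics.transforms.xy_flip import XYFlip

flip = XYFlip(seed=42)  # other constructor arguments assumed to have defaults
patch = np.random.rand(1, 64, 64)
mask = np.ones((1, 64, 64))

patch_t, target_t, extras = flip(patch, None, mask=mask)
# extras["mask"] received exactly the same flip (if any) as patch_t; target_t is None.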
careamics/transforms/xy_random_rotate90.py

@@ -3,6 +3,7 @@
 from typing import Optional, Tuple
 
 import numpy as np
+from numpy.typing import NDArray
 
 from careamics.transforms.transform import Transform
 
@@ -49,8 +50,11 @@ class XYRandomRotate90(Transform):
         self.rng = np.random.default_rng(seed=seed)
 
     def __call__(
-        self, patch: np.ndarray, target: Optional[np.ndarray] = None
-    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+        self,
+        patch: NDArray,
+        target: Optional[NDArray] = None,
+        **additional_arrays: NDArray,
+    ) -> tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
         """Apply the transform to the source patch and the target (optional).
 
         Parameters
@@ -59,6 +63,9 @@
             Patch, 2D or 3D, shape C(Z)YX.
         target : Optional[np.ndarray], optional
             Target for the patch, by default None.
+        **additional_arrays : NDArray
+            Additional arrays that will be transformed identically to `patch` and
+            `target`.
 
         Returns
         -------
@@ -66,7 +73,7 @@
             Transformed patch and target.
         """
         if self.rng.random() > self.p:
-            return patch, target
+            return patch, target, additional_arrays
 
         # number of rotations
         n_rot = self.rng.integers(1, 4)
@@ -76,12 +83,14 @@
         target_transformed = (
             self._apply(target, n_rot, axes) if target is not None else None
         )
+        additional_transformed = {
+            key: self._apply(array, n_rot, axes)
+            for key, array in additional_arrays.items()
+        }
 
-        return patch_transformed, target_transformed
+        return patch_transformed, target_transformed, additional_transformed
 
-    def _apply(
-        self, patch: np.ndarray, n_rot: int, axes: Tuple[int, int]
-    ) -> np.ndarray:
+    def _apply(self, patch: NDArray, n_rot: int, axes: Tuple[int, int]) -> NDArray:
         """Apply the transform to the image.
 
         Parameters