careamics 0.0.9__py3-none-any.whl → 0.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic; see the package's page on the registry for more details.

Files changed (63)
  1. careamics/__init__.py +0 -4
  2. careamics/careamist.py +0 -1
  3. careamics/config/__init__.py +1 -13
  4. careamics/config/algorithms/care_algorithm_model.py +84 -0
  5. careamics/config/algorithms/n2n_algorithm_model.py +85 -0
  6. careamics/config/algorithms/n2v_algorithm_model.py +269 -1
  7. careamics/config/configuration.py +21 -13
  8. careamics/config/configuration_factories.py +179 -187
  9. careamics/config/configuration_io.py +2 -2
  10. careamics/config/data/__init__.py +1 -4
  11. careamics/config/data/data_model.py +46 -62
  12. careamics/config/support/supported_transforms.py +1 -1
  13. careamics/config/transformations/__init__.py +0 -2
  14. careamics/config/transformations/n2v_manipulate_model.py +15 -0
  15. careamics/config/transformations/transform_unions.py +0 -13
  16. careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
  17. careamics/dataset/dataset_utils/running_stats.py +7 -3
  18. careamics/dataset/in_memory_dataset.py +3 -10
  19. careamics/dataset/in_memory_pred_dataset.py +3 -5
  20. careamics/dataset/in_memory_tiled_pred_dataset.py +2 -2
  21. careamics/dataset/iterable_dataset.py +2 -2
  22. careamics/dataset/iterable_pred_dataset.py +3 -5
  23. careamics/dataset/iterable_tiled_pred_dataset.py +3 -3
  24. careamics/dataset_ng/dataset/__init__.py +3 -0
  25. careamics/dataset_ng/dataset/dataset.py +184 -0
  26. careamics/dataset_ng/demo_dataset.ipynb +271 -0
  27. careamics/dataset_ng/demo_patch_extractor.py +53 -0
  28. careamics/dataset_ng/demo_patch_extractor_factory.py +37 -0
  29. careamics/dataset_ng/patch_extractor/__init__.py +10 -0
  30. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +111 -0
  31. careamics/dataset_ng/patch_extractor/image_stack/__init__.py +9 -0
  32. careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +53 -0
  33. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +55 -0
  34. careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +163 -0
  35. careamics/dataset_ng/patch_extractor/image_stack_loader.py +140 -0
  36. careamics/dataset_ng/patch_extractor/patch_extractor.py +29 -0
  37. careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +208 -0
  38. careamics/dataset_ng/patching_strategies/__init__.py +11 -0
  39. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +82 -0
  40. careamics/dataset_ng/patching_strategies/random_patching.py +338 -0
  41. careamics/dataset_ng/patching_strategies/sequential_patching.py +75 -0
  42. careamics/lightning/lightning_module.py +78 -27
  43. careamics/lightning/train_data_module.py +8 -39
  44. careamics/losses/fcn/losses.py +17 -10
  45. careamics/model_io/bioimage/bioimage_utils.py +5 -3
  46. careamics/model_io/bioimage/model_description.py +3 -3
  47. careamics/model_io/bmz_io.py +2 -2
  48. careamics/model_io/model_io_utils.py +2 -2
  49. careamics/transforms/__init__.py +2 -1
  50. careamics/transforms/compose.py +5 -15
  51. careamics/transforms/n2v_manipulate_torch.py +143 -0
  52. careamics/transforms/pixel_manipulation.py +1 -0
  53. careamics/transforms/pixel_manipulation_torch.py +418 -0
  54. careamics/utils/version.py +38 -0
  55. {careamics-0.0.9.dist-info → careamics-0.0.11.dist-info}/METADATA +7 -8
  56. {careamics-0.0.9.dist-info → careamics-0.0.11.dist-info}/RECORD +59 -42
  57. careamics/config/care_configuration.py +0 -100
  58. careamics/config/data/n2v_data_model.py +0 -193
  59. careamics/config/n2n_configuration.py +0 -101
  60. careamics/config/n2v_configuration.py +0 -266
  61. {careamics-0.0.9.dist-info → careamics-0.0.11.dist-info}/WHEEL +0 -0
  62. {careamics-0.0.9.dist-info → careamics-0.0.11.dist-info}/entry_points.txt +0 -0
  63. {careamics-0.0.9.dist-info → careamics-0.0.11.dist-info}/licenses/LICENSE +0 -0
@@ -8,7 +8,7 @@ import pytorch_lightning as L
8
8
  from numpy.typing import NDArray
9
9
  from torch.utils.data import DataLoader, IterableDataset
10
10
 
11
- from careamics.config.data import DataConfig, GeneralDataConfig, N2VDataConfig
11
+ from careamics.config.data import DataConfig
12
12
  from careamics.config.support import SupportedData
13
13
  from careamics.config.transformations import TransformModel
14
14
  from careamics.dataset.dataset_utils import (
@@ -118,7 +118,7 @@ class TrainDataModule(L.LightningDataModule):
118
118
 
119
119
  def __init__(
120
120
  self,
121
- data_config: GeneralDataConfig,
121
+ data_config: DataConfig,
122
122
  train_data: Union[Path, str, NDArray],
123
123
  val_data: Optional[Union[Path, str, NDArray]] = None,
124
124
  train_data_target: Optional[Union[Path, str, NDArray]] = None,
@@ -218,7 +218,7 @@ class TrainDataModule(L.LightningDataModule):
218
218
  )
219
219
 
220
220
  # configuration
221
- self.data_config: GeneralDataConfig = data_config
221
+ self.data_config: DataConfig = data_config
222
222
  self.data_type: str = data_config.data_type
223
223
  self.batch_size: int = data_config.batch_size
224
224
  self.use_in_memory: bool = use_in_memory
@@ -486,9 +486,6 @@ def create_train_datamodule(
486
486
  val_minimum_patches: int = 5,
487
487
  dataloader_params: Optional[dict] = None,
488
488
  use_in_memory: bool = True,
489
- use_n2v2: bool = False,
490
- struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
491
- struct_n2v_span: int = 5,
492
489
  ) -> TrainDataModule:
493
490
  """Create a TrainDataModule.
494
491
 
@@ -499,16 +496,8 @@ def create_train_datamodule(
499
496
  parameters passed to the datamodule are consistent with the model's requirements and
500
497
  are coherent.
501
498
 
502
- By default, the train DataModule will be set for Noise2Void if no target data is
503
- provided. That means that it will add a `N2VManipulateModel` transformation to the
504
- list of augmentations. The default augmentations are XY flip, XY rotation, and N2V
505
- pixel manipulation. If you pass a training target data, the default behaviour is to
506
- train a supervised model. It will use the default XY flip and rotation
507
- augmentations.
508
-
509
- To use a different set of transformations, you can pass a list of transforms to
510
- `transforms`. Note that if you intend to use Noise2Void, you should add
511
- `N2VManipulateModel` as the last transform in the list of transformations.
499
+ The default augmentations are XY flip and XY rotation. To use a different set of
500
+ transformations, you can pass a list of transforms to `transforms`.
512
501
 
513
502
  The data module can be used with Path, str or numpy arrays. In the case of
514
503
  numpy arrays, it loads and computes all the patches in memory. For Path and str
@@ -534,11 +523,6 @@ def create_train_datamodule(
534
523
  In `dataloader_params`, you can pass any parameter accepted by PyTorch dataloaders,
535
524
  except for `batch_size`, which is set by the `batch_size` parameter.
536
525
 
537
- Finally, if you intend to use N2V family of algorithms, you can set `use_n2v2` to
538
- use N2V2, and set the `struct_n2v_axis` and `struct_n2v_span` parameters to define
539
- the axis and span of the structN2V mask. These parameters are without effect if
540
- a `train_target_data` or if `transforms` are provided.
541
-
542
526
  Parameters
543
527
  ----------
544
528
  train_data : pathlib.Path or str or numpy.ndarray
@@ -575,13 +559,6 @@ def create_train_datamodule(
575
559
  Pytorch dataloader parameters, by default {}.
576
560
  use_in_memory : bool, optional
577
561
  Use in memory dataset if possible, by default True.
578
- use_n2v2 : bool, optional
579
- Use N2V2 transformation during training, by default False.
580
- struct_n2v_axis : {"horizontal", "vertical", "none"}, optional
581
- Axis for the structN2V mask, only applied if `struct_n2v_axis` is `none`, by
582
- default "none".
583
- struct_n2v_span : int, optional
584
- Span for the structN2V mask, by default 5.
585
562
 
586
563
  Returns
587
564
  -------
@@ -624,12 +601,11 @@ def create_train_datamodule(
624
601
  transforms:
625
602
  >>> import numpy as np
626
603
  >>> from careamics.lightning import create_train_datamodule
627
- >>> from careamics.config.transformations import XYFlipModel, N2VManipulateModel
604
+ >>> from careamics.config.transformations import XYFlipModel
628
605
  >>> from careamics.config.support import SupportedTransform
629
606
  >>> my_array = np.arange(256).reshape(16, 16)
630
607
  >>> my_transforms = [
631
608
  ... XYFlipModel(flip_y=False),
632
- ... N2VManipulateModel()
633
609
  ... ]
634
610
  >>> data_module = create_train_datamodule(
635
611
  ... train_data=my_array,
@@ -656,15 +632,8 @@ def create_train_datamodule(
656
632
  if transforms is not None:
657
633
  data_dict["transforms"] = transforms
658
634
 
659
- # TODO not compatible with HDN, consider adding an argument for n2v/hdn
660
- if train_target_data is None:
661
- data_config: GeneralDataConfig = N2VDataConfig(**data_dict)
662
- assert isinstance(data_config, N2VDataConfig)
663
-
664
- data_config.set_n2v2(use_n2v2)
665
- data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span)
666
- else:
667
- data_config = DataConfig(**data_dict)
635
+ # instantiate data configuration
636
+ data_config = DataConfig(**data_dict)
668
637
 
669
638
  # sanity check on the dataloader parameters
670
639
  if "batch_size" in dataloader_params:
@@ -8,7 +8,7 @@ import torch
8
8
  from torch.nn import L1Loss, MSELoss
9
9
 
10
10
 
11
- def mse_loss(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
11
+ def mse_loss(source: torch.Tensor, target: torch.Tensor, *args) -> torch.Tensor:
12
12
  """
13
13
  Mean squared error loss.
14
14
 
@@ -18,6 +18,8 @@ def mse_loss(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
18
18
  Source patches.
19
19
  target : torch.Tensor
20
20
  Target patches.
21
+ *args : Any
22
+ Additional arguments.
21
23
 
22
24
  Returns
23
25
  -------
@@ -29,34 +31,37 @@ def mse_loss(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
29
31
 
30
32
 
31
33
  def n2v_loss(
32
- manipulated_patches: torch.Tensor,
33
- original_patches: torch.Tensor,
34
+ manipulated_batch: torch.Tensor,
35
+ original_batch: torch.Tensor,
34
36
  masks: torch.Tensor,
37
+ *args,
35
38
  ) -> torch.Tensor:
36
39
  """
37
40
  N2V Loss function described in A Krull et al 2018.
38
41
 
39
42
  Parameters
40
43
  ----------
41
- manipulated_patches : torch.Tensor
42
- Patches with manipulated pixels.
43
- original_patches : torch.Tensor
44
- Noisy patches.
44
+ manipulated_batch : torch.Tensor
45
+ Batch after manipulation function applied.
46
+ original_batch : torch.Tensor
47
+ Original images.
45
48
  masks : torch.Tensor
46
- Array containing masked pixel locations.
49
+ Coordinates of changed pixels.
50
+ *args : Any
51
+ Additional arguments.
47
52
 
48
53
  Returns
49
54
  -------
50
55
  torch.Tensor
51
56
  Loss value.
52
57
  """
53
- errors = (original_patches - manipulated_patches) ** 2
58
+ errors = (original_batch - manipulated_batch) ** 2
54
59
  # Average over pixels and batch
55
60
  loss = torch.sum(errors * masks) / torch.sum(masks)
56
61
  return loss # TODO change output to dict ?
57
62
 
58
63
 
59
- def mae_loss(samples: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
64
+ def mae_loss(samples: torch.Tensor, labels: torch.Tensor, *args) -> torch.Tensor:
60
65
  """
61
66
  N2N Loss function described in to J Lehtinen et al 2018.
62
67
 
@@ -66,6 +71,8 @@ def mae_loss(samples: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
66
71
  Raw patches.
67
72
  labels : torch.Tensor
68
73
  Different subset of noisy patches.
74
+ *args : Any
75
+ Additional arguments.
69
76
 
70
77
  Returns
71
78
  -------
@@ -3,6 +3,8 @@
3
3
  from pathlib import Path
4
4
  from typing import Union
5
5
 
6
+ from careamics.utils.version import get_careamics_version
7
+
6
8
 
7
9
  def get_unzip_path(zip_path: Union[Path, str]) -> Path:
8
10
  """Generate unzipped folder path from the bioimage.io model path.
@@ -44,11 +46,11 @@ def create_env_text(pytorch_version: str, torchvision_version: str) -> str:
44
46
  f"name: careamics\n"
45
47
  f"dependencies:\n"
46
48
  f" - python=3.10\n"
47
- f" - pytorch={pytorch_version}\n"
48
- f" - torchvision={torchvision_version}\n"
49
49
  f" - pip\n"
50
50
  f" - pip:\n"
51
- f" - git+https://github.com/CAREamics/careamics.git\n"
51
+ f" - torch=={pytorch_version}\n"
52
+ f" - torchvision=={torchvision_version}\n"
53
+ f" - careamics=={get_careamics_version()}"
52
54
  )
53
55
 
54
56
  return env
@@ -28,14 +28,14 @@ from bioimageio.spec.model.v0_5 import (
28
28
  WeightsDescr,
29
29
  )
30
30
 
31
- from careamics.config import Configuration, GeneralDataConfig
31
+ from careamics.config import Configuration, DataConfig
32
32
 
33
33
  from ._readme_factory import readme_factory
34
34
 
35
35
 
36
36
  def _create_axes(
37
37
  array: np.ndarray,
38
- data_config: GeneralDataConfig,
38
+ data_config: DataConfig,
39
39
  channel_names: Optional[list[str]] = None,
40
40
  is_input: bool = True,
41
41
  ) -> list[AxisBase]:
@@ -102,7 +102,7 @@ def _create_axes(
102
102
  def _create_inputs_ouputs(
103
103
  input_array: np.ndarray,
104
104
  output_array: np.ndarray,
105
- data_config: GeneralDataConfig,
105
+ data_config: DataConfig,
106
106
  input_path: Union[Path, str],
107
107
  output_path: Union[Path, str],
108
108
  channel_names: Optional[list[str]] = None,
@@ -5,7 +5,6 @@ from pathlib import Path
5
5
  from typing import Optional, Union
6
6
 
7
7
  import numpy as np
8
- import pkg_resources
9
8
  from bioimageio.core import load_model_description, test_model
10
9
  from bioimageio.spec import ValidationSummary, save_bioimageio_package
11
10
  from pydantic import HttpUrl
@@ -16,6 +15,7 @@ from torchvision import __version__ as TORCHVISION_VERSION
16
15
  from careamics.config import Configuration, load_configuration, save_configuration
17
16
  from careamics.config.support import SupportedArchitecture
18
17
  from careamics.lightning.lightning_module import FCNModule, VAEModule
18
+ from careamics.utils.version import get_careamics_version
19
19
 
20
20
  from .bioimage import (
21
21
  create_env_text,
@@ -139,7 +139,7 @@ def export_to_bmz(
139
139
  path_to_archive.parent.mkdir(parents=True, exist_ok=True)
140
140
 
141
141
  # versions
142
- careamics_version = pkg_resources.get_distribution("careamics").version
142
+ careamics_version = get_careamics_version()
143
143
 
144
144
  # save files in temporary folder
145
145
  with tempfile.TemporaryDirectory() as tmpdirname:
@@ -5,7 +5,7 @@ from typing import Union
5
5
 
6
6
  import torch
7
7
 
8
- from careamics.config import Configuration, configuration_factory
8
+ from careamics.config import Configuration
9
9
  from careamics.lightning.lightning_module import FCNModule, VAEModule
10
10
  from careamics.model_io.bmz_io import load_from_bmz
11
11
  from careamics.utils import check_path_exists
@@ -92,4 +92,4 @@ def _load_checkpoint(
92
92
  f"{cfg_dict['algorithm_config']['model']['architecture']}"
93
93
  )
94
94
 
95
- return model, configuration_factory(cfg_dict)
95
+ return model, Configuration(**cfg_dict)
@@ -5,15 +5,16 @@ __all__ = [
5
5
  "Denormalize",
6
6
  "ImageRestorationTTA",
7
7
  "N2VManipulate",
8
+ "N2VManipulateTorch",
8
9
  "Normalize",
9
10
  "XYFlip",
10
11
  "XYRandomRotate90",
11
12
  "get_all_transforms",
12
13
  ]
13
14
 
14
-
15
15
  from .compose import Compose, get_all_transforms
16
16
  from .n2v_manipulate import N2VManipulate
17
+ from .n2v_manipulate_torch import N2VManipulateTorch
17
18
  from .normalize import Denormalize, Normalize
18
19
  from .tta import ImageRestorationTTA
19
20
  from .xy_flip import XYFlip
@@ -6,7 +6,6 @@ from numpy.typing import NDArray
6
6
 
7
7
  from careamics.config.transformations import NORM_AND_SPATIAL_UNION
8
8
 
9
- from .n2v_manipulate import N2VManipulate
10
9
  from .normalize import Normalize
11
10
  from .transform import Transform
12
11
  from .xy_flip import XYFlip
@@ -14,7 +13,6 @@ from .xy_random_rotate90 import XYRandomRotate90
14
13
 
15
14
  ALL_TRANSFORMS = {
16
15
  "Normalize": Normalize,
17
- "N2VManipulate": N2VManipulate,
18
16
  "XYFlip": XYFlip,
19
17
  "XYRandomRotate90": XYRandomRotate90,
20
18
  }
@@ -82,21 +80,13 @@ class Compose:
82
80
  tuple[np.ndarray, Optional[np.ndarray]]
83
81
  The output of the transformations.
84
82
  """
85
- params: Union[
86
- tuple[NDArray, Optional[NDArray]],
87
- tuple[NDArray, NDArray, NDArray], # N2VManiupulate output
88
- ] = (patch, target)
83
+ params: Union[tuple[NDArray, Optional[NDArray]],] = (patch, target)
89
84
 
90
85
  for t in self.transforms:
91
- # N2VManipulate returns tuple of 3 arrays
92
- # - Other transoforms return tuple of (patch, target, additional_arrays)
93
- if isinstance(t, N2VManipulate):
94
- patch, *_ = params
95
- params = t(patch=patch)
96
- else:
97
- *params, _ = t(*params) # ignore additional_arrays dict
98
-
99
- return params
86
+ *params, _ = t(*params) # ignore additional_arrays dict
87
+
88
+ # avoid None values that create problems for collating
89
+ return tuple(p for p in params if p is not None)
100
90
 
101
91
  def _chain_transforms_additional_arrays(
102
92
  self,
@@ -0,0 +1,143 @@
1
+ """N2V manipulation transform for PyTorch."""
2
+
3
+ import platform
4
+ from typing import Any, Optional
5
+
6
+ import torch
7
+
8
+ from careamics.config.support import SupportedPixelManipulation, SupportedStructAxis
9
+ from careamics.config.transformations import N2VManipulateModel
10
+
11
+ from .pixel_manipulation_torch import (
12
+ median_manipulate_torch,
13
+ uniform_manipulate_torch,
14
+ )
15
+ from .struct_mask_parameters import StructMaskParameters
16
+
17
+
18
+ class N2VManipulateTorch:
19
+ """
20
+ Default augmentation for the N2V model.
21
+
22
+ This transform expects C(Z)YX dimensions.
23
+
24
+ Parameters
25
+ ----------
26
+ n2v_manipulate_config : N2VManipulateConfig
27
+ N2V manipulation configuration.
28
+ seed : Optional[int], optional
29
+ Random seed, by default None.
30
+
31
+ Attributes
32
+ ----------
33
+ masked_pixel_percentage : float
34
+ Percentage of pixels to mask.
35
+ roi_size : int
36
+ Size of the replacement area.
37
+ strategy : Literal[ "uniform", "median" ]
38
+ Replacement strategy, uniform or median.
39
+ remove_center : bool
40
+ Whether to remove central pixel from patch.
41
+ struct_mask : Optional[StructMaskParameters]
42
+ StructN2V mask parameters.
43
+ rng : Generator
44
+ Random number generator.
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ n2v_manipulate_config: N2VManipulateModel,
50
+ seed: Optional[int] = None,
51
+ ):
52
+ """Constructor.
53
+
54
+ Parameters
55
+ ----------
56
+ n2v_manipulate_config : N2VManipulateModel
57
+ N2V manipulation configuration.
58
+ seed : Optional[int], optional
59
+ Random seed, by default None.
60
+ """
61
+ self.masked_pixel_percentage = n2v_manipulate_config.masked_pixel_percentage
62
+ self.roi_size = n2v_manipulate_config.roi_size
63
+ self.strategy = n2v_manipulate_config.strategy
64
+ self.remove_center = n2v_manipulate_config.remove_center
65
+
66
+ if n2v_manipulate_config.struct_mask_axis == SupportedStructAxis.NONE:
67
+ self.struct_mask: Optional[StructMaskParameters] = None
68
+ else:
69
+ self.struct_mask = StructMaskParameters(
70
+ axis=(
71
+ 0
72
+ if n2v_manipulate_config.struct_mask_axis
73
+ == SupportedStructAxis.HORIZONTAL
74
+ else 1
75
+ ),
76
+ span=n2v_manipulate_config.struct_mask_span,
77
+ )
78
+
79
+ # PyTorch random generator
80
+ # TODO refactor into careamics.utils.torch_utils.get_device
81
+ if torch.cuda.is_available():
82
+ device = "cuda"
83
+ elif torch.backends.mps.is_available() and platform.processor() in (
84
+ "arm",
85
+ "arm64",
86
+ ):
87
+ device = "mps"
88
+ else:
89
+ device = "cpu"
90
+
91
+ self.rng = (
92
+ torch.Generator(device=device).manual_seed(seed)
93
+ if seed is not None
94
+ else torch.Generator(device=device)
95
+ )
96
+
97
+ def __call__(
98
+ self, batch: torch.Tensor, *args: Any, **kwargs: Any
99
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
100
+ """Apply the transform to the image.
101
+
102
+ Parameters
103
+ ----------
104
+ batch : torch.Tensor
105
+ Batch if image patches, 2D or 3D, shape BC(Z)YX.
106
+ *args : Any
107
+ Additional arguments, unused.
108
+ **kwargs : Any
109
+ Additional keyword arguments, unused.
110
+
111
+ Returns
112
+ -------
113
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor]
114
+ Masked patch, original patch, and mask.
115
+ """
116
+ masked = torch.zeros_like(batch)
117
+ mask = torch.zeros_like(batch, dtype=torch.uint8)
118
+
119
+ if self.strategy == SupportedPixelManipulation.UNIFORM:
120
+ # Iterate over the channels to apply manipulation separately
121
+ for c in range(batch.shape[1]):
122
+ masked[:, c, ...], mask[:, c, ...] = uniform_manipulate_torch(
123
+ patch=batch[:, c, ...],
124
+ mask_pixel_percentage=self.masked_pixel_percentage,
125
+ subpatch_size=self.roi_size,
126
+ remove_center=self.remove_center,
127
+ struct_params=self.struct_mask,
128
+ rng=self.rng,
129
+ )
130
+ elif self.strategy == SupportedPixelManipulation.MEDIAN:
131
+ # Iterate over the channels to apply manipulation separately
132
+ for c in range(batch.shape[1]):
133
+ masked[:, c, ...], mask[:, c, ...] = median_manipulate_torch(
134
+ batch=batch[:, c, ...],
135
+ mask_pixel_percentage=self.masked_pixel_percentage,
136
+ subpatch_size=self.roi_size,
137
+ struct_params=self.struct_mask,
138
+ rng=self.rng,
139
+ )
140
+ else:
141
+ raise ValueError(f"Unknown masking strategy ({self.strategy}).")
142
+
143
+ return masked, batch, mask
@@ -213,6 +213,7 @@ def _create_subpatch_struct_mask(
213
213
  np.ndarray
214
214
  StructN2V mask for the subpatch.
215
215
  """
216
+ # TODO no test for this function!
216
217
  # Create a mask with the center of the subpatch masked
217
218
  mask_placeholder = np.ones(subpatch.shape)
218
219