careamics 0.1.0rc4__py3-none-any.whl → 0.1.0rc6__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in a supported public registry. It is provided for informational purposes only.

Potentially problematic release: this version of careamics might be problematic.

Files changed (103)
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +92 -55
  4. careamics/config/__init__.py +0 -1
  5. careamics/config/algorithm_model.py +5 -3
  6. careamics/config/architectures/architecture_model.py +7 -0
  7. careamics/config/architectures/custom_model.py +8 -1
  8. careamics/config/architectures/register_model.py +3 -1
  9. careamics/config/architectures/unet_model.py +3 -0
  10. careamics/config/architectures/vae_model.py +2 -0
  11. careamics/config/callback_model.py +4 -15
  12. careamics/config/configuration_example.py +4 -4
  13. careamics/config/configuration_factory.py +113 -55
  14. careamics/config/configuration_model.py +14 -16
  15. careamics/config/data_model.py +63 -165
  16. careamics/config/inference_model.py +9 -75
  17. careamics/config/optimizer_models.py +4 -4
  18. careamics/config/references/algorithm_descriptions.py +1 -0
  19. careamics/config/references/references.py +1 -0
  20. careamics/config/support/__init__.py +0 -2
  21. careamics/config/support/supported_activations.py +2 -0
  22. careamics/config/support/supported_algorithms.py +3 -1
  23. careamics/config/support/supported_architectures.py +2 -0
  24. careamics/config/support/supported_data.py +2 -0
  25. careamics/config/support/supported_loggers.py +2 -0
  26. careamics/config/support/supported_losses.py +2 -0
  27. careamics/config/support/supported_optimizers.py +2 -0
  28. careamics/config/support/supported_pixel_manipulations.py +3 -3
  29. careamics/config/support/supported_struct_axis.py +2 -0
  30. careamics/config/support/supported_transforms.py +4 -15
  31. careamics/config/tile_information.py +2 -0
  32. careamics/config/training_model.py +1 -0
  33. careamics/config/transformations/__init__.py +3 -2
  34. careamics/config/transformations/n2v_manipulate_model.py +1 -0
  35. careamics/config/transformations/normalize_model.py +1 -0
  36. careamics/config/transformations/transform_model.py +1 -0
  37. careamics/config/transformations/xy_flip_model.py +43 -0
  38. careamics/config/transformations/xy_random_rotate90_model.py +13 -7
  39. careamics/config/validators/validator_utils.py +1 -0
  40. careamics/conftest.py +13 -0
  41. careamics/dataset/dataset_utils/__init__.py +0 -1
  42. careamics/dataset/dataset_utils/dataset_utils.py +5 -4
  43. careamics/dataset/dataset_utils/file_utils.py +4 -3
  44. careamics/dataset/dataset_utils/read_tiff.py +6 -2
  45. careamics/dataset/dataset_utils/read_utils.py +2 -0
  46. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  47. careamics/dataset/in_memory_dataset.py +84 -76
  48. careamics/dataset/iterable_dataset.py +166 -134
  49. careamics/dataset/patching/__init__.py +0 -7
  50. careamics/dataset/patching/patching.py +56 -14
  51. careamics/dataset/patching/random_patching.py +8 -2
  52. careamics/dataset/patching/sequential_patching.py +20 -14
  53. careamics/dataset/patching/tiled_patching.py +13 -7
  54. careamics/dataset/patching/validate_patch_dimension.py +2 -0
  55. careamics/dataset/zarr_dataset.py +2 -0
  56. careamics/lightning_datamodule.py +63 -41
  57. careamics/lightning_module.py +9 -3
  58. careamics/lightning_prediction_datamodule.py +15 -20
  59. careamics/lightning_prediction_loop.py +8 -6
  60. careamics/losses/__init__.py +1 -3
  61. careamics/losses/loss_factory.py +2 -1
  62. careamics/losses/losses.py +11 -7
  63. careamics/model_io/__init__.py +0 -1
  64. careamics/model_io/bioimage/_readme_factory.py +2 -1
  65. careamics/model_io/bioimage/bioimage_utils.py +1 -0
  66. careamics/model_io/bioimage/model_description.py +1 -0
  67. careamics/model_io/bmz_io.py +4 -3
  68. careamics/models/activation.py +2 -0
  69. careamics/models/layers.py +122 -25
  70. careamics/models/model_factory.py +2 -1
  71. careamics/models/unet.py +114 -19
  72. careamics/prediction/stitch_prediction.py +2 -5
  73. careamics/transforms/__init__.py +4 -25
  74. careamics/transforms/compose.py +124 -0
  75. careamics/transforms/n2v_manipulate.py +65 -34
  76. careamics/transforms/normalize.py +91 -28
  77. careamics/transforms/pixel_manipulation.py +7 -7
  78. careamics/transforms/struct_mask_parameters.py +3 -1
  79. careamics/transforms/transform.py +24 -0
  80. careamics/transforms/tta.py +2 -2
  81. careamics/transforms/xy_flip.py +123 -0
  82. careamics/transforms/xy_random_rotate90.py +66 -60
  83. careamics/utils/__init__.py +0 -1
  84. careamics/utils/base_enum.py +28 -0
  85. careamics/utils/context.py +1 -0
  86. careamics/utils/logging.py +1 -0
  87. careamics/utils/metrics.py +1 -0
  88. careamics/utils/path_utils.py +2 -0
  89. careamics/utils/ram.py +2 -0
  90. careamics/utils/receptive_field.py +93 -87
  91. careamics/utils/torch_utils.py +1 -0
  92. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/METADATA +17 -61
  93. careamics-0.1.0rc6.dist-info/RECORD +107 -0
  94. careamics/config/noise_models.py +0 -162
  95. careamics/config/support/supported_extraction_strategies.py +0 -24
  96. careamics/config/transformations/nd_flip_model.py +0 -32
  97. careamics/dataset/patching/patch_transform.py +0 -44
  98. careamics/losses/noise_model_factory.py +0 -40
  99. careamics/losses/noise_models.py +0 -524
  100. careamics/transforms/nd_flip.py +0 -93
  101. careamics-0.1.0rc4.dist-info/RECORD +0 -110
  102. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/WHEEL +0 -0
  103. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/licenses/LICENSE +0 -0

careamics/config/support/supported_transforms.py CHANGED
@@ -1,23 +1,12 @@
+"""Transforms supported by CAREamics."""
+
 from careamics.utils import BaseEnum
 
 
 class SupportedTransform(str, BaseEnum):
-    """Transforms officially supported by CAREamics.
-
-    - Flip: from Albumentations, randomly flip the input horizontally, vertically or
-    both, parameter `p` can be used to set the probability to apply the transform.
-    - XYRandomRotate90: #TODO
-    - Normalize # TODO add details, in particular about the parameters
-    - ManipulateN2V # TODO add details, in particular about the parameters
-    - NDFlip
-
-    Note that while any Albumentations (see https://albumentations.ai/) transform can be
-    used in CAREamics, no check are implemented to verify the compatibility of any other
-    transforms than the ones officially supported.
-    """
+    """Transforms officially supported by CAREamics."""
 
-    NDFLIP = "NDFlip"
+    XY_FLIP = "XYFlip"
     XY_RANDOM_ROTATE90 = "XYRandomRotate90"
     NORMALIZE = "Normalize"
     N2V_MANIPULATE = "N2VManipulate"
-    # CUSTOM = "Custom"
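
The rename drops the old "NDFlip" value entirely rather than aliasing it, so configurations written against 0.1.0rc4 need updating. A minimal sketch of the effect, assuming SupportedTransform is still re-exported from careamics.config.support:

    from careamics.config.support import SupportedTransform

    # members subclass str, so they compare equal to raw configuration values
    assert SupportedTransform.XY_FLIP == "XYFlip"

    # the old rc4 value no longer maps to any member
    assert "NDFlip" not in [t.value for t in SupportedTransform]
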

careamics/config/tile_information.py CHANGED
@@ -1,3 +1,5 @@
+"""Pydantic model representing the metadata of a prediction tile."""
+
 from __future__ import annotations
 
 from typing import Optional, Tuple

careamics/config/training_model.py CHANGED
@@ -1,4 +1,5 @@
 """Training configuration."""
+
 from __future__ import annotations
 
 from pprint import pformat

careamics/config/transformations/__init__.py CHANGED
@@ -2,13 +2,14 @@
 
 __all__ = [
     "N2VManipulateModel",
-    "NDFlipModel",
+    "XYFlipModel",
     "NormalizeModel",
     "XYRandomRotate90Model",
+    "XorYFlipModel",
 ]
 
 
 from .n2v_manipulate_model import N2VManipulateModel
-from .nd_flip_model import NDFlipModel
 from .normalize_model import NormalizeModel
+from .xy_flip_model import XYFlipModel
 from .xy_random_rotate90_model import XYRandomRotate90Model

careamics/config/transformations/n2v_manipulate_model.py CHANGED
@@ -1,4 +1,5 @@
 """Pydantic model for the N2VManipulate transform."""
+
 from typing import Literal
 
 from pydantic import ConfigDict, Field, field_validator

careamics/config/transformations/normalize_model.py CHANGED
@@ -1,4 +1,5 @@
 """Pydantic model for the Normalize transform."""
+
 from typing import Literal
 
 from pydantic import ConfigDict, Field
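
Instances of these transform models are what the new Compose wrapper (careamics/transforms/compose.py in the file list above) consumes: it is built from the Pydantic models rather than from transform objects and is applied with keyword arguments, as the dataset diffs further down show. A hedged sketch of that calling convention, with an arbitrary patch shape:

    import numpy as np

    from careamics.config.transformations import NormalizeModel
    from careamics.transforms import Compose

    # channel-first patch, as produced by the patching functions
    patch = np.random.rand(1, 64, 64).astype(np.float32)

    # Compose takes the Pydantic transform models, not transform instances
    transform = Compose(transform_list=[NormalizeModel(mean=0.5, std=0.1)])

    # called with keyword arguments and returning a tuple, no moveaxis needed
    normalized, _ = transform(patch=patch)
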

careamics/config/transformations/transform_model.py CHANGED
@@ -1,4 +1,5 @@
 """Parent model for the transforms."""
+
 from typing import Any, Dict
 
 from pydantic import BaseModel, ConfigDict

careamics/config/transformations/xy_flip_model.py ADDED
@@ -0,0 +1,43 @@
+"""Pydantic model for the XYFlip transform."""
+
+from typing import Literal, Optional
+
+from pydantic import ConfigDict, Field
+
+from .transform_model import TransformModel
+
+
+class XYFlipModel(TransformModel):
+    """
+    Pydantic model used to represent XYFlip transformation.
+
+    Attributes
+    ----------
+    name : Literal["XYFlip"]
+        Name of the transformation.
+    p : float
+        Probability of applying the transform, by default 0.5.
+    seed : Optional[int]
+        Seed for the random number generator, by default None.
+    """
+
+    model_config = ConfigDict(
+        validate_assignment=True,
+    )
+
+    name: Literal["XYFlip"] = "XYFlip"
+    flip_x: bool = Field(
+        True,
+        description="Whether to flip along the X axis.",
+    )
+    flip_y: bool = Field(
+        True,
+        description="Whether to flip along the Y axis.",
+    )
+    p: float = Field(
+        0.5,
+        description="Probability of applying the transform.",
+        ge=0,
+        le=1,
+    )
+    seed: Optional[int] = None
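
Since XYFlipModel is a plain Pydantic model, its constraints can be exercised directly; a short sketch using only the fields shown above:

    from careamics.config.transformations import XYFlipModel

    # defaults: flip along both axes with probability 0.5
    flip = XYFlipModel()
    assert flip.name == "XYFlip" and flip.flip_x and flip.flip_y

    # the ge=0/le=1 bounds reject out-of-range probabilities
    try:
        XYFlipModel(p=1.5)
    except ValueError as err:  # pydantic's ValidationError subclasses ValueError
        print(err)
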

careamics/config/transformations/xy_random_rotate90_model.py CHANGED
@@ -1,5 +1,6 @@
 """Pydantic model for the XYRandomRotate90 transform."""
-from typing import Literal
+
+from typing import Literal, Optional
 
 from pydantic import ConfigDict, Field
 
@@ -8,16 +9,16 @@ from .transform_model import TransformModel
 
 
 class XYRandomRotate90Model(TransformModel):
     """
-    Pydantic model used to represent NDFlip transformation.
+    Pydantic model used to represent the XY random 90 degree rotation transformation.
 
     Attributes
     ----------
     name : Literal["XYRandomRotate90"]
         Name of the transformation.
     p : float
-        Probability of applying the transformation, by default 0.5.
-    is_3D : bool
-        Whether the transformation should be applied in 3D, by default False.
+        Probability of applying the transform, by default 0.5.
+    seed : Optional[int]
+        Seed for the random number generator, by default None.
     """
 
     model_config = ConfigDict(
@@ -25,5 +26,10 @@ class XYRandomRotate90Model(TransformModel):
     )
 
     name: Literal["XYRandomRotate90"] = "XYRandomRotate90"
-    p: float = Field(default=0.5, ge=0.0, le=1.0)
-    is_3D: bool = Field(default=False)
+    p: float = Field(
+        0.5,
+        description="Probability of applying the transform.",
+        ge=0,
+        le=1,
+    )
+    seed: Optional[int] = None
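
Both transform models set validate_assignment=True, so the bounds on p are enforced when mutating an existing instance as well as at construction; for example:

    from careamics.config.transformations import XYRandomRotate90Model

    rotate = XYRandomRotate90Model(p=0.25, seed=42)

    # with validate_assignment=True, assignments are re-validated
    try:
        rotate.p = 2.0
    except ValueError:
        rotate.p = 1.0  # within [0, 1], accepted
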

careamics/config/validators/validator_utils.py CHANGED
@@ -3,6 +3,7 @@ Validator functions.
 
 These functions are used to validate dimensions and axes of inputs.
 """
+
 from typing import List, Optional, Tuple, Union
 
 _AXES = "STCZYX"
careamics/conftest.py CHANGED
@@ -2,6 +2,7 @@
 
 See https://sybil.readthedocs.io/en/latest/use.html#pytest
 """
+
 from pathlib import Path
 
 import pytest
@@ -13,6 +14,18 @@ from sybil.parsers.doctest import DocTestParser
 
 @pytest.fixture(scope="module")
 def my_path(tmpdir_factory: TempPathFactory) -> Path:
+    """Fixture used in doctest to create a temporary directory.
+
+    Parameters
+    ----------
+    tmpdir_factory : TempPathFactory
+        Temporary path factory from pytest.
+
+    Returns
+    -------
+    Path
+        Temporary directory path.
+    """
     return tmpdir_factory.mktemp("my_path")
 
 
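For context, Sybil is what makes such fixtures reachable from doctests: fixtures listed on the Sybil collector are injected into the namespace of the documents it collects. A hypothetical wiring, following the pattern from the Sybil documentation linked above:

    from sybil import Sybil
    from sybil.parsers.doctest import DocTestParser

    # doctests in the matched files can then refer to `my_path` directly
    pytest_collect_file = Sybil(
        parsers=[DocTestParser()],
        patterns=["*.py"],
        fixtures=["my_path"],
    ).pytest()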
 

careamics/dataset/dataset_utils/__init__.py CHANGED
@@ -1,6 +1,5 @@
 """Files and arrays utils used in the datasets."""
 
-
 __all__ = [
     "reshape_array",
     "get_files_size",

careamics/dataset/dataset_utils/dataset_utils.py CHANGED
@@ -1,4 +1,5 @@
-"""Convenience methods for datasets."""
+"""Dataset utilities."""
+
 from typing import List, Tuple
 
 import numpy as np
@@ -16,12 +17,12 @@ def _get_shape_order(
 
     Parameters
     ----------
-    shape_in : Tuple
+    shape_in : Tuple[int, ...]
         Input shape.
-    ref_axes : str
-        Reference axes.
     axes_in : str
         Input axes.
+    ref_axes : str
+        Reference axes.
 
     Returns
     -------
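
The swapped docstring entries now follow the actual parameter order of _get_shape_order (shape_in, axes_in, ref_axes), which backs the public reshape_array helper exported above. A hedged usage sketch, assuming the reference order is SC(Z)YX with singleton axes inserted:

    import numpy as np

    from careamics.dataset.dataset_utils import reshape_array

    # a plain 2D YX image gains singleton sample and channel dimensions
    image = np.zeros((64, 64), dtype=np.float32)
    reshaped = reshape_array(image, axes="YX")
    print(reshaped.shape)  # expected (1, 1, 64, 64) under the assumption above
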

careamics/dataset/dataset_utils/file_utils.py CHANGED
@@ -1,3 +1,5 @@
+"""File utilities."""
+
 from fnmatch import fnmatch
 from pathlib import Path
 from typing import List, Union
@@ -11,8 +13,7 @@ logger = get_logger(__name__)
 
 
 def get_files_size(files: List[Path]) -> float:
-    """
-    Get files size in MB.
+    """Get files size in MB.
 
     Parameters
     ----------
@@ -32,7 +33,7 @@ def list_files(
     data_type: Union[str, SupportedData],
     extension_filter: str = "",
 ) -> List[Path]:
-    """Creates a recursive list of files in `data_path`.
+    """Create a recursive list of files in `data_path`.
 
     If `data_path` is a file, its name is validated against the `data_type` using
     `fnmatch`, and the method returns `data_path` itself.
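
A small usage sketch of the reworded helper, assuming "tiff" is a valid SupportedData value (the data/ directory is a placeholder):

    from pathlib import Path

    from careamics.dataset.dataset_utils.file_utils import list_files

    # recursively collect files matching the data type under data/
    files = list_files(Path("data"), data_type="tiff")
    print(f"found {len(files)} files")
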

careamics/dataset/dataset_utils/read_tiff.py CHANGED
@@ -1,3 +1,5 @@
+"""Funtions to read tiff images."""
+
 import logging
 from fnmatch import fnmatch
 from pathlib import Path
@@ -19,8 +21,10 @@ def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray:
     ----------
     file_path : Path
         Path to a file.
-    axes : str
-        Description of axes in format STCZYX.
+    *args : list
+        Additional arguments.
+    **kwargs : dict
+        Additional keyword arguments.
 
     Returns
     -------
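
The docstring now matches the actual signature, which accepts and ignores extra arguments so that read_tiff can stand in for any read_source_func; a minimal sketch (example.tiff is a placeholder):

    from pathlib import Path

    from careamics.dataset.dataset_utils import read_tiff

    # extra positional and keyword arguments are accepted purely for
    # interface compatibility with other read functions
    image = read_tiff(Path("example.tiff"))
    print(image.shape, image.dtype)
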

careamics/dataset/dataset_utils/read_utils.py CHANGED
@@ -1,3 +1,5 @@
+"""Read function utilities."""
+
 from typing import Callable, Union
 
 from careamics.config.support import SupportedData

careamics/dataset/dataset_utils/read_zarr.py CHANGED
@@ -1,3 +1,5 @@
+"""Function to read zarr images."""
+
 from typing import Union
 
 from zarr import Group, core, hierarchy, storage
@@ -6,26 +8,28 @@ from zarr import Group, core, hierarchy, storage
 def read_zarr(
     zarr_source: Group, axes: str
 ) -> Union[core.Array, storage.DirectoryStore, hierarchy.Group]:
-    """Reads a file and returns a pointer.
+    """Read a file and returns a pointer.
 
     Parameters
     ----------
-    file_path : Path
-        pathlib.Path object containing a path to a file
+    zarr_source : Group
+        Zarr storage.
+    axes : str
+        Axes of the data.
 
     Returns
     -------
     np.ndarray
-        Pointer to zarr storage
+        Pointer to zarr storage.
 
     Raises
     ------
     ValueError, OSError
-        if a file is not a valid tiff or damaged
+        if a file is not a valid tiff or damaged.
     ValueError
-        if data dimensions are not 2, 3 or 4
+        if data dimensions are not 2, 3 or 4.
     ValueError
-        if axes parameter from config is not consistent with data dimensions
+        if axes parameter from config is not consistent with data dimensions.
     """
     if isinstance(zarr_source, hierarchy.Group):
         array = zarr_source[0]
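
A hedged sketch of the corrected signature (example.zarr is a placeholder store):

    import zarr

    from careamics.dataset.dataset_utils.read_zarr import read_zarr

    # open an existing zarr group and hand it to the reader
    group = zarr.open_group("example.zarr", mode="r")
    array = read_zarr(group, axes="ZYX")
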

careamics/dataset/in_memory_dataset.py CHANGED
@@ -1,4 +1,5 @@
 """In-memory dataset module."""
+
 from __future__ import annotations
 
 import copy
@@ -8,11 +9,13 @@ from typing import Any, Callable, List, Optional, Tuple, Union
 import numpy as np
 from torch.utils.data import Dataset
 
+from careamics.transforms import Compose
+
 from ..config import DataConfig, InferenceConfig
 from ..config.tile_information import TileInformation
+from ..config.transformations import NormalizeModel
 from ..utils.logging import get_logger
 from .dataset_utils import read_tiff, reshape_array
-from .patching.patch_transform import get_patch_transform
 from .patching.patching import (
     prepare_patches_supervised,
     prepare_patches_supervised_array,
@@ -25,24 +28,49 @@ logger = get_logger(__name__)
 
 
 class InMemoryDataset(Dataset):
-    """Dataset storing data in memory and allowing generating patches from it."""
+    """Dataset storing data in memory and allowing generating patches from it.
+
+    Parameters
+    ----------
+    data_config : DataConfig
+        Data configuration.
+    inputs : Union[np.ndarray, List[Path]]
+        Input data.
+    input_target : Optional[Union[np.ndarray, List[Path]]], optional
+        Target data, by default None.
+    read_source_func : Callable, optional
+        Read source function for custom types, by default read_tiff.
+    **kwargs : Any
+        Additional keyword arguments, unused.
+    """
 
     def __init__(
         self,
         data_config: DataConfig,
        inputs: Union[np.ndarray, List[Path]],
-        data_target: Optional[Union[np.ndarray, List[Path]]] = None,
+        input_target: Optional[Union[np.ndarray, List[Path]]] = None,
         read_source_func: Callable = read_tiff,
         **kwargs: Any,
     ) -> None:
         """
         Constructor.
 
-        # TODO
+        Parameters
+        ----------
+        data_config : DataConfig
+            Data configuration.
+        inputs : Union[np.ndarray, List[Path]]
+            Input data.
+        input_target : Optional[Union[np.ndarray, List[Path]]], optional
+            Target data, by default None.
+        read_source_func : Callable, optional
+            Read source function for custom types, by default read_tiff.
+        **kwargs : Any
+            Additional keyword arguments, unused.
         """
         self.data_config = data_config
         self.inputs = inputs
-        self.data_target = data_target
+        self.input_targets = input_target
         self.axes = self.data_config.axes
         self.patch_size = self.data_config.patch_size
 
@@ -50,28 +78,25 @@ class InMemoryDataset(Dataset):
         self.read_source_func = read_source_func
 
         # Generate patches
-        supervised = self.data_target is not None
-        patches = self._prepare_patches(supervised)
+        supervised = self.input_targets is not None
+        patch_data = self._prepare_patches(supervised)
 
         # Add results to members
-        self.data, self.data_targets, computed_mean, computed_std = patches
+        self.patches, self.patch_targets, computed_mean, computed_std = patch_data
 
         if not self.data_config.mean or not self.data_config.std:
             self.mean, self.std = computed_mean, computed_std
             logger.info(f"Computed dataset mean: {self.mean}, std: {self.std}")
 
-            # if the transforms are not an instance of Compose
-            if self.data_config.has_transform_list():
-                # update mean and std in configuration
-                # the object is mutable and should then be recorded in the CAREamist obj
-                self.data_config.set_mean_and_std(self.mean, self.std)
+            # update mean and std in configuration
+            # the object is mutable and should then be recorded in the CAREamist obj
+            self.data_config.set_mean_and_std(self.mean, self.std)
         else:
             self.mean, self.std = self.data_config.mean, self.data_config.std
 
         # get transforms
-        self.patch_transform = get_patch_transform(
-            patch_transforms=self.data_config.transforms,
-            with_target=self.data_target is not None,
+        self.patch_transform = Compose(
+            transform_list=self.data_config.transforms,
         )
 
     def _prepare_patches(
@@ -92,18 +117,18 @@ class InMemoryDataset(Dataset):
         """
         if supervised:
             if isinstance(self.inputs, np.ndarray) and isinstance(
-                self.data_target, np.ndarray
+                self.input_targets, np.ndarray
             ):
                 return prepare_patches_supervised_array(
                     self.inputs,
                     self.axes,
-                    self.data_target,
+                    self.input_targets,
                     self.patch_size,
                 )
-            elif isinstance(self.inputs, list) and isinstance(self.data_target, list):
+            elif isinstance(self.inputs, list) and isinstance(self.input_targets, list):
                 return prepare_patches_supervised(
                     self.inputs,
-                    self.data_target,
+                    self.input_targets,
                     self.axes,
                     self.patch_size,
                     self.read_source_func,
@@ -112,7 +137,7 @@ class InMemoryDataset(Dataset):
                 raise ValueError(
                     f"Data and target must be of the same type, either both numpy "
                     f"arrays or both lists of paths, got {type(self.inputs)} (data) "
-                    f"and {type(self.data_target)} (target)."
+                    f"and {type(self.input_targets)} (target)."
                 )
         else:
             if isinstance(self.inputs, np.ndarray):
@@ -138,9 +163,9 @@ class InMemoryDataset(Dataset):
         int
             Length of the dataset.
         """
-        return len(self.data)
+        return len(self.patches)
 
-    def __getitem__(self, index: int) -> Tuple[np.ndarray]:
+    def __getitem__(self, index: int) -> Tuple[np.ndarray, ...]:
         """
         Return the patch corresponding to the provided index.
 
@@ -159,40 +184,17 @@ class InMemoryDataset(Dataset):
         ValueError
             If dataset mean and std are not set.
         """
-        patch = self.data[index]
+        patch = self.patches[index]
 
         # if there is a target
-        if self.data_target is not None:
+        if self.patch_targets is not None:
             # get target
-            target = self.data_targets[index]
-
-            # Albumentations requires Channel last
-            c_patch = np.moveaxis(patch, 0, -1)
-            c_target = np.moveaxis(target, 0, -1)
+            target = self.patch_targets[index]
 
-            # Apply transforms
-            transformed = self.patch_transform(image=c_patch, target=c_target)
-
-            # move axes back
-            patch = np.moveaxis(transformed["image"], -1, 0)
-            target = np.moveaxis(transformed["target"], -1, 0)
-
-            return patch, target
+            return self.patch_transform(patch=patch, target=target)
 
         elif self.data_config.has_n2v_manipulate():
-            # Albumentations requires Channel last
-            patch = np.moveaxis(patch, 0, -1)
-
-            # Apply transforms
-            transformed_patch = self.patch_transform(image=patch)["image"]
-            manip_patch, patch, mask = transformed_patch
-
-            # move C axes back
-            manip_patch = np.moveaxis(manip_patch, -1, 0)
-            patch = np.moveaxis(patch, -1, 0)
-            mask = np.moveaxis(mask, -1, 0)
-
-            return (manip_patch, patch, mask)
+            return self.patch_transform(patch=patch)
         else:
             raise ValueError(
                 "Something went wrong! No target provided (not supervised training) "
@@ -247,25 +249,25 @@ class InMemoryDataset(Dataset):
         indices = np.random.choice(total_patches, n_patches, replace=False)
 
         # extract patches
-        val_patches = self.data[indices]
+        val_patches = self.patches[indices]
 
         # remove patches from self.patch
-        self.data = np.delete(self.data, indices, axis=0)
+        self.patches = np.delete(self.patches, indices, axis=0)
 
         # same for targets
-        if self.data_targets is not None:
-            val_targets = self.data_targets[indices]
-            self.data_targets = np.delete(self.data_targets, indices, axis=0)
+        if self.patch_targets is not None:
+            val_targets = self.patch_targets[indices]
+            self.patch_targets = np.delete(self.patch_targets, indices, axis=0)
 
         # clone the dataset
         dataset = copy.deepcopy(self)
 
         # reassign patches
-        dataset.data = val_patches
+        dataset.patches = val_patches
 
         # reassign targets
-        if self.data_targets is not None:
-            dataset.data_targets = val_targets
+        if self.patch_targets is not None:
+            dataset.patch_targets = val_targets
 
         return dataset
 
@@ -274,7 +276,16 @@ class InMemoryPredictionDataset(Dataset):
     """
     Dataset storing data in memory and allowing generating patches from it.
 
-    # TODO
+    Parameters
+    ----------
+    prediction_config : InferenceConfig
+        Prediction configuration.
+    inputs : np.ndarray
+        Input data.
+    data_target : Optional[np.ndarray], optional
+        Target data, by default None.
+    read_source_func : Optional[Callable], optional
+        Read source function for custom types, by default read_tiff.
     """
 
     def __init__(
@@ -288,10 +299,14 @@ class InMemoryPredictionDataset(Dataset):
 
         Parameters
        ----------
-        array : np.ndarray
-            Array containing the data.
-        axes : str
-            Description of axes in format STCZYX.
+        prediction_config : InferenceConfig
+            Prediction configuration.
+        inputs : np.ndarray
+            Input data.
+        data_target : Optional[np.ndarray], optional
+            Target data, by default None.
+        read_source_func : Optional[Callable], optional
+            Read source function for custom types, by default read_tiff.
 
         Raises
         ------
@@ -318,9 +333,8 @@ class InMemoryPredictionDataset(Dataset):
         self.mean, self.std = self.pred_config.mean, self.pred_config.std
 
         # get transforms
-        self.patch_transform = get_patch_transform(
-            patch_transforms=self.pred_config.transforms,
-            with_target=self.data_target is not None,
+        self.patch_transform = Compose(
+            transform_list=[NormalizeModel(mean=self.mean, std=self.std)],
         )
 
     def _prepare_tiles(self) -> List[Tuple[np.ndarray, TileInformation]]:
@@ -335,7 +349,7 @@ class InMemoryPredictionDataset(Dataset):
         # reshape array
         reshaped_sample = reshape_array(self.input_array, self.axes)
 
-        if self.tiling:
+        if self.tiling and self.tile_size is not None and self.tile_overlap is not None:
             # generate patches, which returns a generator
             patch_generator = extract_tiles(
                 arr=reshaped_sample,
@@ -379,13 +393,7 @@ class InMemoryPredictionDataset(Dataset):
         """
         tile_array, tile_info = self.data[index]
 
-        # Albumentations requires channel last, use the XArrayTile array
-        patch = np.moveaxis(tile_array, 0, -1)
-
         # Apply transforms
-        transformed_patch = self.patch_transform(image=patch)["image"]
-
-        # move C axes back
-        transformed_patch = np.moveaxis(transformed_patch, -1, 0)
+        transformed_tile, _ = self.patch_transform(patch=tile_array)
 
-        return transformed_patch, tile_info
+        return transformed_tile, tile_info