careamics 0.0.8__py3-none-any.whl → 0.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (63)
  1. careamics/__init__.py +0 -4
  2. careamics/careamist.py +0 -1
  3. careamics/config/__init__.py +1 -13
  4. careamics/config/algorithms/care_algorithm_model.py +84 -0
  5. careamics/config/algorithms/n2n_algorithm_model.py +85 -0
  6. careamics/config/algorithms/n2v_algorithm_model.py +269 -1
  7. careamics/config/configuration.py +21 -13
  8. careamics/config/configuration_factories.py +179 -187
  9. careamics/config/configuration_io.py +2 -2
  10. careamics/config/data/__init__.py +1 -4
  11. careamics/config/data/data_model.py +46 -62
  12. careamics/config/support/supported_transforms.py +1 -1
  13. careamics/config/transformations/__init__.py +0 -2
  14. careamics/config/transformations/n2v_manipulate_model.py +15 -0
  15. careamics/config/transformations/transform_unions.py +0 -13
  16. careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
  17. careamics/dataset/in_memory_dataset.py +3 -10
  18. careamics/dataset/in_memory_pred_dataset.py +3 -5
  19. careamics/dataset/in_memory_tiled_pred_dataset.py +2 -2
  20. careamics/dataset/iterable_dataset.py +2 -2
  21. careamics/dataset/iterable_pred_dataset.py +3 -5
  22. careamics/dataset/iterable_tiled_pred_dataset.py +3 -3
  23. careamics/dataset_ng/dataset/__init__.py +3 -0
  24. careamics/dataset_ng/dataset/dataset.py +184 -0
  25. careamics/dataset_ng/demo_dataset.ipynb +271 -0
  26. careamics/dataset_ng/demo_patch_extractor.py +53 -0
  27. careamics/dataset_ng/demo_patch_extractor_factory.py +37 -0
  28. careamics/dataset_ng/patch_extractor/__init__.py +10 -0
  29. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +111 -0
  30. careamics/dataset_ng/patch_extractor/image_stack/__init__.py +9 -0
  31. careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +53 -0
  32. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +55 -0
  33. careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +163 -0
  34. careamics/dataset_ng/patch_extractor/image_stack_loader.py +140 -0
  35. careamics/dataset_ng/patch_extractor/patch_extractor.py +29 -0
  36. careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +208 -0
  37. careamics/dataset_ng/patching_strategies/__init__.py +11 -0
  38. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +82 -0
  39. careamics/dataset_ng/patching_strategies/random_patching.py +338 -0
  40. careamics/dataset_ng/patching_strategies/sequential_patching.py +75 -0
  41. careamics/lightning/lightning_module.py +78 -27
  42. careamics/lightning/train_data_module.py +8 -39
  43. careamics/losses/fcn/losses.py +17 -10
  44. careamics/lvae_training/eval_utils.py +21 -8
  45. careamics/model_io/bioimage/bioimage_utils.py +5 -3
  46. careamics/model_io/bioimage/model_description.py +3 -3
  47. careamics/model_io/bmz_io.py +2 -2
  48. careamics/model_io/model_io_utils.py +2 -2
  49. careamics/transforms/__init__.py +2 -1
  50. careamics/transforms/compose.py +5 -15
  51. careamics/transforms/n2v_manipulate_torch.py +143 -0
  52. careamics/transforms/pixel_manipulation.py +1 -0
  53. careamics/transforms/pixel_manipulation_torch.py +418 -0
  54. careamics/utils/version.py +38 -0
  55. {careamics-0.0.8.dist-info → careamics-0.0.10.dist-info}/METADATA +7 -8
  56. {careamics-0.0.8.dist-info → careamics-0.0.10.dist-info}/RECORD +59 -42
  57. careamics/config/care_configuration.py +0 -100
  58. careamics/config/data/n2v_data_model.py +0 -193
  59. careamics/config/n2n_configuration.py +0 -101
  60. careamics/config/n2v_configuration.py +0 -266
  61. {careamics-0.0.8.dist-info → careamics-0.0.10.dist-info}/WHEEL +0 -0
  62. {careamics-0.0.8.dist-info → careamics-0.0.10.dist-info}/entry_points.txt +0 -0
  63. {careamics-0.0.8.dist-info → careamics-0.0.10.dist-info}/licenses/LICENSE +0 -0
@@ -1,7 +1,6 @@
1
1
  """CAREamics transformation Pydantic models."""
2
2
 
3
3
  __all__ = [
4
- "N2V_TRANSFORMS_UNION",
5
4
  "NORM_AND_SPATIAL_UNION",
6
5
  "SPATIAL_TRANSFORMS_UNION",
7
6
  "N2VManipulateModel",
@@ -16,7 +15,6 @@ from .n2v_manipulate_model import N2VManipulateModel
16
15
  from .normalize_model import NormalizeModel
17
16
  from .transform_model import TransformModel
18
17
  from .transform_unions import (
19
- N2V_TRANSFORMS_UNION,
20
18
  NORM_AND_SPATIAL_UNION,
21
19
  SPATIAL_TRANSFORMS_UNION,
22
20
  )
@@ -7,6 +7,8 @@ from pydantic import ConfigDict, Field, field_validator
7
7
  from .transform_model import TransformModel
8
8
 
9
9
 
10
+ # TODO should probably not be a TransformModel anymore, no reason for it
11
+ # `name` is used as a discriminator field in the transforms
10
12
  class N2VManipulateModel(TransformModel):
11
13
  """
12
14
  Pydantic model used to represent N2V manipulation.
@@ -32,11 +34,24 @@ class N2VManipulateModel(TransformModel):
32
34
  )
33
35
 
34
36
  name: Literal["N2VManipulate"] = "N2VManipulate"
37
+
35
38
  roi_size: int = Field(default=11, ge=3, le=21)
39
+ """Size of the region where the pixel manipulation is applied."""
40
+
36
41
  masked_pixel_percentage: float = Field(default=0.2, ge=0.05, le=10.0)
42
+ """Percentage of masked pixels per image."""
43
+
44
+ remove_center: bool = Field(default=True) # TODO remove it
45
+ """Exclude center pixel from average calculation.""" # TODO rephrase this
46
+
37
47
  strategy: Literal["uniform", "median"] = Field(default="uniform")
48
+ """Strategy for pixel value replacement."""
49
+
38
50
  struct_mask_axis: Literal["horizontal", "vertical", "none"] = Field(default="none")
51
+ """Orientation of the structN2V mask. Set to `\"non\"` to not apply StructN2V."""
52
+
39
53
  struct_mask_span: int = Field(default=5, ge=3, le=15)
54
+ """Size of the structN2V mask."""
40
55
 
41
56
  @field_validator("roi_size", "struct_mask_span")
42
57
  @classmethod
@@ -4,7 +4,6 @@ from typing import Annotated, Union
4
4
 
5
5
  from pydantic import Discriminator
6
6
 
7
- from .n2v_manipulate_model import N2VManipulateModel
8
7
  from .normalize_model import NormalizeModel
9
8
  from .xy_flip_model import XYFlipModel
10
9
  from .xy_random_rotate90_model import XYRandomRotate90Model
@@ -14,7 +13,6 @@ NORM_AND_SPATIAL_UNION = Annotated[
14
13
  NormalizeModel,
15
14
  XYFlipModel,
16
15
  XYRandomRotate90Model,
17
- N2VManipulateModel,
18
16
  ],
19
17
  Discriminator("name"), # used to tell the different transform models apart
20
18
  ]
@@ -29,14 +27,3 @@ SPATIAL_TRANSFORMS_UNION = Annotated[
29
27
  Discriminator("name"), # used to tell the different transform models apart
30
28
  ]
31
29
  """Available spatial transforms in CAREamics."""
32
-
33
-
34
- N2V_TRANSFORMS_UNION = Annotated[
35
- Union[
36
- XYFlipModel,
37
- XYRandomRotate90Model,
38
- N2VManipulateModel,
39
- ],
40
- Discriminator("name"), # used to tell the different transform models apart
41
- ]
42
- """Available N2V-compatible transforms in CAREamics."""
@@ -9,7 +9,7 @@ from typing import Callable, Optional, Union
9
9
  from numpy.typing import NDArray
10
10
  from torch.utils.data import get_worker_info
11
11
 
12
- from careamics.config import GeneralDataConfig, InferenceConfig
12
+ from careamics.config import DataConfig, InferenceConfig
13
13
  from careamics.file_io.read import read_tiff
14
14
  from careamics.utils.logging import get_logger
15
15
 
@@ -19,7 +19,7 @@ logger = get_logger(__name__)
19
19
 
20
20
 
21
21
  def iterate_over_files(
22
- data_config: Union[GeneralDataConfig, InferenceConfig],
22
+ data_config: Union[DataConfig, InferenceConfig],
23
23
  data_files: list[Path],
24
24
  target_files: Optional[list[Path]] = None,
25
25
  read_source_func: Callable = read_tiff,
@@ -9,7 +9,7 @@ from typing import Any, Callable, Optional, Union
9
9
  import numpy as np
10
10
  from torch.utils.data import Dataset
11
11
 
12
- from careamics.config import GeneralDataConfig, N2VDataConfig
12
+ from careamics.config import DataConfig
13
13
  from careamics.config.transformations import NormalizeModel
14
14
  from careamics.dataset.patching.patching import (
15
15
  PatchedOutput,
@@ -46,7 +46,7 @@ class InMemoryDataset(Dataset):
46
46
 
47
47
  def __init__(
48
48
  self,
49
- data_config: GeneralDataConfig,
49
+ data_config: DataConfig,
50
50
  inputs: Union[np.ndarray, list[Path]],
51
51
  input_target: Optional[Union[np.ndarray, list[Path]]] = None,
52
52
  read_source_func: Callable = read_tiff,
@@ -215,16 +215,9 @@ class InMemoryDataset(Dataset):
215
215
  if self.data_targets is not None:
216
216
  # get target
217
217
  target = self.data_targets[index]
218
-
219
218
  return self.patch_transform(patch=patch, target=target)
220
219
 
221
- elif isinstance(self.data_config, N2VDataConfig):
222
- return self.patch_transform(patch=patch)
223
- else:
224
- raise ValueError(
225
- "Something went wrong! No target provided (not supervised training) "
226
- "while the algorithm is not Noise2Void."
227
- )
220
+ return self.patch_transform(patch=patch)
228
221
 
229
222
  def get_data_statistics(self) -> tuple[list[float], list[float]]:
230
223
  """Return training data statistics.
@@ -69,7 +69,7 @@ class InMemoryPredDataset(Dataset):
69
69
  """
70
70
  return len(self.data)
71
71
 
72
- def __getitem__(self, index: int) -> NDArray:
72
+ def __getitem__(self, index: int) -> tuple[NDArray, ...]:
73
73
  """
74
74
  Return the patch corresponding to the provided index.
75
75
 
@@ -80,9 +80,7 @@ class InMemoryPredDataset(Dataset):
80
80
 
81
81
  Returns
82
82
  -------
83
- NDArray
83
+ tuple(numpy.ndarray, ...)
84
84
  Transformed patch.
85
85
  """
86
- transformed_patch, _ = self.patch_transform(patch=self.data[index])
87
-
88
- return transformed_patch
86
+ return self.patch_transform(patch=self.data[index])
@@ -107,7 +107,7 @@ class InMemoryTiledPredDataset(Dataset):
107
107
  """
108
108
  return len(self.data)
109
109
 
110
- def __getitem__(self, index: int) -> tuple[NDArray, TileInformation]:
110
+ def __getitem__(self, index: int) -> tuple[tuple[NDArray, ...], TileInformation]:
111
111
  """
112
112
  Return the patch corresponding to the provided index.
113
113
 
@@ -124,6 +124,6 @@ class InMemoryTiledPredDataset(Dataset):
124
124
  tile_array, tile_info = self.data[index]
125
125
 
126
126
  # Apply transforms
127
- transformed_tile, _ = self.patch_transform(patch=tile_array)
127
+ transformed_tile = self.patch_transform(patch=tile_array)
128
128
 
129
129
  return transformed_tile, tile_info
@@ -10,7 +10,7 @@ from typing import Callable, Optional
10
10
  import numpy as np
11
11
  from torch.utils.data import IterableDataset
12
12
 
13
- from careamics.config import GeneralDataConfig
13
+ from careamics.config import DataConfig
14
14
  from careamics.config.transformations import NormalizeModel
15
15
  from careamics.file_io.read import read_tiff
16
16
  from careamics.transforms import Compose
@@ -49,7 +49,7 @@ class PathIterableDataset(IterableDataset):
49
49
 
50
50
  def __init__(
51
51
  self,
52
- data_config: GeneralDataConfig,
52
+ data_config: DataConfig,
53
53
  src_files: list[Path],
54
54
  target_files: Optional[list[Path]] = None,
55
55
  read_source_func: Callable = read_tiff,
@@ -97,13 +97,13 @@ class IterablePredDataset(IterableDataset):
97
97
 
98
98
  def __iter__(
99
99
  self,
100
- ) -> Generator[NDArray, None, None]:
100
+ ) -> Generator[tuple[NDArray, ...], None, None]:
101
101
  """
102
102
  Iterate over data source and yield single patch.
103
103
 
104
104
  Yields
105
105
  ------
106
- NDArray
106
+ (numpy.ndarray, numpy.ndarray or None)
107
107
  Single patch.
108
108
  """
109
109
  assert (
@@ -118,6 +118,4 @@ class IterablePredDataset(IterableDataset):
118
118
  # sample has S dimension
119
119
  for i in range(sample.shape[0]):
120
120
 
121
- transformed_sample, _ = self.patch_transform(patch=sample[i])
122
-
123
- yield transformed_sample
121
+ yield self.patch_transform(patch=sample[i])
@@ -109,13 +109,13 @@ class IterableTiledPredDataset(IterableDataset):
109
109
 
110
110
  def __iter__(
111
111
  self,
112
- ) -> Generator[tuple[NDArray, TileInformation], None, None]:
112
+ ) -> Generator[tuple[tuple[NDArray, ...], TileInformation], None, None]:
113
113
  """
114
114
  Iterate over data source and yield single patch.
115
115
 
116
116
  Yields
117
117
  ------
118
- Generator of NDArray and TileInformation tuple
118
+ Generator of (np.ndarray, np.ndarray or None) and TileInformation tuple
119
119
  Generator of single tiles.
120
120
  """
121
121
  assert (
@@ -136,6 +136,6 @@ class IterableTiledPredDataset(IterableDataset):
136
136
 
137
137
  # apply transform to patches
138
138
  for patch_array, tile_info in patch_gen:
139
- transformed_patch, _ = self.patch_transform(patch=patch_array)
139
+ transformed_patch = self.patch_transform(patch=patch_array)
140
140
 
141
141
  yield transformed_patch, tile_info
@@ -0,0 +1,3 @@
1
+ __all__ = ["CareamicsDataset"]
2
+
3
+ from .dataset import CareamicsDataset
@@ -0,0 +1,184 @@
1
+ from collections.abc import Sequence
2
+ from enum import Enum
3
+ from pathlib import Path
4
+ from typing import Literal, NamedTuple, Optional, Union
5
+
6
+ import numpy as np
7
+ from numpy.typing import NDArray
8
+ from torch.utils.data import Dataset
9
+ from typing_extensions import ParamSpec
10
+
11
+ from careamics.config import DataConfig, InferenceConfig
12
+ from careamics.config.support import SupportedData
13
+ from careamics.dataset.patching.patching import Stats
14
+ from careamics.dataset_ng.patch_extractor import (
15
+ ImageStackLoader,
16
+ PatchExtractor,
17
+ create_patch_extractor,
18
+ )
19
+ from careamics.dataset_ng.patching_strategies import (
20
+ FixedRandomPatchingStrategy,
21
+ PatchingStrategy,
22
+ PatchSpecs,
23
+ RandomPatchingStrategy,
24
+ )
25
+ from careamics.transforms import Compose
26
+
27
+ P = ParamSpec("P")
28
+
29
+
30
+ class Mode(str, Enum):
31
+ TRAINING = "training"
32
+ VALIDATING = "validating"
33
+ PREDICTING = "predicting"
34
+
35
+
36
class ImageRegionData(NamedTuple):
    """A region of image data bundled with metadata about its origin.

    Instances of this tuple are what ``CareamicsDataset.__getitem__`` returns
    for the input and (optionally) the target.
    """

    # The (possibly transformed) patch array.
    data: NDArray
    # Where the image came from: a file path, or the literal "array"
    # (presumably for in-memory inputs -- confirm against the image stacks).
    source: Union[Path, Literal["array"]]
    # Shape of the whole image stack the region was extracted from.
    data_shape: Sequence[int]
    # dtype of the source stack, kept as a string so it can be collated.
    dtype: str
    # Axes string associated with the data.
    axes: str
    # Patch specification (includes at least `data_idx`) that located the region.
    region_spec: PatchSpecs


# Inputs accepted by CareamicsDataset: a sequence of arrays or of file paths.
InputType = Union[Sequence[np.ndarray], Sequence[Path]]
46
+
47
+
48
class CareamicsDataset(Dataset):
    """Map-style dataset returning input patches and optional target patches.

    Patch locations come from a patching strategy chosen according to
    ``mode``. The pixel data is read through patch extractors built from
    ``inputs`` (and ``targets``, if provided).

    Parameters
    ----------
    data_config : DataConfig or InferenceConfig
        Configuration providing ``data_type``, ``axes`` and normalization
        statistics; a ``DataConfig`` additionally provides ``patch_size``,
        target statistics and ``transforms``.
    mode : Mode
        Stage in which the dataset is used.
    inputs : InputType
        Sequence of input arrays or paths.
    targets : InputType, optional
        Sequence of target arrays or paths, by default None.
    image_stack_loader : ImageStackLoader, optional
        Custom loader forwarded to ``create_patch_extractor``, by default None.
    *args, **kwargs
        Extra arguments forwarded to ``create_patch_extractor`` (typed by the
        loader's ``ParamSpec``).

    Raises
    ------
    ValueError
        If an ``InferenceConfig`` is used for training or validation, or if
        ``mode`` is not recognised.
    NotImplementedError
        If ``mode`` is ``Mode.PREDICTING``.
    """

    def __init__(
        self,
        data_config: Union[DataConfig, InferenceConfig],
        mode: Mode,
        inputs: InputType,
        targets: Optional[InputType] = None,
        image_stack_loader: Optional[ImageStackLoader[P]] = None,
        *args: P.args,
        **kwargs: P.kwargs,
    ):
        self.config = data_config
        self.mode = mode

        data_type_enum = SupportedData(self.config.data_type)
        self.input_extractor = create_patch_extractor(
            inputs,
            self.config.axes,
            data_type_enum,
            image_stack_loader,
            *args,
            **kwargs,
        )
        # Targets are optional; without them `__getitem__` returns `None` as
        # the target region.
        self.target_extractor: Optional[PatchExtractor] = None
        if targets is not None:
            self.target_extractor = create_patch_extractor(
                targets,
                self.config.axes,
                data_type_enum,
                image_stack_loader,
                *args,
                **kwargs,
            )

        self.patching_strategy = self._initialize_patching_strategy()

        self.input_stats, self.target_stats = self._initialize_statistics()

        self.transforms = self._initialize_transforms()

    def _initialize_patching_strategy(self) -> PatchingStrategy:
        """Select the patching strategy matching ``self.mode``.

        Returns
        -------
        PatchingStrategy
            ``RandomPatchingStrategy`` for training and
            ``FixedRandomPatchingStrategy`` for validation.

        Raises
        ------
        ValueError
            If an ``InferenceConfig`` is used for training or validation, or
            if the mode is not recognised.
        NotImplementedError
            In prediction mode.
        """
        patching_strategy: PatchingStrategy
        if self.mode == Mode.TRAINING:
            if isinstance(self.config, InferenceConfig):
                raise ValueError("Inference config cannot be used for training.")
            patching_strategy = RandomPatchingStrategy(
                data_shapes=self.input_extractor.shape,
                patch_size=self.config.patch_size,
                # TODO: Add random seed to dataconfig
                seed=getattr(self.config, "random_seed", None),
            )
        elif self.mode == Mode.VALIDATING:
            if isinstance(self.config, InferenceConfig):
                raise ValueError("Inference config cannot be used for validating.")
            patching_strategy = FixedRandomPatchingStrategy(
                data_shapes=self.input_extractor.shape,
                patch_size=self.config.patch_size,
                # TODO: Add random seed to dataconfig
                seed=getattr(self.config, "random_seed", None),
            )
        elif self.mode == Mode.PREDICTING:
            # TODO: patching strategy will be tilingStrategy in upcoming PR
            raise NotImplementedError(
                "Prediction mode for the CAREamicsDataset has not been implemented yet."
            )
        else:
            raise ValueError(f"Unrecognised dataset mode {self.mode}.")

        return patching_strategy

    def _initialize_transforms(self) -> Optional[Compose]:
        """Build the augmentation pipeline, used in training mode only.

        Augmentations listed in ``DataConfig.transforms`` are not applied
        when validating. Validation patches are drawn by a fixed patching
        strategy, and random augmentation would make validation batches
        differ from one epoch to the next.

        Returns
        -------
        Compose or None
            The composed transforms in training mode with a ``DataConfig``,
            otherwise None.
        """
        if self.mode == Mode.TRAINING and isinstance(self.config, DataConfig):
            return Compose(
                transform_list=list(self.config.transforms),
            )
        # TODO: add TTA
        return None

    def _initialize_statistics(self) -> tuple[Stats, Optional[Stats]]:
        """Read input and target statistics from the configuration.

        Returns
        -------
        tuple of (Stats, Stats or None)
            Input statistics, and target statistics when the configuration is
            a ``DataConfig`` with both target means and stds set.
        """
        # TODO: add running stats
        # Currently assume that stats are provided in the configuration
        input_stats = Stats(self.config.image_means, self.config.image_stds)
        target_stats = None
        if isinstance(self.config, DataConfig):
            target_means = self.config.target_means
            target_stds = self.config.target_stds
            if target_means is not None and target_stds is not None:
                target_stats = Stats(target_means, target_stds)
        return input_stats, target_stats

    def __len__(self) -> int:
        """Return the number of patches defined by the patching strategy."""
        return self.patching_strategy.n_patches

    def _create_image_region(
        self, patch: np.ndarray, patch_spec: PatchSpecs, extractor: PatchExtractor
    ) -> ImageRegionData:
        """Wrap a patch with metadata taken from its source image stack."""
        # Look the image stack up once instead of once per metadata field.
        image_stack = extractor.image_stacks[patch_spec["data_idx"]]
        return ImageRegionData(
            data=patch,
            source=image_stack.source,
            dtype=str(image_stack.data_dtype),
            data_shape=image_stack.data_shape,
            # TODO: should it be axes of the original image instead?
            axes=self.config.axes,
            region_spec=patch_spec,
        )

    def __getitem__(
        self, index: int
    ) -> tuple[ImageRegionData, Optional[ImageRegionData]]:
        """Return the input region and (optional) target region at ``index``.

        Both patches are extracted with the same patch specification. Any
        transforms are applied jointly to input and target.
        """
        patch_spec = self.patching_strategy.get_patch_spec(index)
        input_patch = self.input_extractor.extract_patch(**patch_spec)

        target_patch = (
            self.target_extractor.extract_patch(**patch_spec)
            if self.target_extractor is not None
            else None
        )

        if self.transforms is not None:
            input_patch, target_patch = self.transforms(input_patch, target_patch)

        input_data = self._create_image_region(
            patch=input_patch, patch_spec=patch_spec, extractor=self.input_extractor
        )

        target_data: Optional[ImageRegionData] = None
        if target_patch is not None and self.target_extractor is not None:
            target_data = self._create_image_region(
                patch=target_patch,
                patch_spec=patch_spec,
                extractor=self.target_extractor,
            )

        return input_data, target_data