careamics 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl

This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (79)
  1. careamics/careamist.py +11 -14
  2. careamics/cli/conf.py +18 -3
  3. careamics/config/__init__.py +8 -0
  4. careamics/config/algorithms/__init__.py +4 -0
  5. careamics/config/algorithms/hdn_algorithm_model.py +103 -0
  6. careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
  7. careamics/config/algorithms/n2v_algorithm_model.py +1 -2
  8. careamics/config/algorithms/vae_algorithm_model.py +51 -16
  9. careamics/config/architectures/lvae_model.py +12 -8
  10. careamics/config/callback_model.py +7 -3
  11. careamics/config/configuration.py +15 -63
  12. careamics/config/configuration_factories.py +853 -29
  13. careamics/config/data/data_model.py +50 -11
  14. careamics/config/data/ng_data_model.py +168 -4
  15. careamics/config/data/patch_filter/__init__.py +15 -0
  16. careamics/config/data/patch_filter/filter_model.py +16 -0
  17. careamics/config/data/patch_filter/mask_filter_model.py +17 -0
  18. careamics/config/data/patch_filter/max_filter_model.py +15 -0
  19. careamics/config/data/patch_filter/meanstd_filter_model.py +18 -0
  20. careamics/config/data/patch_filter/shannon_filter_model.py +15 -0
  21. careamics/config/inference_model.py +1 -2
  22. careamics/config/likelihood_model.py +2 -2
  23. careamics/config/loss_model.py +6 -2
  24. careamics/config/nm_model.py +26 -1
  25. careamics/config/optimizer_models.py +1 -2
  26. careamics/config/support/supported_algorithms.py +5 -3
  27. careamics/config/support/supported_filters.py +17 -0
  28. careamics/config/support/supported_losses.py +5 -2
  29. careamics/config/training_model.py +6 -36
  30. careamics/config/transformations/normalize_model.py +1 -2
  31. careamics/dataset_ng/dataset.py +57 -5
  32. careamics/dataset_ng/factory.py +101 -18
  33. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
  34. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
  35. careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
  36. careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
  37. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  38. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  39. careamics/dataset_ng/patch_filter/filter_factory.py +94 -0
  40. careamics/dataset_ng/patch_filter/mask_filter.py +95 -0
  41. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  42. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  43. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  44. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  45. careamics/file_io/read/__init__.py +0 -1
  46. careamics/lightning/__init__.py +16 -2
  47. careamics/lightning/callbacks/__init__.py +2 -0
  48. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  49. careamics/lightning/dataset_ng/data_module.py +79 -2
  50. careamics/lightning/lightning_module.py +162 -61
  51. careamics/lightning/microsplit_data_module.py +636 -0
  52. careamics/lightning/predict_data_module.py +8 -1
  53. careamics/lightning/train_data_module.py +19 -8
  54. careamics/losses/__init__.py +7 -1
  55. careamics/losses/loss_factory.py +9 -1
  56. careamics/losses/lvae/losses.py +85 -0
  57. careamics/lvae_training/dataset/__init__.py +8 -8
  58. careamics/lvae_training/dataset/config.py +56 -44
  59. careamics/lvae_training/dataset/lc_dataset.py +18 -12
  60. careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
  61. careamics/lvae_training/dataset/multich_dataset.py +24 -18
  62. careamics/lvae_training/dataset/multifile_dataset.py +6 -6
  63. careamics/lvae_training/eval_utils.py +46 -24
  64. careamics/model_io/bmz_io.py +9 -5
  65. careamics/models/lvae/likelihoods.py +31 -14
  66. careamics/models/lvae/lvae.py +2 -2
  67. careamics/models/lvae/noise_models.py +20 -14
  68. careamics/prediction_utils/__init__.py +8 -2
  69. careamics/prediction_utils/prediction_outputs.py +49 -3
  70. careamics/prediction_utils/stitch_prediction.py +83 -1
  71. careamics/transforms/xy_random_rotate90.py +1 -1
  72. careamics/utils/version.py +4 -4
  73. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/METADATA +19 -22
  74. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/RECORD +77 -60
  75. careamics/dataset/zarr_dataset.py +0 -151
  76. careamics/file_io/read/zarr.py +0 -60
  77. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/WHEEL +0 -0
  78. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/entry_points.txt +0 -0
  79. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/licenses/LICENSE +0 -0
careamics/dataset_ng/patch_filter/shannon_filter.py
@@ -0,0 +1,188 @@
+ """Filter patches based on Shannon entropy threshold."""
+
+ from collections.abc import Sequence
+
+ import numpy as np
+ from skimage.measure import shannon_entropy
+ from tqdm import tqdm
+
+ from careamics.dataset_ng.patch_extractor.patch_extractor_factory import (
+     create_array_extractor,
+ )
+ from careamics.dataset_ng.patch_filter.patch_filter_protocol import PatchFilterProtocol
+ from careamics.dataset_ng.patching_strategies import TilingStrategy
+
+
+ class ShannonPatchFilter(PatchFilterProtocol):
+     """
+     Filter patches based on Shannon entropy threshold.
+
+     Attributes
+     ----------
+     threshold : float
+         Threshold for the Shannon entropy of the patch.
+     p : float
+         Probability of applying the filter to a patch.
+     rng : np.random.Generator
+         Random number generator for stochastic filtering.
+     """
+
+     def __init__(
+         self, threshold: float, p: float = 1.0, seed: int | None = None
+     ) -> None:
+         """
+         Create a ShannonPatchFilter.
+
+         This filter removes patches whose Shannon entropy is below a specified
+         threshold.
+
+         Parameters
+         ----------
+         threshold : float
+             Threshold for the Shannon entropy of the patch.
+         p : float, default=1
+             Probability of applying the filter to a patch. Must be between 0 and 1.
+         seed : int | None, default=None
+             Seed for the random number generator, for reproducibility.
+
+         Raises
+         ------
+         ValueError
+             If threshold is negative.
+         ValueError
+             If p is not between 0 and 1.
+         """
+         if threshold < 0:
+             raise ValueError("Threshold must be non-negative.")
+         if not (0 <= p <= 1):
+             raise ValueError("Probability p must be between 0 and 1.")
+
+         self.threshold = threshold
+         self.p = p
+         self.rng = np.random.default_rng(seed)
+
+     def filter_out(self, patch: np.ndarray) -> bool:
+         """
+         Determine whether to filter out a patch based on its Shannon entropy.
+
+         Parameters
+         ----------
+         patch : numpy.ndarray
+             The patch to evaluate.
+
+         Returns
+         -------
+         bool
+             True if the patch should be filtered out, False otherwise.
+         """
+         if self.rng.uniform(0, 1) < self.p:
+             return shannon_entropy(patch) < self.threshold
+         return False
+
+     @staticmethod
+     def filter_map(
+         image: np.ndarray,
+         patch_size: Sequence[int],
+     ) -> np.ndarray:
+         """
+         Compute the Shannon entropy map of an image.
+
+         The entropy is computed over non-overlapping patches. This method can be used
+         to assess a useful threshold for the Shannon entropy filter.
+
+         Parameters
+         ----------
+         image : numpy.ndarray
+             The image for which to compute the entropy map, must be 2D or 3D.
+         patch_size : Sequence[int]
+             The size of the patches to compute the entropy over. Must be a sequence
+             of two (2D) or three (3D) integers.
+
+         Returns
+         -------
+         numpy.ndarray
+             The Shannon entropy map of the image.
+
+         Raises
+         ------
+         ValueError
+             If the image is not 2D or 3D.
+
+         Example
+         -------
+         The `filter_map` method can be used to assess a useful threshold for the
+         Shannon entropy filter. Below is an example of how to compute the Shannon
+         entropy map of a random image and visualize thresholded versions of the map.
+         >>> import numpy as np
+         >>> from matplotlib import pyplot as plt
+         >>> from careamics.dataset_ng.patch_filter import ShannonPatchFilter
+         >>> rng = np.random.default_rng(42)
+         >>> image = rng.binomial(20, 0.1, (256, 256)).astype(np.float32)
+         >>> image[64:192, 64:192] += rng.normal(50, 5, (128, 128))
+         >>> image[96:160, 96:160] = rng.poisson(image[96:160, 96:160])
+         >>> patch_size = (16, 16)
+         >>> entropy_map = ShannonPatchFilter.filter_map(image, patch_size)
+         >>> fig, ax = plt.subplots(1, 5, figsize=(20, 5))  # doctest: +SKIP
+         >>> for i, thresh in enumerate([2 + 1.5 * i for i in range(5)]):
+         ...     ax[i].imshow(entropy_map >= thresh, cmap="gray")  # doctest: +SKIP
+         ...     ax[i].set_title(f"Threshold: {thresh}")  # doctest: +SKIP
+         >>> plt.show()  # doctest: +SKIP
+         """
+         if len(image.shape) < 2 or len(image.shape) > 3:
+             raise ValueError("Image must be 2D or 3D.")
+
+         axes = "YX" if len(patch_size) == 2 else "ZYX"
+
+         shannon_img = np.zeros_like(image, dtype=float)
+
+         extractor = create_array_extractor(source=[image], axes=axes)
+         tiling = TilingStrategy(
+             data_shapes=[(1, 1, *image.shape)],
+             tile_size=patch_size,
+             overlaps=(0,) * len(patch_size),  # no overlap
+         )
+
+         for idx in tqdm(range(tiling.n_patches), desc="Computing Shannon Entropy map"):
+             patch_spec = tiling.get_patch_spec(idx)
+             patch = extractor.extract_patch(
+                 data_idx=0,
+                 sample_idx=0,
+                 coords=patch_spec["coords"],
+                 patch_size=patch_size,
+             )
+
+             coordinates = tuple(
+                 slice(patch_spec["coords"][i], patch_spec["coords"][i] + p)
+                 for i, p in enumerate(patch_size)
+             )
+             shannon_img[coordinates] = shannon_entropy(patch)
+
+         return shannon_img
+
+     @staticmethod
+     def apply_filter(
+         filter_map: np.ndarray,
+         threshold: float,
+     ) -> np.ndarray:
+         """
+         Apply the Shannon entropy filter to a precomputed filter map.
+
+         The filter map is the output of the `filter_map` method.
+
+         Parameters
+         ----------
+         filter_map : numpy.ndarray
+             The precomputed Shannon entropy map of the image.
+         threshold : float
+             The Shannon entropy threshold for filtering.
+
+         Returns
+         -------
+         numpy.ndarray
+             A boolean array where True indicates that the patch should be kept
+             (not filtered out) and False indicates that the patch should be
+             filtered out.
+         """
+         return filter_map >= threshold
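
For orientation, here is a minimal usage sketch of the new filter. The class, method names, and import path are taken from the diff above; the threshold and probability values are arbitrary illustrations, not recommendations.

import numpy as np

from careamics.dataset_ng.patch_filter import ShannonPatchFilter

# Filter that drops patches whose entropy is below 5.0; `p=0.8` applies
# the test stochastically to ~80% of patches, `seed` makes it reproducible.
patch_filter = ShannonPatchFilter(threshold=5.0, p=0.8, seed=0)

rng = np.random.default_rng(0)
patch = rng.normal(0, 1, (16, 16))
discard = patch_filter.filter_out(patch)  # True -> the patch is dropped

# Threshold exploration on a whole image: entropy map over
# non-overlapping patches, then a boolean keep-mask.
image = rng.normal(0, 1, (256, 256))
entropy_map = ShannonPatchFilter.filter_map(image, (16, 16))
keep_mask = ShannonPatchFilter.apply_filter(entropy_map, threshold=5.0)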
careamics/file_io/read/__init__.py
@@ -9,4 +9,3 @@ __all__ = [

  from .get_func import ReadFunc, get_read_func
  from .tiff import read_tiff
- from .zarr import read_zarr
careamics/lightning/__init__.py
@@ -1,18 +1,32 @@
  """CAREamics PyTorch Lightning modules."""

  __all__ = [
+     "DataStatsCallback",
      "FCNModule",
      "HyperParametersCallback",
+     "MicroSplitDataModule",
      "PredictDataModule",
      "ProgressBarCallback",
      "TrainDataModule",
      "VAEModule",
      "create_careamics_module",
+     "create_microsplit_predict_datamodule",
+     "create_microsplit_train_datamodule",
      "create_predict_datamodule",
      "create_train_datamodule",
+     "create_unet_based_module",
+     "create_vae_based_module",
  ]

- from .callbacks import HyperParametersCallback, ProgressBarCallback
+ from .callbacks import DataStatsCallback, HyperParametersCallback, ProgressBarCallback
  from .lightning_module import FCNModule, VAEModule, create_careamics_module
+ from .microsplit_data_module import (
+     MicroSplitDataModule,
+     create_microsplit_predict_datamodule,
+     create_microsplit_train_datamodule,
+ )
  from .predict_data_module import PredictDataModule, create_predict_datamodule
- from .train_data_module import TrainDataModule, create_train_datamodule
+ from .train_data_module import (
+     TrainDataModule,
+     create_train_datamodule,
+ )
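
The expanded __all__ means the new MicroSplit entry points are importable straight from careamics.lightning; a quick import-level check, with names grounded in the list above:

from careamics.lightning import (
    DataStatsCallback,
    MicroSplitDataModule,
    create_microsplit_predict_datamodule,
    create_microsplit_train_datamodule,
    create_unet_based_module,
    create_vae_based_module,
)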
careamics/lightning/callbacks/__init__.py
@@ -1,11 +1,13 @@
  """Callbacks module."""

  __all__ = [
+     "DataStatsCallback",
      "HyperParametersCallback",
      "PredictionWriterCallback",
      "ProgressBarCallback",
  ]

+ from .data_stats_callback import DataStatsCallback
  from .hyperparameters_callback import HyperParametersCallback
  from .prediction_writer_callback import PredictionWriterCallback
  from .progress_bar_callback import ProgressBarCallback
careamics/lightning/callbacks/data_stats_callback.py
@@ -0,0 +1,33 @@
+ """Data statistics callback."""
+
+ import pytorch_lightning as L
+ from pytorch_lightning.callbacks import Callback
+
+
+ class DataStatsCallback(Callback):
+     """Callback to update the model's data statistics from the datamodule.
+
+     This callback ensures that the model has access to the data statistics (mean and
+     std) calculated by the datamodule before training starts.
+     """
+
+     def setup(self, trainer: L.Trainer, module: L.LightningModule, stage: str) -> None:
+         """Called when the trainer is setting up.
+
+         Parameters
+         ----------
+         trainer : pytorch_lightning.Trainer
+             The trainer instance.
+         module : pytorch_lightning.LightningModule
+             The model being trained.
+         stage : str
+             The current stage (e.g., 'fit', 'validate', 'test', 'predict').
+         """
+         if stage == "fit":
+             # Get data statistics from the datamodule
+             (data_mean, data_std), _ = trainer.datamodule.get_data_stats()
+
+             # Set data statistics in the model's likelihood module
+             module.noise_model_likelihood.set_data_stats(
+                 data_mean=data_mean["target"], data_std=data_std["target"]
+             )
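
A sketch of wiring the callback into a training run. It assumes, as the setup hook above does, a datamodule implementing get_data_stats() and a module exposing noise_model_likelihood.set_data_stats(); the model and datamodule placeholders and trainer arguments are illustrative only.

import pytorch_lightning as pl

from careamics.lightning import DataStatsCallback

model = ...       # placeholder: a module exposing `noise_model_likelihood`
datamodule = ...  # placeholder: a datamodule implementing `get_data_stats()`

# The callback copies (mean, std) from the datamodule into the model's
# noise-model likelihood when stage == "fit".
trainer = pl.Trainer(max_epochs=10, callbacks=[DataStatsCallback()])
trainer.fit(model, datamodule=datamodule)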
careamics/lightning/dataset_ng/data_module.py
@@ -39,6 +39,10 @@ class CareamicsDataModule(L.LightningDataModule):
      train_data_target : Optional[InputType]
          Training data target, can be a path to a folder,
          a list of paths, or a numpy array.
+     train_data_mask : Optional[InputType]
+         Training data mask, can be a path to a folder,
+         a list of paths, or a numpy array. Used for coordinate filtering.
+         Only required when using coordinate-based patch filtering.
      val_data : Optional[InputType]
          Validation data, can be a path to a folder,
          a list of paths, or a numpy array.
@@ -99,6 +103,9 @@ class CareamicsDataModule(L.LightningDataModule):
      train_data_target : Optional[Any]
          Training data target, can be a path to a folder, a list of paths, or a numpy
          array.
+     train_data_mask : Optional[Any]
+         Training data mask, can be a path to a folder, a list of paths, or a numpy
+         array.
      val_data : Optional[Any]
          Validation data, can be a path to a folder, a list of paths, or a numpy array.
      val_data_target : Optional[Any]
@@ -118,7 +125,7 @@
      If input and target data types are not consistent.
      """

-     # standard use
+     # standard use (no mask)
      @overload
      def __init__(
          self,
@@ -136,7 +143,26 @@
          use_in_memory: bool = True,
      ) -> None: ...

-     # custom read function
+     # with training mask for filtering
+     @overload
+     def __init__(
+         self,
+         data_config: NGDataConfig,
+         *,
+         train_data: InputType | None = None,
+         train_data_target: InputType | None = None,
+         train_data_mask: InputType,
+         val_data: InputType | None = None,
+         val_data_target: InputType | None = None,
+         pred_data: InputType | None = None,
+         pred_data_target: InputType | None = None,
+         extension_filter: str = "",
+         val_percentage: float | None = None,
+         val_minimum_split: int = 5,
+         use_in_memory: bool = True,
+     ) -> None: ...
+
+     # custom read function (no mask)
      @overload
      def __init__(
          self,
@@ -156,6 +182,48 @@
          use_in_memory: bool = True,
      ) -> None: ...

+     # custom read function with training mask
+     @overload
+     def __init__(
+         self,
+         data_config: NGDataConfig,
+         *,
+         train_data: InputType | None = None,
+         train_data_target: InputType | None = None,
+         train_data_mask: InputType,
+         val_data: InputType | None = None,
+         val_data_target: InputType | None = None,
+         pred_data: InputType | None = None,
+         pred_data_target: InputType | None = None,
+         read_source_func: Callable,
+         read_kwargs: dict[str, Any] | None = None,
+         extension_filter: str = "",
+         val_percentage: float | None = None,
+         val_minimum_split: int = 5,
+         use_in_memory: bool = True,
+     ) -> None: ...
+
+     # image stack loader (no mask)
+     @overload
+     def __init__(
+         self,
+         data_config: NGDataConfig,
+         *,
+         train_data: Any | None = None,
+         train_data_target: Any | None = None,
+         val_data: Any | None = None,
+         val_data_target: Any | None = None,
+         pred_data: Any | None = None,
+         pred_data_target: Any | None = None,
+         image_stack_loader: ImageStackLoader,
+         image_stack_loader_kwargs: dict[str, Any] | None = None,
+         extension_filter: str = "",
+         val_percentage: float | None = None,
+         val_minimum_split: int = 5,
+         use_in_memory: bool = True,
+     ) -> None: ...
+
+     # image stack loader with training mask
      @overload
      def __init__(
          self,
@@ -163,6 +231,7 @@
          *,
          train_data: Any | None = None,
          train_data_target: Any | None = None,
+         train_data_mask: Any,
          val_data: Any | None = None,
          val_data_target: Any | None = None,
          pred_data: Any | None = None,
@@ -181,6 +250,7 @@
          *,
          train_data: Any | None = None,
          train_data_target: Any | None = None,
+         train_data_mask: Any | None = None,
          val_data: Any | None = None,
          val_data_target: Any | None = None,
          pred_data: Any | None = None,
@@ -209,6 +279,10 @@
          train_data_target : Optional[InputType]
              Training data target, can be a path to a folder,
              a list of paths, or a numpy array.
+         train_data_mask : Optional[InputType]
+             Training data mask, can be a path to a folder,
+             a list of paths, or a numpy array. Used for coordinate filtering.
+             Only required when using coordinate-based patch filtering.
          val_data : Optional[InputType]
              Validation data, can be a path to a folder,
              a list of paths, or a numpy array.
@@ -268,6 +342,8 @@
          self.train_data, self.train_data_target = self._initialize_data_pair(
              train_data, train_data_target
          )
+         self.train_data_mask, _ = self._initialize_data_pair(train_data_mask, None)
+
          self.val_data, self.val_data_target = self._initialize_data_pair(
              val_data, val_data_target
          )
@@ -574,6 +650,7 @@
              mode=Mode.TRAINING,
              inputs=self.train_data,
              targets=self.train_data_target,
+             masks=self.train_data_mask,
              config=self.config,
              in_memory=self.use_in_memory,
              read_func=self.read_source_func,
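
Taken together, these hunks thread an optional training mask from the constructor down to dataset creation. A hedged sketch of the mask path, assuming arrays as inputs and an NGDataConfig already set up for coordinate-based mask filtering (its construction is not part of this diff):

import numpy as np

from careamics.lightning.dataset_ng.data_module import CareamicsDataModule

train = np.random.rand(256, 256).astype(np.float32)
mask = (train > 0.5).astype(np.uint8)  # array aligned with `train`; its
                                       # interpretation is up to the mask filter

data_config = ...  # placeholder: NGDataConfig with a mask-based patch filter

data_module = CareamicsDataModule(
    data_config=data_config,
    train_data=train,
    train_data_mask=mask,  # new keyword introduced in this release
)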