careamics 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +11 -14
- careamics/cli/conf.py +18 -3
- careamics/config/__init__.py +8 -0
- careamics/config/algorithms/__init__.py +4 -0
- careamics/config/algorithms/hdn_algorithm_model.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
- careamics/config/algorithms/n2v_algorithm_model.py +1 -2
- careamics/config/algorithms/vae_algorithm_model.py +51 -16
- careamics/config/architectures/lvae_model.py +12 -8
- careamics/config/callback_model.py +7 -3
- careamics/config/configuration.py +15 -63
- careamics/config/configuration_factories.py +853 -29
- careamics/config/data/data_model.py +50 -11
- careamics/config/data/ng_data_model.py +168 -4
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_model.py +16 -0
- careamics/config/data/patch_filter/mask_filter_model.py +17 -0
- careamics/config/data/patch_filter/max_filter_model.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_model.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_model.py +15 -0
- careamics/config/inference_model.py +1 -2
- careamics/config/likelihood_model.py +2 -2
- careamics/config/loss_model.py +6 -2
- careamics/config/nm_model.py +26 -1
- careamics/config/optimizer_models.py +1 -2
- careamics/config/support/supported_algorithms.py +5 -3
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_losses.py +5 -2
- careamics/config/training_model.py +6 -36
- careamics/config/transformations/normalize_model.py +1 -2
- careamics/dataset_ng/dataset.py +57 -5
- careamics/dataset_ng/factory.py +101 -18
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
- careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +94 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +95 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/file_io/read/__init__.py +0 -1
- careamics/lightning/__init__.py +16 -2
- careamics/lightning/callbacks/__init__.py +2 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/dataset_ng/data_module.py +79 -2
- careamics/lightning/lightning_module.py +162 -61
- careamics/lightning/microsplit_data_module.py +636 -0
- careamics/lightning/predict_data_module.py +8 -1
- careamics/lightning/train_data_module.py +19 -8
- careamics/losses/__init__.py +7 -1
- careamics/losses/loss_factory.py +9 -1
- careamics/losses/lvae/losses.py +85 -0
- careamics/lvae_training/dataset/__init__.py +8 -8
- careamics/lvae_training/dataset/config.py +56 -44
- careamics/lvae_training/dataset/lc_dataset.py +18 -12
- careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
- careamics/lvae_training/dataset/multich_dataset.py +24 -18
- careamics/lvae_training/dataset/multifile_dataset.py +6 -6
- careamics/lvae_training/eval_utils.py +46 -24
- careamics/model_io/bmz_io.py +9 -5
- careamics/models/lvae/likelihoods.py +31 -14
- careamics/models/lvae/lvae.py +2 -2
- careamics/models/lvae/noise_models.py +20 -14
- careamics/prediction_utils/__init__.py +8 -2
- careamics/prediction_utils/prediction_outputs.py +49 -3
- careamics/prediction_utils/stitch_prediction.py +83 -1
- careamics/transforms/xy_random_rotate90.py +1 -1
- careamics/utils/version.py +4 -4
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/METADATA +19 -22
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/RECORD +77 -60
- careamics/dataset/zarr_dataset.py +0 -151
- careamics/file_io/read/zarr.py +0 -60
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/WHEEL +0 -0
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
"""Filter patches based on Shannon entropy threshold."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from skimage.measure import shannon_entropy
|
|
7
|
+
from tqdm import tqdm
|
|
8
|
+
|
|
9
|
+
from careamics.dataset_ng.patch_extractor.patch_extractor_factory import (
|
|
10
|
+
create_array_extractor,
|
|
11
|
+
)
|
|
12
|
+
from careamics.dataset_ng.patch_filter.patch_filter_protocol import PatchFilterProtocol
|
|
13
|
+
from careamics.dataset_ng.patching_strategies import TilingStrategy
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ShannonPatchFilter(PatchFilterProtocol):
    """
    Filter patches based on Shannon entropy threshold.

    A patch is filtered out when its Shannon entropy is strictly below
    `threshold`. The filter is only applied to a given patch with probability `p`.

    Attributes
    ----------
    threshold : float
        Threshold for the Shannon entropy of the patch.
    p : float
        Probability of applying the filter to a patch.
    rng : np.random.Generator
        Random number generator for stochastic filtering.
    """

    def __init__(
        self, threshold: float, p: float = 1.0, seed: int | None = None
    ) -> None:
        """
        Create a ShannonPatchFilter.

        This filter removes patches whose Shannon entropy is below a specified
        threshold.

        Parameters
        ----------
        threshold : float
            Threshold for the Shannon entropy of the patch. Must be non-negative.
        p : float, default=1
            Probability of applying the filter to a patch. Must be between 0 and 1.
        seed : int | None, default=None
            Seed for the random number generator for reproducibility.

        Raises
        ------
        ValueError
            If threshold is negative.
        ValueError
            If p is not between 0 and 1.
        """
        if threshold < 0:
            raise ValueError("Threshold must be non-negative.")
        if not (0 <= p <= 1):
            raise ValueError("Probability p must be between 0 and 1.")

        self.threshold = threshold
        self.p = p
        self.rng = np.random.default_rng(seed)

    def filter_out(self, patch: np.ndarray) -> bool:
        """
        Determine whether to filter out a patch based on its Shannon entropy.

        Parameters
        ----------
        patch : numpy.NDArray
            The patch to evaluate.

        Returns
        -------
        bool
            True if the patch should be filtered out, False otherwise.
        """
        # One random draw per call decides whether the filter is applied at all.
        # The draw lies in [0, 1), so with p=1 the filter is always applied and
        # with p=0 it never is.
        if self.rng.uniform(0, 1) < self.p:
            # Comparing the entropy (a NumPy float) yields numpy.bool_; convert
            # so the declared `bool` return type actually holds.
            return bool(shannon_entropy(patch) < self.threshold)
        return False

    @staticmethod
    def filter_map(
        image: np.ndarray,
        patch_size: Sequence[int],
    ) -> np.ndarray:
        """
        Compute the Shannon entropy map of an image.

        The image is tiled without overlap and the entropy of each tile is
        written into that tile's region of the output map. This method can be
        used to assess a useful threshold for the Shannon entropy filter.

        Parameters
        ----------
        image : numpy.NDArray
            The image for which to compute the entropy map, must be 2D (YX) or
            3D (ZYX).
        patch_size : Sequence[int]
            The size of the patches to compute the entropy over. Must contain one
            integer per image dimension, i.e. two integers for a 2D image and
            three for a 3D image.

        Returns
        -------
        numpy.NDArray
            The Shannon entropy map of the image, a float array with the same
            shape as `image`.

        Raises
        ------
        ValueError
            If the image is not 2D or 3D.
        ValueError
            If the length of `patch_size` does not match the number of image
            dimensions.

        Examples
        --------
        The `filter_map` method can be used to assess a useful threshold for the
        Shannon entropy filter. Below is an example of how to compute and visualize
        the Shannon entropy map of a random image and visualize thresholded versions
        of the map.
        >>> import numpy as np
        >>> from matplotlib import pyplot as plt
        >>> from careamics.dataset_ng.patch_filter import ShannonPatchFilter
        >>> rng = np.random.default_rng(42)
        >>> image = rng.binomial(20, 0.1, (256, 256)).astype(np.float32)
        >>> image[64:192, 64:192] += rng.normal(50, 5, (128, 128))
        >>> image[96:160, 96:160] = rng.poisson(image[96:160, 96:160])
        >>> patch_size = (16, 16)
        >>> entropy_map = ShannonPatchFilter.filter_map(image, patch_size)
        >>> fig, ax = plt.subplots(1, 5, figsize=(20, 5)) # doctest: +SKIP
        >>> for i, thresh in enumerate([2 + 1.5 * i for i in range(5)]):
        ...     ax[i].imshow(entropy_map >= thresh, cmap="gray") # doctest: +SKIP
        ...     ax[i].set_title(f"Threshold: {thresh}") # doctest: +SKIP
        >>> plt.show() # doctest: +SKIP
        """
        n_dims = len(image.shape)
        if n_dims < 2 or n_dims > 3:
            raise ValueError("Image must be 2D or 3D.")
        # The axes string below is derived from the image, so the patch size must
        # agree with it; otherwise the extractor and the slicing would disagree.
        if len(patch_size) != n_dims:
            raise ValueError(
                f"Patch size {tuple(patch_size)} must have one entry per image "
                f"dimension (image has {n_dims} dimensions)."
            )

        axes = "YX" if n_dims == 2 else "ZYX"

        shannon_img = np.zeros_like(image, dtype=float)

        extractor = create_array_extractor(source=[image], axes=axes)
        tiling = TilingStrategy(
            # presumably (samples, channels, *spatial): one sample, one channel
            # -- TODO confirm against TilingStrategy's expected data_shapes layout
            data_shapes=[(1, 1, *image.shape)],
            tile_size=patch_size,
            overlaps=(0,) * len(patch_size),  # no overlap
        )

        for idx in tqdm(range(tiling.n_patches), desc="Computing Shannon Entropy map"):
            patch_spec = tiling.get_patch_spec(idx)
            patch = extractor.extract_patch(
                data_idx=0,
                sample_idx=0,
                coords=patch_spec["coords"],
                patch_size=patch_size,
            )

            # Region of the output map covered by this tile.
            coordinates = tuple(
                slice(patch_spec["coords"][i], patch_spec["coords"][i] + p)
                for i, p in enumerate(patch_size)
            )
            shannon_img[coordinates] = shannon_entropy(patch)

        return shannon_img

    @staticmethod
    def apply_filter(
        filter_map: np.ndarray,
        threshold: float,
    ) -> np.ndarray:
        """
        Apply the Shannon entropy filter to a precomputed filter map.

        The filter map is the output of the `filter_map` method.

        Parameters
        ----------
        filter_map : numpy.NDArray
            The precomputed Shannon entropy map of the image.
        threshold : float
            The Shannon entropy threshold for filtering.

        Returns
        -------
        numpy.NDArray
            A boolean array where True indicates that the patch should be kept
            (not filtered out) and False indicates that the patch should be filtered
            out.
        """
        return filter_map >= threshold
|
careamics/lightning/__init__.py
CHANGED
|
@@ -1,18 +1,32 @@
|
|
|
1
1
|
"""CAREamics PyTorch Lightning modules."""
|
|
2
2
|
|
|
3
3
|
__all__ = [
|
|
4
|
+
"DataStatsCallback",
|
|
4
5
|
"FCNModule",
|
|
5
6
|
"HyperParametersCallback",
|
|
7
|
+
"MicroSplitDataModule",
|
|
6
8
|
"PredictDataModule",
|
|
7
9
|
"ProgressBarCallback",
|
|
8
10
|
"TrainDataModule",
|
|
9
11
|
"VAEModule",
|
|
10
12
|
"create_careamics_module",
|
|
13
|
+
"create_microsplit_predict_datamodule",
|
|
14
|
+
"create_microsplit_train_datamodule",
|
|
11
15
|
"create_predict_datamodule",
|
|
12
16
|
"create_train_datamodule",
|
|
17
|
+
"create_unet_based_module",
|
|
18
|
+
"create_vae_based_module",
|
|
13
19
|
]
|
|
14
20
|
|
|
15
|
-
from .callbacks import HyperParametersCallback, ProgressBarCallback
|
|
21
|
+
from .callbacks import DataStatsCallback, HyperParametersCallback, ProgressBarCallback
|
|
16
22
|
from .lightning_module import FCNModule, VAEModule, create_careamics_module
|
|
23
|
+
from .microsplit_data_module import (
|
|
24
|
+
MicroSplitDataModule,
|
|
25
|
+
create_microsplit_predict_datamodule,
|
|
26
|
+
create_microsplit_train_datamodule,
|
|
27
|
+
)
|
|
17
28
|
from .predict_data_module import PredictDataModule, create_predict_datamodule
|
|
18
|
-
from .train_data_module import
|
|
29
|
+
from .train_data_module import (
|
|
30
|
+
TrainDataModule,
|
|
31
|
+
create_train_datamodule,
|
|
32
|
+
)
|
|
@@ -1,11 +1,13 @@
|
|
|
1
1
|
"""Callbacks module."""
|
|
2
2
|
|
|
3
3
|
__all__ = [
|
|
4
|
+
"DataStatsCallback",
|
|
4
5
|
"HyperParametersCallback",
|
|
5
6
|
"PredictionWriterCallback",
|
|
6
7
|
"ProgressBarCallback",
|
|
7
8
|
]
|
|
8
9
|
|
|
10
|
+
from .data_stats_callback import DataStatsCallback
|
|
9
11
|
from .hyperparameters_callback import HyperParametersCallback
|
|
10
12
|
from .prediction_writer_callback import PredictionWriterCallback
|
|
11
13
|
from .progress_bar_callback import ProgressBarCallback
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""Data statistics callback."""
|
|
2
|
+
|
|
3
|
+
import pytorch_lightning as L
|
|
4
|
+
from pytorch_lightning.callbacks import Callback
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class DataStatsCallback(Callback):
    """Push the datamodule's data statistics into the model before fitting.

    When the trainer sets up for the ``"fit"`` stage, the mean and std computed by
    the datamodule are handed to the model's noise-model likelihood, so the model
    has them available before training starts. All other stages are left
    untouched.
    """

    def setup(self, trainer: L.Trainer, module: L.LightningModule, stage: str) -> None:
        """Lightning hook invoked while the trainer is setting up.

        Parameters
        ----------
        trainer : Lightning.Trainer
            The trainer instance; its ``datamodule`` provides the statistics.
        module : Lightning.LightningModule
            The model being trained; its ``noise_model_likelihood`` receives the
            statistics.
        stage : str
            The current stage of training (e.g., 'fit', 'validate', 'test', 'predict').
        """
        # Statistics are only needed for training.
        if stage != "fit":
            return

        # get_data_stats() yields a pair whose first item is (mean, std); the
        # second item is not used by this callback.
        stats, _ = trainer.datamodule.get_data_stats()
        mean_by_key, std_by_key = stats

        # Only the "target" entries are forwarded to the likelihood.
        likelihood = module.noise_model_likelihood
        likelihood.set_data_stats(
            data_mean=mean_by_key["target"],
            data_std=std_by_key["target"],
        )
|
|
@@ -39,6 +39,10 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
39
39
|
train_data_target : Optional[InputType]
|
|
40
40
|
Training data target, can be a path to a folder,
|
|
41
41
|
a list of paths, or a numpy array.
|
|
42
|
+
train_data_mask : InputType (when filtering is needed)
|
|
43
|
+
Training data mask, can be a path to a folder,
|
|
44
|
+
a list of paths, or a numpy array. Used for coordinate filtering.
|
|
45
|
+
Only required when using coordinate-based patch filtering.
|
|
42
46
|
val_data : Optional[InputType]
|
|
43
47
|
Validation data, can be a path to a folder,
|
|
44
48
|
a list of paths, or a numpy array.
|
|
@@ -99,6 +103,9 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
99
103
|
train_data_target : Optional[Any]
|
|
100
104
|
Training data target, can be a path to a folder, a list of paths, or a numpy
|
|
101
105
|
array.
|
|
106
|
+
train_data_mask : Optional[Any]
|
|
107
|
+
Training data mask, can be a path to a folder, a list of paths, or a numpy
|
|
108
|
+
array.
|
|
102
109
|
val_data : Optional[Any]
|
|
103
110
|
Validation data, can be a path to a folder, a list of paths, or a numpy array.
|
|
104
111
|
val_data_target : Optional[Any]
|
|
@@ -118,7 +125,7 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
118
125
|
If input and target data types are not consistent.
|
|
119
126
|
"""
|
|
120
127
|
|
|
121
|
-
# standard use
|
|
128
|
+
# standard use (no mask)
|
|
122
129
|
@overload
|
|
123
130
|
def __init__(
|
|
124
131
|
self,
|
|
@@ -136,7 +143,26 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
136
143
|
use_in_memory: bool = True,
|
|
137
144
|
) -> None: ...
|
|
138
145
|
|
|
139
|
-
#
|
|
146
|
+
# with training mask for filtering
|
|
147
|
+
@overload
|
|
148
|
+
def __init__(
|
|
149
|
+
self,
|
|
150
|
+
data_config: NGDataConfig,
|
|
151
|
+
*,
|
|
152
|
+
train_data: InputType | None = None,
|
|
153
|
+
train_data_target: InputType | None = None,
|
|
154
|
+
train_data_mask: InputType,
|
|
155
|
+
val_data: InputType | None = None,
|
|
156
|
+
val_data_target: InputType | None = None,
|
|
157
|
+
pred_data: InputType | None = None,
|
|
158
|
+
pred_data_target: InputType | None = None,
|
|
159
|
+
extension_filter: str = "",
|
|
160
|
+
val_percentage: float | None = None,
|
|
161
|
+
val_minimum_split: int = 5,
|
|
162
|
+
use_in_memory: bool = True,
|
|
163
|
+
) -> None: ...
|
|
164
|
+
|
|
165
|
+
# custom read function (no mask)
|
|
140
166
|
@overload
|
|
141
167
|
def __init__(
|
|
142
168
|
self,
|
|
@@ -156,6 +182,48 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
156
182
|
use_in_memory: bool = True,
|
|
157
183
|
) -> None: ...
|
|
158
184
|
|
|
185
|
+
# custom read function with training mask
|
|
186
|
+
@overload
|
|
187
|
+
def __init__(
|
|
188
|
+
self,
|
|
189
|
+
data_config: NGDataConfig,
|
|
190
|
+
*,
|
|
191
|
+
train_data: InputType | None = None,
|
|
192
|
+
train_data_target: InputType | None = None,
|
|
193
|
+
train_data_mask: InputType,
|
|
194
|
+
val_data: InputType | None = None,
|
|
195
|
+
val_data_target: InputType | None = None,
|
|
196
|
+
pred_data: InputType | None = None,
|
|
197
|
+
pred_data_target: InputType | None = None,
|
|
198
|
+
read_source_func: Callable,
|
|
199
|
+
read_kwargs: dict[str, Any] | None = None,
|
|
200
|
+
extension_filter: str = "",
|
|
201
|
+
val_percentage: float | None = None,
|
|
202
|
+
val_minimum_split: int = 5,
|
|
203
|
+
use_in_memory: bool = True,
|
|
204
|
+
) -> None: ...
|
|
205
|
+
|
|
206
|
+
# image stack loader (no mask)
|
|
207
|
+
@overload
|
|
208
|
+
def __init__(
|
|
209
|
+
self,
|
|
210
|
+
data_config: NGDataConfig,
|
|
211
|
+
*,
|
|
212
|
+
train_data: Any | None = None,
|
|
213
|
+
train_data_target: Any | None = None,
|
|
214
|
+
val_data: Any | None = None,
|
|
215
|
+
val_data_target: Any | None = None,
|
|
216
|
+
pred_data: Any | None = None,
|
|
217
|
+
pred_data_target: Any | None = None,
|
|
218
|
+
image_stack_loader: ImageStackLoader,
|
|
219
|
+
image_stack_loader_kwargs: dict[str, Any] | None = None,
|
|
220
|
+
extension_filter: str = "",
|
|
221
|
+
val_percentage: float | None = None,
|
|
222
|
+
val_minimum_split: int = 5,
|
|
223
|
+
use_in_memory: bool = True,
|
|
224
|
+
) -> None: ...
|
|
225
|
+
|
|
226
|
+
# image stack loader with training mask
|
|
159
227
|
@overload
|
|
160
228
|
def __init__(
|
|
161
229
|
self,
|
|
@@ -163,6 +231,7 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
163
231
|
*,
|
|
164
232
|
train_data: Any | None = None,
|
|
165
233
|
train_data_target: Any | None = None,
|
|
234
|
+
train_data_mask: Any,
|
|
166
235
|
val_data: Any | None = None,
|
|
167
236
|
val_data_target: Any | None = None,
|
|
168
237
|
pred_data: Any | None = None,
|
|
@@ -181,6 +250,7 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
181
250
|
*,
|
|
182
251
|
train_data: Any | None = None,
|
|
183
252
|
train_data_target: Any | None = None,
|
|
253
|
+
train_data_mask: Any | None = None,
|
|
184
254
|
val_data: Any | None = None,
|
|
185
255
|
val_data_target: Any | None = None,
|
|
186
256
|
pred_data: Any | None = None,
|
|
@@ -209,6 +279,10 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
209
279
|
train_data_target : Optional[InputType]
|
|
210
280
|
Training data target, can be a path to a folder,
|
|
211
281
|
a list of paths, or a numpy array.
|
|
282
|
+
train_data_mask : InputType (when filtering is needed)
|
|
283
|
+
Training data mask, can be a path to a folder,
|
|
284
|
+
a list of paths, or a numpy array. Used for coordinate filtering.
|
|
285
|
+
Only required when using coordinate-based patch filtering.
|
|
212
286
|
val_data : Optional[InputType]
|
|
213
287
|
Validation data, can be a path to a folder,
|
|
214
288
|
a list of paths, or a numpy array.
|
|
@@ -268,6 +342,8 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
268
342
|
self.train_data, self.train_data_target = self._initialize_data_pair(
|
|
269
343
|
train_data, train_data_target
|
|
270
344
|
)
|
|
345
|
+
self.train_data_mask, _ = self._initialize_data_pair(train_data_mask, None)
|
|
346
|
+
|
|
271
347
|
self.val_data, self.val_data_target = self._initialize_data_pair(
|
|
272
348
|
val_data, val_data_target
|
|
273
349
|
)
|
|
@@ -574,6 +650,7 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
574
650
|
mode=Mode.TRAINING,
|
|
575
651
|
inputs=self.train_data,
|
|
576
652
|
targets=self.train_data_target,
|
|
653
|
+
masks=self.train_data_mask,
|
|
577
654
|
config=self.config,
|
|
578
655
|
in_memory=self.use_in_memory,
|
|
579
656
|
read_func=self.read_source_func,
|