careamics 0.0.11__py3-none-any.whl → 0.0.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/careamist.py +20 -4
- careamics/config/configuration.py +10 -5
- careamics/config/data/data_model.py +38 -1
- careamics/config/optimizer_models.py +1 -3
- careamics/config/training_model.py +0 -2
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/dataset.py +233 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +356 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +734 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +443 -0
- careamics/dataset_ng/{demo_dataset.ipynb → demos/demo_dataset.ipynb} +39 -15
- careamics/dataset_ng/{demo_patch_extractor.py → demos/demo_patch_extractor.py} +7 -9
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +292 -0
- careamics/dataset_ng/factory.py +408 -0
- careamics/dataset_ng/legacy_interoperability.py +168 -0
- careamics/dataset_ng/patch_extractor/__init__.py +3 -8
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +6 -4
- careamics/dataset_ng/patch_extractor/image_stack/__init__.py +2 -1
- careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +5 -1
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +5 -75
- careamics/dataset_ng/patch_extractor/patch_extractor.py +5 -4
- careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +73 -106
- careamics/dataset_ng/patching_strategies/__init__.py +6 -1
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +31 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +3 -1
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +171 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +36 -0
- careamics/lightning/dataset_ng/data_module.py +488 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +58 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +67 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +143 -0
- careamics/lightning/lightning_module.py +3 -0
- careamics/lvae_training/dataset/__init__.py +8 -3
- careamics/lvae_training/dataset/config.py +3 -3
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +46 -17
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/types.py +3 -3
- careamics/lvae_training/dataset/utils/index_manager.py +259 -0
- careamics/lvae_training/eval_utils.py +93 -3
- careamics/transforms/compose.py +1 -0
- careamics/transforms/normalize.py +18 -7
- careamics/utils/lightning_utils.py +25 -11
- {careamics-0.0.11.dist-info → careamics-0.0.12.dist-info}/METADATA +3 -3
- {careamics-0.0.11.dist-info → careamics-0.0.12.dist-info}/RECORD +50 -35
- careamics/dataset_ng/dataset/__init__.py +0 -3
- careamics/dataset_ng/dataset/dataset.py +0 -184
- careamics/dataset_ng/demo_patch_extractor_factory.py +0 -37
- {careamics-0.0.11.dist-info → careamics-0.0.12.dist-info}/WHEEL +0 -0
- {careamics-0.0.11.dist-info → careamics-0.0.12.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.11.dist-info → careamics-0.0.12.dist-info}/licenses/LICENSE +0 -0
careamics/lightning/dataset_ng/data_module.py (new file)

@@ -0,0 +1,488 @@
from pathlib import Path
from typing import Any, Callable, Optional, Union, overload

import numpy as np
import pytorch_lightning as L
from numpy.typing import NDArray
from torch.utils.data import DataLoader
from torch.utils.data._utils.collate import default_collate

from careamics.config.data import DataConfig
from careamics.config.support import SupportedData
from careamics.dataset.dataset_utils import list_files, validate_source_target_files
from careamics.dataset_ng.dataset import Mode
from careamics.dataset_ng.factory import create_dataset
from careamics.dataset_ng.patch_extractor import ImageStackLoader
from careamics.utils import get_logger

logger = get_logger(__name__)

ItemType = Union[Path, str, NDArray[Any]]
InputType = Union[ItemType, list[ItemType], None]


class CareamicsDataModule(L.LightningDataModule):
    """Data module for Careamics dataset."""

    @overload
    def __init__(
        self,
        data_config: DataConfig,
        *,
        train_data: Optional[InputType] = None,
        train_data_target: Optional[InputType] = None,
        val_data: Optional[InputType] = None,
        val_data_target: Optional[InputType] = None,
        pred_data: Optional[InputType] = None,
        pred_data_target: Optional[InputType] = None,
        extension_filter: str = "",
        val_percentage: Optional[float] = None,
        val_minimum_split: int = 5,
        use_in_memory: bool = True,
    ) -> None: ...

    @overload
    def __init__(
        self,
        data_config: DataConfig,
        *,
        train_data: Optional[InputType] = None,
        train_data_target: Optional[InputType] = None,
        val_data: Optional[InputType] = None,
        val_data_target: Optional[InputType] = None,
        pred_data: Optional[InputType] = None,
        pred_data_target: Optional[InputType] = None,
        read_source_func: Callable,
        read_kwargs: Optional[dict[str, Any]] = None,
        extension_filter: str = "",
        val_percentage: Optional[float] = None,
        val_minimum_split: int = 5,
        use_in_memory: bool = True,
    ) -> None: ...

    @overload
    def __init__(
        self,
        data_config: DataConfig,
        *,
        train_data: Optional[Any] = None,
        train_data_target: Optional[Any] = None,
        val_data: Optional[Any] = None,
        val_data_target: Optional[Any] = None,
        pred_data: Optional[Any] = None,
        pred_data_target: Optional[Any] = None,
        image_stack_loader: ImageStackLoader,
        image_stack_loader_kwargs: Optional[dict[str, Any]] = None,
        extension_filter: str = "",
        val_percentage: Optional[float] = None,
        val_minimum_split: int = 5,
        use_in_memory: bool = True,
    ) -> None: ...

    def __init__(
        self,
        data_config: DataConfig,
        *,
        train_data: Optional[Any] = None,
        train_data_target: Optional[Any] = None,
        val_data: Optional[Any] = None,
        val_data_target: Optional[Any] = None,
        pred_data: Optional[Any] = None,
        pred_data_target: Optional[Any] = None,
        read_source_func: Optional[Callable] = None,
        read_kwargs: Optional[dict[str, Any]] = None,
        image_stack_loader: Optional[ImageStackLoader] = None,
        image_stack_loader_kwargs: Optional[dict[str, Any]] = None,
        extension_filter: str = "",
        val_percentage: Optional[float] = None,
        val_minimum_split: int = 5,
        use_in_memory: bool = True,
    ) -> None:
        """
        Data module for Careamics dataset initialization.

        Create a lightning datamodule that handles creating datasets for training,
        validation, and prediction.

        Parameters
        ----------
        data_config : DataConfig
            Pydantic model for CAREamics data configuration.
        train_data : Optional[InputType]
            Training data, can be a path to a folder, a list of paths, or a numpy array.
        train_data_target : Optional[InputType]
            Training data target, can be a path to a folder,
            a list of paths, or a numpy array.
        val_data : Optional[InputType]
            Validation data, can be a path to a folder,
            a list of paths, or a numpy array.
        val_data_target : Optional[InputType]
            Validation data target, can be a path to a folder,
            a list of paths, or a numpy array.
        pred_data : Optional[InputType]
            Prediction data, can be a path to a folder, a list of paths,
            or a numpy array.
        pred_data_target : Optional[InputType]
            Prediction data target, can be a path to a folder,
            a list of paths, or a numpy array.
        read_source_func : Optional[Callable]
            Function to read the source data, by default None. Only used for `custom`
            data type (see DataModel).
        read_kwargs : Optional[dict[str, Any]]
            The kwargs for the read source function.
        image_stack_loader : Optional[ImageStackLoader]
            The image stack loader.
        image_stack_loader_kwargs : Optional[dict[str, Any]]
            The image stack loader kwargs.
        extension_filter : str
            Filter for file extensions, by default "". Only used for `custom` data types
            (see DataModel).
        val_percentage : Optional[float]
            Percentage of the training data to use for validation. Only
            used if `val_data` is None.
        val_minimum_split : int
            Minimum number of patches or files to split from the training data for
            validation, by default 5. Only used if `val_data` is None.
        use_in_memory : bool
            Load data in memory dataset if possible, by default True.
        """
        super().__init__()

        if train_data is None and val_data is None and pred_data is None:
            raise ValueError(
                "At least one of train_data, val_data or pred_data must be provided."
            )

        self.config: DataConfig = data_config
        self.data_type: str = data_config.data_type
        self.batch_size: int = data_config.batch_size
        self.use_in_memory: bool = use_in_memory
        self.extension_filter: str = extension_filter
        self.read_source_func = read_source_func
        self.read_kwargs = read_kwargs
        self.image_stack_loader = image_stack_loader
        self.image_stack_loader_kwargs = image_stack_loader_kwargs

        # TODO: implement the validation split logic
        self.val_percentage = val_percentage
        self.val_minimum_split = val_minimum_split
        if self.val_percentage is not None:
            raise NotImplementedError("Validation split not implemented")

        self.train_data, self.train_data_target = self._initialize_data_pair(
            train_data, train_data_target
        )
        self.val_data, self.val_data_target = self._initialize_data_pair(
            val_data, val_data_target
        )

        # The pred_data_target can be needed to count metrics on the prediction
        self.pred_data, self.pred_data_target = self._initialize_data_pair(
            pred_data, pred_data_target
        )

    def _validate_input_target_type_consistency(
        self,
        input_data: InputType,
        target_data: Optional[InputType],
    ) -> None:
        """Validate if the input and target data types are consistent."""
        if input_data is not None and target_data is not None:
            if not isinstance(input_data, type(target_data)):
                raise ValueError(
                    f"Inputs for input and target must be of the same type or None. "
                    f"Got {type(input_data)} and {type(target_data)}."
                )
            if isinstance(input_data, list) and isinstance(target_data, list):
                if len(input_data) != len(target_data):
                    raise ValueError(
                        f"Inputs and targets must have the same length. "
                        f"Got {len(input_data)} and {len(target_data)}."
                    )
                if not isinstance(input_data[0], type(target_data[0])):
                    raise ValueError(
                        f"Inputs and targets must have the same type. "
                        f"Got {type(input_data[0])} and {type(target_data[0])}."
                    )

    def _list_files_in_directory(
        self,
        input_data,
        target_data=None,
    ) -> tuple[list[Path], Optional[list[Path]]]:
        """List files from input and target directories."""
        input_data = Path(input_data)
        input_files = list_files(input_data, self.data_type, self.extension_filter)
        if target_data is None:
            return input_files, None
        else:
            target_data = Path(target_data)
            target_files = list_files(
                target_data, self.data_type, self.extension_filter
            )
            validate_source_target_files(input_files, target_files)
            return input_files, target_files

    def _convert_paths_to_pathlib(
        self,
        input_data,
        target_data=None,
    ) -> tuple[list[Path], Optional[list[Path]]]:
        """Create a list of file paths from the input and target data."""
        input_files = [
            Path(item) if isinstance(item, str) else item for item in input_data
        ]
        if target_data is None:
            return input_files, None
        else:
            target_files = [
                Path(item) if isinstance(item, str) else item for item in target_data
            ]
            validate_source_target_files(input_files, target_files)
            return input_files, target_files

    def _validate_array_input(
        self,
        input_data: InputType,
        target_data: Optional[InputType],
    ) -> tuple[Any, Any]:
        """Validate if the input data is a numpy array."""
        if isinstance(input_data, np.ndarray):
            input_array = [input_data]
            target_array = [target_data] if target_data is not None else None
            return input_array, target_array
        elif isinstance(input_data, list):
            return input_data, target_data
        else:
            raise ValueError(
                f"Unsupported input type for {self.data_type}: {type(input_data)}"
            )

    def _validate_path_input(
        self, input_data: InputType, target_data: Optional[InputType]
    ) -> tuple[list[Path], Optional[list[Path]]]:
        if isinstance(input_data, (str, Path)):
            if target_data is not None:
                assert isinstance(target_data, (str, Path))
            input_list, target_list = self._list_files_in_directory(
                input_data, target_data
            )
            return input_list, target_list
        elif isinstance(input_data, list):
            if target_data is not None:
                assert isinstance(target_data, list)
            input_list, target_list = self._convert_paths_to_pathlib(
                input_data, target_data
            )
            return input_list, target_list
        else:
            raise ValueError(
                f"Unsupported input type for {self.data_type}: {type(input_data)}"
            )

    def _validate_custom_input(self, input_data, target_data) -> tuple[Any, Any]:
        if self.image_stack_loader is not None:
            return input_data, target_data
        elif isinstance(input_data, (str, Path)):
            if target_data is not None:
                assert isinstance(target_data, (str, Path))
            input_list, target_list = self._list_files_in_directory(
                input_data, target_data
            )
            return input_list, target_list
        elif isinstance(input_data, list):
            if isinstance(input_data[0], (str, Path)):
                if target_data is not None:
                    assert isinstance(target_data, list)
                input_list, target_list = self._convert_paths_to_pathlib(
                    input_data, target_data
                )
                return input_list, target_list
            else:
                raise ValueError(
                    f"If using {self.data_type}, pass a custom "
                    f"image_stack_loader or read_source_func"
                )
        return input_data, target_data

    def _initialize_data_pair(
        self,
        input_data: Optional[InputType],
        target_data: Optional[InputType],
    ) -> tuple[Any, Any]:
        """
        Initialize a pair of input and target data.

        Returns
        -------
        tuple[Union[list[NDArray], list[Path]],
              Optional[Union[list[NDArray], list[Path]]]]
            A tuple containing the initialized input and target data.
            For file paths, returns lists of Path objects.
            For numpy arrays, returns the arrays directly.
        """
        if input_data is None:
            return None, None

        self._validate_input_target_type_consistency(input_data, target_data)

        if self.data_type == SupportedData.ARRAY:
            if isinstance(input_data, np.ndarray):
                return self._validate_array_input(input_data, target_data)
            elif isinstance(input_data, list):
                if isinstance(input_data[0], np.ndarray):
                    return self._validate_array_input(input_data, target_data)
                else:
                    raise ValueError(
                        f"Unsupported input type for {self.data_type}: "
                        f"{type(input_data[0])}"
                    )
            else:
                raise ValueError(
                    f"Unsupported input type for {self.data_type}: {type(input_data)}"
                )
        elif self.data_type == SupportedData.TIFF:
            if isinstance(input_data, (str, Path)):
                return self._validate_path_input(input_data, target_data)
            elif isinstance(input_data, list):
                if isinstance(input_data[0], (Path, str)):
                    return self._validate_path_input(input_data, target_data)
                else:
                    raise ValueError(
                        f"Unsupported input type for {self.data_type}: "
                        f"{type(input_data[0])}"
                    )
            else:
                raise ValueError(
                    f"Unsupported input type for {self.data_type}: {type(input_data)}"
                )
        elif self.data_type == SupportedData.CUSTOM:
            return self._validate_custom_input(input_data, target_data)
        else:
            raise NotImplementedError(f"Unsupported data type: {self.data_type}")

    def setup(self, stage: str) -> None:
        """
        Setup datasets.

        Lightning hook that is called at the beginning of fit (train + validate),
        validate, test, or predict. Creates the datasets for a given stage.

        Parameters
        ----------
        stage : str
            The stage to set up datasets for.
            Is either 'fit', 'validate', 'test', or 'predict'.

        Raises
        ------
        NotImplementedError
            If stage is not one of "fit", "validate" or "predict".
        """
        if stage == "fit":
            self.train_dataset = create_dataset(
                mode=Mode.TRAINING,
                inputs=self.train_data,
                targets=self.train_data_target,
                config=self.config,
                in_memory=self.use_in_memory,
                read_func=self.read_source_func,
                read_kwargs=self.read_kwargs,
                image_stack_loader=self.image_stack_loader,
                image_stack_loader_kwargs=self.image_stack_loader_kwargs,
            )
            # TODO: ugly, need to find a better solution
            self.stats = self.train_dataset.input_stats
            self.config.set_means_and_stds(
                self.train_dataset.input_stats.means,
                self.train_dataset.input_stats.stds,
                self.train_dataset.target_stats.means,
                self.train_dataset.target_stats.stds,
            )
            self.val_dataset = create_dataset(
                mode=Mode.VALIDATING,
                inputs=self.val_data,
                targets=self.val_data_target,
                config=self.config,
                in_memory=self.use_in_memory,
                read_func=self.read_source_func,
                read_kwargs=self.read_kwargs,
                image_stack_loader=self.image_stack_loader,
                image_stack_loader_kwargs=self.image_stack_loader_kwargs,
            )
        elif stage == "validate":
            self.val_dataset = create_dataset(
                mode=Mode.VALIDATING,
                inputs=self.val_data,
                targets=self.val_data_target,
                config=self.config,
                in_memory=self.use_in_memory,
                read_func=self.read_source_func,
                read_kwargs=self.read_kwargs,
                image_stack_loader=self.image_stack_loader,
                image_stack_loader_kwargs=self.image_stack_loader_kwargs,
            )
            self.stats = self.val_dataset.input_stats
        elif stage == "predict":
            self.predict_dataset = create_dataset(
                mode=Mode.PREDICTING,
                inputs=self.pred_data,
                targets=self.pred_data_target,
                config=self.config,
                in_memory=self.use_in_memory,
                read_func=self.read_source_func,
                read_kwargs=self.read_kwargs,
                image_stack_loader=self.image_stack_loader,
                image_stack_loader_kwargs=self.image_stack_loader_kwargs,
            )
            self.stats = self.predict_dataset.input_stats
        else:
            raise NotImplementedError(f"Stage {stage} not implemented")

    def train_dataloader(self) -> DataLoader:
        """
        Create a dataloader for training.

        Returns
        -------
        DataLoader
            Training dataloader.
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=default_collate,
            **self.config.train_dataloader_params,
        )

    def val_dataloader(self) -> DataLoader:
        """
        Create a dataloader for validation.

        Returns
        -------
        DataLoader
            Validation dataloader.
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            collate_fn=default_collate,
            **self.config.val_dataloader_params,
        )

    def predict_dataloader(self) -> DataLoader:
        """
        Create a dataloader for prediction.

        Returns
        -------
        DataLoader
            Prediction dataloader.
        """
        return DataLoader(
            self.predict_dataset,
            batch_size=self.batch_size,
            collate_fn=default_collate,
            # TODO: set appropriate key for params once config changes are merged
        )
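The new CareamicsDataModule accepts arrays, folder paths, or lists of either, validates input/target pairs in __init__, and builds the actual datasets in setup(). Below is a minimal usage sketch, not taken from the release: the DataConfig field values (data_type, axes, patch_size, batch_size) and the random arrays are illustrative assumptions; only the constructor arguments shown in the diff above are used.

import numpy as np
from careamics.config.data import DataConfig
from careamics.lightning.dataset_ng.data_module import CareamicsDataModule

# Illustrative configuration values (assumed fields, not prescribed by this diff).
config = DataConfig(
    data_type="array",   # exercises the SupportedData.ARRAY branch above
    axes="YX",
    patch_size=[64, 64],
    batch_size=8,
)

# Random 2D images standing in for real training/validation data.
train_image = np.random.rand(256, 256).astype(np.float32)
val_image = np.random.rand(256, 256).astype(np.float32)

data_module = CareamicsDataModule(
    data_config=config,
    train_data=train_image,
    val_data=val_image,
)

# Lightning normally calls setup() itself; calling it manually here just
# exposes the dataloaders for inspection.
data_module.setup("fit")
train_loader = data_module.train_dataloader()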
careamics/lightning/dataset_ng/lightning_modules/care_module.py (new file)

@@ -0,0 +1,58 @@
from typing import Any, Callable, Union

from careamics.config.algorithms.care_algorithm_model import CAREAlgorithm
from careamics.config.algorithms.n2n_algorithm_model import N2NAlgorithm
from careamics.config.support import SupportedLoss
from careamics.dataset_ng.dataset import ImageRegionData
from careamics.losses import mae_loss, mse_loss
from careamics.utils.logging import get_logger

from .unet_module import UnetModule

logger = get_logger(__name__)


class CAREModule(UnetModule):
    """CAREamics PyTorch Lightning module for CARE algorithm."""

    def __init__(self, algorithm_config: Union[CAREAlgorithm, dict]) -> None:
        super().__init__(algorithm_config)
        assert isinstance(
            algorithm_config, (CAREAlgorithm, N2NAlgorithm)
        ), "algorithm_config must be a CAREAlgorithm or a N2NAlgorithm"
        loss = algorithm_config.loss
        if loss == SupportedLoss.MAE:
            self.loss_func: Callable = mae_loss
        elif loss == SupportedLoss.MSE:
            self.loss_func = mse_loss
        else:
            raise ValueError(f"Unsupported loss for Care: {loss}")

    def training_step(
        self,
        batch: tuple[ImageRegionData, ImageRegionData],
        batch_idx: Any,
    ) -> Any:
        """Training step for CARE module."""
        # TODO: add validation to determine if target is initialized
        x, target = batch[0], batch[1]

        prediction = self.model(x.data)
        loss = self.loss_func(prediction, target.data)

        self._log_training_stats(loss, batch_size=x.data.shape[0])

        return loss

    def validation_step(
        self,
        batch: tuple[ImageRegionData, ImageRegionData],
        batch_idx: Any,
    ) -> None:
        """Validation step for CARE module."""
        x, target = batch[0], batch[1]

        prediction = self.model(x.data)
        val_loss = self.loss_func(prediction, target.data)
        self.metrics(prediction, target.data)
        self._log_validation_stats(val_loss, batch_size=x.data.shape[0])
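CAREModule selects mae_loss or mse_loss from algorithm_config.loss and reuses the logging helpers of the shared UnetModule. The sketch below shows how it could be wired to the data module and a standard Lightning Trainer; it assumes an already-constructed CAREAlgorithm (for example taken from a full CAREamics configuration), whose construction is not part of this diff, and the Trainer settings are illustrative only.

import pytorch_lightning as L

from careamics.lightning.dataset_ng.lightning_modules.care_module import CAREModule


def train_care(algorithm_config, data_module) -> None:
    """Fit a CAREModule on a CareamicsDataModule (sketch only).

    `algorithm_config` is assumed to be a CAREAlgorithm built elsewhere.
    """
    model = CAREModule(algorithm_config)   # picks MAE or MSE loss from the config
    trainer = L.Trainer(max_epochs=1)      # illustrative settings
    trainer.fit(model, datamodule=data_module)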
careamics/lightning/dataset_ng/lightning_modules/n2v_module.py (new file)

@@ -0,0 +1,67 @@
from typing import Any, Union

from careamics.config import (
    N2VAlgorithm,
)
from careamics.dataset_ng.dataset import ImageRegionData
from careamics.losses import n2v_loss
from careamics.transforms import N2VManipulateTorch
from careamics.utils.logging import get_logger

from .unet_module import UnetModule

logger = get_logger(__name__)


class N2VModule(UnetModule):
    """CAREamics PyTorch Lightning module for N2V algorithm."""

    def __init__(self, algorithm_config: Union[N2VAlgorithm, dict]) -> None:
        super().__init__(algorithm_config)

        assert isinstance(
            algorithm_config, N2VAlgorithm
        ), "algorithm_config must be a N2VAlgorithm"

        self.n2v_manipulate = N2VManipulateTorch(
            n2v_manipulate_config=algorithm_config.n2v_config
        )
        self.loss_func = n2v_loss

    def _load_best_checkpoint(self) -> None:
        logger.warning(
            "Loading best checkpoint for N2V model. Note that for N2V, "
            "the checkpoint with the best validation metrics may not necessarily "
            "have the best denoising performance."
        )
        super()._load_best_checkpoint()

    def training_step(
        self,
        batch: Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]],
        batch_idx: Any,
    ) -> Any:
        """Training step for N2V model."""
        x = batch[0]
        x_masked, x_original, mask = self.n2v_manipulate(x.data)
        prediction = self.model(x_masked)
        loss = self.loss_func(prediction, x_original, mask)

        self._log_training_stats(loss, batch_size=x.data.shape[0])

        return loss

    def validation_step(
        self,
        batch: Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]],
        batch_idx: Any,
    ) -> None:
        """Validation step for N2V model."""
        x = batch[0]

        x_masked, x_original, mask = self.n2v_manipulate(x.data)
        prediction = self.model(x_masked)

        val_loss = self.loss_func(prediction, x_original, mask)
        self.metrics(prediction, x_original)
        self._log_validation_stats(val_loss, batch_size=x.data.shape[0])
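Unlike the CARE module, N2VModule needs no target batch: N2VManipulateTorch masks a subset of input pixels and n2v_loss is evaluated only at the masked positions. The snippet below is a toy, from-scratch illustration of that blind-spot masking idea; it is not the N2VManipulateTorch implementation, and a real manipulation would sample replacement values from a local neighbourhood rather than shuffling globally.

import torch


def toy_blind_spot_mask(x: torch.Tensor, mask_fraction: float = 0.002):
    """Replace a random subset of pixels with other pixel values (toy version)."""
    mask = (torch.rand_like(x) < mask_fraction).float()
    # stand-in replacement values: globally shuffled pixels instead of the
    # local-neighbourhood sampling used by a real N2V manipulation
    shuffled = x.flatten()[torch.randperm(x.numel())].reshape(x.shape)
    x_masked = x * (1 - mask) + shuffled * mask
    return x_masked, x, mask


patches = torch.rand(8, 1, 64, 64)  # batch of single-channel patches
x_masked, x_original, mask = toy_blind_spot_mask(patches)
# a blind-spot loss such as n2v_loss then compares the network prediction to
# x_original only where mask == 1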