careamics 0.0.11__py3-none-any.whl → 0.0.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (53)
  1. careamics/careamist.py +20 -4
  2. careamics/config/configuration.py +10 -5
  3. careamics/config/data/data_model.py +38 -1
  4. careamics/config/optimizer_models.py +1 -3
  5. careamics/config/training_model.py +0 -2
  6. careamics/dataset_ng/README.md +212 -0
  7. careamics/dataset_ng/dataset.py +233 -0
  8. careamics/dataset_ng/demos/bsd68_demo.ipynb +356 -0
  9. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  10. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +734 -0
  11. careamics/dataset_ng/demos/demo_datamodule.ipynb +443 -0
  12. careamics/dataset_ng/{demo_dataset.ipynb → demos/demo_dataset.ipynb} +39 -15
  13. careamics/dataset_ng/{demo_patch_extractor.py → demos/demo_patch_extractor.py} +7 -9
  14. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +292 -0
  15. careamics/dataset_ng/factory.py +408 -0
  16. careamics/dataset_ng/legacy_interoperability.py +168 -0
  17. careamics/dataset_ng/patch_extractor/__init__.py +3 -8
  18. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +6 -4
  19. careamics/dataset_ng/patch_extractor/image_stack/__init__.py +2 -1
  20. careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +5 -1
  21. careamics/dataset_ng/patch_extractor/image_stack_loader.py +5 -75
  22. careamics/dataset_ng/patch_extractor/patch_extractor.py +5 -4
  23. careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +73 -106
  24. careamics/dataset_ng/patching_strategies/__init__.py +6 -1
  25. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +31 -0
  26. careamics/dataset_ng/patching_strategies/random_patching.py +3 -1
  27. careamics/dataset_ng/patching_strategies/tiling_strategy.py +171 -0
  28. careamics/dataset_ng/patching_strategies/whole_sample.py +36 -0
  29. careamics/lightning/dataset_ng/data_module.py +488 -0
  30. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  31. careamics/lightning/dataset_ng/lightning_modules/care_module.py +58 -0
  32. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +67 -0
  33. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +143 -0
  34. careamics/lightning/lightning_module.py +3 -0
  35. careamics/lvae_training/dataset/__init__.py +8 -3
  36. careamics/lvae_training/dataset/config.py +3 -3
  37. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  38. careamics/lvae_training/dataset/multich_dataset.py +46 -17
  39. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  40. careamics/lvae_training/dataset/types.py +3 -3
  41. careamics/lvae_training/dataset/utils/index_manager.py +259 -0
  42. careamics/lvae_training/eval_utils.py +93 -3
  43. careamics/transforms/compose.py +1 -0
  44. careamics/transforms/normalize.py +18 -7
  45. careamics/utils/lightning_utils.py +25 -11
  46. {careamics-0.0.11.dist-info → careamics-0.0.12.dist-info}/METADATA +3 -3
  47. {careamics-0.0.11.dist-info → careamics-0.0.12.dist-info}/RECORD +50 -35
  48. careamics/dataset_ng/dataset/__init__.py +0 -3
  49. careamics/dataset_ng/dataset/dataset.py +0 -184
  50. careamics/dataset_ng/demo_patch_extractor_factory.py +0 -37
  51. {careamics-0.0.11.dist-info → careamics-0.0.12.dist-info}/WHEEL +0 -0
  52. {careamics-0.0.11.dist-info → careamics-0.0.12.dist-info}/entry_points.txt +0 -0
  53. {careamics-0.0.11.dist-info → careamics-0.0.12.dist-info}/licenses/LICENSE +0 -0
careamics/lightning/dataset_ng/data_module.py
@@ -0,0 +1,488 @@
+ from pathlib import Path
+ from typing import Any, Callable, Optional, Union, overload
+
+ import numpy as np
+ import pytorch_lightning as L
+ from numpy.typing import NDArray
+ from torch.utils.data import DataLoader
+ from torch.utils.data._utils.collate import default_collate
+
+ from careamics.config.data import DataConfig
+ from careamics.config.support import SupportedData
+ from careamics.dataset.dataset_utils import list_files, validate_source_target_files
+ from careamics.dataset_ng.dataset import Mode
+ from careamics.dataset_ng.factory import create_dataset
+ from careamics.dataset_ng.patch_extractor import ImageStackLoader
+ from careamics.utils import get_logger
+
+ logger = get_logger(__name__)
+
+ ItemType = Union[Path, str, NDArray[Any]]
+ InputType = Union[ItemType, list[ItemType], None]
+
+
+ class CareamicsDataModule(L.LightningDataModule):
+     """Data module for Careamics dataset."""
+
+     @overload
+     def __init__(
+         self,
+         data_config: DataConfig,
+         *,
+         train_data: Optional[InputType] = None,
+         train_data_target: Optional[InputType] = None,
+         val_data: Optional[InputType] = None,
+         val_data_target: Optional[InputType] = None,
+         pred_data: Optional[InputType] = None,
+         pred_data_target: Optional[InputType] = None,
+         extension_filter: str = "",
+         val_percentage: Optional[float] = None,
+         val_minimum_split: int = 5,
+         use_in_memory: bool = True,
+     ) -> None: ...
+
+     @overload
+     def __init__(
+         self,
+         data_config: DataConfig,
+         *,
+         train_data: Optional[InputType] = None,
+         train_data_target: Optional[InputType] = None,
+         val_data: Optional[InputType] = None,
+         val_data_target: Optional[InputType] = None,
+         pred_data: Optional[InputType] = None,
+         pred_data_target: Optional[InputType] = None,
+         read_source_func: Callable,
+         read_kwargs: Optional[dict[str, Any]] = None,
+         extension_filter: str = "",
+         val_percentage: Optional[float] = None,
+         val_minimum_split: int = 5,
+         use_in_memory: bool = True,
+     ) -> None: ...
+
+     @overload
+     def __init__(
+         self,
+         data_config: DataConfig,
+         *,
+         train_data: Optional[Any] = None,
+         train_data_target: Optional[Any] = None,
+         val_data: Optional[Any] = None,
+         val_data_target: Optional[Any] = None,
+         pred_data: Optional[Any] = None,
+         pred_data_target: Optional[Any] = None,
+         image_stack_loader: ImageStackLoader,
+         image_stack_loader_kwargs: Optional[dict[str, Any]] = None,
+         extension_filter: str = "",
+         val_percentage: Optional[float] = None,
+         val_minimum_split: int = 5,
+         use_in_memory: bool = True,
+     ) -> None: ...
+
+     def __init__(
+         self,
+         data_config: DataConfig,
+         *,
+         train_data: Optional[Any] = None,
+         train_data_target: Optional[Any] = None,
+         val_data: Optional[Any] = None,
+         val_data_target: Optional[Any] = None,
+         pred_data: Optional[Any] = None,
+         pred_data_target: Optional[Any] = None,
+         read_source_func: Optional[Callable] = None,
+         read_kwargs: Optional[dict[str, Any]] = None,
+         image_stack_loader: Optional[ImageStackLoader] = None,
+         image_stack_loader_kwargs: Optional[dict[str, Any]] = None,
+         extension_filter: str = "",
+         val_percentage: Optional[float] = None,
+         val_minimum_split: int = 5,
+         use_in_memory: bool = True,
+     ) -> None:
+         """
+         Data module for Careamics dataset initialization.
+
+         Create a lightning datamodule that handles creating datasets for training,
+         validation, and prediction.
+
+         Parameters
+         ----------
+         data_config : DataConfig
+             Pydantic model for CAREamics data configuration.
+         train_data : Optional[InputType]
+             Training data, can be a path to a folder, a list of paths, or a numpy array.
+         train_data_target : Optional[InputType]
+             Training data target, can be a path to a folder,
+             a list of paths, or a numpy array.
+         val_data : Optional[InputType]
+             Validation data, can be a path to a folder,
+             a list of paths, or a numpy array.
+         val_data_target : Optional[InputType]
+             Validation data target, can be a path to a folder,
+             a list of paths, or a numpy array.
+         pred_data : Optional[InputType]
+             Prediction data, can be a path to a folder, a list of paths,
+             or a numpy array.
+         pred_data_target : Optional[InputType]
+             Prediction data target, can be a path to a folder,
+             a list of paths, or a numpy array.
+         read_source_func : Optional[Callable]
+             Function to read the source data, by default None. Only used for `custom`
+             data type (see DataModel).
+         read_kwargs : Optional[dict[str, Any]]
+             The kwargs for the read source function.
+         image_stack_loader : Optional[ImageStackLoader]
+             The image stack loader.
+         image_stack_loader_kwargs : Optional[dict[str, Any]]
+             The image stack loader kwargs.
+         extension_filter : str
+             Filter for file extensions, by default "". Only used for `custom` data types
+             (see DataModel).
+         val_percentage : Optional[float]
+             Percentage of the training data to use for validation. Only
+             used if `val_data` is None.
+         val_minimum_split : int
+             Minimum number of patches or files to split from the training data for
+             validation, by default 5. Only used if `val_data` is None.
+         use_in_memory : bool
+             Load data in memory dataset if possible, by default True.
+         """
+         super().__init__()
+
+         if train_data is None and val_data is None and pred_data is None:
+             raise ValueError(
+                 "At least one of train_data, val_data or pred_data must be provided."
+             )
+
+         self.config: DataConfig = data_config
+         self.data_type: str = data_config.data_type
+         self.batch_size: int = data_config.batch_size
+         self.use_in_memory: bool = use_in_memory
+         self.extension_filter: str = extension_filter
+         self.read_source_func = read_source_func
+         self.read_kwargs = read_kwargs
+         self.image_stack_loader = image_stack_loader
+         self.image_stack_loader_kwargs = image_stack_loader_kwargs
+
+         # TODO: implement the validation split logic
+         self.val_percentage = val_percentage
+         self.val_minimum_split = val_minimum_split
+         if self.val_percentage is not None:
+             raise NotImplementedError("Validation split not implemented")
+
+         self.train_data, self.train_data_target = self._initialize_data_pair(
+             train_data, train_data_target
+         )
+         self.val_data, self.val_data_target = self._initialize_data_pair(
+             val_data, val_data_target
+         )
+
+         # The pred_data_target can be needed to count metrics on the prediction
+         self.pred_data, self.pred_data_target = self._initialize_data_pair(
+             pred_data, pred_data_target
+         )
+
+     def _validate_input_target_type_consistency(
+         self,
+         input_data: InputType,
+         target_data: Optional[InputType],
+     ) -> None:
+         """Validate if the input and target data types are consistent."""
+         if input_data is not None and target_data is not None:
+             if not isinstance(input_data, type(target_data)):
+                 raise ValueError(
+                     f"Inputs for input and target must be of the same type or None. "
+                     f"Got {type(input_data)} and {type(target_data)}."
+                 )
+             if isinstance(input_data, list) and isinstance(target_data, list):
+                 if len(input_data) != len(target_data):
+                     raise ValueError(
+                         f"Inputs and targets must have the same length. "
+                         f"Got {len(input_data)} and {len(target_data)}."
+                     )
+                 if not isinstance(input_data[0], type(target_data[0])):
+                     raise ValueError(
+                         f"Inputs and targets must have the same type. "
+                         f"Got {type(input_data[0])} and {type(target_data[0])}."
+                     )
+
+     def _list_files_in_directory(
+         self,
+         input_data,
+         target_data=None,
+     ) -> tuple[list[Path], Optional[list[Path]]]:
+         """List files from input and target directories."""
+         input_data = Path(input_data)
+         input_files = list_files(input_data, self.data_type, self.extension_filter)
+         if target_data is None:
+             return input_files, None
+         else:
+             target_data = Path(target_data)
+             target_files = list_files(
+                 target_data, self.data_type, self.extension_filter
+             )
+             validate_source_target_files(input_files, target_files)
+             return input_files, target_files
+
+     def _convert_paths_to_pathlib(
+         self,
+         input_data,
+         target_data=None,
+     ) -> tuple[list[Path], Optional[list[Path]]]:
+         """Create a list of file paths from the input and target data."""
+         input_files = [
+             Path(item) if isinstance(item, str) else item for item in input_data
+         ]
+         if target_data is None:
+             return input_files, None
+         else:
+             target_files = [
+                 Path(item) if isinstance(item, str) else item for item in target_data
+             ]
+             validate_source_target_files(input_files, target_files)
+             return input_files, target_files
+
+     def _validate_array_input(
+         self,
+         input_data: InputType,
+         target_data: Optional[InputType],
+     ) -> tuple[Any, Any]:
+         """Validate if the input data is a numpy array."""
+         if isinstance(input_data, np.ndarray):
+             input_array = [input_data]
+             target_array = [target_data] if target_data is not None else None
+             return input_array, target_array
+         elif isinstance(input_data, list):
+             return input_data, target_data
+         else:
+             raise ValueError(
+                 f"Unsupported input type for {self.data_type}: {type(input_data)}"
+             )
+
+     def _validate_path_input(
+         self, input_data: InputType, target_data: Optional[InputType]
+     ) -> tuple[list[Path], Optional[list[Path]]]:
+         if isinstance(input_data, (str, Path)):
+             if target_data is not None:
+                 assert isinstance(target_data, (str, Path))
+             input_list, target_list = self._list_files_in_directory(
+                 input_data, target_data
+             )
+             return input_list, target_list
+         elif isinstance(input_data, list):
+             if target_data is not None:
+                 assert isinstance(target_data, list)
+             input_list, target_list = self._convert_paths_to_pathlib(
+                 input_data, target_data
+             )
+             return input_list, target_list
+         else:
+             raise ValueError(
+                 f"Unsupported input type for {self.data_type}: {type(input_data)}"
+             )
+
+     def _validate_custom_input(self, input_data, target_data) -> tuple[Any, Any]:
+         if self.image_stack_loader is not None:
+             return input_data, target_data
+         elif isinstance(input_data, (str, Path)):
+             if target_data is not None:
+                 assert isinstance(target_data, (str, Path))
+             input_list, target_list = self._list_files_in_directory(
+                 input_data, target_data
+             )
+             return input_list, target_list
+         elif isinstance(input_data, list):
+             if isinstance(input_data[0], (str, Path)):
+                 if target_data is not None:
+                     assert isinstance(target_data, list)
+                 input_list, target_list = self._convert_paths_to_pathlib(
+                     input_data, target_data
+                 )
+                 return input_list, target_list
+         else:
+             raise ValueError(
+                 f"If using {self.data_type}, pass a custom "
+                 f"image_stack_loader or read_source_func"
+             )
+         return input_data, target_data
+
+     def _initialize_data_pair(
+         self,
+         input_data: Optional[InputType],
+         target_data: Optional[InputType],
+     ) -> tuple[Any, Any]:
+         """
+         Initialize a pair of input and target data.
+
+         Returns
+         -------
+         tuple[Union[list[NDArray], list[Path]],
+               Optional[Union[list[NDArray], list[Path]]]]
+             A tuple containing the initialized input and target data.
+             For file paths, returns lists of Path objects.
+             For numpy arrays, returns the arrays directly.
+         """
+         if input_data is None:
+             return None, None
+
+         self._validate_input_target_type_consistency(input_data, target_data)
+
+         if self.data_type == SupportedData.ARRAY:
+             if isinstance(input_data, np.ndarray):
+                 return self._validate_array_input(input_data, target_data)
+             elif isinstance(input_data, list):
+                 if isinstance(input_data[0], np.ndarray):
+                     return self._validate_array_input(input_data, target_data)
+                 else:
+                     raise ValueError(
+                         f"Unsupported input type for {self.data_type}: "
+                         f"{type(input_data[0])}"
+                     )
+             else:
+                 raise ValueError(
+                     f"Unsupported input type for {self.data_type}: {type(input_data)}"
+                 )
+         elif self.data_type == SupportedData.TIFF:
+             if isinstance(input_data, (str, Path)):
+                 return self._validate_path_input(input_data, target_data)
+             elif isinstance(input_data, list):
+                 if isinstance(input_data[0], (Path, str)):
+                     return self._validate_path_input(input_data, target_data)
+                 else:
+                     raise ValueError(
+                         f"Unsupported input type for {self.data_type}: "
+                         f"{type(input_data[0])}"
+                     )
+             else:
+                 raise ValueError(
+                     f"Unsupported input type for {self.data_type}: {type(input_data)}"
+                 )
+         elif self.data_type == SupportedData.CUSTOM:
+             return self._validate_custom_input(input_data, target_data)
+         else:
+             raise NotImplementedError(f"Unsupported data type: {self.data_type}")
+
+     def setup(self, stage: str) -> None:
+         """
+         Setup datasets.
+
+         Lightning hook that is called at the beginning of fit (train + validate),
+         validate, test, or predict. Creates the datasets for a given stage.
+
+         Parameters
+         ----------
+         stage : str
+             The stage to set up datasets for.
+             Is either 'fit', 'validate', 'test', or 'predict'.
+
+         Raises
+         ------
+         NotImplementedError
+             If stage is not one of "fit", "validate" or "predict".
+         """
+         if stage == "fit":
+             self.train_dataset = create_dataset(
+                 mode=Mode.TRAINING,
+                 inputs=self.train_data,
+                 targets=self.train_data_target,
+                 config=self.config,
+                 in_memory=self.use_in_memory,
+                 read_func=self.read_source_func,
+                 read_kwargs=self.read_kwargs,
+                 image_stack_loader=self.image_stack_loader,
+                 image_stack_loader_kwargs=self.image_stack_loader_kwargs,
+             )
+             # TODO: ugly, need to find a better solution
+             self.stats = self.train_dataset.input_stats
+             self.config.set_means_and_stds(
+                 self.train_dataset.input_stats.means,
+                 self.train_dataset.input_stats.stds,
+                 self.train_dataset.target_stats.means,
+                 self.train_dataset.target_stats.stds,
+             )
+             self.val_dataset = create_dataset(
+                 mode=Mode.VALIDATING,
+                 inputs=self.val_data,
+                 targets=self.val_data_target,
+                 config=self.config,
+                 in_memory=self.use_in_memory,
+                 read_func=self.read_source_func,
+                 read_kwargs=self.read_kwargs,
+                 image_stack_loader=self.image_stack_loader,
+                 image_stack_loader_kwargs=self.image_stack_loader_kwargs,
+             )
+         elif stage == "validate":
+             self.val_dataset = create_dataset(
+                 mode=Mode.VALIDATING,
+                 inputs=self.val_data,
+                 targets=self.val_data_target,
+                 config=self.config,
+                 in_memory=self.use_in_memory,
+                 read_func=self.read_source_func,
+                 read_kwargs=self.read_kwargs,
+                 image_stack_loader=self.image_stack_loader,
+                 image_stack_loader_kwargs=self.image_stack_loader_kwargs,
+             )
+             self.stats = self.val_dataset.input_stats
+         elif stage == "predict":
+             self.predict_dataset = create_dataset(
+                 mode=Mode.PREDICTING,
+                 inputs=self.pred_data,
+                 targets=self.pred_data_target,
+                 config=self.config,
+                 in_memory=self.use_in_memory,
+                 read_func=self.read_source_func,
+                 read_kwargs=self.read_kwargs,
+                 image_stack_loader=self.image_stack_loader,
+                 image_stack_loader_kwargs=self.image_stack_loader_kwargs,
+             )
+             self.stats = self.predict_dataset.input_stats
+         else:
+             raise NotImplementedError(f"Stage {stage} not implemented")
+
+     def train_dataloader(self) -> DataLoader:
+         """
+         Create a dataloader for training.
+
+         Returns
+         -------
+         DataLoader
+             Training dataloader.
+         """
+         return DataLoader(
+             self.train_dataset,
+             batch_size=self.batch_size,
+             collate_fn=default_collate,
+             **self.config.train_dataloader_params,
+         )
+
+     def val_dataloader(self) -> DataLoader:
+         """
+         Create a dataloader for validation.
+
+         Returns
+         -------
+         DataLoader
+             Validation dataloader.
+         """
+         return DataLoader(
+             self.val_dataset,
+             batch_size=self.batch_size,
+             collate_fn=default_collate,
+             **self.config.val_dataloader_params,
+         )
+
+     def predict_dataloader(self) -> DataLoader:
+         """
+         Create a dataloader for prediction.
+
+         Returns
+         -------
+         DataLoader
+             Prediction dataloader.
+         """
+         return DataLoader(
+             self.predict_dataset,
+             batch_size=self.batch_size,
+             collate_fn=default_collate,
+             # TODO: set appropriate key for params once config changes are merged
+         )
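
The new CareamicsDataModule above is the dataset_ng entry point for Lightning training. Below is a minimal, hypothetical usage sketch (not taken from this release) that feeds it in-memory numpy arrays; the DataConfig field names (data_type, axes, patch_size, batch_size) are assumptions based on earlier CAREamics releases:

    # Hypothetical usage sketch (not part of this diff). DataConfig field names
    # (data_type, axes, patch_size, batch_size) are assumptions.
    import numpy as np

    from careamics.config.data import DataConfig
    from careamics.lightning.dataset_ng.data_module import CareamicsDataModule

    rng = np.random.default_rng(42)
    train_image = rng.normal(size=(512, 512)).astype(np.float32)
    val_image = rng.normal(size=(256, 256)).astype(np.float32)

    config = DataConfig(
        data_type="array",   # handled by the SupportedData.ARRAY branch above
        axes="YX",           # assumed axis convention for a single 2D image
        patch_size=[64, 64],
        batch_size=8,
    )

    data_module = CareamicsDataModule(
        data_config=config,
        train_data=train_image,
        val_data=val_image,
    )

    data_module.setup("fit")  # creates train/val datasets and records input stats
    train_loader = data_module.train_dataloader()

With data_type="array", the _validate_array_input branch shown above wraps a bare numpy array in a single-element list before dataset creation.
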
careamics/lightning/dataset_ng/lightning_modules/__init__.py
@@ -0,0 +1,9 @@
+ """CAREamics PyTorch Lightning modules."""
+
+ from .care_module import CAREModule
+ from .n2v_module import N2VModule
+
+ __all__ = [
+     "CAREModule",
+     "N2VModule",
+ ]
careamics/lightning/dataset_ng/lightning_modules/care_module.py
@@ -0,0 +1,58 @@
+ from typing import Any, Callable, Union
+
+ from careamics.config.algorithms.care_algorithm_model import CAREAlgorithm
+ from careamics.config.algorithms.n2n_algorithm_model import N2NAlgorithm
+ from careamics.config.support import SupportedLoss
+ from careamics.dataset_ng.dataset import ImageRegionData
+ from careamics.losses import mae_loss, mse_loss
+ from careamics.utils.logging import get_logger
+
+ from .unet_module import UnetModule
+
+ logger = get_logger(__name__)
+
+
+ class CAREModule(UnetModule):
+     """CAREamics PyTorch Lightning module for CARE algorithm."""
+
+     def __init__(self, algorithm_config: Union[CAREAlgorithm, dict]) -> None:
+         super().__init__(algorithm_config)
+         assert isinstance(
+             algorithm_config, (CAREAlgorithm, N2NAlgorithm)
+         ), "algorithm_config must be a CAREAlgorithm or a N2NAlgorithm"
+         loss = algorithm_config.loss
+         if loss == SupportedLoss.MAE:
+             self.loss_func: Callable = mae_loss
+         elif loss == SupportedLoss.MSE:
+             self.loss_func = mse_loss
+         else:
+             raise ValueError(f"Unsupported loss for Care: {loss}")
+
+     def training_step(
+         self,
+         batch: tuple[ImageRegionData, ImageRegionData],
+         batch_idx: Any,
+     ) -> Any:
+         """Training step for CARE module."""
+         # TODO: add validation to determine if target is initialized
+         x, target = batch[0], batch[1]
+
+         prediction = self.model(x.data)
+         loss = self.loss_func(prediction, target.data)
+
+         self._log_training_stats(loss, batch_size=x.data.shape[0])
+
+         return loss
+
+     def validation_step(
+         self,
+         batch: tuple[ImageRegionData, ImageRegionData],
+         batch_idx: Any,
+     ) -> None:
+         """Validation step for CARE module."""
+         x, target = batch[0], batch[1]
+
+         prediction = self.model(x.data)
+         val_loss = self.loss_func(prediction, target.data)
+         self.metrics(prediction, target.data)
+         self._log_validation_stats(val_loss, batch_size=x.data.shape[0])
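
CAREModule selects its loss from the algorithm configuration (mae_loss for MAE, mse_loss for MSE) and expects batches of paired input/target ImageRegionData. A hypothetical wiring sketch (not from this diff), assuming an already-built CAREAlgorithm instance and a CareamicsDataModule with paired train/target data:

    # Hypothetical sketch: `algorithm_config` is assumed to be a validated
    # CAREAlgorithm (e.g. built via CAREamics configuration helpers) and
    # `data_module` a CareamicsDataModule with train_data and train_data_target.
    import pytorch_lightning as L

    from careamics.lightning.dataset_ng.lightning_modules import CAREModule

    module = CAREModule(algorithm_config)  # maps config.loss ("mae"/"mse") to the loss function

    trainer = L.Trainer(max_epochs=50)
    trainer.fit(module, datamodule=data_module)
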
careamics/lightning/dataset_ng/lightning_modules/n2v_module.py
@@ -0,0 +1,67 @@
+ from typing import Any, Union
+
+ from careamics.config import (
+     N2VAlgorithm,
+ )
+ from careamics.dataset_ng.dataset import ImageRegionData
+ from careamics.losses import n2v_loss
+ from careamics.transforms import N2VManipulateTorch
+ from careamics.utils.logging import get_logger
+
+ from .unet_module import UnetModule
+
+ logger = get_logger(__name__)
+
+
+ class N2VModule(UnetModule):
+     """CAREamics PyTorch Lightning module for N2V algorithm."""
+
+     def __init__(self, algorithm_config: Union[N2VAlgorithm, dict]) -> None:
+         super().__init__(algorithm_config)
+
+         assert isinstance(
+             algorithm_config, N2VAlgorithm
+         ), "algorithm_config must be a N2VAlgorithm"
+
+         self.n2v_manipulate = N2VManipulateTorch(
+             n2v_manipulate_config=algorithm_config.n2v_config
+         )
+         self.loss_func = n2v_loss
+
+     def _load_best_checkpoint(self) -> None:
+         logger.warning(
+             "Loading best checkpoint for N2V model. Note that for N2V, "
+             "the checkpoint with the best validation metrics may not necessarily "
+             "have the best denoising performance."
+         )
+         super()._load_best_checkpoint()
+
+     def training_step(
+         self,
+         batch: Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]],
+         batch_idx: Any,
+     ) -> Any:
+         """Training step for N2V model."""
+         x = batch[0]
+         x_masked, x_original, mask = self.n2v_manipulate(x.data)
+         prediction = self.model(x_masked)
+         loss = self.loss_func(prediction, x_original, mask)
+
+         self._log_training_stats(loss, batch_size=x.data.shape[0])
+
+         return loss
+
+     def validation_step(
+         self,
+         batch: Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]],
+         batch_idx: Any,
+     ) -> None:
+         """Validation step for N2V model."""
+         x = batch[0]
+
+         x_masked, x_original, mask = self.n2v_manipulate(x.data)
+         prediction = self.model(x_masked)
+
+         val_loss = self.loss_func(prediction, x_original, mask)
+         self.metrics(prediction, x_original)
+         self._log_validation_stats(val_loss, batch_size=x.data.shape[0])
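
N2VModule is self-supervised: only unpaired inputs are needed, since N2VManipulateTorch masks pixels inside the training and validation steps and n2v_loss compares predictions against the original values at the masked locations. A hypothetical end-to-end sketch (not from this diff), with the same assumed DataConfig fields as above and an already-built N2VAlgorithm:

    # Hypothetical sketch: `algorithm_config` is assumed to be a validated
    # N2VAlgorithm; DataConfig field names are assumptions.
    import numpy as np
    import pytorch_lightning as L

    from careamics.config.data import DataConfig
    from careamics.lightning.dataset_ng.data_module import CareamicsDataModule
    from careamics.lightning.dataset_ng.lightning_modules import N2VModule

    noisy = np.random.default_rng(0).normal(size=(512, 512)).astype(np.float32)

    data_module = CareamicsDataModule(
        data_config=DataConfig(
            data_type="array", axes="YX", patch_size=[64, 64], batch_size=16
        ),
        train_data=noisy,
        val_data=noisy[:128, :128],  # a crop reused for validation, illustration only
    )

    module = N2VModule(algorithm_config)
    trainer = L.Trainer(max_epochs=100)
    trainer.fit(module, datamodule=data_module)
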