careamics 0.0.12__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic; see the registry's page for this release for more details.

Files changed (73)
  1. careamics/careamist.py +4 -3
  2. careamics/cli/utils.py +1 -1
  3. careamics/config/algorithms/n2v_algorithm_model.py +1 -1
  4. careamics/config/architectures/unet_model.py +3 -0
  5. careamics/config/callback_model.py +23 -34
  6. careamics/config/configuration.py +47 -1
  7. careamics/config/configuration_factories.py +288 -23
  8. careamics/config/data/__init__.py +2 -0
  9. careamics/config/data/data_model.py +3 -3
  10. careamics/config/data/ng_data_model.py +381 -0
  11. careamics/config/data/patching_strategies/__init__.py +14 -0
  12. careamics/config/data/patching_strategies/_overlapping_patched_model.py +103 -0
  13. careamics/config/data/patching_strategies/_patched_model.py +56 -0
  14. careamics/config/data/patching_strategies/random_patching_model.py +21 -0
  15. careamics/config/data/patching_strategies/sequential_patching_model.py +25 -0
  16. careamics/config/data/patching_strategies/tiled_patching_model.py +40 -0
  17. careamics/config/data/patching_strategies/whole_patching_model.py +12 -0
  18. careamics/config/inference_model.py +6 -3
  19. careamics/config/support/supported_data.py +7 -0
  20. careamics/config/support/supported_patching_strategies.py +22 -0
  21. careamics/config/validators/validator_utils.py +4 -3
  22. careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
  23. careamics/dataset/in_memory_dataset.py +2 -1
  24. careamics/dataset/iterable_dataset.py +2 -2
  25. careamics/dataset/iterable_pred_dataset.py +2 -2
  26. careamics/dataset/iterable_tiled_pred_dataset.py +2 -2
  27. careamics/dataset/patching/patching.py +3 -2
  28. careamics/dataset/tiling/lvae_tiled_patching.py +16 -6
  29. careamics/dataset/tiling/tiled_patching.py +2 -1
  30. careamics/dataset_ng/dataset.py +46 -50
  31. careamics/dataset_ng/demos/bsd68_demo.ipynb +28 -23
  32. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +1 -1
  33. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +1 -1
  34. careamics/dataset_ng/demos/demo_datamodule.ipynb +50 -46
  35. careamics/dataset_ng/demos/demo_dataset.ipynb +32 -49
  36. careamics/dataset_ng/factory.py +58 -15
  37. careamics/dataset_ng/legacy_interoperability.py +3 -1
  38. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +1 -1
  39. careamics/dataset_ng/patch_extractor/image_stack/__init__.py +2 -0
  40. careamics/dataset_ng/patch_extractor/image_stack/czi_image_stack.py +360 -0
  41. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -1
  42. careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +43 -1
  43. careamics/dataset_ng/patching_strategies/random_patching.py +3 -1
  44. careamics/dataset_ng/patching_strategies/sequential_patching.py +5 -5
  45. careamics/dataset_ng/patching_strategies/tiling_strategy.py +2 -1
  46. careamics/file_io/read/get_func.py +2 -1
  47. careamics/lightning/dataset_ng/__init__.py +1 -0
  48. careamics/lightning/dataset_ng/data_module.py +218 -28
  49. careamics/lightning/dataset_ng/lightning_modules/care_module.py +44 -5
  50. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +42 -3
  51. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +73 -4
  52. careamics/lightning/lightning_module.py +2 -1
  53. careamics/lightning/predict_data_module.py +2 -1
  54. careamics/lightning/train_data_module.py +2 -1
  55. careamics/losses/loss_factory.py +2 -1
  56. careamics/lvae_training/dataset/multicrop_dset.py +1 -1
  57. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  58. careamics/model_io/bioimage/model_description.py +1 -1
  59. careamics/model_io/bmz_io.py +1 -1
  60. careamics/model_io/model_io_utils.py +2 -2
  61. careamics/models/activation.py +2 -1
  62. careamics/prediction_utils/prediction_outputs.py +1 -1
  63. careamics/prediction_utils/stitch_prediction.py +1 -1
  64. careamics/transforms/n2v_manipulate_torch.py +15 -9
  65. careamics/transforms/pixel_manipulation_torch.py +59 -92
  66. careamics/utils/lightning_utils.py +2 -2
  67. careamics/utils/metrics.py +2 -1
  68. careamics/utils/torch_utils.py +23 -0
  69. {careamics-0.0.12.dist-info → careamics-0.0.13.dist-info}/METADATA +10 -9
  70. {careamics-0.0.12.dist-info → careamics-0.0.13.dist-info}/RECORD +73 -62
  71. {careamics-0.0.12.dist-info → careamics-0.0.13.dist-info}/WHEEL +0 -0
  72. {careamics-0.0.12.dist-info → careamics-0.0.13.dist-info}/entry_points.txt +0 -0
  73. {careamics-0.0.12.dist-info → careamics-0.0.13.dist-info}/licenses/LICENSE +0 -0
@@ -1,5 +1,8 @@
1
+ """Next-Generation CAREamics DataModule."""
2
+
3
+ from collections.abc import Callable
1
4
  from pathlib import Path
2
- from typing import Any, Callable, Optional, Union, overload
5
+ from typing import Any, Optional, Union, overload
3
6
 
4
7
  import numpy as np
5
8
  import pytorch_lightning as L
@@ -7,7 +10,7 @@ from numpy.typing import NDArray
7
10
  from torch.utils.data import DataLoader
8
11
  from torch.utils.data._utils.collate import default_collate
9
12
 
10
- from careamics.config.data import DataConfig
13
+ from careamics.config.data.ng_data_model import NGDataConfig
11
14
  from careamics.config.support import SupportedData
12
15
  from careamics.dataset.dataset_utils import list_files, validate_source_target_files
13
16
  from careamics.dataset_ng.dataset import Mode
@@ -18,16 +21,108 @@ from careamics.utils import get_logger
18
21
  logger = get_logger(__name__)
19
22
 
20
23
  ItemType = Union[Path, str, NDArray[Any]]
24
+ """Type of input items passed to the dataset."""
25
+
21
26
  InputType = Union[ItemType, list[ItemType], None]
27
+ """Type of input data passed to the dataset."""
22
28
 
23
29
 
24
30
  class CareamicsDataModule(L.LightningDataModule):
25
- """Data module for Careamics dataset."""
26
-
31
+ """Data module for Careamics dataset.
32
+
33
+ Parameters
34
+ ----------
35
+ data_config : DataConfig
36
+ Pydantic model for CAREamics data configuration.
37
+ train_data : Optional[InputType]
38
+ Training data, can be a path to a folder, a list of paths, or a numpy array.
39
+ train_data_target : Optional[InputType]
40
+ Training data target, can be a path to a folder,
41
+ a list of paths, or a numpy array.
42
+ val_data : Optional[InputType]
43
+ Validation data, can be a path to a folder,
44
+ a list of paths, or a numpy array.
45
+ val_data_target : Optional[InputType]
46
+ Validation data target, can be a path to a folder,
47
+ a list of paths, or a numpy array.
48
+ pred_data : Optional[InputType]
49
+ Prediction data, can be a path to a folder, a list of paths,
50
+ or a numpy array.
51
+ pred_data_target : Optional[InputType]
52
+ Prediction data target, can be a path to a folder,
53
+ a list of paths, or a numpy array.
54
+ read_source_func : Optional[Callable], default=None
55
+ Function to read the source data. Only used for `custom`
56
+ data type (see DataModel).
57
+ read_kwargs : Optional[dict[str, Any]]
58
+ The kwargs for the read source function.
59
+ image_stack_loader : Optional[ImageStackLoader]
60
+ The image stack loader.
61
+ image_stack_loader_kwargs : Optional[dict[str, Any]]
62
+ The image stack loader kwargs.
63
+ extension_filter : str, default=""
64
+ Filter for file extensions. Only used for `custom` data types
65
+ (see DataModel).
66
+ val_percentage : Optional[float]
67
+ Percentage of the training data to use for validation. Only
68
+ used if `val_data` is None.
69
+ val_minimum_split : int, default=5
70
+ Minimum number of patches or files to split from the training data for
71
+ validation. Only used if `val_data` is None.
72
+ use_in_memory : bool
73
+ Load data in memory dataset if possible, by default True.
74
+
75
+
76
+ Attributes
77
+ ----------
78
+ config : DataConfig
79
+ Pydantic model for CAREamics data configuration.
80
+ data_type : str
81
+ Type of data, one of SupportedData.
82
+ batch_size : int
83
+ Batch size for the dataloaders.
84
+ use_in_memory : bool
85
+ Whether to load data in memory if possible.
86
+ extension_filter : str
87
+ Filter for file extensions, by default "".
88
+ read_source_func : Optional[Callable], default=None
89
+ Function to read the source data.
90
+ read_kwargs : Optional[dict[str, Any]], default=None
91
+ The kwargs for the read source function.
92
+ val_percentage : Optional[float]
93
+ Percentage of the training data to use for validation.
94
+ val_minimum_split : int, default=5
95
+ Minimum number of patches or files to split from the training data for
96
+ validation.
97
+ train_data : Optional[Any]
98
+ Training data, can be a path to a folder, a list of paths, or a numpy array.
99
+ train_data_target : Optional[Any]
100
+ Training data target, can be a path to a folder, a list of paths, or a numpy
101
+ array.
102
+ val_data : Optional[Any]
103
+ Validation data, can be a path to a folder, a list of paths, or a numpy array.
104
+ val_data_target : Optional[Any]
105
+ Validation data target, can be a path to a folder, a list of paths, or a numpy
106
+ array.
107
+ pred_data : Optional[Any]
108
+ Prediction data, can be a path to a folder, a list of paths, or a numpy array.
109
+ pred_data_target : Optional[Any]
110
+ Prediction data target, can be a path to a folder, a list of paths, or a numpy
111
+ array.
112
+
113
+ Raises
114
+ ------
115
+ ValueError
116
+ If at least one of train_data, val_data or pred_data is not provided.
117
+ ValueError
118
+ If input and target data types are not consistent.
119
+ """
120
+
121
+ # standard use
27
122
  @overload
28
123
  def __init__(
29
124
  self,
30
- data_config: DataConfig,
125
+ data_config: NGDataConfig,
31
126
  *,
32
127
  train_data: Optional[InputType] = None,
33
128
  train_data_target: Optional[InputType] = None,
@@ -41,10 +136,11 @@ class CareamicsDataModule(L.LightningDataModule):
41
136
  use_in_memory: bool = True,
42
137
  ) -> None: ...
43
138
 
139
+ # custom read function
44
140
  @overload
45
141
  def __init__(
46
142
  self,
47
- data_config: DataConfig,
143
+ data_config: NGDataConfig,
48
144
  *,
49
145
  train_data: Optional[InputType] = None,
50
146
  train_data_target: Optional[InputType] = None,
@@ -63,7 +159,7 @@ class CareamicsDataModule(L.LightningDataModule):
63
159
  @overload
64
160
  def __init__(
65
161
  self,
66
- data_config: DataConfig,
162
+ data_config: NGDataConfig,
67
163
  *,
68
164
  train_data: Optional[Any] = None,
69
165
  train_data_target: Optional[Any] = None,
@@ -81,7 +177,7 @@ class CareamicsDataModule(L.LightningDataModule):
81
177
 
82
178
  def __init__(
83
179
  self,
84
- data_config: DataConfig,
180
+ data_config: NGDataConfig,
85
181
  *,
86
182
  train_data: Optional[Any] = None,
87
183
  train_data_target: Optional[Any] = None,
@@ -106,7 +202,7 @@ class CareamicsDataModule(L.LightningDataModule):
106
202
 
107
203
  Parameters
108
204
  ----------
109
- data_config : DataConfig
205
+ data_config : NGDataConfig
110
206
  Pydantic model for CAREamics data configuration.
111
207
  train_data : Optional[InputType]
112
208
  Training data, can be a path to a folder, a list of paths, or a numpy array.
@@ -153,7 +249,7 @@ class CareamicsDataModule(L.LightningDataModule):
153
249
  "At least one of train_data, val_data or pred_data must be provided."
154
250
  )
155
251
 
156
- self.config: DataConfig = data_config
252
+ self.config: NGDataConfig = data_config
157
253
  self.data_type: str = data_config.data_type
158
254
  self.batch_size: int = data_config.batch_size
159
255
  self.use_in_memory: bool = use_in_memory
@@ -186,7 +282,16 @@ class CareamicsDataModule(L.LightningDataModule):
186
282
  input_data: InputType,
187
283
  target_data: Optional[InputType],
188
284
  ) -> None:
189
- """Validate if the input and target data types are consistent."""
285
+ """Validate if the input and target data types are consistent.
286
+
287
+ Parameters
288
+ ----------
289
+ input_data : InputType
290
+ Input data, can be a path to a folder, a list of paths, or a numpy array.
291
+ target_data : Optional[InputType]
292
+ Target data, can be None, a path to a folder, a list of paths, or a numpy
293
+ array.
294
+ """
190
295
  if input_data is not None and target_data is not None:
191
296
  if not isinstance(input_data, type(target_data)):
192
297
  raise ValueError(
@@ -210,7 +315,22 @@ class CareamicsDataModule(L.LightningDataModule):
210
315
  input_data,
211
316
  target_data=None,
212
317
  ) -> tuple[list[Path], Optional[list[Path]]]:
213
- """List files from input and target directories."""
318
+ """List files from input and target directories.
319
+
320
+ Parameters
321
+ ----------
322
+ input_data : InputType
323
+ Input data, can be a path to a folder, a list of paths, or a numpy array.
324
+ target_data : Optional[InputType]
325
+ Target data, can be None, a path to a folder, a list of paths, or a numpy
326
+ array.
327
+
328
+ Returns
329
+ -------
330
+ (list[Path], Optional[list[Path]])
331
+ A tuple containing lists of file paths for input and target data.
332
+ If target_data is None, the second element will be None.
333
+ """
214
334
  input_data = Path(input_data)
215
335
  input_files = list_files(input_data, self.data_type, self.extension_filter)
216
336
  if target_data is None:
@@ -228,7 +348,22 @@ class CareamicsDataModule(L.LightningDataModule):
228
348
  input_data,
229
349
  target_data=None,
230
350
  ) -> tuple[list[Path], Optional[list[Path]]]:
231
- """Create a list of file paths from the input and target data."""
351
+ """Create a list of file paths from the input and target data.
352
+
353
+ Parameters
354
+ ----------
355
+ input_data : InputType
356
+ Input data, can be a path to a folder, a list of paths, or a numpy array.
357
+ target_data : Optional[InputType]
358
+ Target data, can be None, a path to a folder, a list of paths, or a numpy
359
+ array.
360
+
361
+ Returns
362
+ -------
363
+ (list[Path], Optional[list[Path]])
364
+ A tuple containing lists of file paths for input and target data.
365
+ If target_data is None, the second element will be None.
366
+ """
232
367
  input_files = [
233
368
  Path(item) if isinstance(item, str) else item for item in input_data
234
369
  ]
@@ -246,7 +381,21 @@ class CareamicsDataModule(L.LightningDataModule):
246
381
  input_data: InputType,
247
382
  target_data: Optional[InputType],
248
383
  ) -> tuple[Any, Any]:
249
- """Validate if the input data is a numpy array."""
384
+ """Validate if the input data is a numpy array.
385
+
386
+ Parameters
387
+ ----------
388
+ input_data : InputType
389
+ Input data, can be a path to a folder, a list of paths, or a numpy array.
390
+ target_data : Optional[InputType]
391
+ Target data, can be None, a path to a folder, a list of paths, or a numpy
392
+ array.
393
+
394
+ Returns
395
+ -------
396
+ (Any, Any)
397
+ A tuple containing the input and target.
398
+ """
250
399
  if isinstance(input_data, np.ndarray):
251
400
  input_array = [input_data]
252
401
  target_array = [target_data] if target_data is not None else None
@@ -261,9 +410,25 @@ class CareamicsDataModule(L.LightningDataModule):
261
410
  def _validate_path_input(
262
411
  self, input_data: InputType, target_data: Optional[InputType]
263
412
  ) -> tuple[list[Path], Optional[list[Path]]]:
264
- if isinstance(input_data, (str, Path)):
413
+ """Validate if the input data is a path or a list of paths.
414
+
415
+ Parameters
416
+ ----------
417
+ input_data : InputType
418
+ Input data, can be a path to a folder, a list of paths, or a numpy array.
419
+ target_data : Optional[InputType]
420
+ Target data, can be None, a path to a folder, a list of paths, or a numpy
421
+ array.
422
+
423
+ Returns
424
+ -------
425
+ (list[Path], Optional[list[Path]])
426
+ A tuple containing lists of file paths for input and target data.
427
+ If target_data is None, the second element will be None.
428
+ """
429
+ if isinstance(input_data, str | Path):
265
430
  if target_data is not None:
266
- assert isinstance(target_data, (str, Path))
431
+ assert isinstance(target_data, str | Path)
267
432
  input_list, target_list = self._list_files_in_directory(
268
433
  input_data, target_data
269
434
  )
@@ -281,17 +446,33 @@ class CareamicsDataModule(L.LightningDataModule):
281
446
  )
282
447
 
283
448
  def _validate_custom_input(self, input_data, target_data) -> tuple[Any, Any]:
449
+ """Convert custom input data to a list of file paths.
450
+
451
+ Parameters
452
+ ----------
453
+ input_data : InputType
454
+ Input data, can be a path to a folder, a list of paths, or a numpy array.
455
+ target_data : Optional[InputType]
456
+ Target data, can be None, a path to a folder, a list of paths, or a numpy
457
+ array.
458
+
459
+ Returns
460
+ -------
461
+ (Any, Any)
462
+ A tuple containing lists of file paths for input and target data.
463
+ If target_data is None, the second element will be None.
464
+ """
284
465
  if self.image_stack_loader is not None:
285
466
  return input_data, target_data
286
- elif isinstance(input_data, (str, Path)):
467
+ elif isinstance(input_data, str | Path):
287
468
  if target_data is not None:
288
- assert isinstance(target_data, (str, Path))
469
+ assert isinstance(target_data, str | Path)
289
470
  input_list, target_list = self._list_files_in_directory(
290
471
  input_data, target_data
291
472
  )
292
473
  return input_list, target_list
293
474
  elif isinstance(input_data, list):
294
- if isinstance(input_data[0], (str, Path)):
475
+ if isinstance(input_data[0], str | Path):
295
476
  if target_data is not None:
296
477
  assert isinstance(target_data, list)
297
478
  input_list, target_list = self._convert_paths_to_pathlib(
@@ -313,13 +494,22 @@ class CareamicsDataModule(L.LightningDataModule):
313
494
  """
314
495
  Initialize a pair of input and target data.
315
496
 
497
+ Parameters
498
+ ----------
499
+ input_data : InputType
500
+ Input data, can be None, a path to a folder, a list of paths, or a numpy
501
+ array.
502
+ target_data : Optional[InputType]
503
+ Target data, can be None, a path to a folder, a list of paths, or a numpy
504
+ array.
505
+
316
506
  Returns
317
507
  -------
318
- tuple[Union[list[NDArray], list[Path]],
319
- Optional[Union[list[NDArray], list[Path]]]]
320
- A tuple containing the initialized input and target data.
321
- For file paths, returns lists of Path objects.
322
- For numpy arrays, returns the arrays directly.
508
+ (list of numpy.ndarray or list of pathlib.Path, None or list of numpy.ndarray or
509
+ list of pathlib.Path)
510
+ A tuple containing the initialized input and target data. For file paths,
511
+ returns lists of Path objects. For numpy arrays, returns the arrays
512
+ directly.
323
513
  """
324
514
  if input_data is None:
325
515
  return None, None
@@ -341,11 +531,11 @@ class CareamicsDataModule(L.LightningDataModule):
341
531
  raise ValueError(
342
532
  f"Unsupported input type for {self.data_type}: {type(input_data)}"
343
533
  )
344
- elif self.data_type == SupportedData.TIFF:
345
- if isinstance(input_data, (str, Path)):
534
+ elif self.data_type in (SupportedData.TIFF, SupportedData.CZI):
535
+ if isinstance(input_data, str | Path):
346
536
  return self._validate_path_input(input_data, target_data)
347
537
  elif isinstance(input_data, list):
348
- if isinstance(input_data[0], (Path, str)):
538
+ if isinstance(input_data[0], str | Path):
349
539
  return self._validate_path_input(input_data, target_data)
350
540
  else:
351
541
  raise ValueError(
@@ -484,5 +674,5 @@ class CareamicsDataModule(L.LightningDataModule):
484
674
  self.predict_dataset,
485
675
  batch_size=self.batch_size,
486
676
  collate_fn=default_collate,
487
- # TODO: set appropriate key for params once config changes are merged
677
+ **self.config.test_dataloader_params,
488
678
  )
@@ -1,4 +1,7 @@
1
- from typing import Any, Callable, Union
1
+ """CARE Lightning DataModule."""
2
+
3
+ from collections.abc import Callable
4
+ from typing import Any, Union
2
5
 
3
6
  from careamics.config.algorithms.care_algorithm_model import CAREAlgorithm
4
7
  from careamics.config.algorithms.n2n_algorithm_model import N2NAlgorithm
@@ -13,12 +16,27 @@ logger = get_logger(__name__)
13
16
 
14
17
 
15
18
  class CAREModule(UnetModule):
16
- """CAREamics PyTorch Lightning module for CARE algorithm."""
19
+ """CAREamics PyTorch Lightning module for CARE algorithm.
20
+
21
+ Parameters
22
+ ----------
23
+ algorithm_config : CAREAlgorithm or dict
24
+ Configuration for the CARE algorithm, either as a CAREAlgorithm instance or a
25
+ dictionary.
26
+ """
17
27
 
18
28
  def __init__(self, algorithm_config: Union[CAREAlgorithm, dict]) -> None:
29
+ """Instantiate CARE DataModule.
30
+
31
+ Parameters
32
+ ----------
33
+ algorithm_config : CAREAlgorithm or dict
34
+ Configuration for the CARE algorithm, either as a CAREAlgorithm instance or
35
+ a dictionary.
36
+ """
19
37
  super().__init__(algorithm_config)
20
38
  assert isinstance(
21
- algorithm_config, (CAREAlgorithm, N2NAlgorithm)
39
+ algorithm_config, CAREAlgorithm | N2NAlgorithm
22
40
  ), "algorithm_config must be a CAREAlgorithm or a N2NAlgorithm"
23
41
  loss = algorithm_config.loss
24
42
  if loss == SupportedLoss.MAE:
@@ -33,7 +51,20 @@ class CAREModule(UnetModule):
33
51
  batch: tuple[ImageRegionData, ImageRegionData],
34
52
  batch_idx: Any,
35
53
  ) -> Any:
36
- """Training step for CARE module."""
54
+ """Training step for CARE module.
55
+
56
+ Parameters
57
+ ----------
58
+ batch : (ImageRegionData, ImageRegionData)
59
+ A tuple containing the input data and the target data.
60
+ batch_idx : Any
61
+ The index of the current batch in the training loop.
62
+
63
+ Returns
64
+ -------
65
+ Any
66
+ The loss value computed for the current batch.
67
+ """
37
68
  # TODO: add validation to determine if target is initialized
38
69
  x, target = batch[0], batch[1]
39
70
 
@@ -49,7 +80,15 @@ class CAREModule(UnetModule):
49
80
  batch: tuple[ImageRegionData, ImageRegionData],
50
81
  batch_idx: Any,
51
82
  ) -> None:
52
- """Validation step for CARE module."""
83
+ """Validation step for CARE module.
84
+
85
+ Parameters
86
+ ----------
87
+ batch : (ImageRegionData, ImageRegionData)
88
+ A tuple containing the input data and the target data.
89
+ batch_idx : Any
90
+ The index of the current batch in the training loop.
91
+ """
53
92
  x, target = batch[0], batch[1]
54
93
 
55
94
  prediction = self.model(x.data)
@@ -1,3 +1,5 @@
1
+ """Noise2Void Lightning DataModule."""
2
+
1
3
  from typing import Any, Union
2
4
 
3
5
  from careamics.config import (
@@ -14,9 +16,24 @@ logger = get_logger(__name__)
14
16
 
15
17
 
16
18
  class N2VModule(UnetModule):
17
- """CAREamics PyTorch Lightning module for N2V algorithm."""
19
+ """CAREamics PyTorch Lightning module for N2V algorithm.
20
+
21
+ Parameters
22
+ ----------
23
+ algorithm_config : N2VAlgorithm or dict
24
+ Configuration for the N2V algorithm, either as an N2VAlgorithm instance or a
25
+ dictionary.
26
+ """
18
27
 
19
28
  def __init__(self, algorithm_config: Union[N2VAlgorithm, dict]) -> None:
29
+ """Instantiate N2V DataModule.
30
+
31
+ Parameters
32
+ ----------
33
+ algorithm_config : N2VAlgorithm or dict
34
+ Configuration for the N2V algorithm, either as an N2VAlgorithm instance or a
35
+ dictionary.
36
+ """
20
37
  super().__init__(algorithm_config)
21
38
 
22
39
  assert isinstance(
@@ -29,6 +46,7 @@ class N2VModule(UnetModule):
29
46
  self.loss_func = n2v_loss
30
47
 
31
48
  def _load_best_checkpoint(self) -> None:
49
+ """Load the best checkpoint for N2V model."""
32
50
  logger.warning(
33
51
  "Loading best checkpoint for N2V model. Note that for N2V, "
34
52
  "the checkpoint with the best validation metrics may not necessarily "
@@ -41,7 +59,20 @@ class N2VModule(UnetModule):
41
59
  batch: Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]],
42
60
  batch_idx: Any,
43
61
  ) -> Any:
44
- """Training step for N2V model."""
62
+ """Training step for N2V model.
63
+
64
+ Parameters
65
+ ----------
66
+ batch : ImageRegionData or (ImageRegionData, ImageRegionData)
67
+ A tuple containing the input data and the target data.
68
+ batch_idx : Any
69
+ The index of the current batch in the training loop.
70
+
71
+ Returns
72
+ -------
73
+ Any
74
+ The loss value for the current training step.
75
+ """
45
76
  x = batch[0]
46
77
  x_masked, x_original, mask = self.n2v_manipulate(x.data)
47
78
  prediction = self.model(x_masked)
@@ -56,7 +87,15 @@ class N2VModule(UnetModule):
56
87
  batch: Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]],
57
88
  batch_idx: Any,
58
89
  ) -> None:
59
- """Validation step for N2V model."""
90
+ """Validation step for N2V model.
91
+
92
+ Parameters
93
+ ----------
94
+ batch : ImageRegionData or (ImageRegionData, ImageRegionData)
95
+ A tuple containing the input data and the target data.
96
+ batch_idx : Any
97
+ The index of the current batch in the training loop.
98
+ """
60
99
  x = batch[0]
61
100
 
62
101
  x_masked, x_original, mask = self.n2v_manipulate(x.data)
@@ -1,3 +1,5 @@
1
+ """Generic UNet Lightning DataModule."""
2
+
1
3
  from typing import Any, Union
2
4
 
3
5
  import pytorch_lightning as L
@@ -18,11 +20,27 @@ logger = get_logger(__name__)
18
20
 
19
21
 
20
22
  class UnetModule(L.LightningModule):
21
- """CAREamics PyTorch Lightning module for UNet based algorithms."""
23
+ """CAREamics PyTorch Lightning module for UNet based algorithms.
24
+
25
+ Parameters
26
+ ----------
27
+ algorithm_config : CAREAlgorithm, N2VAlgorithm, N2NAlgorithm, or dict
28
+ Configuration for the algorithm, either as an instance of a specific algorithm
29
+ class or a dictionary that can be converted to an algorithm instance.
30
+ """
22
31
 
23
32
  def __init__(
24
33
  self, algorithm_config: Union[CAREAlgorithm, N2VAlgorithm, N2NAlgorithm, dict]
25
34
  ) -> None:
35
+ """Instantiate UNet DataModule.
36
+
37
+ Parameters
38
+ ----------
39
+ algorithm_config : CAREAlgorithm, N2VAlgorithm, N2NAlgorithm, or dict
40
+ Configuration for the algorithm, either as an instance of a specific
41
+ algorithm class or a dictionary that can be converted to an algorithm
42
+ instance.
43
+ """
26
44
  super().__init__()
27
45
 
28
46
  if isinstance(algorithm_config, dict):
@@ -37,10 +55,30 @@ class UnetModule(L.LightningModule):
37
55
  self.metrics = MetricCollection(PeakSignalNoiseRatio())
38
56
 
39
57
  def forward(self, x: Any) -> Any:
40
- """Default forward method."""
58
+ """Default forward method.
59
+
60
+ Parameters
61
+ ----------
62
+ x : Any
63
+ Input data.
64
+
65
+ Returns
66
+ -------
67
+ Any
68
+ Output from the model.
69
+ """
41
70
  return self.model(x)
42
71
 
43
72
  def _log_training_stats(self, loss: Any, batch_size: Any) -> None:
73
+ """Log training statistics.
74
+
75
+ Parameters
76
+ ----------
77
+ loss : Any
78
+ The loss value for the current training step.
79
+ batch_size : Any
80
+ The size of the batch used in the current training step.
81
+ """
44
82
  self.log(
45
83
  "train_loss",
46
84
  loss,
@@ -66,6 +104,15 @@ class UnetModule(L.LightningModule):
66
104
  )
67
105
 
68
106
  def _log_validation_stats(self, loss: Any, batch_size: Any) -> None:
107
+ """Log validation statistics.
108
+
109
+ Parameters
110
+ ----------
111
+ loss : Any
112
+ The loss value for the current validation step.
113
+ batch_size : Any
114
+ The size of the batch used in the current validation step.
115
+ """
69
116
  self.log(
70
117
  "val_loss",
71
118
  loss,
@@ -78,6 +125,7 @@ class UnetModule(L.LightningModule):
78
125
  self.log_dict(self.metrics, on_step=False, on_epoch=True, batch_size=batch_size)
79
126
 
80
127
  def _load_best_checkpoint(self) -> None:
128
+ """Load the best checkpoint from the trainer's checkpoint callback."""
81
129
  if (
82
130
  not hasattr(self.trainer, "checkpoint_callback")
83
131
  or self.trainer.checkpoint_callback is None
@@ -99,7 +147,22 @@ class UnetModule(L.LightningModule):
99
147
  batch_idx: Any,
100
148
  load_best_checkpoint=False,
101
149
  ) -> Any:
102
- """Default predict step."""
150
+ """Default predict step.
151
+
152
+ Parameters
153
+ ----------
154
+ batch : ImageRegionData or (ImageRegionData, ImageRegionData)
155
+ A tuple containing the input data and optionally the target data.
156
+ batch_idx : Any
157
+ The index of the current batch in the prediction loop.
158
+ load_best_checkpoint : bool, default=False
159
+ Whether to load the best checkpoint before making predictions.
160
+
161
+ Returns
162
+ -------
163
+ Any
164
+ The output batch containing the predictions.
165
+ """
103
166
  if self._best_checkpoint_loaded is False and load_best_checkpoint:
104
167
  self._load_best_checkpoint()
105
168
  self._best_checkpoint_loaded = True
@@ -127,7 +190,13 @@ class UnetModule(L.LightningModule):
127
190
  return output_batch
128
191
 
129
192
  def configure_optimizers(self) -> Any:
130
- """Configure optimizers."""
193
+ """Configure optimizers.
194
+
195
+ Returns
196
+ -------
197
+ Any
198
+ A dictionary containing the optimizer and learning rate scheduler.
199
+ """
131
200
  optimizer_func = get_optimizer(self.config.optimizer.name)
132
201
  optimizer = optimizer_func(
133
202
  self.model.parameters(), **self.config.optimizer.parameters
@@ -1,6 +1,7 @@
1
1
  """CAREamics Lightning module."""
2
2
 
3
- from typing import Any, Callable, Literal, Optional, Union
3
+ from collections.abc import Callable
4
+ from typing import Any, Literal, Optional, Union
4
5
 
5
6
  import numpy as np
6
7
  import pytorch_lightning as L
@@ -1,7 +1,8 @@
1
1
  """Prediction Lightning data modules."""
2
2
 
3
+ from collections.abc import Callable
3
4
  from pathlib import Path
4
- from typing import Any, Callable, Literal, Optional, Union
5
+ from typing import Any, Literal, Optional, Union
5
6
 
6
7
  import numpy as np
7
8
  import pytorch_lightning as L
@@ -1,7 +1,8 @@
1
1
  """Training and validation Lightning data modules."""
2
2
 
3
+ from collections.abc import Callable
3
4
  from pathlib import Path
4
- from typing import Any, Callable, Literal, Optional, Union
5
+ from typing import Any, Literal, Optional, Union
5
6
 
6
7
  import numpy as np
7
8
  import pytorch_lightning as L