careamics 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (98)
  1. careamics/careamist.py +24 -7
  2. careamics/cli/utils.py +1 -1
  3. careamics/config/algorithms/n2v_algorithm_model.py +1 -1
  4. careamics/config/architectures/unet_model.py +3 -0
  5. careamics/config/callback_model.py +23 -34
  6. careamics/config/configuration.py +55 -4
  7. careamics/config/configuration_factories.py +288 -23
  8. careamics/config/data/__init__.py +2 -0
  9. careamics/config/data/data_model.py +41 -4
  10. careamics/config/data/ng_data_model.py +381 -0
  11. careamics/config/data/patching_strategies/__init__.py +14 -0
  12. careamics/config/data/patching_strategies/_overlapping_patched_model.py +103 -0
  13. careamics/config/data/patching_strategies/_patched_model.py +56 -0
  14. careamics/config/data/patching_strategies/random_patching_model.py +21 -0
  15. careamics/config/data/patching_strategies/sequential_patching_model.py +25 -0
  16. careamics/config/data/patching_strategies/tiled_patching_model.py +40 -0
  17. careamics/config/data/patching_strategies/whole_patching_model.py +12 -0
  18. careamics/config/inference_model.py +6 -3
  19. careamics/config/optimizer_models.py +1 -3
  20. careamics/config/support/supported_data.py +7 -0
  21. careamics/config/support/supported_patching_strategies.py +22 -0
  22. careamics/config/training_model.py +0 -2
  23. careamics/config/validators/validator_utils.py +4 -3
  24. careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
  25. careamics/dataset/in_memory_dataset.py +2 -1
  26. careamics/dataset/iterable_dataset.py +2 -2
  27. careamics/dataset/iterable_pred_dataset.py +2 -2
  28. careamics/dataset/iterable_tiled_pred_dataset.py +2 -2
  29. careamics/dataset/patching/patching.py +3 -2
  30. careamics/dataset/tiling/lvae_tiled_patching.py +16 -6
  31. careamics/dataset/tiling/tiled_patching.py +2 -1
  32. careamics/dataset_ng/README.md +212 -0
  33. careamics/dataset_ng/dataset.py +229 -0
  34. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  35. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  36. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +734 -0
  37. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  38. careamics/dataset_ng/{demo_dataset.ipynb → demos/demo_dataset.ipynb} +60 -53
  39. careamics/dataset_ng/{demo_patch_extractor.py → demos/demo_patch_extractor.py} +7 -9
  40. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +292 -0
  41. careamics/dataset_ng/factory.py +451 -0
  42. careamics/dataset_ng/legacy_interoperability.py +170 -0
  43. careamics/dataset_ng/patch_extractor/__init__.py +3 -8
  44. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +7 -5
  45. careamics/dataset_ng/patch_extractor/image_stack/__init__.py +4 -1
  46. careamics/dataset_ng/patch_extractor/image_stack/czi_image_stack.py +360 -0
  47. careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +5 -1
  48. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -1
  49. careamics/dataset_ng/patch_extractor/image_stack_loader.py +5 -75
  50. careamics/dataset_ng/patch_extractor/patch_extractor.py +5 -4
  51. careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +114 -105
  52. careamics/dataset_ng/patching_strategies/__init__.py +6 -1
  53. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +31 -0
  54. careamics/dataset_ng/patching_strategies/random_patching.py +5 -1
  55. careamics/dataset_ng/patching_strategies/sequential_patching.py +5 -5
  56. careamics/dataset_ng/patching_strategies/tiling_strategy.py +172 -0
  57. careamics/dataset_ng/patching_strategies/whole_sample.py +36 -0
  58. careamics/file_io/read/get_func.py +2 -1
  59. careamics/lightning/dataset_ng/__init__.py +1 -0
  60. careamics/lightning/dataset_ng/data_module.py +678 -0
  61. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  62. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  63. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  64. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +212 -0
  65. careamics/lightning/lightning_module.py +5 -1
  66. careamics/lightning/predict_data_module.py +2 -1
  67. careamics/lightning/train_data_module.py +2 -1
  68. careamics/losses/loss_factory.py +2 -1
  69. careamics/lvae_training/dataset/__init__.py +8 -3
  70. careamics/lvae_training/dataset/config.py +3 -3
  71. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  72. careamics/lvae_training/dataset/multich_dataset.py +46 -17
  73. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  74. careamics/lvae_training/dataset/types.py +3 -3
  75. careamics/lvae_training/dataset/utils/index_manager.py +259 -0
  76. careamics/lvae_training/eval_utils.py +93 -3
  77. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  78. careamics/model_io/bioimage/model_description.py +1 -1
  79. careamics/model_io/bmz_io.py +1 -1
  80. careamics/model_io/model_io_utils.py +2 -2
  81. careamics/models/activation.py +2 -1
  82. careamics/prediction_utils/prediction_outputs.py +1 -1
  83. careamics/prediction_utils/stitch_prediction.py +1 -1
  84. careamics/transforms/compose.py +1 -0
  85. careamics/transforms/n2v_manipulate_torch.py +15 -9
  86. careamics/transforms/normalize.py +18 -7
  87. careamics/transforms/pixel_manipulation_torch.py +59 -92
  88. careamics/utils/lightning_utils.py +25 -11
  89. careamics/utils/metrics.py +2 -1
  90. careamics/utils/torch_utils.py +23 -0
  91. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/METADATA +12 -11
  92. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/RECORD +95 -69
  93. careamics/dataset_ng/dataset/__init__.py +0 -3
  94. careamics/dataset_ng/dataset/dataset.py +0 -184
  95. careamics/dataset_ng/demo_patch_extractor_factory.py +0 -37
  96. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/WHEEL +0 -0
  97. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/entry_points.txt +0 -0
  98. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/licenses/LICENSE +0 -0
careamics/lightning/dataset_ng/data_module.py (new file, +678 lines)
@@ -0,0 +1,678 @@
+ """Next-Generation CAREamics DataModule."""
+
+ from collections.abc import Callable
+ from pathlib import Path
+ from typing import Any, Optional, Union, overload
+
+ import numpy as np
+ import pytorch_lightning as L
+ from numpy.typing import NDArray
+ from torch.utils.data import DataLoader
+ from torch.utils.data._utils.collate import default_collate
+
+ from careamics.config.data.ng_data_model import NGDataConfig
+ from careamics.config.support import SupportedData
+ from careamics.dataset.dataset_utils import list_files, validate_source_target_files
+ from careamics.dataset_ng.dataset import Mode
+ from careamics.dataset_ng.factory import create_dataset
+ from careamics.dataset_ng.patch_extractor import ImageStackLoader
+ from careamics.utils import get_logger
+
+ logger = get_logger(__name__)
+
+ ItemType = Union[Path, str, NDArray[Any]]
+ """Type of input items passed to the dataset."""
+
+ InputType = Union[ItemType, list[ItemType], None]
+ """Type of input data passed to the dataset."""
+
+
+ class CareamicsDataModule(L.LightningDataModule):
+     """Data module for Careamics dataset.
+
+     Parameters
+     ----------
+     data_config : DataConfig
+         Pydantic model for CAREamics data configuration.
+     train_data : Optional[InputType]
+         Training data, can be a path to a folder, a list of paths, or a numpy array.
+     train_data_target : Optional[InputType]
+         Training data target, can be a path to a folder,
+         a list of paths, or a numpy array.
+     val_data : Optional[InputType]
+         Validation data, can be a path to a folder,
+         a list of paths, or a numpy array.
+     val_data_target : Optional[InputType]
+         Validation data target, can be a path to a folder,
+         a list of paths, or a numpy array.
+     pred_data : Optional[InputType]
+         Prediction data, can be a path to a folder, a list of paths,
+         or a numpy array.
+     pred_data_target : Optional[InputType]
+         Prediction data target, can be a path to a folder,
+         a list of paths, or a numpy array.
+     read_source_func : Optional[Callable], default=None
+         Function to read the source data. Only used for `custom`
+         data type (see DataModel).
+     read_kwargs : Optional[dict[str, Any]]
+         The kwargs for the read source function.
+     image_stack_loader : Optional[ImageStackLoader]
+         The image stack loader.
+     image_stack_loader_kwargs : Optional[dict[str, Any]]
+         The image stack loader kwargs.
+     extension_filter : str, default=""
+         Filter for file extensions. Only used for `custom` data types
+         (see DataModel).
+     val_percentage : Optional[float]
+         Percentage of the training data to use for validation. Only
+         used if `val_data` is None.
+     val_minimum_split : int, default=5
+         Minimum number of patches or files to split from the training data for
+         validation. Only used if `val_data` is None.
+     use_in_memory : bool
+         Load data in memory dataset if possible, by default True.
+
+
+     Attributes
+     ----------
+     config : DataConfig
+         Pydantic model for CAREamics data configuration.
+     data_type : str
+         Type of data, one of SupportedData.
+     batch_size : int
+         Batch size for the dataloaders.
+     use_in_memory : bool
+         Whether to load data in memory if possible.
+     extension_filter : str
+         Filter for file extensions, by default "".
+     read_source_func : Optional[Callable], default=None
+         Function to read the source data.
+     read_kwargs : Optional[dict[str, Any]], default=None
+         The kwargs for the read source function.
+     val_percentage : Optional[float]
+         Percentage of the training data to use for validation.
+     val_minimum_split : int, default=5
+         Minimum number of patches or files to split from the training data for
+         validation.
+     train_data : Optional[Any]
+         Training data, can be a path to a folder, a list of paths, or a numpy array.
+     train_data_target : Optional[Any]
+         Training data target, can be a path to a folder, a list of paths, or a numpy
+         array.
+     val_data : Optional[Any]
+         Validation data, can be a path to a folder, a list of paths, or a numpy array.
+     val_data_target : Optional[Any]
+         Validation data target, can be a path to a folder, a list of paths, or a numpy
+         array.
+     pred_data : Optional[Any]
+         Prediction data, can be a path to a folder, a list of paths, or a numpy array.
+     pred_data_target : Optional[Any]
+         Prediction data target, can be a path to a folder, a list of paths, or a numpy
+         array.
+
+     Raises
+     ------
+     ValueError
+         If at least one of train_data, val_data or pred_data is not provided.
+     ValueError
+         If input and target data types are not consistent.
+     """
+
+     # standard use
+     @overload
+     def __init__(
+         self,
+         data_config: NGDataConfig,
+         *,
+         train_data: Optional[InputType] = None,
+         train_data_target: Optional[InputType] = None,
+         val_data: Optional[InputType] = None,
+         val_data_target: Optional[InputType] = None,
+         pred_data: Optional[InputType] = None,
+         pred_data_target: Optional[InputType] = None,
+         extension_filter: str = "",
+         val_percentage: Optional[float] = None,
+         val_minimum_split: int = 5,
+         use_in_memory: bool = True,
+     ) -> None: ...
+
+     # custom read function
+     @overload
+     def __init__(
+         self,
+         data_config: NGDataConfig,
+         *,
+         train_data: Optional[InputType] = None,
+         train_data_target: Optional[InputType] = None,
+         val_data: Optional[InputType] = None,
+         val_data_target: Optional[InputType] = None,
+         pred_data: Optional[InputType] = None,
+         pred_data_target: Optional[InputType] = None,
+         read_source_func: Callable,
+         read_kwargs: Optional[dict[str, Any]] = None,
+         extension_filter: str = "",
+         val_percentage: Optional[float] = None,
+         val_minimum_split: int = 5,
+         use_in_memory: bool = True,
+     ) -> None: ...
+
+     @overload
+     def __init__(
+         self,
+         data_config: NGDataConfig,
+         *,
+         train_data: Optional[Any] = None,
+         train_data_target: Optional[Any] = None,
+         val_data: Optional[Any] = None,
+         val_data_target: Optional[Any] = None,
+         pred_data: Optional[Any] = None,
+         pred_data_target: Optional[Any] = None,
+         image_stack_loader: ImageStackLoader,
+         image_stack_loader_kwargs: Optional[dict[str, Any]] = None,
+         extension_filter: str = "",
+         val_percentage: Optional[float] = None,
+         val_minimum_split: int = 5,
+         use_in_memory: bool = True,
+     ) -> None: ...
+
+     def __init__(
+         self,
+         data_config: NGDataConfig,
+         *,
+         train_data: Optional[Any] = None,
+         train_data_target: Optional[Any] = None,
+         val_data: Optional[Any] = None,
+         val_data_target: Optional[Any] = None,
+         pred_data: Optional[Any] = None,
+         pred_data_target: Optional[Any] = None,
+         read_source_func: Optional[Callable] = None,
+         read_kwargs: Optional[dict[str, Any]] = None,
+         image_stack_loader: Optional[ImageStackLoader] = None,
+         image_stack_loader_kwargs: Optional[dict[str, Any]] = None,
+         extension_filter: str = "",
+         val_percentage: Optional[float] = None,
+         val_minimum_split: int = 5,
+         use_in_memory: bool = True,
+     ) -> None:
+         """
+         Data module for Careamics dataset initialization.
+
+         Create a lightning datamodule that handles creating datasets for training,
+         validation, and prediction.
+
+         Parameters
+         ----------
+         data_config : NGDataConfig
+             Pydantic model for CAREamics data configuration.
+         train_data : Optional[InputType]
+             Training data, can be a path to a folder, a list of paths, or a numpy array.
+         train_data_target : Optional[InputType]
+             Training data target, can be a path to a folder,
+             a list of paths, or a numpy array.
+         val_data : Optional[InputType]
+             Validation data, can be a path to a folder,
+             a list of paths, or a numpy array.
+         val_data_target : Optional[InputType]
+             Validation data target, can be a path to a folder,
+             a list of paths, or a numpy array.
+         pred_data : Optional[InputType]
+             Prediction data, can be a path to a folder, a list of paths,
+             or a numpy array.
+         pred_data_target : Optional[InputType]
+             Prediction data target, can be a path to a folder,
+             a list of paths, or a numpy array.
+         read_source_func : Optional[Callable]
+             Function to read the source data, by default None. Only used for `custom`
+             data type (see DataModel).
+         read_kwargs : Optional[dict[str, Any]]
+             The kwargs for the read source function.
+         image_stack_loader : Optional[ImageStackLoader]
+             The image stack loader.
+         image_stack_loader_kwargs : Optional[dict[str, Any]]
+             The image stack loader kwargs.
+         extension_filter : str
+             Filter for file extensions, by default "". Only used for `custom` data types
+             (see DataModel).
+         val_percentage : Optional[float]
+             Percentage of the training data to use for validation. Only
+             used if `val_data` is None.
+         val_minimum_split : int
+             Minimum number of patches or files to split from the training data for
+             validation, by default 5. Only used if `val_data` is None.
+         use_in_memory : bool
+             Load data in memory dataset if possible, by default True.
+         """
+         super().__init__()
+
+         if train_data is None and val_data is None and pred_data is None:
+             raise ValueError(
+                 "At least one of train_data, val_data or pred_data must be provided."
+             )
+
+         self.config: NGDataConfig = data_config
+         self.data_type: str = data_config.data_type
+         self.batch_size: int = data_config.batch_size
+         self.use_in_memory: bool = use_in_memory
+         self.extension_filter: str = extension_filter
+         self.read_source_func = read_source_func
+         self.read_kwargs = read_kwargs
+         self.image_stack_loader = image_stack_loader
+         self.image_stack_loader_kwargs = image_stack_loader_kwargs
+
+         # TODO: implement the validation split logic
+         self.val_percentage = val_percentage
+         self.val_minimum_split = val_minimum_split
+         if self.val_percentage is not None:
+             raise NotImplementedError("Validation split not implemented")
+
+         self.train_data, self.train_data_target = self._initialize_data_pair(
+             train_data, train_data_target
+         )
+         self.val_data, self.val_data_target = self._initialize_data_pair(
+             val_data, val_data_target
+         )
+
+         # The pred_data_target can be needed to count metrics on the prediction
+         self.pred_data, self.pred_data_target = self._initialize_data_pair(
+             pred_data, pred_data_target
+         )
+
+     def _validate_input_target_type_consistency(
+         self,
+         input_data: InputType,
+         target_data: Optional[InputType],
+     ) -> None:
+         """Validate if the input and target data types are consistent.
+
+         Parameters
+         ----------
+         input_data : InputType
+             Input data, can be a path to a folder, a list of paths, or a numpy array.
+         target_data : Optional[InputType]
+             Target data, can be None, a path to a folder, a list of paths, or a numpy
+             array.
+         """
+         if input_data is not None and target_data is not None:
+             if not isinstance(input_data, type(target_data)):
+                 raise ValueError(
+                     f"Inputs for input and target must be of the same type or None. "
+                     f"Got {type(input_data)} and {type(target_data)}."
+                 )
+             if isinstance(input_data, list) and isinstance(target_data, list):
+                 if len(input_data) != len(target_data):
+                     raise ValueError(
+                         f"Inputs and targets must have the same length. "
+                         f"Got {len(input_data)} and {len(target_data)}."
+                     )
+                 if not isinstance(input_data[0], type(target_data[0])):
+                     raise ValueError(
+                         f"Inputs and targets must have the same type. "
+                         f"Got {type(input_data[0])} and {type(target_data[0])}."
+                     )
+
+     def _list_files_in_directory(
+         self,
+         input_data,
+         target_data=None,
+     ) -> tuple[list[Path], Optional[list[Path]]]:
+         """List files from input and target directories.
+
+         Parameters
+         ----------
+         input_data : InputType
+             Input data, can be a path to a folder, a list of paths, or a numpy array.
+         target_data : Optional[InputType]
+             Target data, can be None, a path to a folder, a list of paths, or a numpy
+             array.
+
+         Returns
+         -------
+         (list[Path], Optional[list[Path]])
+             A tuple containing lists of file paths for input and target data.
+             If target_data is None, the second element will be None.
+         """
+         input_data = Path(input_data)
+         input_files = list_files(input_data, self.data_type, self.extension_filter)
+         if target_data is None:
+             return input_files, None
+         else:
+             target_data = Path(target_data)
+             target_files = list_files(
+                 target_data, self.data_type, self.extension_filter
+             )
+             validate_source_target_files(input_files, target_files)
+             return input_files, target_files
+
+     def _convert_paths_to_pathlib(
+         self,
+         input_data,
+         target_data=None,
+     ) -> tuple[list[Path], Optional[list[Path]]]:
+         """Create a list of file paths from the input and target data.
+
+         Parameters
+         ----------
+         input_data : InputType
+             Input data, can be a path to a folder, a list of paths, or a numpy array.
+         target_data : Optional[InputType]
+             Target data, can be None, a path to a folder, a list of paths, or a numpy
+             array.
+
+         Returns
+         -------
+         (list[Path], Optional[list[Path]])
+             A tuple containing lists of file paths for input and target data.
+             If target_data is None, the second element will be None.
+         """
+         input_files = [
+             Path(item) if isinstance(item, str) else item for item in input_data
+         ]
+         if target_data is None:
+             return input_files, None
+         else:
+             target_files = [
+                 Path(item) if isinstance(item, str) else item for item in target_data
+             ]
+             validate_source_target_files(input_files, target_files)
+             return input_files, target_files
+
+     def _validate_array_input(
+         self,
+         input_data: InputType,
+         target_data: Optional[InputType],
+     ) -> tuple[Any, Any]:
+         """Validate if the input data is a numpy array.
+
+         Parameters
+         ----------
+         input_data : InputType
+             Input data, can be a path to a folder, a list of paths, or a numpy array.
+         target_data : Optional[InputType]
+             Target data, can be None, a path to a folder, a list of paths, or a numpy
+             array.
+
+         Returns
+         -------
+         (Any, Any)
+             A tuple containing the input and target.
+         """
+         if isinstance(input_data, np.ndarray):
+             input_array = [input_data]
+             target_array = [target_data] if target_data is not None else None
+             return input_array, target_array
+         elif isinstance(input_data, list):
+             return input_data, target_data
+         else:
+             raise ValueError(
+                 f"Unsupported input type for {self.data_type}: {type(input_data)}"
+             )
+
+     def _validate_path_input(
+         self, input_data: InputType, target_data: Optional[InputType]
+     ) -> tuple[list[Path], Optional[list[Path]]]:
+         """Validate if the input data is a path or a list of paths.
+
+         Parameters
+         ----------
+         input_data : InputType
+             Input data, can be a path to a folder, a list of paths, or a numpy array.
+         target_data : Optional[InputType]
+             Target data, can be None, a path to a folder, a list of paths, or a numpy
+             array.
+
+         Returns
+         -------
+         (list[Path], Optional[list[Path]])
+             A tuple containing lists of file paths for input and target data.
+             If target_data is None, the second element will be None.
+         """
+         if isinstance(input_data, str | Path):
+             if target_data is not None:
+                 assert isinstance(target_data, str | Path)
+             input_list, target_list = self._list_files_in_directory(
+                 input_data, target_data
+             )
+             return input_list, target_list
+         elif isinstance(input_data, list):
+             if target_data is not None:
+                 assert isinstance(target_data, list)
+             input_list, target_list = self._convert_paths_to_pathlib(
+                 input_data, target_data
+             )
+             return input_list, target_list
+         else:
+             raise ValueError(
+                 f"Unsupported input type for {self.data_type}: {type(input_data)}"
+             )
+
+     def _validate_custom_input(self, input_data, target_data) -> tuple[Any, Any]:
+         """Convert custom input data to a list of file paths.
+
+         Parameters
+         ----------
+         input_data : InputType
+             Input data, can be a path to a folder, a list of paths, or a numpy array.
+         target_data : Optional[InputType]
+             Target data, can be None, a path to a folder, a list of paths, or a numpy
+             array.
+
+         Returns
+         -------
+         (Any, Any)
+             A tuple containing lists of file paths for input and target data.
+             If target_data is None, the second element will be None.
+         """
+         if self.image_stack_loader is not None:
+             return input_data, target_data
+         elif isinstance(input_data, str | Path):
+             if target_data is not None:
+                 assert isinstance(target_data, str | Path)
+             input_list, target_list = self._list_files_in_directory(
+                 input_data, target_data
+             )
+             return input_list, target_list
+         elif isinstance(input_data, list):
+             if isinstance(input_data[0], str | Path):
+                 if target_data is not None:
+                     assert isinstance(target_data, list)
+                 input_list, target_list = self._convert_paths_to_pathlib(
+                     input_data, target_data
+                 )
+                 return input_list, target_list
+         else:
+             raise ValueError(
+                 f"If using {self.data_type}, pass a custom "
+                 f"image_stack_loader or read_source_func"
+             )
+         return input_data, target_data
+
+     def _initialize_data_pair(
+         self,
+         input_data: Optional[InputType],
+         target_data: Optional[InputType],
+     ) -> tuple[Any, Any]:
+         """
+         Initialize a pair of input and target data.
+
+         Parameters
+         ----------
+         input_data : InputType
+             Input data, can be None, a path to a folder, a list of paths, or a numpy
+             array.
+         target_data : Optional[InputType]
+             Target data, can be None, a path to a folder, a list of paths, or a numpy
+             array.
+
+         Returns
+         -------
+         (list of numpy.ndarray or list of pathlib.Path, None or list of numpy.ndarray or
+         list of pathlib.Path)
+             A tuple containing the initialized input and target data. For file paths,
+             returns lists of Path objects. For numpy arrays, returns the arrays
+             directly.
+         """
+         if input_data is None:
+             return None, None
+
+         self._validate_input_target_type_consistency(input_data, target_data)
+
+         if self.data_type == SupportedData.ARRAY:
+             if isinstance(input_data, np.ndarray):
+                 return self._validate_array_input(input_data, target_data)
+             elif isinstance(input_data, list):
+                 if isinstance(input_data[0], np.ndarray):
+                     return self._validate_array_input(input_data, target_data)
+                 else:
+                     raise ValueError(
+                         f"Unsupported input type for {self.data_type}: "
+                         f"{type(input_data[0])}"
+                     )
+             else:
+                 raise ValueError(
+                     f"Unsupported input type for {self.data_type}: {type(input_data)}"
+                 )
+         elif self.data_type in (SupportedData.TIFF, SupportedData.CZI):
+             if isinstance(input_data, str | Path):
+                 return self._validate_path_input(input_data, target_data)
+             elif isinstance(input_data, list):
+                 if isinstance(input_data[0], str | Path):
+                     return self._validate_path_input(input_data, target_data)
+                 else:
+                     raise ValueError(
+                         f"Unsupported input type for {self.data_type}: "
+                         f"{type(input_data[0])}"
+                     )
+             else:
+                 raise ValueError(
+                     f"Unsupported input type for {self.data_type}: {type(input_data)}"
+                 )
+         elif self.data_type == SupportedData.CUSTOM:
+             return self._validate_custom_input(input_data, target_data)
+         else:
+             raise NotImplementedError(f"Unsupported data type: {self.data_type}")
+
+     def setup(self, stage: str) -> None:
+         """
+         Setup datasets.
+
+         Lightning hook that is called at the beginning of fit (train + validate),
+         validate, test, or predict. Creates the datasets for a given stage.
+
+         Parameters
+         ----------
+         stage : str
+             The stage to set up datasets for.
+             Is either 'fit', 'validate', 'test', or 'predict'.
+
+         Raises
+         ------
+         NotImplementedError
+             If stage is not one of "fit", "validate" or "predict".
+         """
+         if stage == "fit":
+             self.train_dataset = create_dataset(
+                 mode=Mode.TRAINING,
+                 inputs=self.train_data,
+                 targets=self.train_data_target,
+                 config=self.config,
+                 in_memory=self.use_in_memory,
+                 read_func=self.read_source_func,
+                 read_kwargs=self.read_kwargs,
+                 image_stack_loader=self.image_stack_loader,
+                 image_stack_loader_kwargs=self.image_stack_loader_kwargs,
+             )
+             # TODO: ugly, need to find a better solution
+             self.stats = self.train_dataset.input_stats
+             self.config.set_means_and_stds(
+                 self.train_dataset.input_stats.means,
+                 self.train_dataset.input_stats.stds,
+                 self.train_dataset.target_stats.means,
+                 self.train_dataset.target_stats.stds,
+             )
+             self.val_dataset = create_dataset(
+                 mode=Mode.VALIDATING,
+                 inputs=self.val_data,
+                 targets=self.val_data_target,
+                 config=self.config,
+                 in_memory=self.use_in_memory,
+                 read_func=self.read_source_func,
+                 read_kwargs=self.read_kwargs,
+                 image_stack_loader=self.image_stack_loader,
+                 image_stack_loader_kwargs=self.image_stack_loader_kwargs,
+             )
+         elif stage == "validate":
+             self.val_dataset = create_dataset(
+                 mode=Mode.VALIDATING,
+                 inputs=self.val_data,
+                 targets=self.val_data_target,
+                 config=self.config,
+                 in_memory=self.use_in_memory,
+                 read_func=self.read_source_func,
+                 read_kwargs=self.read_kwargs,
+                 image_stack_loader=self.image_stack_loader,
+                 image_stack_loader_kwargs=self.image_stack_loader_kwargs,
+             )
+             self.stats = self.val_dataset.input_stats
+         elif stage == "predict":
+             self.predict_dataset = create_dataset(
+                 mode=Mode.PREDICTING,
+                 inputs=self.pred_data,
+                 targets=self.pred_data_target,
+                 config=self.config,
+                 in_memory=self.use_in_memory,
+                 read_func=self.read_source_func,
+                 read_kwargs=self.read_kwargs,
+                 image_stack_loader=self.image_stack_loader,
+                 image_stack_loader_kwargs=self.image_stack_loader_kwargs,
+             )
+             self.stats = self.predict_dataset.input_stats
+         else:
+             raise NotImplementedError(f"Stage {stage} not implemented")
+
+     def train_dataloader(self) -> DataLoader:
+         """
+         Create a dataloader for training.
+
+         Returns
+         -------
+         DataLoader
+             Training dataloader.
+         """
+         return DataLoader(
+             self.train_dataset,
+             batch_size=self.batch_size,
+             collate_fn=default_collate,
+             **self.config.train_dataloader_params,
+         )
+
+     def val_dataloader(self) -> DataLoader:
+         """
+         Create a dataloader for validation.
+
+         Returns
+         -------
+         DataLoader
+             Validation dataloader.
+         """
+         return DataLoader(
+             self.val_dataset,
+             batch_size=self.batch_size,
+             collate_fn=default_collate,
+             **self.config.val_dataloader_params,
+         )
+
+     def predict_dataloader(self) -> DataLoader:
+         """
+         Create a dataloader for prediction.
+
+         Returns
+         -------
+         DataLoader
+             Prediction dataloader.
+         """
+         return DataLoader(
+             self.predict_dataset,
+             batch_size=self.batch_size,
+             collate_fn=default_collate,
+             **self.config.test_dataloader_params,
+         )
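
The hunk above adds the full `CareamicsDataModule`. Below is a minimal usage sketch based only on what the hunk shows (constructor signature, `setup`, and the dataloader hooks); the `NGDataConfig` fields other than `data_type` and `batch_size` are assumptions modeled on the legacy `DataConfig`, so check `careamics/config/data/ng_data_model.py` in this release for the actual schema.

```python
# Minimal sketch (not part of the diff): build the NG datamodule from in-memory
# numpy arrays. Fields marked "assumed" are guesses based on the legacy
# DataConfig and may differ from the real NGDataConfig schema.
import numpy as np

from careamics.config.data.ng_data_model import NGDataConfig
from careamics.lightning.dataset_ng.data_module import CareamicsDataModule

rng = np.random.default_rng(0)
train_image = rng.normal(size=(512, 512)).astype(np.float32)
val_image = rng.normal(size=(256, 256)).astype(np.float32)

config = NGDataConfig(
    data_type="array",                                    # read by the datamodule
    batch_size=8,                                         # used by the dataloaders
    axes="YX",                                            # assumed field
    patching={"name": "random", "patch_size": [64, 64]},  # assumed field
)

data_module = CareamicsDataModule(
    data_config=config,
    train_data=train_image,  # a single array is wrapped into a list internally
    val_data=val_image,
)

# When passed to a Trainer, Lightning calls setup("fit") and the dataloader hooks:
# trainer.fit(model, datamodule=data_module)
```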
careamics/lightning/dataset_ng/lightning_modules/__init__.py (new file, +9 lines)
@@ -0,0 +1,9 @@
+ """CAREamics PyTorch Lightning modules."""
+
+ from .care_module import CAREModule
+ from .n2v_module import N2VModule
+
+ __all__ = [
+     "CAREModule",
+     "N2VModule",
+ ]
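
The second hunk re-exports the new Lightning modules from `careamics.lightning.dataset_ng.lightning_modules`. A hedged sketch of how those exports pair with the datamodule from the previous example; the Trainer calls are standard PyTorch Lightning, while the module constructors are defined in `care_module.py` and `n2v_module.py` elsewhere in this diff, so their arguments are left out.

```python
# Sketch (not part of the diff): the new __init__ makes both modules importable
# from one package path. Constructing them requires the algorithm configuration
# added elsewhere in this release, so model creation is left as a comment.
import pytorch_lightning as L

from careamics.lightning.dataset_ng.lightning_modules import CAREModule, N2VModule

trainer = L.Trainer(max_epochs=10, logger=False, enable_checkpointing=False)

# model = N2VModule(...)  # or CAREModule(...); see n2v_module.py / care_module.py
# trainer.fit(model, datamodule=data_module)
# predictions = trainer.predict(model, datamodule=data_module)
```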