careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (91)
  1. careamics/__init__.py +1 -14
  2. careamics/careamist.py +212 -294
  3. careamics/config/__init__.py +0 -3
  4. careamics/config/algorithm_model.py +8 -15
  5. careamics/config/architectures/architecture_model.py +1 -0
  6. careamics/config/architectures/custom_model.py +5 -3
  7. careamics/config/architectures/unet_model.py +19 -0
  8. careamics/config/architectures/vae_model.py +1 -0
  9. careamics/config/callback_model.py +76 -34
  10. careamics/config/configuration_factory.py +18 -98
  11. careamics/config/configuration_model.py +23 -18
  12. careamics/config/data_model.py +103 -54
  13. careamics/config/inference_model.py +41 -19
  14. careamics/config/optimizer_models.py +13 -7
  15. careamics/config/support/supported_data.py +29 -4
  16. careamics/config/support/supported_transforms.py +0 -1
  17. careamics/config/tile_information.py +36 -58
  18. careamics/config/training_model.py +5 -1
  19. careamics/config/transformations/normalize_model.py +32 -4
  20. careamics/config/validators/validator_utils.py +1 -1
  21. careamics/dataset/__init__.py +12 -1
  22. careamics/dataset/dataset_utils/__init__.py +8 -7
  23. careamics/dataset/dataset_utils/file_utils.py +2 -2
  24. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  25. careamics/dataset/dataset_utils/running_stats.py +186 -0
  26. careamics/dataset/in_memory_dataset.py +84 -173
  27. careamics/dataset/in_memory_pred_dataset.py +88 -0
  28. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  29. careamics/dataset/iterable_dataset.py +97 -250
  30. careamics/dataset/iterable_pred_dataset.py +122 -0
  31. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  32. careamics/dataset/patching/patching.py +97 -52
  33. careamics/dataset/patching/random_patching.py +9 -4
  34. careamics/dataset/patching/validate_patch_dimension.py +5 -3
  35. careamics/dataset/tiling/__init__.py +10 -0
  36. careamics/dataset/tiling/collate_tiles.py +33 -0
  37. careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
  38. careamics/file_io/__init__.py +7 -0
  39. careamics/file_io/read/__init__.py +11 -0
  40. careamics/file_io/read/get_func.py +56 -0
  41. careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
  42. careamics/file_io/write/__init__.py +9 -0
  43. careamics/file_io/write/get_func.py +59 -0
  44. careamics/file_io/write/tiff.py +39 -0
  45. careamics/lightning/__init__.py +17 -0
  46. careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
  47. careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
  48. careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
  49. careamics/lvae_training/__init__.py +0 -0
  50. careamics/lvae_training/data_modules.py +1220 -0
  51. careamics/lvae_training/data_utils.py +618 -0
  52. careamics/lvae_training/eval_utils.py +905 -0
  53. careamics/lvae_training/get_config.py +84 -0
  54. careamics/lvae_training/lightning_module.py +701 -0
  55. careamics/lvae_training/metrics.py +214 -0
  56. careamics/lvae_training/train_lvae.py +339 -0
  57. careamics/lvae_training/train_utils.py +121 -0
  58. careamics/model_io/bioimage/model_description.py +40 -32
  59. careamics/model_io/bmz_io.py +2 -2
  60. careamics/model_io/model_io_utils.py +6 -3
  61. careamics/models/lvae/__init__.py +0 -0
  62. careamics/models/lvae/layers.py +1998 -0
  63. careamics/models/lvae/likelihoods.py +312 -0
  64. careamics/models/lvae/lvae.py +985 -0
  65. careamics/models/lvae/noise_models.py +409 -0
  66. careamics/models/lvae/utils.py +395 -0
  67. careamics/prediction_utils/__init__.py +10 -0
  68. careamics/prediction_utils/prediction_outputs.py +137 -0
  69. careamics/prediction_utils/stitch_prediction.py +103 -0
  70. careamics/transforms/n2v_manipulate.py +3 -1
  71. careamics/transforms/normalize.py +139 -68
  72. careamics/transforms/pixel_manipulation.py +33 -9
  73. careamics/transforms/tta.py +43 -29
  74. careamics/utils/__init__.py +2 -0
  75. careamics/utils/autocorrelation.py +40 -0
  76. careamics/utils/ram.py +2 -2
  77. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
  78. careamics-0.1.0rc8.dist-info/RECORD +135 -0
  79. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
  80. careamics/config/configuration_example.py +0 -89
  81. careamics/dataset/dataset_utils/read_utils.py +0 -27
  82. careamics/lightning_prediction_loop.py +0 -118
  83. careamics/prediction/__init__.py +0 -7
  84. careamics/prediction/stitch_prediction.py +0 -70
  85. careamics/utils/running_stats.py +0 -43
  86. careamics-0.1.0rc6.dist-info/RECORD +0 -107
  87. /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
  88. /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
  89. /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
  90. /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
  91. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py}

@@ -44,7 +44,9 @@ def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray:
     ValueError
         If the axes length is incorrect.
     """
-    if fnmatch(file_path.suffix, SupportedData.get_extension(SupportedData.TIFF)):
+    if fnmatch(
+        file_path.suffix, SupportedData.get_extension_pattern(SupportedData.TIFF)
+    ):
         try:
             array = tifffile.imread(file_path)
         except (ValueError, OSError) as e:
@@ -53,13 +55,4 @@ def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray:
     else:
         raise ValueError(f"File {file_path} is not a valid tiff.")
 
-    # check dimensions
-    # TODO or should this really be done here? probably in the LightningDataModule
-    # TODO this should also be centralized somewhere else (validate_dimensions)
-    if len(array.shape) < 2 or len(array.shape) > 6:
-        raise ValueError(
-            f"Incorrect data dimensions. Must be 2, 3 or 4 (got {array.shape} for"
-            f"file {file_path})."
-        )
-
     return array
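The reader now matches the file suffix against a Unix shell-style pattern via the renamed `SupportedData.get_extension_pattern`, and the ad-hoc dimension check is removed from the reader. A minimal sketch of what the suffix check accepts, assuming the TIFF pattern is '*.tif*' as stated in the new `write_tiff` docstring below:

from fnmatch import fnmatch

# '*.tif*' accepts both common tiff suffixes
assert fnmatch(".tif", "*.tif*")
assert fnmatch(".tiff", "*.tif*")
assert not fnmatch(".png", "*.tif*")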
careamics/file_io/write/__init__.py (new file)

@@ -0,0 +1,9 @@
+"""Functions relating to writing image files of different formats."""
+
+__all__ = [
+    "get_write_func",
+    "write_tiff",
+]
+
+from .get_func import get_write_func
+from .tiff import write_tiff
careamics/file_io/write/get_func.py (new file)

@@ -0,0 +1,59 @@
+"""Module to get write functions."""
+
+from pathlib import Path
+from typing import Protocol, Union
+
+from numpy.typing import NDArray
+
+from careamics.config.support import SupportedData
+
+from .tiff import write_tiff
+
+
+# This is very strict, arguments have to be called file_path & img
+# Alternative? - doesn't capture *args & **kwargs
+# WriteFunc = Callable[[Path, NDArray], None]
+class WriteFunc(Protocol):
+    """Protocol for type hinting write functions."""
+
+    def __call__(self, file_path: Path, img: NDArray, *args, **kwargs) -> None:
+        """
+        Type hinted callables must match this function signature (not including self).
+
+        Parameters
+        ----------
+        file_path : pathlib.Path
+            Path to file.
+        img : numpy.ndarray
+            Image data to save.
+        *args
+            Other positional arguments.
+        **kwargs
+            Other keyword arguments.
+        """
+
+
+WRITE_FUNCS: dict[SupportedData, WriteFunc] = {
+    SupportedData.TIFF: write_tiff,
+}
+
+
+def get_write_func(data_type: Union[str, SupportedData]) -> WriteFunc:
+    """
+    Get the write function for the data type.
+
+    Parameters
+    ----------
+    data_type : SupportedData
+        Data type.
+
+    Returns
+    -------
+    callable
+        Write function.
+    """
+    if data_type in WRITE_FUNCS:
+        data_type = SupportedData(data_type)  # mypy complaining about dict key type
+        return WRITE_FUNCS[data_type]
+    else:
+        raise NotImplementedError(f"Data type {data_type} is not supported.")
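As the inline comment notes, `WriteFunc` is deliberately strict: because a `Protocol` preserves parameter names, only callables whose arguments are named `file_path` and `img` type-check, which a plain `Callable[[Path, NDArray], None]` could not enforce. A sketch of a conforming custom writer (the `write_npy` function is hypothetical, not part of the package):

from pathlib import Path

import numpy as np
from numpy.typing import NDArray

from careamics.file_io.write.get_func import WriteFunc


def write_npy(file_path: Path, img: NDArray, *args, **kwargs) -> None:
    # parameter names match the protocol, so keyword calls type-check
    np.save(file_path, img)


writer: WriteFunc = write_npy
writer(file_path=Path("out.npy"), img=np.zeros((8, 8)))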
careamics/file_io/write/tiff.py (new file)

@@ -0,0 +1,39 @@
+"""Write tiff function."""
+
+from fnmatch import fnmatch
+from pathlib import Path
+
+import tifffile
+from numpy.typing import NDArray
+
+from careamics.config.support import SupportedData
+
+
+def write_tiff(file_path: Path, img: NDArray, *args, **kwargs) -> None:
+    """
+    Write tiff files.
+
+    Parameters
+    ----------
+    file_path : pathlib.Path
+        Path to file.
+    img : numpy.ndarray
+        Image data to save.
+    *args
+        Positional arguments passed to `tifffile.imwrite`.
+    **kwargs
+        Keyword arguments passed to `tifffile.imwrite`.
+
+    Raises
+    ------
+    ValueError
+        When the file extension of `file_path` does not match the Unix shell-style
+        pattern '*.tif*'.
+    """
+    if not fnmatch(
+        file_path.suffix, SupportedData.get_extension_pattern(SupportedData.TIFF)
+    ):
+        raise ValueError(
+            f"Unexpected extension '{file_path.suffix}' for save file type 'tiff'."
+        )
+    tifffile.imwrite(file_path, img, *args, **kwargs)
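Together these give a small dispatch layer for prediction outputs. A usage sketch, assuming `SupportedData` is a string-valued enum so that "tiff" resolves as a dictionary key:

from pathlib import Path

import numpy as np

from careamics.file_io.write import get_write_func

write_func = get_write_func("tiff")  # resolves to write_tiff
write_func(Path("prediction.tiff"), np.zeros((64, 64), dtype=np.float32))

# a suffix that does not match '*.tif*' raises ValueError before writing:
# write_func(Path("prediction.png"), np.zeros((64, 64)))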
careamics/lightning/__init__.py (new file)

@@ -0,0 +1,17 @@
+"""CAREamics PyTorch Lightning modules."""
+
+__all__ = [
+    "CAREamicsModule",
+    "create_careamics_module",
+    "TrainDataModule",
+    "create_train_datamodule",
+    "PredictDataModule",
+    "create_predict_datamodule",
+    "HyperParametersCallback",
+    "ProgressBarCallback",
+]
+
+from .callbacks import HyperParametersCallback, ProgressBarCallback
+from .lightning_module import CAREamicsModule, create_careamics_module
+from .predict_data_module import PredictDataModule, create_predict_datamodule
+from .train_data_module import TrainDataModule, create_train_datamodule
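This package consolidates what were top-level modules in rc6 (`lightning_module.py`, `lightning_datamodule.py`, `lightning_prediction_datamodule.py`; see the renames in the file list), so downstream code imports everything from one place:

# single import point for the Lightning-facing API in rc8
from careamics.lightning import (
    CAREamicsModule,
    PredictDataModule,
    TrainDataModule,
    create_careamics_module,
    create_predict_datamodule,
    create_train_datamodule,
)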
careamics/{lightning_module.py → lightning/lightning_module.py}

@@ -23,19 +23,19 @@ class CAREamicsModule(L.LightningModule):
     """
     CAREamics Lightning module.
 
-    This class encapsulates the a PyTorch model along with the training, validation,
+    This class encapsulates the PyTorch model along with the training, validation,
     and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
 
     Parameters
     ----------
-    algorithm_config : Union[AlgorithmModel, dict]
+    algorithm_config : AlgorithmModel or dict
         Algorithm configuration.
 
     Attributes
     ----------
-    model : nn.Module
+    model : torch.nn.Module
         PyTorch model.
-    loss_func : nn.Module
+    loss_func : torch.nn.Module
         Loss function.
     optimizer_name : str
         Optimizer name.
@@ -53,7 +53,7 @@ class CAREamicsModule(L.LightningModule):
 
        Parameters
        ----------
-        algorithm_config : Union[AlgorithmModel, dict]
+        algorithm_config : AlgorithmModel or dict
            Algorithm configuration.
        """
        super().__init__()
@@ -91,7 +91,7 @@ class CAREamicsModule(L.LightningModule):
 
        Parameters
        ----------
-        batch : Tensor
+        batch : torch.Tensor
            Input batch.
        batch_idx : Any
            Batch index.
@@ -114,7 +114,7 @@ class CAREamicsModule(L.LightningModule):
 
        Parameters
        ----------
-        batch : Tensor
+        batch : torch.Tensor
            Input batch.
        batch_idx : Any
            Batch index.
@@ -138,7 +138,7 @@ class CAREamicsModule(L.LightningModule):
 
        Parameters
        ----------
-        batch : Tensor
+        batch : torch.Tensor
            Input batch.
        batch_idx : Any
            Batch index.
@@ -148,13 +148,17 @@ class CAREamicsModule(L.LightningModule):
        Any
            Model output.
        """
-        x, *aux = batch
+        if self._trainer.datamodule.tiled:
+            x, *aux = batch
+        else:
+            x = batch
+            aux = []
 
        # apply test-time augmentation if available
        # TODO: probably wont work with batch size > 1
        if self._trainer.datamodule.prediction_config.tta_transforms:
            tta = ImageRestorationTTA()
-            augmented_batch = tta.forward(batch[0])  # list of augmented tensors
+            augmented_batch = tta.forward(x)  # list of augmented tensors
            augmented_output = []
            for augmented in augmented_batch:
                augmented_pred = self.model(augmented)
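The unpacking now branches on the datamodule's `tiled` flag: tiled dataloaders yield a tuple of tiles plus tile metadata (see the new `careamics/dataset/tiling/collate_tiles.py` in the file list), while untiled ones yield the image tensor directly. A sketch of the two batch layouts this code handles, with illustrative shapes and placeholder metadata:

import torch

# tiled prediction: batch is (tiles, tile_information)
batch = (torch.zeros(1, 1, 64, 64), ["tile metadata placeholder"])
x, *aux = batch  # aux keeps the metadata for stitching

# untiled prediction: batch is the tensor itself
batch = torch.zeros(1, 1, 64, 64)
x, aux = batch, []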
@@ -165,13 +169,13 @@ class CAREamicsModule(L.LightningModule):
 
        # Denormalize the output
        denorm = Denormalize(
-            mean=self._trainer.datamodule.predict_dataset.mean,
-            std=self._trainer.datamodule.predict_dataset.std,
+            image_means=self._trainer.datamodule.predict_dataset.image_means,
+            image_stds=self._trainer.datamodule.predict_dataset.image_stds,
        )
-        denormalized_output, _ = denorm(patch=output)
+        denormalized_output = denorm(patch=output.cpu().numpy())
 
-        if len(aux) > 0:
-            return denormalized_output, aux
+        if len(aux) > 0:  # aux can be tiling information
+            return denormalized_output, *aux
        else:
            return denormalized_output
 
@@ -198,101 +202,74 @@ class CAREamicsModule(L.LightningModule):
        }
 
 
-class CAREamicsModuleWrapper(CAREamicsModule):
-    """Class defining the API for CAREamics Lightning layer.
+def create_careamics_module(
+    algorithm: Union[SupportedAlgorithm, str],
+    loss: Union[SupportedLoss, str],
+    architecture: Union[SupportedArchitecture, str],
+    model_parameters: Optional[dict] = None,
+    optimizer: Union[SupportedOptimizer, str] = "Adam",
+    optimizer_parameters: Optional[dict] = None,
+    lr_scheduler: Union[SupportedScheduler, str] = "ReduceLROnPlateau",
+    lr_scheduler_parameters: Optional[dict] = None,
+) -> CAREamicsModule:
+    """Create a CAREamics Lithgning module.
 
-    This class exposes parameters used to create an AlgorithmModel instance, triggering
-    parameters validation.
+    This function exposes parameters used to create an AlgorithmModel instance,
+    triggering parameters validation.
 
     Parameters
     ----------
-    algorithm : Union[SupportedAlgorithm, str]
+    algorithm : SupportedAlgorithm or str
         Algorithm to use for training (see SupportedAlgorithm).
-    loss : Union[SupportedLoss, str]
+    loss : SupportedLoss or str
         Loss function to use for training (see SupportedLoss).
-    architecture : Union[SupportedArchitecture, str]
+    architecture : SupportedArchitecture or str
         Model architecture to use for training (see SupportedArchitecture).
     model_parameters : dict, optional
         Model parameters to use for training, by default {}. Model parameters are
         defined in the relevant `torch.nn.Module` class, or Pyddantic model (see
         `careamics.config.architectures`).
-    optimizer : Union[SupportedOptimizer, str], optional
+    optimizer : SupportedOptimizer or str, optional
         Optimizer to use for training, by default "Adam" (see SupportedOptimizer).
     optimizer_parameters : dict, optional
         Optimizer parameters to use for training, as defined in `torch.optim`, by
         default {}.
-    lr_scheduler : Union[SupportedScheduler, str], optional
+    lr_scheduler : SupportedScheduler or str, optional
         Learning rate scheduler to use for training, by default "ReduceLROnPlateau"
         (see SupportedScheduler).
     lr_scheduler_parameters : dict, optional
         Learning rate scheduler parameters to use for training, as defined in
         `torch.optim`, by default {}.
-    """
 
-    def __init__(
-        self,
-        algorithm: Union[SupportedAlgorithm, str],
-        loss: Union[SupportedLoss, str],
-        architecture: Union[SupportedArchitecture, str],
-        model_parameters: Optional[dict] = None,
-        optimizer: Union[SupportedOptimizer, str] = "Adam",
-        optimizer_parameters: Optional[dict] = None,
-        lr_scheduler: Union[SupportedScheduler, str] = "ReduceLROnPlateau",
-        lr_scheduler_parameters: Optional[dict] = None,
-    ) -> None:
-        """
-        Wrapper for the CAREamics model, exposing all algorithm configuration arguments.
-
-        Parameters
-        ----------
-        algorithm : Union[SupportedAlgorithm, str]
-            Algorithm to use for training (see SupportedAlgorithm).
-        loss : Union[SupportedLoss, str]
-            Loss function to use for training (see SupportedLoss).
-        architecture : Union[SupportedArchitecture, str]
-            Model architecture to use for training (see SupportedArchitecture).
-        model_parameters : dict, optional
-            Model parameters to use for training, by default {}. Model parameters are
-            defined in the relevant `torch.nn.Module` class, or Pyddantic model (see
-            `careamics.config.architectures`).
-        optimizer : Union[SupportedOptimizer, str], optional
-            Optimizer to use for training, by default "Adam" (see SupportedOptimizer).
-        optimizer_parameters : dict, optional
-            Optimizer parameters to use for training, as defined in `torch.optim`, by
-            default {}.
-        lr_scheduler : Union[SupportedScheduler, str], optional
-            Learning rate scheduler to use for training, by default "ReduceLROnPlateau"
-            (see SupportedScheduler).
-        lr_scheduler_parameters : dict, optional
-            Learning rate scheduler parameters to use for training, as defined in
-            `torch.optim`, by default {}.
-        """
-        # create a AlgorithmModel compatible dictionary
-        if lr_scheduler_parameters is None:
-            lr_scheduler_parameters = {}
-        if optimizer_parameters is None:
-            optimizer_parameters = {}
-        if model_parameters is None:
-            model_parameters = {}
-        algorithm_configuration = {
-            "algorithm": algorithm,
-            "loss": loss,
-            "optimizer": {
-                "name": optimizer,
-                "parameters": optimizer_parameters,
-            },
-            "lr_scheduler": {
-                "name": lr_scheduler,
-                "parameters": lr_scheduler_parameters,
-            },
-        }
-        model_configuration = {"architecture": architecture}
-        model_configuration.update(model_parameters)
-
-        # add model parameters to algorithm configuration
-        algorithm_configuration["model"] = model_configuration
-
-        # call the parent init using an AlgorithmModel instance
-        super().__init__(AlgorithmConfig(**algorithm_configuration))
-
-        # TODO add load_from_checkpoint wrapper
+    Returns
+    -------
+    CAREamicsModule
+        CAREamics Lightning module.
+    """
+    # create a AlgorithmModel compatible dictionary
+    if lr_scheduler_parameters is None:
+        lr_scheduler_parameters = {}
+    if optimizer_parameters is None:
+        optimizer_parameters = {}
+    if model_parameters is None:
+        model_parameters = {}
+    algorithm_configuration = {
+        "algorithm": algorithm,
+        "loss": loss,
+        "optimizer": {
+            "name": optimizer,
+            "parameters": optimizer_parameters,
+        },
+        "lr_scheduler": {
+            "name": lr_scheduler,
+            "parameters": lr_scheduler_parameters,
+        },
+    }
+    model_configuration = {"architecture": architecture}
+    model_configuration.update(model_parameters)
+
+    # add model parameters to algorithm configuration
+    algorithm_configuration["model"] = model_configuration
+
+    # call the parent init using an AlgorithmModel instance
+    return CAREamicsModule(AlgorithmConfig(**algorithm_configuration))
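The public entry point is now a factory function rather than the `CAREamicsModuleWrapper` subclass, with the same keyword arguments. A migration sketch (the "n2v" and "UNet" values and the learning rate are illustrative choices, not the only supported options):

from careamics.lightning import create_careamics_module

# rc6: module = CAREamicsModuleWrapper(algorithm="n2v", loss="n2v", architecture="UNet")
# rc8:
module = create_careamics_module(
    algorithm="n2v",
    loss="n2v",
    architecture="UNet",
    optimizer="Adam",
    optimizer_parameters={"lr": 1e-4},
)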