careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (118)
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +164 -231
  4. careamics/config/algorithm_model.py +5 -18
  5. careamics/config/architectures/architecture_model.py +7 -0
  6. careamics/config/architectures/custom_model.py +11 -4
  7. careamics/config/architectures/register_model.py +3 -1
  8. careamics/config/architectures/unet_model.py +2 -0
  9. careamics/config/architectures/vae_model.py +2 -0
  10. careamics/config/callback_model.py +3 -15
  11. careamics/config/configuration_example.py +4 -5
  12. careamics/config/configuration_factory.py +27 -41
  13. careamics/config/configuration_model.py +11 -11
  14. careamics/config/data_model.py +89 -63
  15. careamics/config/inference_model.py +28 -81
  16. careamics/config/optimizer_models.py +11 -11
  17. careamics/config/support/__init__.py +0 -2
  18. careamics/config/support/supported_activations.py +2 -0
  19. careamics/config/support/supported_algorithms.py +3 -1
  20. careamics/config/support/supported_architectures.py +2 -0
  21. careamics/config/support/supported_data.py +2 -0
  22. careamics/config/support/supported_loggers.py +2 -0
  23. careamics/config/support/supported_losses.py +2 -0
  24. careamics/config/support/supported_optimizers.py +2 -0
  25. careamics/config/support/supported_pixel_manipulations.py +3 -3
  26. careamics/config/support/supported_struct_axis.py +2 -0
  27. careamics/config/support/supported_transforms.py +4 -16
  28. careamics/config/tile_information.py +28 -58
  29. careamics/config/transformations/__init__.py +3 -2
  30. careamics/config/transformations/normalize_model.py +32 -4
  31. careamics/config/transformations/xy_flip_model.py +43 -0
  32. careamics/config/transformations/xy_random_rotate90_model.py +11 -3
  33. careamics/config/validators/validator_utils.py +1 -1
  34. careamics/conftest.py +12 -0
  35. careamics/dataset/__init__.py +12 -1
  36. careamics/dataset/dataset_utils/__init__.py +8 -1
  37. careamics/dataset/dataset_utils/dataset_utils.py +4 -4
  38. careamics/dataset/dataset_utils/file_utils.py +4 -3
  39. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  40. careamics/dataset/dataset_utils/read_tiff.py +6 -11
  41. careamics/dataset/dataset_utils/read_utils.py +2 -0
  42. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  43. careamics/dataset/dataset_utils/running_stats.py +186 -0
  44. careamics/dataset/in_memory_dataset.py +88 -154
  45. careamics/dataset/in_memory_pred_dataset.py +88 -0
  46. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  47. careamics/dataset/iterable_dataset.py +121 -191
  48. careamics/dataset/iterable_pred_dataset.py +121 -0
  49. careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
  50. careamics/dataset/patching/patching.py +109 -39
  51. careamics/dataset/patching/random_patching.py +17 -6
  52. careamics/dataset/patching/sequential_patching.py +14 -8
  53. careamics/dataset/patching/validate_patch_dimension.py +7 -3
  54. careamics/dataset/tiling/__init__.py +10 -0
  55. careamics/dataset/tiling/collate_tiles.py +33 -0
  56. careamics/dataset/{patching → tiling}/tiled_patching.py +7 -5
  57. careamics/dataset/zarr_dataset.py +2 -0
  58. careamics/lightning_datamodule.py +46 -25
  59. careamics/lightning_module.py +19 -9
  60. careamics/lightning_prediction_datamodule.py +54 -84
  61. careamics/losses/__init__.py +2 -3
  62. careamics/losses/loss_factory.py +1 -1
  63. careamics/losses/losses.py +11 -7
  64. careamics/lvae_training/__init__.py +0 -0
  65. careamics/lvae_training/data_modules.py +1220 -0
  66. careamics/lvae_training/data_utils.py +618 -0
  67. careamics/lvae_training/eval_utils.py +905 -0
  68. careamics/lvae_training/get_config.py +84 -0
  69. careamics/lvae_training/lightning_module.py +701 -0
  70. careamics/lvae_training/metrics.py +214 -0
  71. careamics/lvae_training/train_lvae.py +339 -0
  72. careamics/lvae_training/train_utils.py +121 -0
  73. careamics/model_io/bioimage/model_description.py +40 -32
  74. careamics/model_io/bmz_io.py +3 -3
  75. careamics/model_io/model_io_utils.py +5 -2
  76. careamics/models/activation.py +2 -0
  77. careamics/models/layers.py +121 -25
  78. careamics/models/lvae/__init__.py +0 -0
  79. careamics/models/lvae/layers.py +1998 -0
  80. careamics/models/lvae/likelihoods.py +312 -0
  81. careamics/models/lvae/lvae.py +985 -0
  82. careamics/models/lvae/noise_models.py +409 -0
  83. careamics/models/lvae/utils.py +395 -0
  84. careamics/models/model_factory.py +1 -1
  85. careamics/models/unet.py +35 -14
  86. careamics/prediction_utils/__init__.py +12 -0
  87. careamics/prediction_utils/create_pred_datamodule.py +185 -0
  88. careamics/prediction_utils/prediction_outputs.py +165 -0
  89. careamics/prediction_utils/stitch_prediction.py +100 -0
  90. careamics/transforms/__init__.py +2 -2
  91. careamics/transforms/compose.py +33 -7
  92. careamics/transforms/n2v_manipulate.py +52 -14
  93. careamics/transforms/normalize.py +171 -48
  94. careamics/transforms/pixel_manipulation.py +35 -11
  95. careamics/transforms/struct_mask_parameters.py +3 -1
  96. careamics/transforms/transform.py +10 -19
  97. careamics/transforms/tta.py +43 -29
  98. careamics/transforms/xy_flip.py +123 -0
  99. careamics/transforms/xy_random_rotate90.py +38 -5
  100. careamics/utils/base_enum.py +28 -0
  101. careamics/utils/path_utils.py +2 -0
  102. careamics/utils/ram.py +4 -2
  103. careamics/utils/receptive_field.py +93 -87
  104. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +8 -6
  105. careamics-0.1.0rc7.dist-info/RECORD +130 -0
  106. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
  107. careamics/config/noise_models.py +0 -162
  108. careamics/config/support/supported_extraction_strategies.py +0 -25
  109. careamics/config/transformations/nd_flip_model.py +0 -27
  110. careamics/lightning_prediction_loop.py +0 -116
  111. careamics/losses/noise_model_factory.py +0 -40
  112. careamics/losses/noise_models.py +0 -524
  113. careamics/prediction/__init__.py +0 -7
  114. careamics/prediction/stitch_prediction.py +0 -74
  115. careamics/transforms/nd_flip.py +0 -67
  116. careamics/utils/running_stats.py +0 -43
  117. careamics-0.1.0rc5.dist-info/RECORD +0 -111
  118. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
--- a/careamics/lightning_datamodule.py
+++ b/careamics/lightning_datamodule.py
@@ -95,13 +95,13 @@ class CAREamicsTrainData(L.LightningDataModule):
         Batch size.
     use_in_memory : bool
         Whether to use in memory dataset if possible.
-    train_data : Union[Path, str, np.ndarray]
+    train_data : Union[Path, np.ndarray]
         Training data.
-    val_data : Optional[Union[Path, str, np.ndarray]]
+    val_data : Optional[Union[Path, np.ndarray]]
         Validation data.
-    train_data_target : Optional[Union[Path, str, np.ndarray]]
+    train_data_target : Optional[Union[Path, np.ndarray]]
         Training target data.
-    val_data_target : Optional[Union[Path, str, np.ndarray]]
+    val_data_target : Optional[Union[Path, np.ndarray]]
         Validation target data.
     val_percentage : float
         Percentage of the training data to use for validation, if no validation data is
@@ -217,17 +217,33 @@ class CAREamicsTrainData(L.LightningDataModule):
         )
 
         # configuration
-        self.data_config = data_config
-        self.data_type = data_config.data_type
-        self.batch_size = data_config.batch_size
-        self.use_in_memory = use_in_memory
+        self.data_config: DataConfig = data_config
+        self.data_type: str = data_config.data_type
+        self.batch_size: int = data_config.batch_size
+        self.use_in_memory: bool = use_in_memory
+
+        # data: make data Path or np.ndarray, use type annotations for mypy
+        self.train_data: Union[Path, np.ndarray] = (
+            Path(train_data) if isinstance(train_data, str) else train_data
+        )
+
+        self.val_data: Union[Path, np.ndarray] = (
+            Path(val_data) if isinstance(val_data, str) else val_data
+        )
 
-        # data
-        self.train_data = train_data
-        self.val_data = val_data
+        self.train_data_target: Union[Path, np.ndarray] = (
+            Path(train_data_target)
+            if isinstance(train_data_target, str)
+            else train_data_target
+        )
 
-        self.train_data_target = train_data_target
-        self.val_data_target = val_data_target
+        self.val_data_target: Union[Path, np.ndarray] = (
+            Path(val_data_target)
+            if isinstance(val_data_target, str)
+            else val_data_target
+        )
+
+        # validation split
         self.val_percentage = val_percentage
         self.val_minimum_split = val_minimum_split
 
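The hunk above shows the pattern now used throughout the datamodule: accept `str` at the API boundary, coerce it to `Path` once, and annotate the attribute so that downstream code and mypy only ever see `Union[Path, np.ndarray]`. A minimal standalone sketch of the same idea (the function name is illustrative, not part of the package):

    from pathlib import Path
    from typing import Union

    import numpy as np

    def coerce_source(source: Union[str, Path, np.ndarray]) -> Union[Path, np.ndarray]:
        # convert strings exactly once at the boundary, so downstream code
        # and the type checker never have to handle the str case again
        return Path(source) if isinstance(source, str) else source
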
@@ -241,10 +257,10 @@ class CAREamicsTrainData(L.LightningDataModule):
         elif data_config.data_type != SupportedData.ARRAY:
             self.read_source_func = get_read_func(data_config.data_type)
 
-        self.extension_filter = extension_filter
+        self.extension_filter: str = extension_filter
 
         # Pytorch dataloader parameters
-        self.dataloader_params = (
+        self.dataloader_params: Dict[str, Any] = (
             data_config.dataloader_params if data_config.dataloader_params else {}
         )
 
@@ -309,20 +325,30 @@ class CAREamicsTrainData(L.LightningDataModule):
         """
         # if numpy array
         if self.data_type == SupportedData.ARRAY:
+            # mypy checks
+            assert isinstance(self.train_data, np.ndarray)
+            if self.train_data_target is not None:
+                assert isinstance(self.train_data_target, np.ndarray)
+
            # train dataset
            self.train_dataset: DatasetType = InMemoryDataset(
                data_config=self.data_config,
                inputs=self.train_data,
-                data_target=self.train_data_target,
+                input_target=self.train_data_target,
            )

            # validation dataset
            if self.val_data is not None:
+                # mypy checks
+                assert isinstance(self.val_data, np.ndarray)
+                if self.val_data_target is not None:
+                    assert isinstance(self.val_data_target, np.ndarray)
+
                # create its own dataset
                self.val_dataset: DatasetType = InMemoryDataset(
                    data_config=self.data_config,
                    inputs=self.val_data,
-                    data_target=self.val_data_target,
+                    input_target=self.val_data_target,
                )
            else:
                # extract validation from the training patches
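The `# mypy checks` asserts added above are the standard narrowing idiom: after `assert isinstance(x, np.ndarray)`, the checker treats `x` as `np.ndarray` rather than `Union[Path, np.ndarray]`, so passing it to an array-only API type-checks, and the assert fails loudly if the invariant is ever violated at runtime. A small illustration of the idiom (function name illustrative):

    from pathlib import Path
    from typing import Tuple, Union

    import numpy as np

    def patch_shape(data: Union[Path, np.ndarray]) -> Tuple[int, ...]:
        # without this assert, mypy rejects `.shape`: Path has no such attribute
        assert isinstance(data, np.ndarray)
        return data.shape
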
@@ -341,7 +367,7 @@ class CAREamicsTrainData(L.LightningDataModule):
             self.train_dataset = InMemoryDataset(
                 data_config=self.data_config,
                 inputs=self.train_files,
-                data_target=(
+                input_target=(
                     self.train_target_files if self.train_data_target else None
                 ),
                 read_source_func=self.read_source_func,
@@ -352,7 +378,7 @@ class CAREamicsTrainData(L.LightningDataModule):
             self.val_dataset = InMemoryDataset(
                 data_config=self.data_config,
                 inputs=self.val_files,
-                data_target=(
+                input_target=(
                     self.val_target_files if self.val_data_target else None
                 ),
                 read_source_func=self.read_source_func,
@@ -557,12 +583,7 @@ class TrainingDataWrapper(CAREamicsTrainData):
     >>> my_array = np.arange(256).reshape(16, 16)
     >>> my_transforms = [
     ...     {
-    ...         "name": SupportedTransform.NORMALIZE.value,
-    ...         "mean": 0,
-    ...         "std": 1,
-    ...     },
-    ...     {
-    ...         "name": SupportedTransform.N2V_MANIPULATE.value,
+    ...         "name": SupportedTransform.XY_FLIP.value,
     ...     }
     ... ]
     >>> data_module = TrainingDataWrapper(
--- a/careamics/lightning_module.py
+++ b/careamics/lightning_module.py
@@ -1,3 +1,5 @@
+"""CAREamics Lightning module."""
+
 from typing import Any, Optional, Union
 
 import pytorch_lightning as L
@@ -24,6 +26,11 @@ class CAREamicsModule(L.LightningModule):
     This class encapsulates the a PyTorch model along with the training, validation,
     and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
 
+    Parameters
+    ----------
+    algorithm_config : Union[AlgorithmModel, dict]
+        Algorithm configuration.
+
     Attributes
     ----------
     model : nn.Module
@@ -39,8 +46,7 @@ class CAREamicsModule(L.LightningModule):
     """
 
     def __init__(self, algorithm_config: Union[AlgorithmConfig, dict]) -> None:
-        """
-        CAREamics Lightning module.
+        """Lightning module for CAREamics.
 
         This class encapsulates the a PyTorch model along with the training, validation,
         and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
@@ -142,13 +148,17 @@ class CAREamicsModule(L.LightningModule):
         Any
             Model output.
         """
-        x, *aux = batch
+        if self._trainer.datamodule.tiled:
+            x, *aux = batch
+        else:
+            x = batch
+            aux = []
 
         # apply test-time augmentation if available
         # TODO: probably wont work with batch size > 1
         if self._trainer.datamodule.prediction_config.tta_transforms:
             tta = ImageRestorationTTA()
-            augmented_batch = tta.forward(batch[0])  # list of augmented tensors
+            augmented_batch = tta.forward(x)  # list of augmented tensors
             augmented_output = []
             for augmented in augmented_batch:
                 augmented_pred = self.model(augmented)
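For context on the `ImageRestorationTTA` call above: test-time augmentation for image restoration typically runs the model over the eight flip/90°-rotation variants of the input, inverts each transform on the prediction, and averages the results. A self-contained sketch of that scheme, assuming a generic `model` callable and square 2D inputs (this is not the package's own implementation):

    import torch

    def tta_predict(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        # average predictions over the 8 dihedral transforms of a (B, C, H, W) batch
        preds = []
        for k in range(4):  # 0, 90, 180, 270 degree rotations
            for flip in (False, True):
                aug = torch.rot90(x, k, dims=(-2, -1))
                if flip:
                    aug = torch.flip(aug, dims=(-1,))
                pred = model(aug)
                # invert the augmentation: undo the flip first, then rotate back
                if flip:
                    pred = torch.flip(pred, dims=(-1,))
                preds.append(torch.rot90(pred, -k, dims=(-2, -1)))
        return torch.stack(preds).mean(dim=0)
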
@@ -159,13 +169,13 @@
 
         # Denormalize the output
         denorm = Denormalize(
-            mean=self._trainer.datamodule.predict_dataset.mean,
-            std=self._trainer.datamodule.predict_dataset.std,
+            image_means=self._trainer.datamodule.predict_dataset.image_means,
+            image_stds=self._trainer.datamodule.predict_dataset.image_stds,
         )
-        denormalized_output, _ = denorm(patch=output)
+        denormalized_output = denorm(patch=output.cpu().numpy())
 
-        if len(aux) > 0:
-            return denormalized_output, aux
+        if len(aux) > 0:  # aux can be tiling information
+            return denormalized_output, *aux
         else:
             return denormalized_output
 
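Note the changed return shape in the last lines: the old code returned a nested pair `(output, aux)`, while `return denormalized_output, *aux` splats the auxiliary items (here, tiling information) into one flat tuple. A tiny illustration of the difference:

    aux = ["tile_info"]
    nested = ("prediction", aux)  # old style: ("prediction", ["tile_info"])
    flat = ("prediction", *aux)   # new style: ("prediction", "tile_info")
    assert nested != flat
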
--- a/careamics/lightning_prediction_datamodule.py
+++ b/careamics/lightning_prediction_datamodule.py
@@ -1,68 +1,37 @@
 """Prediction Lightning data modules."""
 
 from pathlib import Path
-from typing import Any, Callable, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union
 
 import numpy as np
 import pytorch_lightning as L
 from torch.utils.data import DataLoader
-from torch.utils.data.dataloader import default_collate
 
 from careamics.config import InferenceConfig
 from careamics.config.support import SupportedData
-from careamics.config.tile_information import TileInformation
+from careamics.dataset import (
+    InMemoryPredDataset,
+    InMemoryTiledPredDataset,
+    IterablePredDataset,
+    IterableTiledPredDataset,
+)
 from careamics.dataset.dataset_utils import (
     get_read_func,
     list_files,
 )
-from careamics.dataset.in_memory_dataset import (
-    InMemoryPredictionDataset,
-)
-from careamics.dataset.iterable_dataset import (
-    IterablePredictionDataset,
-)
+from careamics.dataset.tiling.collate_tiles import collate_tiles
 from careamics.utils import get_logger
 
-PredictDatasetType = Union[InMemoryPredictionDataset, IterablePredictionDataset]
+PredictDatasetType = Union[
+    InMemoryPredDataset,
+    InMemoryTiledPredDataset,
+    IterablePredDataset,
+    IterableTiledPredDataset,
+]
 
 logger = get_logger(__name__)
 
 
-def _collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
-    """
-    Collate tiles received from CAREamics prediction dataloader.
-
-    CAREamics prediction dataloader returns tuples of arrays and TileInformation. In
-    case of non-tiled data, this function will return the arrays. In case of tiled data,
-    it will return the arrays, the last tile flag, the overlap crop coordinates and the
-    stitch coordinates.
-
-    Parameters
-    ----------
-    batch : List[Tuple[np.ndarray, TileInformation], ...]
-        Batch of tiles.
-
-    Returns
-    -------
-    Any
-        Collated batch.
-    """
-    first_tile_info: TileInformation = batch[0][1]
-    # if not tiled, then return arrays
-    if not first_tile_info.tiled:
-        arrays, _ = zip(*batch)
-
-        return default_collate(arrays)
-    # else we explicit the last_tile flag and coordinates
-    else:
-        new_batch = [
-            (tile, t.last_tile, t.array_shape, t.overlap_crop_coords, t.stitch_coords)
-            for tile, t in batch
-        ]
-
-        return default_collate(new_batch)
-
-
 class CAREamicsPredictData(L.LightningDataModule):
     """
     CAREamics Lightning prediction data module.
@@ -182,6 +151,9 @@ class CAREamicsPredictData(L.LightningDataModule):
         self.tile_size = pred_config.tile_size
         self.tile_overlap = pred_config.tile_overlap
 
+        # check if it is tiled
+        self.tiled = self.tile_size is not None and self.tile_overlap is not None
+
         # read source function
         if pred_config.data_type == SupportedData.CUSTOM:
             # mypy check
@@ -212,17 +184,29 @@ class CAREamicsPredictData(L.LightningDataModule):
         """
         # if numpy array
         if self.data_type == SupportedData.ARRAY:
-            # prediction dataset
-            self.predict_dataset: PredictDatasetType = InMemoryPredictionDataset(
-                prediction_config=self.prediction_config,
-                inputs=self.pred_data,
-            )
+            if self.tiled:
+                self.predict_dataset: PredictDatasetType = InMemoryTiledPredDataset(
+                    prediction_config=self.prediction_config,
+                    inputs=self.pred_data,
+                )
+            else:
+                self.predict_dataset = InMemoryPredDataset(
+                    prediction_config=self.prediction_config,
+                    inputs=self.pred_data,
+                )
         else:
-            self.predict_dataset = IterablePredictionDataset(
-                prediction_config=self.prediction_config,
-                src_files=self.pred_files,
-                read_source_func=self.read_source_func,
-            )
+            if self.tiled:
+                self.predict_dataset = IterableTiledPredDataset(
+                    prediction_config=self.prediction_config,
+                    src_files=self.pred_files,
+                    read_source_func=self.read_source_func,
+                )
+            else:
+                self.predict_dataset = IterablePredDataset(
+                    prediction_config=self.prediction_config,
+                    src_files=self.pred_files,
+                    read_source_func=self.read_source_func,
+                )
 
     def predict_dataloader(self) -> DataLoader:
         """
@@ -236,7 +220,7 @@ class CAREamicsPredictData(L.LightningDataModule):
         return DataLoader(
             self.predict_dataset,
             batch_size=self.batch_size,
-            collate_fn=_collate_tiles,
+            collate_fn=collate_tiles if self.tiled else None,
             **self.dataloader_params,
         )  # TODO check workers are used
 
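In the hunk above, `collate_fn=None` does not disable collation: `DataLoader` then falls back to PyTorch's `default_collate`, so untiled batches are stacked as usual and the custom `collate_tiles` only runs when tiling is active. A quick self-contained check of that fallback:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.dataloader import default_collate

    ds = TensorDataset(torch.zeros(4, 1, 8, 8))
    # collate_fn=None and collate_fn=default_collate produce identical batches
    a = next(iter(DataLoader(ds, batch_size=2, collate_fn=None)))[0]
    b = next(iter(DataLoader(ds, batch_size=2, collate_fn=default_collate)))[0]
    assert torch.equal(a, b)
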
@@ -287,12 +271,10 @@ class PredictDataWrapper(CAREamicsPredictData):
         Prediction data.
     data_type : Union[Literal["array", "tiff", "custom"], SupportedData]
         Data type, see `SupportedData` for available options.
-    mean : float
-        Mean value for normalization, only used if Normalization is defined in the
-        transforms.
-    std : float
-        Standard deviation value for normalization, only used if Normalization is
-        defined in the transform.
+    image_means : list of float
+        Mean values for normalization, only used if Normalization is defined.
+    image_stds : list of float
+        Std values for normalization, only used if Normalization is defined.
     tile_size : Tuple[int, ...]
         Tile size, 2D or 3D tile size.
     tile_overlap : Tuple[int, ...]
@@ -303,9 +285,6 @@ class PredictDataWrapper(CAREamicsPredictData):
         Batch size.
     tta_transforms : bool, optional
         Use test time augmentation, by default True.
-    transforms : List, optional
-        List of transforms to apply to prediction patches. If None, default
-        transforms are applied.
     read_source_func : Optional[Callable], optional
         Function to read the source data, used if `data_type` is `custom`, by
         default None.
@@ -319,14 +298,13 @@ class PredictDataWrapper(CAREamicsPredictData):
         self,
         pred_data: Union[str, Path, np.ndarray],
         data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
-        mean: float,
-        std: float,
+        image_means=list[float],
+        image_stds=list[float],
         tile_size: Optional[Tuple[int, ...]] = None,
         tile_overlap: Optional[Tuple[int, ...]] = None,
         axes: str = "YX",
         batch_size: int = 1,
         tta_transforms: bool = True,
-        transforms: Optional[List] = None,
         read_source_func: Optional[Callable] = None,
         extension_filter: str = "",
         dataloader_params: Optional[dict] = None,
@@ -340,12 +318,10 @@ class PredictDataWrapper(CAREamicsPredictData):
             Prediction data.
         data_type : Union[Literal["array", "tiff", "custom"], SupportedData]
             Data type, see `SupportedData` for available options.
-        mean : float
-            Mean value for normalization, only used if Normalization is defined in the
-            transforms.
-        std : float
-            Standard deviation value for normalization, only used if Normalization is
-            defined in the transform.
+        image_means : list of float
+            Mean values for normalization, only used if Normalization is defined.
+        image_stds : list of float
+            Std values for normalization, only used if Normalization is defined.
         tile_size : List[int]
             Tile size, 2D or 3D tile size.
         tile_overlap : List[int]
@@ -356,9 +332,6 @@ class PredictDataWrapper(CAREamicsPredictData):
             Batch size.
         tta_transforms : bool, optional
             Use test time augmentation, by default True.
-        transforms : Optional[List], optional
-            List of transforms to apply to prediction patches. If None, default
-            transforms are applied.
         read_source_func : Optional[Callable], optional
             Function to read the source data, used if `data_type` is `custom`, by
             default None.
@@ -369,21 +342,18 @@ class PredictDataWrapper(CAREamicsPredictData):
         """
         if dataloader_params is None:
             dataloader_params = {}
-        prediction_dict = {
+        prediction_dict: Dict[str, Any] = {
             "data_type": data_type,
             "tile_size": tile_size,
             "tile_overlap": tile_overlap,
             "axes": axes,
-            "mean": mean,
-            "std": std,
+            "image_means": image_means,
+            "image_stds": image_stds,
             "tta": tta_transforms,
             "batch_size": batch_size,
+            "transforms": [],
         }
 
-        # if transforms are passed (otherwise it will use the default ones)
-        if transforms is not None:
-            prediction_dict["transforms"] = transforms
-
         # validate configuration
         self.prediction_config = InferenceConfig(**prediction_dict)
 
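Taken together, a hypothetical call consistent with the new wrapper signature (values are illustrative; note that normalization statistics are now lists with one entry per channel rather than scalar `mean`/`std`, and that the shipped signature spells them as keyword defaults, `image_means=list[float]`, so explicit keyword arguments as below are the unambiguous way to pass them):

    import numpy as np
    from careamics.lightning_prediction_datamodule import PredictDataWrapper

    datamodule = PredictDataWrapper(
        pred_data=np.random.rand(64, 64).astype(np.float32),
        data_type="array",
        image_means=[0.5],  # one entry per channel
        image_stds=[0.2],
        tile_size=(32, 32),
        tile_overlap=(8, 8),
        axes="YX",
        batch_size=1,
    )
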
--- a/careamics/losses/__init__.py
+++ b/careamics/losses/__init__.py
@@ -1,6 +1,5 @@
 """Losses module."""
 
-from .loss_factory import loss_factory
+__all__ = ["loss_factory"]
 
-# from .noise_model_factory import noise_model_factory as noise_model_factory
-# from .noise_models import GaussianMixtureNoiseModel, HistogramNoiseModel
+from .loss_factory import loss_factory
--- a/careamics/losses/loss_factory.py
+++ b/careamics/losses/loss_factory.py
@@ -17,7 +17,7 @@ def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
 
     Parameters
     ----------
-    loss: SupportedLoss
+    loss : Union[SupportedLoss, str]
         Requested loss.
 
     Returns
--- a/careamics/losses/losses.py
+++ b/careamics/losses/losses.py
@@ -5,23 +5,27 @@ This submodule contains the various losses used in CAREamics.
 """
 
 import torch
-
-# TODO if we are only using the DiceLoss, can we just implement it?
-# from segmentation_models_pytorch.losses import DiceLoss
 from torch.nn import L1Loss, MSELoss
 
 
-def mse_loss(samples: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+def mse_loss(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
     """
     Mean squared error loss.
 
+    Parameters
+    ----------
+    source : torch.Tensor
+        Source patches.
+    target : torch.Tensor
+        Target patches.
+
     Returns
     -------
     torch.Tensor
         Loss value.
     """
     loss = MSELoss()
-    return loss(samples, labels)
+    return loss(source, target)
 
 
 def n2v_loss(
@@ -34,9 +38,9 @@ def n2v_loss(
 
     Parameters
     ----------
-    samples : torch.Tensor
+    manipulated_patches : torch.Tensor
         Patches with manipulated pixels.
-    labels : torch.Tensor
+    original_patches : torch.Tensor
         Noisy patches.
     masks : torch.Tensor
         Array containing masked pixel locations.
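The renamed parameters spell out the Noise2Void contract: the loss compares predictions made on pixel-manipulated patches against the original noisy values, restricted to the masked pixel locations. A minimal sketch of such a masked MSE, under the assumption that the loss averages over masked pixels only (the package's exact weighting may differ):

    import torch

    def masked_mse(
        prediction: torch.Tensor,
        original_patches: torch.Tensor,
        masks: torch.Tensor,
    ) -> torch.Tensor:
        # only masked (manipulated) pixels contribute; normalize by their count
        errors = (prediction - original_patches) ** 2
        return (errors * masks).sum() / masks.sum()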