careamics 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (92)
  1. careamics/careamist.py +55 -61
  2. careamics/cli/conf.py +24 -9
  3. careamics/cli/main.py +8 -8
  4. careamics/cli/utils.py +2 -4
  5. careamics/config/__init__.py +8 -0
  6. careamics/config/algorithms/__init__.py +4 -0
  7. careamics/config/algorithms/hdn_algorithm_model.py +103 -0
  8. careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
  9. careamics/config/algorithms/n2v_algorithm_model.py +1 -2
  10. careamics/config/algorithms/vae_algorithm_model.py +53 -18
  11. careamics/config/architectures/lvae_model.py +12 -8
  12. careamics/config/callback_model.py +15 -11
  13. careamics/config/configuration.py +9 -8
  14. careamics/config/configuration_factories.py +892 -78
  15. careamics/config/data/data_model.py +7 -14
  16. careamics/config/data/ng_data_model.py +8 -15
  17. careamics/config/data/patching_strategies/_overlapping_patched_model.py +4 -5
  18. careamics/config/inference_model.py +6 -11
  19. careamics/config/likelihood_model.py +4 -4
  20. careamics/config/loss_model.py +6 -2
  21. careamics/config/nm_model.py +30 -7
  22. careamics/config/optimizer_models.py +1 -2
  23. careamics/config/support/supported_algorithms.py +5 -3
  24. careamics/config/support/supported_losses.py +5 -2
  25. careamics/config/training_model.py +8 -38
  26. careamics/config/transformations/normalize_model.py +3 -4
  27. careamics/config/transformations/xy_flip_model.py +2 -2
  28. careamics/config/transformations/xy_random_rotate90_model.py +2 -2
  29. careamics/config/validators/validator_utils.py +1 -2
  30. careamics/dataset/dataset_utils/iterate_over_files.py +3 -3
  31. careamics/dataset/in_memory_dataset.py +2 -2
  32. careamics/dataset/iterable_dataset.py +1 -2
  33. careamics/dataset/patching/random_patching.py +6 -6
  34. careamics/dataset/patching/sequential_patching.py +4 -4
  35. careamics/dataset/tiling/lvae_tiled_patching.py +2 -2
  36. careamics/dataset_ng/dataset.py +3 -3
  37. careamics/dataset_ng/factory.py +19 -19
  38. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
  39. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
  40. careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
  41. careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
  42. careamics/dataset_ng/patching_strategies/random_patching.py +2 -3
  43. careamics/dataset_ng/patching_strategies/sequential_patching.py +1 -2
  44. careamics/file_io/read/__init__.py +0 -1
  45. careamics/lightning/__init__.py +16 -2
  46. careamics/lightning/callbacks/__init__.py +2 -0
  47. careamics/lightning/callbacks/data_stats_callback.py +23 -0
  48. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +5 -5
  49. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +5 -5
  50. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +8 -8
  51. careamics/lightning/dataset_ng/data_module.py +43 -43
  52. careamics/lightning/lightning_module.py +166 -68
  53. careamics/lightning/microsplit_data_module.py +631 -0
  54. careamics/lightning/predict_data_module.py +16 -9
  55. careamics/lightning/train_data_module.py +29 -18
  56. careamics/losses/__init__.py +7 -1
  57. careamics/losses/loss_factory.py +9 -1
  58. careamics/losses/lvae/losses.py +94 -9
  59. careamics/lvae_training/dataset/__init__.py +8 -8
  60. careamics/lvae_training/dataset/config.py +56 -44
  61. careamics/lvae_training/dataset/lc_dataset.py +18 -12
  62. careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
  63. careamics/lvae_training/dataset/multich_dataset.py +24 -18
  64. careamics/lvae_training/dataset/multifile_dataset.py +6 -6
  65. careamics/model_io/bioimage/model_description.py +12 -11
  66. careamics/model_io/bmz_io.py +12 -8
  67. careamics/models/layers.py +5 -5
  68. careamics/models/lvae/likelihoods.py +30 -14
  69. careamics/models/lvae/lvae.py +2 -2
  70. careamics/models/lvae/noise_models.py +20 -14
  71. careamics/prediction_utils/__init__.py +8 -2
  72. careamics/prediction_utils/lvae_prediction.py +5 -5
  73. careamics/prediction_utils/prediction_outputs.py +48 -3
  74. careamics/prediction_utils/stitch_prediction.py +71 -0
  75. careamics/transforms/compose.py +9 -9
  76. careamics/transforms/n2v_manipulate.py +3 -3
  77. careamics/transforms/n2v_manipulate_torch.py +4 -4
  78. careamics/transforms/normalize.py +4 -6
  79. careamics/transforms/pixel_manipulation.py +6 -8
  80. careamics/transforms/pixel_manipulation_torch.py +5 -7
  81. careamics/transforms/xy_flip.py +3 -5
  82. careamics/transforms/xy_random_rotate90.py +4 -6
  83. careamics/utils/logging.py +8 -8
  84. careamics/utils/metrics.py +2 -2
  85. careamics/utils/plotting.py +1 -3
  86. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/METADATA +18 -16
  87. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/RECORD +90 -88
  88. careamics/dataset/zarr_dataset.py +0 -151
  89. careamics/file_io/read/zarr.py +0 -60
  90. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/WHEEL +0 -0
  91. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/entry_points.txt +0 -0
  92. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/licenses/LICENSE +0 -0
@@ -2,7 +2,7 @@
 
 from collections.abc import Callable
 from pathlib import Path
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal, Union
 
 import numpy as np
 import pytorch_lightning as L
@@ -121,10 +121,10 @@ class TrainDataModule(L.LightningDataModule):
         self,
         data_config: DataConfig,
         train_data: Union[Path, str, NDArray],
-        val_data: Optional[Union[Path, str, NDArray]] = None,
-        train_data_target: Optional[Union[Path, str, NDArray]] = None,
-        val_data_target: Optional[Union[Path, str, NDArray]] = None,
-        read_source_func: Optional[Callable] = None,
+        val_data: Union[Path, str, NDArray] | None = None,
+        train_data_target: Union[Path, str, NDArray] | None = None,
+        val_data_target: Union[Path, str, NDArray] | None = None,
+        read_source_func: Callable | None = None,
         extension_filter: str = "",
         val_percentage: float = 0.1,
         val_minimum_split: int = 5,
@@ -477,15 +477,16 @@ def create_train_datamodule(
     patch_size: list[int],
     axes: str,
     batch_size: int,
-    val_data: Optional[Union[str, Path, NDArray]] = None,
-    transforms: Optional[list[TransformModel]] = None,
-    train_target_data: Optional[Union[str, Path, NDArray]] = None,
-    val_target_data: Optional[Union[str, Path, NDArray]] = None,
-    read_source_func: Optional[Callable] = None,
+    val_data: Union[str, Path, NDArray] | None = None,
+    transforms: list[TransformModel] | None = None,
+    train_target_data: Union[str, Path, NDArray] | None = None,
+    val_target_data: Union[str, Path, NDArray] | None = None,
+    read_source_func: Callable | None = None,
     extension_filter: str = "",
     val_percentage: float = 0.1,
     val_minimum_patches: int = 5,
-    dataloader_params: Optional[dict] = None,
+    train_dataloader_params: dict | None = None,
+    val_dataloader_params: dict | None = None,
     use_in_memory: bool = True,
 ) -> TrainDataModule:
     """Create a TrainDataModule.
@@ -556,8 +557,10 @@ def create_train_datamodule(
     val_minimum_patches : int, optional
         Minimum number of patches to split from the training data for validation if
        no validation data is given, by default 5.
-    dataloader_params : dict, optional
-        Pytorch dataloader parameters, by default {}.
+    train_dataloader_params : dict, optional
+        Pytorch dataloader parameters for the training data, by default {}.
+    val_dataloader_params : dict, optional
+        Pytorch dataloader parameters for the validation data, by default {}.
     use_in_memory : bool, optional
         Use in memory dataset if possible, by default True.
 
@@ -617,8 +620,11 @@ def create_train_datamodule(
     ...     transforms=my_transforms,
     ... )
     """
-    if dataloader_params is None:
-        dataloader_params = {}
+    if train_dataloader_params is None:
+        train_dataloader_params = {"shuffle": True}
+
+    if val_dataloader_params is None:
+        val_dataloader_params = {"shuffle": False}
 
     data_dict: dict[str, Any] = {
         "mode": "train",
@@ -626,7 +632,8 @@
         "patch_size": patch_size,
         "axes": axes,
         "batch_size": batch_size,
-        "dataloader_params": dataloader_params,
+        "train_dataloader_params": train_dataloader_params,
+        "val_dataloader_params": val_dataloader_params,
     }
 
     # if transforms are passed (otherwise it will use the default ones)
@@ -637,9 +644,13 @@
     data_config = DataConfig(**data_dict)
 
     # sanity check on the dataloader parameters
-    if "batch_size" in dataloader_params:
+    if "batch_size" in train_dataloader_params:
+        # remove it
+        del train_dataloader_params["batch_size"]
+
+    if "batch_size" in val_dataloader_params:
         # remove it
-        del dataloader_params["batch_size"]
+        del val_dataloader_params["batch_size"]
 
     return TrainDataModule(
         data_config=data_config,
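
With this change, training and validation dataloaders are configured independently (training defaults to shuffling, validation does not). A minimal usage sketch, assuming the leading `train_data`/`data_type` arguments and the `careamics.lightning` export are unchanged from 0.0.14:

import numpy as np
from careamics.lightning import create_train_datamodule

data_module = create_train_datamodule(
    train_data=np.random.rand(256, 256).astype(np.float32),
    data_type="array",
    patch_size=[64, 64],
    axes="YX",
    batch_size=8,
    train_dataloader_params={"shuffle": True, "num_workers": 4},
    val_dataloader_params={"num_workers": 4},
)

Any `batch_size` key passed inside either dict is dropped in favour of the top-level `batch_size` argument, as the sanity check above shows.
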
@@ -3,6 +3,7 @@
 __all__ = [
     "denoisplit_loss",
     "denoisplit_musplit_loss",
+    "hdn_loss",
     "loss_factory",
     "mae_loss",
     "mse_loss",
@@ -12,4 +13,9 @@ __all__ = [
 
 from .fcn.losses import mae_loss, mse_loss, n2v_loss
 from .loss_factory import loss_factory
-from .lvae.losses import denoisplit_loss, denoisplit_musplit_loss, musplit_loss
+from .lvae.losses import (
+    denoisplit_loss,
+    denoisplit_musplit_loss,
+    hdn_loss,
+    musplit_loss,
+)
@@ -14,7 +14,12 @@ from torch import Tensor as tensor
 
 from ..config.support import SupportedLoss
 from .fcn.losses import mae_loss, mse_loss, n2v_loss
-from .lvae.losses import denoisplit_loss, denoisplit_musplit_loss, musplit_loss
+from .lvae.losses import (
+    denoisplit_loss,
+    denoisplit_musplit_loss,
+    hdn_loss,
+    musplit_loss,
+)
 
 
 @dataclass
@@ -59,6 +64,9 @@ def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
     elif loss == SupportedLoss.MSE:
         return mse_loss
 
+    elif loss == SupportedLoss.HDN:
+        return hdn_loss
+
     elif loss == SupportedLoss.MUSPLIT:
         return musplit_loss
 
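
The new branch means the factory resolves the HDN loss like any other supported loss; a minimal sketch:

from careamics.config.support import SupportedLoss
from careamics.losses import hdn_loss, loss_factory

loss_fn = loss_factory(SupportedLoss.HDN)
assert loss_fn is hdn_loss  # the factory simply returns the new function

The returned callable takes the same `(model_outputs, targets, config, gaussian_likelihood, noise_model_likelihood)` arguments as the other LVAE losses, as the `hdn_loss` hunk below shows.
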
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Literal, Union
 
 import numpy as np
 import torch
@@ -89,6 +89,7 @@ def _reconstruction_loss_musplit_denoisplit(
     if predictions.shape[1] == 2 * targets.shape[1]:
         # predictions contain both mean and log-variance
         pred_mean, _ = predictions.chunk(2, dim=1)
+    # TODO if this condition does not hold, everything breaks later!
     else:
         pred_mean = predictions
 
@@ -112,7 +113,7 @@ def get_kl_divergence_loss(
     rescaling: Literal["latent_dim", "image_dim"],
     aggregation: Literal["mean", "sum"],
     free_bits_coeff: float,
-    img_shape: Optional[tuple[int]] = None,
+    img_shape: tuple[int] | None = None,
 ) -> torch.Tensor:
     """Compute the KL divergence loss.
 
@@ -269,13 +270,97 @@ def _get_kl_divergence_loss_denoisplit(
 # - `__init__` method initializes the loss parameters now contained in
 # the `LVAELossParameters` class
 # NOTE: same for the other loss functions
+
+
+def hdn_loss(
+    model_outputs: tuple[torch.Tensor, dict[str, Any]],
+    targets: torch.Tensor,
+    config: LVAELossConfig,
+    gaussian_likelihood: GaussianLikelihood | None,
+    noise_model_likelihood: NoiseModelLikelihood | None,
+) -> dict[str, torch.Tensor] | None:
+    """Loss function for HDN.
+
+    Parameters
+    ----------
+    model_outputs : tuple[torch.Tensor, dict[str, Any]]
+        Tuple containing the model predictions (shape is (B, `target_ch`, [Z], Y, X))
+        and the top-down layer data (e.g., sampled latents, KL-loss values, etc.).
+    targets : torch.Tensor
+        The target image used to compute the reconstruction loss. In this case we use
+        the input patch itself as target. Shape is (B, `target_ch`, [Z], Y, X).
+    config : LVAELossConfig
+        The config for loss function containing all loss hyperparameters.
+    gaussian_likelihood : GaussianLikelihood
+        The Gaussian likelihood object.
+    noise_model_likelihood : NoiseModelLikelihood
+        The noise model likelihood object.
+
+    Returns
+    -------
+    output : Optional[dict[str, torch.Tensor]]
+        A dictionary containing the overall loss `["loss"]`, the reconstruction loss
+        `["reconstruction_loss"]`, and the KL divergence loss `["kl_loss"]`.
+    """
+    if gaussian_likelihood is not None:
+        likelihood = gaussian_likelihood
+    elif noise_model_likelihood is not None:
+        likelihood = noise_model_likelihood
+    else:
+        raise ValueError("Invalid likelihood object.")
+    # TODO refactor loss signature
+    predictions, td_data = model_outputs
+
+    # Reconstruction loss computation
+    recons_loss = config.reconstruction_weight * get_reconstruction_loss(
+        reconstruction=predictions,
+        target=targets,
+        likelihood_obj=likelihood,
+    )
+    if torch.isnan(recons_loss).any():
+        recons_loss = 0.0
+
+    # KL loss computation
+    kl_weight = get_kl_weight(
+        config.kl_params.annealing,
+        config.kl_params.start,
+        config.kl_params.annealtime,
+        config.kl_weight,
+        config.kl_params.current_epoch,
+    )
+    kl_loss = (
+        _get_kl_divergence_loss_denoisplit(
+            topdown_data=td_data,
+            img_shape=targets.shape[2:],
+            kl_type=config.kl_params.loss_type,
+        )
+        * kl_weight
+    )
+
+    net_loss = recons_loss + kl_loss  # TODO add check that losses coefs sum to 1
+    output = {
+        "loss": net_loss,
+        "reconstruction_loss": (
+            recons_loss.detach()
+            if isinstance(recons_loss, torch.Tensor)
+            else recons_loss
+        ),
+        "kl_loss": kl_loss.detach(),
+    }
+    # https://github.com/openai/vdvae/blob/main/train.py#L26
+    if torch.isnan(net_loss).any():
+        return None
+
+    return output
+
+
 def musplit_loss(
     model_outputs: tuple[torch.Tensor, dict[str, Any]],
     targets: torch.Tensor,
     config: LVAELossConfig,
-    gaussian_likelihood: Optional[GaussianLikelihood],
-    noise_model_likelihood: Optional[NoiseModelLikelihood] = None,  # TODO: ugly
-) -> Optional[dict[str, torch.Tensor]]:
+    gaussian_likelihood: GaussianLikelihood | None,
+    noise_model_likelihood: NoiseModelLikelihood | None = None,  # TODO: ugly
+) -> dict[str, torch.Tensor] | None:
     """Loss function for muSplit.
 
     Parameters
@@ -351,9 +436,9 @@ def denoisplit_loss(
     model_outputs: tuple[torch.Tensor, dict[str, Any]],
     targets: torch.Tensor,
     config: LVAELossConfig,
-    gaussian_likelihood: Optional[GaussianLikelihood] = None,
-    noise_model_likelihood: Optional[NoiseModelLikelihood] = None,
-) -> Optional[dict[str, torch.Tensor]]:
+    gaussian_likelihood: GaussianLikelihood | None = None,
+    noise_model_likelihood: NoiseModelLikelihood | None = None,
+) -> dict[str, torch.Tensor] | None:
     """Loss function for DenoiSplit.
 
     Parameters
@@ -430,7 +515,7 @@ def denoisplit_musplit_loss(
     config: LVAELossConfig,
     gaussian_likelihood: GaussianLikelihood,
     noise_model_likelihood: NoiseModelLikelihood,
-) -> Optional[dict[str, torch.Tensor]]:
+) -> dict[str, torch.Tensor] | None:
     """Loss function for DenoiSplit.
 
     Parameters
@@ -1,4 +1,4 @@
-from .config import DatasetConfig
+from .config import MicroSplitDataConfig
 from .lc_dataset import LCMultiChDloader
 from .ms_dataset_ref import MultiChDloaderRef
 from .multich_dataset import MultiChDloader
@@ -7,14 +7,14 @@ from .multifile_dataset import MultiFileDset
 from .types import DataSplitType, DataType, TilingMode
 
 __all__ = [
-    "DatasetConfig",
-    "MultiChDloader",
+    "DataSplitType",
+    "DataType",
     "LCMultiChDloader",
-    "MultiFileDset",
-    "MultiCropDset",
-    "MultiChDloaderRef",
     "LCMultiChDloaderRef",
-    "DataType",
-    "DataSplitType",
+    "MicroSplitDataConfig",
+    "MultiChDloader",
+    "MultiChDloaderRef",
+    "MultiCropDset",
+    "MultiFileDset",
     "TilingMode",
 ]
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union
+from typing import Any, Union
 
 from pydantic import BaseModel, ConfigDict
 
@@ -6,70 +6,70 @@ from .types import DataSplitType, DataType, TilingMode
 
 
 # TODO: check if any bool logic can be removed
-class DatasetConfig(BaseModel):
-    model_config = ConfigDict(validate_assignment=True, extra="forbid")
+class MicroSplitDataConfig(BaseModel):
+    model_config = ConfigDict(validate_assignment=True, extra="allow")
 
-    data_type: Optional[DataType]
+    data_type: Union[DataType, str] | None  # TODO remove or refactor!!
     """Type of the dataset, should be one of DataType"""
 
-    depth3D: Optional[int] = 1
+    depth3D: int | None = 1
     """Number of slices in 3D. If data is 2D depth3D is equal to 1"""
 
-    datasplit_type: Optional[DataSplitType] = None
-    """Whether to return training, validation or test split, should be one of
+    datasplit_type: DataSplitType | None = None
+    """Whether to return training, validation or test split, should be one of
     DataSplitType"""
 
-    num_channels: Optional[int] = 2
+    num_channels: int | None = 2
     """Number of channels in the input"""
 
     # TODO: remove ch*_fname parameters, should be parsed automatically from a name list
-    ch1_fname: Optional[str] = None
-    ch2_fname: Optional[str] = None
-    ch_input_fname: Optional[str] = None
+    ch1_fname: str | None = None
+    ch2_fname: str | None = None
+    ch_input_fname: str | None = None
 
-    input_is_sum: Optional[bool] = False
+    input_is_sum: bool | None = False
     """Whether the input is the sum or average of channels"""
 
-    input_idx: Optional[int] = None
+    input_idx: int | None = None
     """Index of the channel where the input is stored in the data"""
 
-    target_idx_list: Optional[list[int]] = None
+    target_idx_list: list[int] | None = None
     """Indices of the channels where the targets are stored in the data"""
 
     # TODO: where are there used?
-    start_alpha: Optional[Any] = None
-    end_alpha: Optional[Any] = None
+    start_alpha: Any | None = None
+    end_alpha: Any | None = None
 
     image_size: tuple  # TODO: revisit, new model_config uses tuple
     """Size of one patch of data"""
 
-    grid_size: Optional[Union[int, tuple[int, int, int]]] = None
+    grid_size: Union[int, tuple[int, int, int]] | None = None
     """Frame is divided into square grids of this size. A patch centered on a grid
     having size `image_size` is returned. Grid size not used in training,
     used only during val / test, grid size controls the overlap of the patches"""
 
-    empty_patch_replacement_enabled: Optional[bool] = False
+    empty_patch_replacement_enabled: bool | None = False
     """Whether to replace the content of one of the channels
     with background with given probability"""
-    empty_patch_replacement_channel_idx: Optional[Any] = None
-    empty_patch_replacement_probab: Optional[Any] = None
-    empty_patch_max_val_threshold: Optional[Any] = None
+    empty_patch_replacement_channel_idx: Any | None = None
+    empty_patch_replacement_probab: Any | None = None
+    empty_patch_max_val_threshold: Any | None = None
 
-    uncorrelated_channels: Optional[bool] = False
-    """Replace the content in one of the channels with given probability to make
+    uncorrelated_channels: bool | None = False
+    """Replace the content in one of the channels with given probability to make
     channel content 'uncorrelated'"""
-    uncorrelated_channel_probab: Optional[float] = 0.5
+    uncorrelated_channel_probab: float | None = 0.5
 
-    poisson_noise_factor: Optional[float] = -1
+    poisson_noise_factor: float | None = -1
     """The added poisson noise factor"""
 
-    synthetic_gaussian_scale: Optional[float] = 0.1
+    synthetic_gaussian_scale: float | None = 0.1
 
     # TODO: set to True in training code, recheck
-    input_has_dependant_noise: Optional[bool] = False
+    input_has_dependant_noise: bool | None = False
 
     # TODO: sometimes max_val differs between runs with fixed seeds with noise enabled
-    enable_gaussian_noise: Optional[bool] = False
+    enable_gaussian_noise: bool | None = False
     """Whether to enable gaussian noise"""
 
     # TODO: is this parameter used?
@@ -80,44 +80,56 @@ class DatasetConfig(BaseModel):
     deterministic_grid: Any = None
 
     # TODO: why is this not used?
-    enable_rotation_aug: Optional[bool] = False
+    enable_rotation_aug: bool | None = False
 
-    max_val: Optional[Union[float, tuple]] = None
-    """Maximum data in the dataset. Is calculated for train split, and should be
+    max_val: Union[float, tuple] | None = None
+    """Maximum data in the dataset. Is calculated for train split, and should be
     externally set for val and test splits."""
 
     overlapping_padding_kwargs: Any = None
     """Parameters for np.pad method"""
 
     # TODO: remove this parameter, controls debug print
-    print_vars: Optional[bool] = False
+    print_vars: bool | None = False
 
     # Hard-coded parameters (used to be in the config file)
     normalized_input: bool = True
     """If this is set to true, then one mean and stdev is used
     for both channels. Otherwise, two different mean and stdev are used."""
-    use_one_mu_std: Optional[bool] = True
+    use_one_mu_std: bool | None = True
 
     # TODO: is this parameter used?
-    train_aug_rotate: Optional[bool] = False
-    enable_random_cropping: Optional[bool] = True
+    train_aug_rotate: bool | None = False
+    enable_random_cropping: bool | None = True
 
-    multiscale_lowres_count: Optional[int] = None
+    multiscale_lowres_count: int | None = None
     """Number of LC scales"""
 
-    tiling_mode: Optional[TilingMode] = TilingMode.ShiftBoundary
+    tiling_mode: TilingMode | None = TilingMode.ShiftBoundary
 
-    target_separate_normalization: Optional[bool] = True
+    target_separate_normalization: bool | None = True
 
-    mode_3D: Optional[bool] = False
+    mode_3D: bool | None = False
     """If training in 3D mode or not"""
 
-    trainig_datausage_fraction: Optional[float] = 1.0
+    trainig_datausage_fraction: float | None = 1.0
 
-    validtarget_random_fraction: Optional[float] = None
+    validtarget_random_fraction: float | None = None
 
-    validation_datausage_fraction: Optional[float] = 1.0
+    validation_datausage_fraction: float | None = 1.0
 
-    random_flip_z_3D: Optional[bool] = False
+    random_flip_z_3D: bool | None = False
 
-    padding_kwargs: Optional[dict] = None
+    padding_kwargs: dict = {"mode": "reflect"}  # TODO remove !!
+
+    def __init__(self, **data):
+        # Convert string data_type to enum if needed
+        if "data_type" in data and isinstance(data["data_type"], str):
+            try:
+                data["data_type"] = DataType[data["data_type"]]
+            except KeyError:
+                # Keep original value to let validation handle the error
+                pass
+        super().__init__(**data)
+
+    # TODO add validators !
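
A minimal construction sketch for the renamed config; the `data_type` member name below is hypothetical, and with `extra="allow"` unknown keys are now kept instead of rejected:

from careamics.lvae_training.dataset import DataType, MicroSplitDataConfig

config = MicroSplitDataConfig(
    data_type="SeparateTiffData",  # hypothetical name, resolved via DataType[...] by the new __init__
    image_size=(64, 64),
    num_channels=2,
)

If the string does not match a `DataType` member, it is kept as-is; the `data_type` field now also accepts plain strings.
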
@@ -2,23 +2,29 @@
 A place for Datasets and Dataloaders.
 """
 
-from typing import Tuple, Union, Callable
+import logging
+import math
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
 
 import numpy as np
 from skimage.transform import resize
 
-from .config import DatasetConfig
+from .config import MicroSplitDataConfig
 from .multich_dataset import MultiChDloader
 
 
 class LCMultiChDloader(MultiChDloader):
+    """Multi-channel dataset loader for LC-style datasets."""
+
     def __init__(
         self,
-        data_config: DatasetConfig,
-        fpath: str,
-        load_data_fn: Callable,
-        val_fraction=None,
-        test_fraction=None,
+        data_config: MicroSplitDataConfig,
+        datapath: Union[str, Path],
+        load_data_fn: Optional[Callable] = None,
+        val_fraction: float = 0.1,
+        test_fraction: float = 0.1,
+        allow_generation: bool = False,
     ):
         self._padding_kwargs = (
             data_config.padding_kwargs  # mode=padding_mode, constant_values=constant_value
@@ -27,7 +33,7 @@ class LCMultiChDloader(MultiChDloader):
 
         super().__init__(
             data_config,
-            fpath,
+            datapath,
             load_data_fn=load_data_fn,
             val_fraction=val_fraction,
             test_fraction=test_fraction,
111
117
  return msg
112
118
 
113
119
  def _load_scaled_img(
114
- self, scaled_index, index: Union[int, Tuple[int, int]]
115
- ) -> Tuple[np.ndarray, np.ndarray]:
120
+ self, scaled_index, index: Union[int, tuple[int, int]]
121
+ ) -> tuple[np.ndarray, np.ndarray]:
116
122
  if isinstance(index, int):
117
123
  idx = index
118
124
  else:
@@ -131,7 +137,7 @@ class LCMultiChDloader(MultiChDloader):
         imgs = tuple([img + noise[0] * factor for img in imgs])
         return imgs
 
-    def _crop_img(self, img: np.ndarray, patch_start_loc: Tuple):
+    def _crop_img(self, img: np.ndarray, patch_start_loc: tuple):
         """
         Here, h_start, w_start could be negative. That simply means we need to pick the content from 0. So,
         the cropped image will be smaller than self._img_sz * self._img_sz
@@ -202,7 +208,7 @@ class LCMultiChDloader(MultiChDloader):
         )
         return output_img_tuples, cropped_noise_tuples
 
-    def __getitem__(self, index: Union[int, Tuple[int, int]]):
+    def __getitem__(self, index: Union[int, tuple[int, int]]):
         img_tuples, noise_tuples = self._get_img(index)
         if self._uncorrelated_channels:
             assert (
@@ -10,7 +10,7 @@ from typing import Callable, Union
 import numpy as np
 from skimage.transform import resize
 
-from .config import DatasetConfig
+from .config import MicroSplitDataConfig
 from .types import DataSplitType, TilingMode
 from .utils.empty_patch_fetcher import EmptyPatchFetcher
 from .utils.index_manager import GridIndexManagerRef
@@ -19,7 +19,7 @@ from .utils.index_manager import GridIndexManagerRef
 class MultiChDloaderRef:
     def __init__(
         self,
-        data_config: DatasetConfig,
+        data_config: MicroSplitDataConfig,
         fpath: str,
         load_data_fn: Callable,
         val_fraction: float = None,
@@ -171,8 +171,8 @@ class MultiChDloaderRef:
 
     def load_data(
         self,
-        data_config,
-        datasplit_type,
+        data_config: MicroSplitDataConfig,
+        datasplit_type: DataSplitType,
         load_data_fn: Callable,
         val_fraction=None,
         test_fraction=None,
@@ -813,7 +813,7 @@ class MultiChDloaderRef:
 class LCMultiChDloaderRef(MultiChDloaderRef):
     def __init__(
         self,
-        data_config: DatasetConfig,
+        data_config: MicroSplitDataConfig,
         fpath: str,
         load_data_fn: Callable,
         val_fraction=None,
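
Putting the renamed pieces together, constructing the LC loader could look like the following sketch (the `load_my_data` callable and its signature are placeholders, not part of the package):

from careamics.lvae_training.dataset import LCMultiChDloader

def load_my_data(*args, **kwargs):
    # placeholder loader returning the arrays expected by MultiChDloader
    ...

train_dset = LCMultiChDloader(
    data_config=config,           # MicroSplitDataConfig from the sketch above
    datapath="path/to/data.tif",
    load_data_fn=load_my_data,
    val_fraction=0.1,
    test_fraction=0.1,
    allow_generation=False,
)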