careamics 0.0.15 → 0.0.17 (py3-none-any.whl)

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of careamics might be problematic.

Files changed (79):
  1. careamics/careamist.py +11 -14
  2. careamics/cli/conf.py +18 -3
  3. careamics/config/__init__.py +8 -0
  4. careamics/config/algorithms/__init__.py +4 -0
  5. careamics/config/algorithms/hdn_algorithm_model.py +103 -0
  6. careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
  7. careamics/config/algorithms/n2v_algorithm_model.py +1 -2
  8. careamics/config/algorithms/vae_algorithm_model.py +51 -16
  9. careamics/config/architectures/lvae_model.py +12 -8
  10. careamics/config/callback_model.py +7 -3
  11. careamics/config/configuration.py +15 -63
  12. careamics/config/configuration_factories.py +853 -29
  13. careamics/config/data/data_model.py +50 -11
  14. careamics/config/data/ng_data_model.py +168 -4
  15. careamics/config/data/patch_filter/__init__.py +15 -0
  16. careamics/config/data/patch_filter/filter_model.py +16 -0
  17. careamics/config/data/patch_filter/mask_filter_model.py +17 -0
  18. careamics/config/data/patch_filter/max_filter_model.py +15 -0
  19. careamics/config/data/patch_filter/meanstd_filter_model.py +18 -0
  20. careamics/config/data/patch_filter/shannon_filter_model.py +15 -0
  21. careamics/config/inference_model.py +1 -2
  22. careamics/config/likelihood_model.py +2 -2
  23. careamics/config/loss_model.py +6 -2
  24. careamics/config/nm_model.py +26 -1
  25. careamics/config/optimizer_models.py +1 -2
  26. careamics/config/support/supported_algorithms.py +5 -3
  27. careamics/config/support/supported_filters.py +17 -0
  28. careamics/config/support/supported_losses.py +5 -2
  29. careamics/config/training_model.py +6 -36
  30. careamics/config/transformations/normalize_model.py +1 -2
  31. careamics/dataset_ng/dataset.py +57 -5
  32. careamics/dataset_ng/factory.py +101 -18
  33. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
  34. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
  35. careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
  36. careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
  37. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  38. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  39. careamics/dataset_ng/patch_filter/filter_factory.py +94 -0
  40. careamics/dataset_ng/patch_filter/mask_filter.py +95 -0
  41. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  42. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  43. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  44. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  45. careamics/file_io/read/__init__.py +0 -1
  46. careamics/lightning/__init__.py +16 -2
  47. careamics/lightning/callbacks/__init__.py +2 -0
  48. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  49. careamics/lightning/dataset_ng/data_module.py +79 -2
  50. careamics/lightning/lightning_module.py +162 -61
  51. careamics/lightning/microsplit_data_module.py +636 -0
  52. careamics/lightning/predict_data_module.py +8 -1
  53. careamics/lightning/train_data_module.py +19 -8
  54. careamics/losses/__init__.py +7 -1
  55. careamics/losses/loss_factory.py +9 -1
  56. careamics/losses/lvae/losses.py +85 -0
  57. careamics/lvae_training/dataset/__init__.py +8 -8
  58. careamics/lvae_training/dataset/config.py +56 -44
  59. careamics/lvae_training/dataset/lc_dataset.py +18 -12
  60. careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
  61. careamics/lvae_training/dataset/multich_dataset.py +24 -18
  62. careamics/lvae_training/dataset/multifile_dataset.py +6 -6
  63. careamics/lvae_training/eval_utils.py +46 -24
  64. careamics/model_io/bmz_io.py +9 -5
  65. careamics/models/lvae/likelihoods.py +31 -14
  66. careamics/models/lvae/lvae.py +2 -2
  67. careamics/models/lvae/noise_models.py +20 -14
  68. careamics/prediction_utils/__init__.py +8 -2
  69. careamics/prediction_utils/prediction_outputs.py +49 -3
  70. careamics/prediction_utils/stitch_prediction.py +83 -1
  71. careamics/transforms/xy_random_rotate90.py +1 -1
  72. careamics/utils/version.py +4 -4
  73. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/METADATA +19 -22
  74. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/RECORD +77 -60
  75. careamics/dataset/zarr_dataset.py +0 -151
  76. careamics/file_io/read/zarr.py +0 -60
  77. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/WHEEL +0 -0
  78. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/entry_points.txt +0 -0
  79. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/licenses/LICENSE +0 -0
--- a/careamics/lightning/train_data_module.py
+++ b/careamics/lightning/train_data_module.py
@@ -485,7 +485,8 @@ def create_train_datamodule(
     extension_filter: str = "",
     val_percentage: float = 0.1,
     val_minimum_patches: int = 5,
-    dataloader_params: dict | None = None,
+    train_dataloader_params: dict | None = None,
+    val_dataloader_params: dict | None = None,
     use_in_memory: bool = True,
 ) -> TrainDataModule:
     """Create a TrainDataModule.
@@ -556,8 +557,10 @@ def create_train_datamodule(
     val_minimum_patches : int, optional
         Minimum number of patches to split from the training data for validation if
         no validation data is given, by default 5.
-    dataloader_params : dict, optional
-        Pytorch dataloader parameters, by default {}.
+    train_dataloader_params : dict, optional
+        Pytorch dataloader parameters for the training data, by default {}.
+    val_dataloader_params : dict, optional
+        Pytorch dataloader parameters for the validation data, by default {}.
     use_in_memory : bool, optional
         Use in memory dataset if possible, by default True.
 
@@ -617,8 +620,11 @@ def create_train_datamodule(
     ...     transforms=my_transforms,
     ... )
     """
-    if dataloader_params is None:
-        dataloader_params = {}
+    if train_dataloader_params is None:
+        train_dataloader_params = {"shuffle": True}
+
+    if val_dataloader_params is None:
+        val_dataloader_params = {"shuffle": False}
 
     data_dict: dict[str, Any] = {
         "mode": "train",
@@ -626,7 +632,8 @@
         "patch_size": patch_size,
         "axes": axes,
         "batch_size": batch_size,
-        "dataloader_params": dataloader_params,
+        "train_dataloader_params": train_dataloader_params,
+        "val_dataloader_params": val_dataloader_params,
     }
 
     # if transforms are passed (otherwise it will use the default ones)
@@ -637,9 +644,13 @@
     data_config = DataConfig(**data_dict)
 
     # sanity check on the dataloader parameters
-    if "batch_size" in dataloader_params:
+    if "batch_size" in train_dataloader_params:
+        # remove it
+        del train_dataloader_params["batch_size"]
+
+    if "batch_size" in val_dataloader_params:
         # remove it
-        del dataloader_params["batch_size"]
+        del val_dataloader_params["batch_size"]
 
     return TrainDataModule(
         data_config=data_config,
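The effect of this change is that the training and validation loaders are now configured independently, with shuffling enabled by default for training and disabled for validation. A minimal sketch of a call using the new keywords (the path, patch size, and worker counts are illustrative):

    from careamics.lightning import create_train_datamodule

    data_module = create_train_datamodule(
        train_data="path/to/train",  # illustrative path
        data_type="tiff",
        patch_size=(64, 64),
        axes="YX",
        batch_size=8,
        # The {"shuffle": True} default applies only when the argument is
        # omitted entirely, so it is restated here next to extra DataLoader kwargs.
        train_dataloader_params={"shuffle": True, "num_workers": 4},
        # Defaults to {"shuffle": False} when omitted.
        val_dataloader_params={"num_workers": 2},
    )

Note that a `batch_size` key in either dict is silently dropped, since the batch size is already passed explicitly.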
--- a/careamics/losses/__init__.py
+++ b/careamics/losses/__init__.py
@@ -3,6 +3,7 @@
 __all__ = [
     "denoisplit_loss",
     "denoisplit_musplit_loss",
+    "hdn_loss",
     "loss_factory",
     "mae_loss",
     "mse_loss",
@@ -12,4 +13,9 @@ __all__ = [
 
 from .fcn.losses import mae_loss, mse_loss, n2v_loss
 from .loss_factory import loss_factory
-from .lvae.losses import denoisplit_loss, denoisplit_musplit_loss, musplit_loss
+from .lvae.losses import (
+    denoisplit_loss,
+    denoisplit_musplit_loss,
+    hdn_loss,
+    musplit_loss,
+)
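These re-exports make the new loss importable from the package root, and the factory branch shown below resolves it from the `SupportedLoss` enum. A small sketch of the lookup:

    from careamics.config.support import SupportedLoss
    from careamics.losses import hdn_loss, loss_factory

    # The factory accepts the enum member; per its Union[SupportedLoss, str]
    # signature, the member's string value should work as well.
    loss_fn = loss_factory(SupportedLoss.HDN)
    assert loss_fn is hdn_loss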
--- a/careamics/losses/loss_factory.py
+++ b/careamics/losses/loss_factory.py
@@ -14,7 +14,12 @@ from torch import Tensor as tensor
 
 from ..config.support import SupportedLoss
 from .fcn.losses import mae_loss, mse_loss, n2v_loss
-from .lvae.losses import denoisplit_loss, denoisplit_musplit_loss, musplit_loss
+from .lvae.losses import (
+    denoisplit_loss,
+    denoisplit_musplit_loss,
+    hdn_loss,
+    musplit_loss,
+)
 
 
 @dataclass
@@ -59,6 +64,9 @@ def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
     elif loss == SupportedLoss.MSE:
         return mse_loss
 
+    elif loss == SupportedLoss.HDN:
+        return hdn_loss
+
     elif loss == SupportedLoss.MUSPLIT:
         return musplit_loss
 
--- a/careamics/losses/lvae/losses.py
+++ b/careamics/losses/lvae/losses.py
@@ -89,6 +89,7 @@ def _reconstruction_loss_musplit_denoisplit(
     if predictions.shape[1] == 2 * targets.shape[1]:
         # predictions contain both mean and log-variance
         pred_mean, _ = predictions.chunk(2, dim=1)
+    # TODO if this condition does not hold, everything breaks later!
     else:
         pred_mean = predictions
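The branch above encodes a channel convention: for a Gaussian likelihood, the network stacks predicted means and log-variances along the channel axis, so a prediction with twice the target's channel count splits cleanly in two. A standalone illustration of that split (shapes are arbitrary):

    import torch

    # Target has 2 channels; the prediction carries 4: means, then log-variances.
    targets = torch.randn(8, 2, 64, 64)
    predictions = torch.randn(8, 4, 64, 64)

    assert predictions.shape[1] == 2 * targets.shape[1]
    pred_mean, pred_logvar = predictions.chunk(2, dim=1)  # two (8, 2, 64, 64) tensors
    assert pred_mean.shape == targets.shape

As the new TODO warns, a prediction whose channel count is neither `C` nor `2 * C` silently falls into the `else` branch and misaligns downstream.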
 
@@ -269,6 +270,90 @@ def _get_kl_divergence_loss_denoisplit(
 # - `__init__` method initializes the loss parameters now contained in
 #   the `LVAELossParameters` class
 # NOTE: same for the other loss functions
+
+
+def hdn_loss(
+    model_outputs: tuple[torch.Tensor, dict[str, Any]],
+    targets: torch.Tensor,
+    config: LVAELossConfig,
+    gaussian_likelihood: GaussianLikelihood | None,
+    noise_model_likelihood: NoiseModelLikelihood | None,
+) -> dict[str, torch.Tensor] | None:
+    """Loss function for HDN.
+
+    Parameters
+    ----------
+    model_outputs : tuple[torch.Tensor, dict[str, Any]]
+        Tuple containing the model predictions (shape is (B, `target_ch`, [Z], Y, X))
+        and the top-down layer data (e.g., sampled latents, KL-loss values, etc.).
+    targets : torch.Tensor
+        The target image used to compute the reconstruction loss. In this case we use
+        the input patch itself as target. Shape is (B, `target_ch`, [Z], Y, X).
+    config : LVAELossConfig
+        The config for loss function containing all loss hyperparameters.
+    gaussian_likelihood : GaussianLikelihood
+        The Gaussian likelihood object.
+    noise_model_likelihood : NoiseModelLikelihood
+        The noise model likelihood object.
+
+    Returns
+    -------
+    output : Optional[dict[str, torch.Tensor]]
+        A dictionary containing the overall loss `["loss"]`, the reconstruction loss
+        `["reconstruction_loss"]`, and the KL divergence loss `["kl_loss"]`.
+    """
+    if gaussian_likelihood is not None:
+        likelihood = gaussian_likelihood
+    elif noise_model_likelihood is not None:
+        likelihood = noise_model_likelihood
+    else:
+        raise ValueError("Invalid likelihood object.")
+    # TODO refactor loss signature
+    predictions, td_data = model_outputs
+
+    # Reconstruction loss computation
+    recons_loss = config.reconstruction_weight * get_reconstruction_loss(
+        reconstruction=predictions,
+        target=targets,
+        likelihood_obj=likelihood,
+    )
+    if torch.isnan(recons_loss).any():
+        recons_loss = 0.0
+
+    # KL loss computation
+    kl_weight = get_kl_weight(
+        config.kl_params.annealing,
+        config.kl_params.start,
+        config.kl_params.annealtime,
+        config.kl_weight,
+        config.kl_params.current_epoch,
+    )
+    kl_loss = (
+        _get_kl_divergence_loss_denoisplit(
+            topdown_data=td_data,
+            img_shape=targets.shape[2:],
+            kl_type=config.kl_params.loss_type,
+        )
+        * kl_weight
+    )
+
+    net_loss = recons_loss + kl_loss  # TODO add check that losses coefs sum to 1
+    output = {
+        "loss": net_loss,
+        "reconstruction_loss": (
+            recons_loss.detach()
+            if isinstance(recons_loss, torch.Tensor)
+            else recons_loss
+        ),
+        "kl_loss": kl_loss.detach(),
+    }
+    # https://github.com/openai/vdvae/blob/main/train.py#L26
+    if torch.isnan(net_loss).any():
+        return None
+
+    return output
+
+
 def musplit_loss(
     model_outputs: tuple[torch.Tensor, dict[str, Any]],
     targets: torch.Tensor,
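Note the contract at the end of `hdn_loss`: following the linked vdvae practice, a NaN total loss yields `None` instead of a tensor, so the caller is expected to skip that optimization step. A hedged sketch of a Lightning `training_step` honoring the contract (the module attributes are hypothetical; only the `None` handling is the point):

    def training_step(self, batch, batch_idx):
        model_outputs = self.model(batch)
        loss_dict = hdn_loss(
            model_outputs,
            targets=batch,
            config=self.loss_config,
            gaussian_likelihood=self.gaussian_likelihood,
            noise_model_likelihood=self.noise_model_likelihood,
        )
        if loss_dict is None:
            # NaN loss: returning None from a Lightning training_step skips
            # the optimizer update for this batch.
            return None
        self.log("train_loss", loss_dict["loss"])
        return loss_dict["loss"]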
--- a/careamics/lvae_training/dataset/__init__.py
+++ b/careamics/lvae_training/dataset/__init__.py
@@ -1,4 +1,4 @@
-from .config import DatasetConfig
+from .config import MicroSplitDataConfig
 from .lc_dataset import LCMultiChDloader
 from .ms_dataset_ref import MultiChDloaderRef
 from .multich_dataset import MultiChDloader
@@ -7,14 +7,14 @@ from .multifile_dataset import MultiFileDset
 from .types import DataSplitType, DataType, TilingMode
 
 __all__ = [
-    "DatasetConfig",
-    "MultiChDloader",
+    "DataSplitType",
+    "DataType",
     "LCMultiChDloader",
-    "MultiFileDset",
-    "MultiCropDset",
-    "MultiChDloaderRef",
     "LCMultiChDloaderRef",
-    "DataType",
-    "DataSplitType",
+    "MicroSplitDataConfig",
+    "MultiChDloader",
+    "MultiChDloaderRef",
+    "MultiCropDset",
+    "MultiFileDset",
     "TilingMode",
 ]
--- a/careamics/lvae_training/dataset/config.py
+++ b/careamics/lvae_training/dataset/config.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union
+from typing import Any, Union
 
 from pydantic import BaseModel, ConfigDict
 
@@ -6,70 +6,70 @@ from .types import DataSplitType, DataType, TilingMode
 
 
 # TODO: check if any bool logic can be removed
-class DatasetConfig(BaseModel):
-    model_config = ConfigDict(validate_assignment=True, extra="forbid")
+class MicroSplitDataConfig(BaseModel):
+    model_config = ConfigDict(validate_assignment=True, extra="allow")
 
-    data_type: Optional[DataType]
+    data_type: Union[DataType, str] | None  # TODO remove or refactor!!
     """Type of the dataset, should be one of DataType"""
 
-    depth3D: Optional[int] = 1
+    depth3D: int | None = 1
     """Number of slices in 3D. If data is 2D depth3D is equal to 1"""
 
-    datasplit_type: Optional[DataSplitType] = None
-    """Whether to return training, validation or test split, should be one of
+    datasplit_type: DataSplitType | None = None
+    """Whether to return training, validation or test split, should be one of
     DataSplitType"""
 
-    num_channels: Optional[int] = 2
+    num_channels: int | None = 2
     """Number of channels in the input"""
 
     # TODO: remove ch*_fname parameters, should be parsed automatically from a name list
-    ch1_fname: Optional[str] = None
-    ch2_fname: Optional[str] = None
-    ch_input_fname: Optional[str] = None
+    ch1_fname: str | None = None
+    ch2_fname: str | None = None
+    ch_input_fname: str | None = None
 
-    input_is_sum: Optional[bool] = False
+    input_is_sum: bool | None = False
     """Whether the input is the sum or average of channels"""
 
-    input_idx: Optional[int] = None
+    input_idx: int | None = None
     """Index of the channel where the input is stored in the data"""
 
-    target_idx_list: Optional[list[int]] = None
+    target_idx_list: list[int] | None = None
     """Indices of the channels where the targets are stored in the data"""
 
    # TODO: where are there used?
-    start_alpha: Optional[Any] = None
-    end_alpha: Optional[Any] = None
+    start_alpha: Any | None = None
+    end_alpha: Any | None = None
 
     image_size: tuple  # TODO: revisit, new model_config uses tuple
     """Size of one patch of data"""
 
-    grid_size: Optional[Union[int, tuple[int, int, int]]] = None
+    grid_size: Union[int, tuple[int, int, int]] | None = None
     """Frame is divided into square grids of this size. A patch centered on a grid
     having size `image_size` is returned. Grid size not used in training,
     used only during val / test, grid size controls the overlap of the patches"""
 
-    empty_patch_replacement_enabled: Optional[bool] = False
+    empty_patch_replacement_enabled: bool | None = False
     """Whether to replace the content of one of the channels
     with background with given probability"""
-    empty_patch_replacement_channel_idx: Optional[Any] = None
-    empty_patch_replacement_probab: Optional[Any] = None
-    empty_patch_max_val_threshold: Optional[Any] = None
+    empty_patch_replacement_channel_idx: Any | None = None
+    empty_patch_replacement_probab: Any | None = None
+    empty_patch_max_val_threshold: Any | None = None
 
-    uncorrelated_channels: Optional[bool] = False
-    """Replace the content in one of the channels with given probability to make
+    uncorrelated_channels: bool | None = False
+    """Replace the content in one of the channels with given probability to make
     channel content 'uncorrelated'"""
-    uncorrelated_channel_probab: Optional[float] = 0.5
+    uncorrelated_channel_probab: float | None = 0.5
 
-    poisson_noise_factor: Optional[float] = -1
+    poisson_noise_factor: float | None = -1
     """The added poisson noise factor"""
 
-    synthetic_gaussian_scale: Optional[float] = 0.1
+    synthetic_gaussian_scale: float | None = 0.1
 
     # TODO: set to True in training code, recheck
-    input_has_dependant_noise: Optional[bool] = False
+    input_has_dependant_noise: bool | None = False
 
     # TODO: sometimes max_val differs between runs with fixed seeds with noise enabled
-    enable_gaussian_noise: Optional[bool] = False
+    enable_gaussian_noise: bool | None = False
     """Whether to enable gaussian noise"""
 
     # TODO: is this parameter used?
@@ -80,44 +80,56 @@ class DatasetConfig(BaseModel):
     deterministic_grid: Any = None
 
     # TODO: why is this not used?
-    enable_rotation_aug: Optional[bool] = False
+    enable_rotation_aug: bool | None = False
 
-    max_val: Optional[Union[float, tuple]] = None
-    """Maximum data in the dataset. Is calculated for train split, and should be
+    max_val: Union[float, tuple] | None = None
+    """Maximum data in the dataset. Is calculated for train split, and should be
     externally set for val and test splits."""
 
     overlapping_padding_kwargs: Any = None
     """Parameters for np.pad method"""
 
     # TODO: remove this parameter, controls debug print
-    print_vars: Optional[bool] = False
+    print_vars: bool | None = False
 
     # Hard-coded parameters (used to be in the config file)
     normalized_input: bool = True
     """If this is set to true, then one mean and stdev is used
     for both channels. Otherwise, two different mean and stdev are used."""
-    use_one_mu_std: Optional[bool] = True
+    use_one_mu_std: bool | None = True
 
     # TODO: is this parameter used?
-    train_aug_rotate: Optional[bool] = False
-    enable_random_cropping: Optional[bool] = True
+    train_aug_rotate: bool | None = False
+    enable_random_cropping: bool | None = True
 
-    multiscale_lowres_count: Optional[int] = None
+    multiscale_lowres_count: int | None = None
     """Number of LC scales"""
 
-    tiling_mode: Optional[TilingMode] = TilingMode.ShiftBoundary
+    tiling_mode: TilingMode | None = TilingMode.ShiftBoundary
 
-    target_separate_normalization: Optional[bool] = True
+    target_separate_normalization: bool | None = True
 
-    mode_3D: Optional[bool] = False
+    mode_3D: bool | None = False
     """If training in 3D mode or not"""
 
-    trainig_datausage_fraction: Optional[float] = 1.0
+    trainig_datausage_fraction: float | None = 1.0
 
-    validtarget_random_fraction: Optional[float] = None
+    validtarget_random_fraction: float | None = None
 
-    validation_datausage_fraction: Optional[float] = 1.0
+    validation_datausage_fraction: float | None = 1.0
 
-    random_flip_z_3D: Optional[bool] = False
+    random_flip_z_3D: bool | None = False
 
-    padding_kwargs: Optional[dict] = None
+    padding_kwargs: dict = {"mode": "reflect"}  # TODO remove !!
+
+    def __init__(self, **data):
+        # Convert string data_type to enum if needed
+        if "data_type" in data and isinstance(data["data_type"], str):
+            try:
+                data["data_type"] = DataType[data["data_type"]]
+            except KeyError:
+                # Keep original value to let validation handle the error
+                pass
+        super().__init__(**data)
+
+    # TODO add validators !
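Two behavioral changes ride along with the rename: `extra="allow"` means unknown keys no longer raise, and the new `__init__` coerces a string `data_type` through the `DataType` enum by member name. A small sketch (the `Tiff` member name is illustrative and must match an actual `DataType` member for the coercion to apply):

    from careamics.lvae_training.dataset import DataType, MicroSplitDataConfig

    # The string is looked up as DataType["Tiff"]; an unknown name is passed
    # through unchanged for pydantic validation to handle.
    config = MicroSplitDataConfig(data_type="Tiff", image_size=(64, 64))

    # Equivalent, without relying on the string-to-enum conversion:
    config = MicroSplitDataConfig(data_type=DataType.Tiff, image_size=(64, 64))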
--- a/careamics/lvae_training/dataset/lc_dataset.py
+++ b/careamics/lvae_training/dataset/lc_dataset.py
@@ -2,23 +2,29 @@
 A place for Datasets and Dataloaders.
 """
 
-from typing import Tuple, Union, Callable
+import logging
+import math
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
 
 import numpy as np
 from skimage.transform import resize
 
-from .config import DatasetConfig
+from .config import MicroSplitDataConfig
 from .multich_dataset import MultiChDloader
 
 
 class LCMultiChDloader(MultiChDloader):
+    """Multi-channel dataset loader for LC-style datasets."""
+
     def __init__(
         self,
-        data_config: DatasetConfig,
-        fpath: str,
-        load_data_fn: Callable,
-        val_fraction=None,
-        test_fraction=None,
+        data_config: MicroSplitDataConfig,
+        datapath: Union[str, Path],
+        load_data_fn: Optional[Callable] = None,
+        val_fraction: float = 0.1,
+        test_fraction: float = 0.1,
+        allow_generation: bool = False,
     ):
         self._padding_kwargs = (
             data_config.padding_kwargs  # mode=padding_mode, constant_values=constant_value
@@ -27,7 +33,7 @@ class LCMultiChDloader(MultiChDloader):
 
         super().__init__(
             data_config,
-            fpath,
+            datapath,
             load_data_fn=load_data_fn,
             val_fraction=val_fraction,
             test_fraction=test_fraction,
@@ -111,8 +117,8 @@
         return msg
 
     def _load_scaled_img(
-        self, scaled_index, index: Union[int, Tuple[int, int]]
-    ) -> Tuple[np.ndarray, np.ndarray]:
+        self, scaled_index, index: Union[int, tuple[int, int]]
+    ) -> tuple[np.ndarray, np.ndarray]:
         if isinstance(index, int):
             idx = index
         else:
@@ -131,7 +137,7 @@
             imgs = tuple([img + noise[0] * factor for img in imgs])
         return imgs
 
-    def _crop_img(self, img: np.ndarray, patch_start_loc: Tuple):
+    def _crop_img(self, img: np.ndarray, patch_start_loc: tuple):
         """
         Here, h_start, w_start could be negative. That simply means we need to pick the content from 0. So,
         the cropped image will be smaller than self._img_sz * self._img_sz
@@ -202,7 +208,7 @@
         )
         return output_img_tuples, cropped_noise_tuples
 
-    def __getitem__(self, index: Union[int, Tuple[int, int]]):
+    def __getitem__(self, index: Union[int, tuple[int, int]]):
         img_tuples, noise_tuples = self._get_img(index)
         if self._uncorrelated_channels:
             assert (
--- a/careamics/lvae_training/dataset/ms_dataset_ref.py
+++ b/careamics/lvae_training/dataset/ms_dataset_ref.py
@@ -10,7 +10,7 @@ from typing import Callable, Union
 import numpy as np
 from skimage.transform import resize
 
-from .config import DatasetConfig
+from .config import MicroSplitDataConfig
 from .types import DataSplitType, TilingMode
 from .utils.empty_patch_fetcher import EmptyPatchFetcher
 from .utils.index_manager import GridIndexManagerRef
@@ -19,7 +19,7 @@ from .utils.index_manager import GridIndexManagerRef
 class MultiChDloaderRef:
     def __init__(
         self,
-        data_config: DatasetConfig,
+        data_config: MicroSplitDataConfig,
         fpath: str,
         load_data_fn: Callable,
         val_fraction: float = None,
@@ -171,8 +171,8 @@
 
     def load_data(
         self,
-        data_config,
-        datasplit_type,
+        data_config: MicroSplitDataConfig,
+        datasplit_type: DataSplitType,
         load_data_fn: Callable,
         val_fraction=None,
         test_fraction=None,
@@ -813,7 +813,7 @@
 class LCMultiChDloaderRef(MultiChDloaderRef):
     def __init__(
         self,
-        data_config: DatasetConfig,
+        data_config: MicroSplitDataConfig,
         fpath: str,
         load_data_fn: Callable,
         val_fraction=None,
--- a/careamics/lvae_training/dataset/multich_dataset.py
+++ b/careamics/lvae_training/dataset/multich_dataset.py
@@ -2,29 +2,35 @@
 A place for Datasets and Dataloaders.
 """
 
-from typing import Tuple, Union, Callable
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
 
 import numpy as np
+import torch
+from torch.utils.data import Dataset
 
 from .utils.empty_patch_fetcher import EmptyPatchFetcher
 from .utils.index_manager import GridIndexManager
 from .utils.index_switcher import IndexSwitcher
-from .config import DatasetConfig
+from .config import MicroSplitDataConfig
 from .types import DataSplitType, TilingMode
 
 
-class MultiChDloader:
+class MultiChDloader(Dataset):
+    """Multi-channel dataset loader."""
+
     def __init__(
         self,
-        data_config: DatasetConfig,
-        fpath: str,
-        load_data_fn: Callable,
-        val_fraction: float = None,
-        test_fraction: float = None,
+        data_config: MicroSplitDataConfig,
+        datapath: Union[str, Path],
+        load_data_fn: Optional[Callable] = None,
+        val_fraction: float = 0.1,
+        test_fraction: float = 0.1,
+        allow_generation: bool = False,
     ):
         """ """
         self._data_type = data_config.data_type
-        self._fpath = fpath
+        self._fpath = datapath
         self._data = self._noise_data = None
         self.Z = 1
         self._5Ddata = False
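Deriving from `torch.utils.data.Dataset` makes the loader directly consumable by a PyTorch `DataLoader`, and the constructor now takes a `datapath` (str or Path) with concrete default fractions. A sketch of the new signature in use; the config fields and path are placeholders, and instantiation eagerly loads data, so `datapath` must point at real files:

    from torch.utils.data import DataLoader

    from careamics.lvae_training.dataset import (
        DataSplitType,
        MicroSplitDataConfig,
        MultiChDloader,
    )

    config = MicroSplitDataConfig(
        data_type=None,  # placeholder; see the refactoring TODO in config.py
        image_size=(64, 64),
        datasplit_type=DataSplitType.Train,
    )
    dataset = MultiChDloader(
        data_config=config,
        datapath="path/to/data.tif",  # renamed from `fpath`
        val_fraction=0.1,             # defaults changed from None to 0.1
        test_fraction=0.1,
    )
    loader = DataLoader(dataset, batch_size=8, shuffle=True)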
@@ -395,7 +401,7 @@
         )
 
     def get_idx_manager_shapes(
-        self, patch_size: int, grid_size: Union[int, Tuple[int, int, int]]
+        self, patch_size: int, grid_size: Union[int, tuple[int, int, int]]
     ):
         numC = self._data.shape[-1]
         if self._5Ddata:
@@ -415,7 +421,7 @@
 
         return patch_shape, grid_shape
 
-    def set_img_sz(self, image_size, grid_size: Union[int, Tuple[int, int, int]]):
+    def set_img_sz(self, image_size, grid_size: Union[int, tuple[int, int, int]]):
         """
         If one wants to change the image size on the go, then this can be used.
         Args:
@@ -519,7 +525,7 @@
             },
         )
 
-    def _crop_img(self, img: np.ndarray, patch_start_loc: Tuple):
+    def _crop_img(self, img: np.ndarray, patch_start_loc: tuple):
         if self._tiling_mode in [TilingMode.TrimBoundary, TilingMode.ShiftBoundary]:
             # In training, this is used.
             # NOTE: It is my opinion that if I just use self._crop_img_with_padding, it will work perfectly fine.
@@ -600,7 +606,7 @@
         return new_img
 
     def _crop_flip_img(
-        self, img: np.ndarray, patch_start_loc: Tuple, h_flip: bool, w_flip: bool
+        self, img: np.ndarray, patch_start_loc: tuple, h_flip: bool, w_flip: bool
     ):
         new_img = self._crop_img(img, patch_start_loc)
         if h_flip:
@@ -611,8 +617,8 @@
         return new_img.astype(np.float32)
 
     def _load_img(
-        self, index: Union[int, Tuple[int, int]]
-    ) -> Tuple[np.ndarray, np.ndarray]:
+        self, index: Union[int, tuple[int, int]]
+    ) -> tuple[np.ndarray, np.ndarray]:
         """
         Returns the channels and also the respective noise channels.
         """
@@ -806,7 +812,7 @@
             w_start = 0
         return h_start, w_start
 
-    def _get_img(self, index: Union[int, Tuple[int, int]]):
+    def _get_img(self, index: Union[int, tuple[int, int]]):
         """
         Loads an image.
         Crops the image such that cropped image has content.
@@ -1056,8 +1062,8 @@
         return img_tuples, noise_tuples
 
     def __getitem__(
-        self, index: Union[int, Tuple[int, int]]
-    ) -> Tuple[np.ndarray, np.ndarray]:
+        self, index: Union[int, tuple[int, int]]
+    ) -> tuple[np.ndarray, np.ndarray]:
         # Vera: input can be both real microscopic image and two separate channels that are summed in the code
 
         if self._train_index_switcher is not None: