careamics 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (98)
  1. careamics/careamist.py +24 -7
  2. careamics/cli/utils.py +1 -1
  3. careamics/config/algorithms/n2v_algorithm_model.py +1 -1
  4. careamics/config/architectures/unet_model.py +3 -0
  5. careamics/config/callback_model.py +23 -34
  6. careamics/config/configuration.py +55 -4
  7. careamics/config/configuration_factories.py +288 -23
  8. careamics/config/data/__init__.py +2 -0
  9. careamics/config/data/data_model.py +41 -4
  10. careamics/config/data/ng_data_model.py +381 -0
  11. careamics/config/data/patching_strategies/__init__.py +14 -0
  12. careamics/config/data/patching_strategies/_overlapping_patched_model.py +103 -0
  13. careamics/config/data/patching_strategies/_patched_model.py +56 -0
  14. careamics/config/data/patching_strategies/random_patching_model.py +21 -0
  15. careamics/config/data/patching_strategies/sequential_patching_model.py +25 -0
  16. careamics/config/data/patching_strategies/tiled_patching_model.py +40 -0
  17. careamics/config/data/patching_strategies/whole_patching_model.py +12 -0
  18. careamics/config/inference_model.py +6 -3
  19. careamics/config/optimizer_models.py +1 -3
  20. careamics/config/support/supported_data.py +7 -0
  21. careamics/config/support/supported_patching_strategies.py +22 -0
  22. careamics/config/training_model.py +0 -2
  23. careamics/config/validators/validator_utils.py +4 -3
  24. careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
  25. careamics/dataset/in_memory_dataset.py +2 -1
  26. careamics/dataset/iterable_dataset.py +2 -2
  27. careamics/dataset/iterable_pred_dataset.py +2 -2
  28. careamics/dataset/iterable_tiled_pred_dataset.py +2 -2
  29. careamics/dataset/patching/patching.py +3 -2
  30. careamics/dataset/tiling/lvae_tiled_patching.py +16 -6
  31. careamics/dataset/tiling/tiled_patching.py +2 -1
  32. careamics/dataset_ng/README.md +212 -0
  33. careamics/dataset_ng/dataset.py +229 -0
  34. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  35. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  36. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +734 -0
  37. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  38. careamics/dataset_ng/{demo_dataset.ipynb → demos/demo_dataset.ipynb} +60 -53
  39. careamics/dataset_ng/{demo_patch_extractor.py → demos/demo_patch_extractor.py} +7 -9
  40. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +292 -0
  41. careamics/dataset_ng/factory.py +451 -0
  42. careamics/dataset_ng/legacy_interoperability.py +170 -0
  43. careamics/dataset_ng/patch_extractor/__init__.py +3 -8
  44. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +7 -5
  45. careamics/dataset_ng/patch_extractor/image_stack/__init__.py +4 -1
  46. careamics/dataset_ng/patch_extractor/image_stack/czi_image_stack.py +360 -0
  47. careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +5 -1
  48. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -1
  49. careamics/dataset_ng/patch_extractor/image_stack_loader.py +5 -75
  50. careamics/dataset_ng/patch_extractor/patch_extractor.py +5 -4
  51. careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +114 -105
  52. careamics/dataset_ng/patching_strategies/__init__.py +6 -1
  53. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +31 -0
  54. careamics/dataset_ng/patching_strategies/random_patching.py +5 -1
  55. careamics/dataset_ng/patching_strategies/sequential_patching.py +5 -5
  56. careamics/dataset_ng/patching_strategies/tiling_strategy.py +172 -0
  57. careamics/dataset_ng/patching_strategies/whole_sample.py +36 -0
  58. careamics/file_io/read/get_func.py +2 -1
  59. careamics/lightning/dataset_ng/__init__.py +1 -0
  60. careamics/lightning/dataset_ng/data_module.py +678 -0
  61. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  62. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  63. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  64. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +212 -0
  65. careamics/lightning/lightning_module.py +5 -1
  66. careamics/lightning/predict_data_module.py +2 -1
  67. careamics/lightning/train_data_module.py +2 -1
  68. careamics/losses/loss_factory.py +2 -1
  69. careamics/lvae_training/dataset/__init__.py +8 -3
  70. careamics/lvae_training/dataset/config.py +3 -3
  71. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  72. careamics/lvae_training/dataset/multich_dataset.py +46 -17
  73. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  74. careamics/lvae_training/dataset/types.py +3 -3
  75. careamics/lvae_training/dataset/utils/index_manager.py +259 -0
  76. careamics/lvae_training/eval_utils.py +93 -3
  77. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  78. careamics/model_io/bioimage/model_description.py +1 -1
  79. careamics/model_io/bmz_io.py +1 -1
  80. careamics/model_io/model_io_utils.py +2 -2
  81. careamics/models/activation.py +2 -1
  82. careamics/prediction_utils/prediction_outputs.py +1 -1
  83. careamics/prediction_utils/stitch_prediction.py +1 -1
  84. careamics/transforms/compose.py +1 -0
  85. careamics/transforms/n2v_manipulate_torch.py +15 -9
  86. careamics/transforms/normalize.py +18 -7
  87. careamics/transforms/pixel_manipulation_torch.py +59 -92
  88. careamics/utils/lightning_utils.py +25 -11
  89. careamics/utils/metrics.py +2 -1
  90. careamics/utils/torch_utils.py +23 -0
  91. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/METADATA +12 -11
  92. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/RECORD +95 -69
  93. careamics/dataset_ng/dataset/__init__.py +0 -3
  94. careamics/dataset_ng/dataset/dataset.py +0 -184
  95. careamics/dataset_ng/demo_patch_extractor_factory.py +0 -37
  96. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/WHEEL +0 -0
  97. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/entry_points.txt +0 -0
  98. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/licenses/LICENSE +0 -0
careamics/lightning/dataset_ng/lightning_modules/care_module.py
@@ -0,0 +1,97 @@
+ """CARE Lightning module."""
+
+ from collections.abc import Callable
+ from typing import Any, Union
+
+ from careamics.config.algorithms.care_algorithm_model import CAREAlgorithm
+ from careamics.config.algorithms.n2n_algorithm_model import N2NAlgorithm
+ from careamics.config.support import SupportedLoss
+ from careamics.dataset_ng.dataset import ImageRegionData
+ from careamics.losses import mae_loss, mse_loss
+ from careamics.utils.logging import get_logger
+
+ from .unet_module import UnetModule
+
+ logger = get_logger(__name__)
+
+
+ class CAREModule(UnetModule):
+     """CAREamics PyTorch Lightning module for the CARE algorithm.
+
+     Parameters
+     ----------
+     algorithm_config : CAREAlgorithm or dict
+         Configuration for the CARE algorithm, either as a CAREAlgorithm instance or a
+         dictionary.
+     """
+
+     def __init__(self, algorithm_config: Union[CAREAlgorithm, dict]) -> None:
+         """Instantiate the CARE module.
+
+         Parameters
+         ----------
+         algorithm_config : CAREAlgorithm or dict
+             Configuration for the CARE algorithm, either as a CAREAlgorithm instance or
+             a dictionary.
+         """
+         super().__init__(algorithm_config)
+         assert isinstance(
+             self.config, (CAREAlgorithm, N2NAlgorithm)
+         ), "algorithm_config must be a CAREAlgorithm or an N2NAlgorithm"
+         loss = self.config.loss
+         if loss == SupportedLoss.MAE:
+             self.loss_func: Callable = mae_loss
+         elif loss == SupportedLoss.MSE:
+             self.loss_func = mse_loss
+         else:
+             raise ValueError(f"Unsupported loss for CARE: {loss}")
+
+     def training_step(
+         self,
+         batch: tuple[ImageRegionData, ImageRegionData],
+         batch_idx: Any,
+     ) -> Any:
+         """Training step for the CARE module.
+
+         Parameters
+         ----------
+         batch : (ImageRegionData, ImageRegionData)
+             A tuple containing the input data and the target data.
+         batch_idx : Any
+             The index of the current batch in the training loop.
+
+         Returns
+         -------
+         Any
+             The loss value computed for the current batch.
+         """
+         # TODO: add validation to determine if target is initialized
+         x, target = batch[0], batch[1]
+
+         prediction = self.model(x.data)
+         loss = self.loss_func(prediction, target.data)
+
+         self._log_training_stats(loss, batch_size=x.data.shape[0])
+
+         return loss
+
+     def validation_step(
+         self,
+         batch: tuple[ImageRegionData, ImageRegionData],
+         batch_idx: Any,
+     ) -> None:
+         """Validation step for the CARE module.
+
+         Parameters
+         ----------
+         batch : (ImageRegionData, ImageRegionData)
+             A tuple containing the input data and the target data.
+         batch_idx : Any
+             The index of the current batch in the validation loop.
+         """
+         x, target = batch[0], batch[1]
+
+         prediction = self.model(x.data)
+         val_loss = self.loss_func(prediction, target.data)
+         self.metrics(prediction, target.data)
+         self._log_validation_stats(val_loss, batch_size=x.data.shape[0])
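For orientation, a minimal usage sketch of the new module follows. The constructor signature and the CAREAlgorithm import are taken from the hunk above; the lightning_modules re-export, the config field values, and the data module are illustrative assumptions, not part of this diff.

import pytorch_lightning as L

from careamics.config.algorithms.care_algorithm_model import CAREAlgorithm
# assumed re-export from lightning_modules/__init__.py (+9 above, not shown in full)
from careamics.lightning.dataset_ng.lightning_modules import CAREModule

# illustrative field values; a dict following the CAREAlgorithm schema is also
# accepted and converted via algorithm_factory in UnetModule.__init__
config = CAREAlgorithm(loss="mae", model={"architecture": "UNet"})
module = CAREModule(config)

trainer = L.Trainer(max_epochs=10)
# `datamodule` stands in for a data module yielding (input, target)
# ImageRegionData batches, as training_step expects:
# trainer.fit(module, datamodule=datamodule)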
careamics/lightning/dataset_ng/lightning_modules/n2v_module.py
@@ -0,0 +1,106 @@
+ """Noise2Void Lightning module."""
+
+ from typing import Any, Union
+
+ from careamics.config import (
+     N2VAlgorithm,
+ )
+ from careamics.dataset_ng.dataset import ImageRegionData
+ from careamics.losses import n2v_loss
+ from careamics.transforms import N2VManipulateTorch
+ from careamics.utils.logging import get_logger
+
+ from .unet_module import UnetModule
+
+ logger = get_logger(__name__)
+
+
+ class N2VModule(UnetModule):
+     """CAREamics PyTorch Lightning module for the N2V algorithm.
+
+     Parameters
+     ----------
+     algorithm_config : N2VAlgorithm or dict
+         Configuration for the N2V algorithm, either as an N2VAlgorithm instance or a
+         dictionary.
+     """
+
+     def __init__(self, algorithm_config: Union[N2VAlgorithm, dict]) -> None:
+         """Instantiate the N2V module.
+
+         Parameters
+         ----------
+         algorithm_config : N2VAlgorithm or dict
+             Configuration for the N2V algorithm, either as an N2VAlgorithm instance or a
+             dictionary.
+         """
+         super().__init__(algorithm_config)
+
+         assert isinstance(
+             self.config, N2VAlgorithm
+         ), "algorithm_config must be an N2VAlgorithm"
+
+         self.n2v_manipulate = N2VManipulateTorch(
+             n2v_manipulate_config=self.config.n2v_config
+         )
+         self.loss_func = n2v_loss
+
+     def _load_best_checkpoint(self) -> None:
+         """Load the best checkpoint for the N2V model."""
+         logger.warning(
+             "Loading best checkpoint for N2V model. Note that for N2V, "
+             "the checkpoint with the best validation metrics may not necessarily "
+             "have the best denoising performance."
+         )
+         super()._load_best_checkpoint()
+
+     def training_step(
+         self,
+         batch: Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]],
+         batch_idx: Any,
+     ) -> Any:
+         """Training step for the N2V model.
+
+         Parameters
+         ----------
+         batch : ImageRegionData or (ImageRegionData, ImageRegionData)
+             A tuple containing the input data and, optionally, the target data.
+         batch_idx : Any
+             The index of the current batch in the training loop.
+
+         Returns
+         -------
+         Any
+             The loss value for the current training step.
+         """
+         x = batch[0]
+         x_masked, x_original, mask = self.n2v_manipulate(x.data)
+         prediction = self.model(x_masked)
+         loss = self.loss_func(prediction, x_original, mask)
+
+         self._log_training_stats(loss, batch_size=x.data.shape[0])
+
+         return loss
+
+     def validation_step(
+         self,
+         batch: Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]],
+         batch_idx: Any,
+     ) -> None:
+         """Validation step for the N2V model.
+
+         Parameters
+         ----------
+         batch : ImageRegionData or (ImageRegionData, ImageRegionData)
+             A tuple containing the input data and, optionally, the target data.
+         batch_idx : Any
+             The index of the current batch in the validation loop.
+         """
+         x = batch[0]
+
+         x_masked, x_original, mask = self.n2v_manipulate(x.data)
+         prediction = self.model(x_masked)
+
+         val_loss = self.loss_func(prediction, x_original, mask)
+         self.metrics(prediction, x_original)
+         self._log_validation_stats(val_loss, batch_size=x.data.shape[0])
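The dict route also works here; a sketch with illustrative values assumed to follow the N2VAlgorithm schema (UnetModule.__init__ converts the dict through algorithm_factory before the assert above checks the result):

# assumed re-export from lightning_modules/__init__.py (+9 above, not shown in full)
from careamics.lightning.dataset_ng.lightning_modules import N2VModule

# hypothetical config dict; keys and values follow the N2VAlgorithm schema
n2v_config = {
    "algorithm": "n2v",
    "loss": "n2v",
    "model": {"architecture": "UNet"},
}
module = N2VModule(n2v_config)
# N2V batches need only the input: n2v_manipulate masks pixels on the fly, and
# n2v_loss compares predictions to the original values at the masked locations only.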
careamics/lightning/dataset_ng/lightning_modules/unet_module.py
@@ -0,0 +1,212 @@
+ """Generic UNet Lightning module."""
+
+ from typing import Any, Union
+
+ import pytorch_lightning as L
+ import torch
+ from torch import nn
+ from torchmetrics import MetricCollection
+ from torchmetrics.image import PeakSignalNoiseRatio
+
+ from careamics.config import algorithm_factory
+ from careamics.config.algorithms import CAREAlgorithm, N2NAlgorithm, N2VAlgorithm
+ from careamics.dataset_ng.dataset import ImageRegionData
+ from careamics.models.unet import UNet
+ from careamics.transforms import Denormalize
+ from careamics.utils.logging import get_logger
+ from careamics.utils.torch_utils import get_optimizer, get_scheduler
+
+ logger = get_logger(__name__)
+
+
+ class UnetModule(L.LightningModule):
+     """CAREamics PyTorch Lightning module for UNet-based algorithms.
+
+     Parameters
+     ----------
+     algorithm_config : CAREAlgorithm, N2VAlgorithm, N2NAlgorithm, or dict
+         Configuration for the algorithm, either as an instance of a specific algorithm
+         class or a dictionary that can be converted to an algorithm instance.
+     """
+
+     def __init__(
+         self, algorithm_config: Union[CAREAlgorithm, N2VAlgorithm, N2NAlgorithm, dict]
+     ) -> None:
+         """Instantiate the UNet module.
+
+         Parameters
+         ----------
+         algorithm_config : CAREAlgorithm, N2VAlgorithm, N2NAlgorithm, or dict
+             Configuration for the algorithm, either as an instance of a specific
+             algorithm class or a dictionary that can be converted to an algorithm
+             instance.
+         """
+         super().__init__()
+
+         if isinstance(algorithm_config, dict):
+             algorithm_config = algorithm_factory(algorithm_config)
+
+         self.config = algorithm_config
+         self.model: nn.Module = UNet(**algorithm_config.model.model_dump())
+
+         self._best_checkpoint_loaded = False
+
+         # TODO: how to support metric evaluation better
+         self.metrics = MetricCollection(PeakSignalNoiseRatio())
+
+     def forward(self, x: Any) -> Any:
+         """Default forward method.
+
+         Parameters
+         ----------
+         x : Any
+             Input data.
+
+         Returns
+         -------
+         Any
+             Output from the model.
+         """
+         return self.model(x)
+
+     def _log_training_stats(self, loss: Any, batch_size: Any) -> None:
+         """Log training statistics.
+
+         Parameters
+         ----------
+         loss : Any
+             The loss value for the current training step.
+         batch_size : Any
+             The size of the batch used in the current training step.
+         """
+         self.log(
+             "train_loss",
+             loss,
+             on_step=True,
+             on_epoch=True,
+             prog_bar=True,
+             logger=True,
+             batch_size=batch_size,
+         )
+
+         optimizer = self.optimizers()
+         if isinstance(optimizer, list):
+             current_lr = optimizer[0].param_groups[0]["lr"]
+         else:
+             current_lr = optimizer.param_groups[0]["lr"]
+         self.log(
+             "learning_rate",
+             current_lr,
+             on_step=False,
+             on_epoch=True,
+             logger=True,
+             batch_size=batch_size,
+         )
+
+     def _log_validation_stats(self, loss: Any, batch_size: Any) -> None:
+         """Log validation statistics.
+
+         Parameters
+         ----------
+         loss : Any
+             The loss value for the current validation step.
+         batch_size : Any
+             The size of the batch used in the current validation step.
+         """
+         self.log(
+             "val_loss",
+             loss,
+             on_step=False,
+             on_epoch=True,
+             prog_bar=True,
+             logger=True,
+             batch_size=batch_size,
+         )
+         self.log_dict(self.metrics, on_step=False, on_epoch=True, batch_size=batch_size)
+
+     def _load_best_checkpoint(self) -> None:
+         """Load the best checkpoint from the trainer's checkpoint callback."""
+         if (
+             not hasattr(self.trainer, "checkpoint_callback")
+             or self.trainer.checkpoint_callback is None
+         ):
+             logger.warning("No checkpoint callback found, cannot load best checkpoint.")
+             return
+
+         best_model_path = self.trainer.checkpoint_callback.best_model_path
+         if best_model_path:
+             logger.info(f"Loading best checkpoint from: {best_model_path}")
+             model_state = torch.load(best_model_path, weights_only=True)["state_dict"]
+             self.load_state_dict(model_state)
+         else:
+             logger.warning("No best checkpoint found.")
+
+     def predict_step(
+         self,
+         batch: Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]],
+         batch_idx: Any,
+         load_best_checkpoint: bool = False,
+     ) -> Any:
+         """Default predict step.
+
+         Parameters
+         ----------
+         batch : ImageRegionData or (ImageRegionData, ImageRegionData)
+             A tuple containing the input data and optionally the target data.
+         batch_idx : Any
+             The index of the current batch in the prediction loop.
+         load_best_checkpoint : bool, default=False
+             Whether to load the best checkpoint before making predictions.
+
+         Returns
+         -------
+         Any
+             The output batch containing the predictions.
+         """
+         if load_best_checkpoint and not self._best_checkpoint_loaded:
+             self._load_best_checkpoint()
+             self._best_checkpoint_loaded = True
+
+         x = batch[0]
+         # TODO: add TTA
+         prediction = self.model(x.data).cpu().numpy()
+
+         means = self.trainer.datamodule.stats.means
+         stds = self.trainer.datamodule.stats.stds
+         denormalize = Denormalize(
+             image_means=means,
+             image_stds=stds,
+         )
+         denormalized_output = denormalize(prediction)
+
+         output_batch = ImageRegionData(
+             data=denormalized_output,
+             source=x.source,
+             data_shape=x.data_shape,
+             dtype=x.dtype,
+             axes=x.axes,
+             region_spec=x.region_spec,
+         )
+         return output_batch
+
+     def configure_optimizers(self) -> Any:
+         """Configure optimizers.
+
+         Returns
+         -------
+         Any
+             A dictionary containing the optimizer and learning rate scheduler.
+         """
+         optimizer_func = get_optimizer(self.config.optimizer.name)
+         optimizer = optimizer_func(
+             self.model.parameters(), **self.config.optimizer.parameters
+         )
+
+         scheduler_func = get_scheduler(self.config.lr_scheduler.name)
+         scheduler = scheduler_func(optimizer, **self.config.lr_scheduler.parameters)
+
+         return {
+             "optimizer": optimizer,
+             "lr_scheduler": scheduler,
+             "monitor": "val_loss",  # otherwise triggers MisconfigurationException
+         }
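configure_optimizers resolves both classes by name through get_optimizer and get_scheduler (extended in careamics/utils/torch_utils.py, +23 above). A minimal sketch of that resolution, assuming the helpers return torch.optim classes by name, as their call sites above suggest:

from torch import nn

from careamics.utils.torch_utils import get_optimizer, get_scheduler

model = nn.Conv2d(1, 1, kernel_size=3)

optimizer_cls = get_optimizer("Adam")  # assumed to resolve torch.optim.Adam
optimizer = optimizer_cls(model.parameters(), lr=1e-4)

scheduler_cls = get_scheduler("ReduceLROnPlateau")
scheduler = scheduler_cls(optimizer)  # monitored against "val_loss", as returned above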
careamics/lightning/lightning_module.py
@@ -1,6 +1,7 @@
  """CAREamics Lightning module."""

- from typing import Any, Callable, Literal, Optional, Union
+ from collections.abc import Callable
+ from typing import Any, Literal, Optional, Union

  import numpy as np
  import pytorch_lightning as L
@@ -148,6 +149,9 @@ class FCNModule(L.LightningModule):
          self.log(
              "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
          )
+         optimizer = self.optimizers()
+         current_lr = optimizer.param_groups[0]["lr"]
+         self.log("learning_rate", current_lr, on_step=False, on_epoch=True, logger=True)
          return loss

      def validation_step(self, batch: Tensor, batch_idx: Any) -> None:
careamics/lightning/predict_data_module.py
@@ -1,7 +1,8 @@
  """Prediction Lightning data modules."""

+ from collections.abc import Callable
  from pathlib import Path
- from typing import Any, Callable, Literal, Optional, Union
+ from typing import Any, Literal, Optional, Union

  import numpy as np
  import pytorch_lightning as L
careamics/lightning/train_data_module.py
@@ -1,7 +1,8 @@
  """Training and validation Lightning data modules."""

+ from collections.abc import Callable
  from pathlib import Path
- from typing import Any, Callable, Literal, Optional, Union
+ from typing import Any, Literal, Optional, Union

  import numpy as np
  import pytorch_lightning as L
careamics/losses/loss_factory.py
@@ -6,8 +6,9 @@ This module contains a factory function for creating loss functions.

  from __future__ import annotations

+ from collections.abc import Callable
  from dataclasses import dataclass
- from typing import Callable, Union
+ from typing import Union

  from torch import Tensor as tensor

careamics/lvae_training/dataset/__init__.py
@@ -1,14 +1,19 @@
- from .multich_dataset import MultiChDloader
+ from .config import DatasetConfig
  from .lc_dataset import LCMultiChDloader
+ from .ms_dataset_ref import LCMultiChDloaderRef, MultiChDloaderRef
+ from .multich_dataset import MultiChDloader
+ from .multicrop_dset import MultiCropDset
  from .multifile_dataset import MultiFileDset
- from .config import DatasetConfig
- from .types import DataType, DataSplitType, TilingMode
+ from .types import DataSplitType, DataType, TilingMode

  __all__ = [
      "DatasetConfig",
      "MultiChDloader",
      "LCMultiChDloader",
      "MultiFileDset",
+     "MultiCropDset",
+     "MultiChDloaderRef",
+     "LCMultiChDloaderRef",
      "DataType",
      "DataSplitType",
      "TilingMode",
careamics/lvae_training/dataset/config.py
@@ -1,4 +1,4 @@
- from typing import Any, Optional
+ from typing import Any, Optional, Union

  from pydantic import BaseModel, ConfigDict

@@ -43,7 +43,7 @@ class DatasetConfig(BaseModel):
      image_size: tuple  # TODO: revisit, new model_config uses tuple
      """Size of one patch of data."""

-     grid_size: Optional[int] = None
+     grid_size: Optional[Union[int, tuple[int, int, int]]] = None
      """The frame is divided into square grids of this size; a patch of size
      `image_size` centered on a grid is returned. The grid size is not used in
      training; it is used only during val/test, where it controls patch overlap."""
@@ -82,7 +82,7 @@ class DatasetConfig(BaseModel):
      # TODO: why is this not used?
      enable_rotation_aug: Optional[bool] = False

-     max_val: Optional[float] = None
+     max_val: Optional[Union[float, tuple]] = None
      """Maximum value in the dataset. It is calculated for the train split and
      should be set externally for the val and test splits."""

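The widened annotations mean grid_size now validates either a scalar or a per-dimension 3-tuple, and max_val a float or a tuple. A self-contained toy model mirroring just the grid_size change (not the real DatasetConfig, which has further required fields):

from typing import Optional, Union

from pydantic import BaseModel


class GridDemo(BaseModel):
    # same annotation as the new grid_size field above
    grid_size: Optional[Union[int, tuple[int, int, int]]] = None


GridDemo(grid_size=32)  # scalar grid, as in 0.0.11
GridDemo(grid_size=(1, 32, 32))  # per-dimension 3D grid, allowed as of 0.0.13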