careamics 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of careamics has been flagged as potentially problematic.
- careamics/careamist.py +55 -61
- careamics/cli/conf.py +24 -9
- careamics/cli/main.py +8 -8
- careamics/cli/utils.py +2 -4
- careamics/config/__init__.py +8 -0
- careamics/config/algorithms/__init__.py +4 -0
- careamics/config/algorithms/hdn_algorithm_model.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
- careamics/config/algorithms/n2v_algorithm_model.py +1 -2
- careamics/config/algorithms/vae_algorithm_model.py +53 -18
- careamics/config/architectures/lvae_model.py +12 -8
- careamics/config/callback_model.py +15 -11
- careamics/config/configuration.py +9 -8
- careamics/config/configuration_factories.py +892 -78
- careamics/config/data/data_model.py +7 -14
- careamics/config/data/ng_data_model.py +8 -15
- careamics/config/data/patching_strategies/_overlapping_patched_model.py +4 -5
- careamics/config/inference_model.py +6 -11
- careamics/config/likelihood_model.py +4 -4
- careamics/config/loss_model.py +6 -2
- careamics/config/nm_model.py +30 -7
- careamics/config/optimizer_models.py +1 -2
- careamics/config/support/supported_algorithms.py +5 -3
- careamics/config/support/supported_losses.py +5 -2
- careamics/config/training_model.py +8 -38
- careamics/config/transformations/normalize_model.py +3 -4
- careamics/config/transformations/xy_flip_model.py +2 -2
- careamics/config/transformations/xy_random_rotate90_model.py +2 -2
- careamics/config/validators/validator_utils.py +1 -2
- careamics/dataset/dataset_utils/iterate_over_files.py +3 -3
- careamics/dataset/in_memory_dataset.py +2 -2
- careamics/dataset/iterable_dataset.py +1 -2
- careamics/dataset/patching/random_patching.py +6 -6
- careamics/dataset/patching/sequential_patching.py +4 -4
- careamics/dataset/tiling/lvae_tiled_patching.py +2 -2
- careamics/dataset_ng/dataset.py +3 -3
- careamics/dataset_ng/factory.py +19 -19
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
- careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
- careamics/dataset_ng/patching_strategies/random_patching.py +2 -3
- careamics/dataset_ng/patching_strategies/sequential_patching.py +1 -2
- careamics/file_io/read/__init__.py +0 -1
- careamics/lightning/__init__.py +16 -2
- careamics/lightning/callbacks/__init__.py +2 -0
- careamics/lightning/callbacks/data_stats_callback.py +23 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +5 -5
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +5 -5
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +8 -8
- careamics/lightning/dataset_ng/data_module.py +43 -43
- careamics/lightning/lightning_module.py +166 -68
- careamics/lightning/microsplit_data_module.py +631 -0
- careamics/lightning/predict_data_module.py +16 -9
- careamics/lightning/train_data_module.py +29 -18
- careamics/losses/__init__.py +7 -1
- careamics/losses/loss_factory.py +9 -1
- careamics/losses/lvae/losses.py +94 -9
- careamics/lvae_training/dataset/__init__.py +8 -8
- careamics/lvae_training/dataset/config.py +56 -44
- careamics/lvae_training/dataset/lc_dataset.py +18 -12
- careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
- careamics/lvae_training/dataset/multich_dataset.py +24 -18
- careamics/lvae_training/dataset/multifile_dataset.py +6 -6
- careamics/model_io/bioimage/model_description.py +12 -11
- careamics/model_io/bmz_io.py +12 -8
- careamics/models/layers.py +5 -5
- careamics/models/lvae/likelihoods.py +30 -14
- careamics/models/lvae/lvae.py +2 -2
- careamics/models/lvae/noise_models.py +20 -14
- careamics/prediction_utils/__init__.py +8 -2
- careamics/prediction_utils/lvae_prediction.py +5 -5
- careamics/prediction_utils/prediction_outputs.py +48 -3
- careamics/prediction_utils/stitch_prediction.py +71 -0
- careamics/transforms/compose.py +9 -9
- careamics/transforms/n2v_manipulate.py +3 -3
- careamics/transforms/n2v_manipulate_torch.py +4 -4
- careamics/transforms/normalize.py +4 -6
- careamics/transforms/pixel_manipulation.py +6 -8
- careamics/transforms/pixel_manipulation_torch.py +5 -7
- careamics/transforms/xy_flip.py +3 -5
- careamics/transforms/xy_random_rotate90.py +4 -6
- careamics/utils/logging.py +8 -8
- careamics/utils/metrics.py +2 -2
- careamics/utils/plotting.py +1 -3
- {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/METADATA +18 -16
- {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/RECORD +90 -88
- careamics/dataset/zarr_dataset.py +0 -151
- careamics/file_io/read/zarr.py +0 -60
- {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/WHEEL +0 -0
- {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/licenses/LICENSE +0 -0
careamics/lightning/train_data_module.py
CHANGED

@@ -2,7 +2,7 @@
 
 from collections.abc import Callable
 from pathlib import Path
-from typing import Any, Literal,
+from typing import Any, Literal, Union
 
 import numpy as np
 import pytorch_lightning as L
@@ -121,10 +121,10 @@ class TrainDataModule(L.LightningDataModule):
         self,
         data_config: DataConfig,
         train_data: Union[Path, str, NDArray],
-        val_data:
-        train_data_target:
-        val_data_target:
-        read_source_func:
+        val_data: Union[Path, str, NDArray] | None = None,
+        train_data_target: Union[Path, str, NDArray] | None = None,
+        val_data_target: Union[Path, str, NDArray] | None = None,
+        read_source_func: Callable | None = None,
         extension_filter: str = "",
         val_percentage: float = 0.1,
         val_minimum_split: int = 5,
@@ -477,15 +477,16 @@ def create_train_datamodule(
     patch_size: list[int],
     axes: str,
     batch_size: int,
-    val_data:
-    transforms:
-    train_target_data:
-    val_target_data:
-    read_source_func:
+    val_data: Union[str, Path, NDArray] | None = None,
+    transforms: list[TransformModel] | None = None,
+    train_target_data: Union[str, Path, NDArray] | None = None,
+    val_target_data: Union[str, Path, NDArray] | None = None,
+    read_source_func: Callable | None = None,
     extension_filter: str = "",
     val_percentage: float = 0.1,
     val_minimum_patches: int = 5,
-
+    train_dataloader_params: dict | None = None,
+    val_dataloader_params: dict | None = None,
     use_in_memory: bool = True,
 ) -> TrainDataModule:
     """Create a TrainDataModule.
@@ -556,8 +557,10 @@ def create_train_datamodule(
     val_minimum_patches : int, optional
         Minimum number of patches to split from the training data for validation if
        no validation data is given, by default 5.
-
-        Pytorch dataloader parameters, by default {}.
+    train_dataloader_params : dict, optional
+        Pytorch dataloader parameters for the training data, by default {}.
+    val_dataloader_params : dict, optional
+        Pytorch dataloader parameters for the validation data, by default {}.
     use_in_memory : bool, optional
         Use in memory dataset if possible, by default True.
 
@@ -617,8 +620,11 @@ def create_train_datamodule(
     ... transforms=my_transforms,
     ... )
     """
-    if
-
+    if train_dataloader_params is None:
+        train_dataloader_params = {"shuffle": True}
+
+    if val_dataloader_params is None:
+        val_dataloader_params = {"shuffle": False}
 
     data_dict: dict[str, Any] = {
         "mode": "train",
@@ -626,7 +632,8 @@ def create_train_datamodule(
         "patch_size": patch_size,
         "axes": axes,
         "batch_size": batch_size,
-        "
+        "train_dataloader_params": train_dataloader_params,
+        "val_dataloader_params": val_dataloader_params,
     }
 
     # if transforms are passed (otherwise it will use the default ones)
@@ -637,9 +644,13 @@ def create_train_datamodule(
     data_config = DataConfig(**data_dict)
 
     # sanity check on the dataloader parameters
-    if "batch_size" in
+    if "batch_size" in train_dataloader_params:
+        # remove it
+        del train_dataloader_params["batch_size"]
+
+    if "batch_size" in val_dataloader_params:
         # remove it
-        del
+        del val_dataloader_params["batch_size"]
 
     return TrainDataModule(
         data_config=data_config,
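The headline change in this file is that the single dataloader-parameters argument has been split into `train_dataloader_params` and `val_dataloader_params`, each with its own default shuffle behavior. A minimal calling sketch under the new signature, assuming the leading `train_data`/`data_type` parameters are unchanged from 0.0.14; the array data and worker counts are illustrative, not from the diff:

```python
import numpy as np

from careamics.lightning import create_train_datamodule

rng = np.random.default_rng(42)
train_array = rng.random((256, 256)).astype(np.float32)  # illustrative data

datamodule = create_train_datamodule(
    train_data=train_array,
    data_type="array",
    patch_size=[64, 64],
    axes="YX",
    batch_size=8,
    # New in 0.0.16: per-split dataloader parameters. Any "batch_size" key is
    # stripped by the sanity check in the hunk above, so pass it separately.
    train_dataloader_params={"shuffle": True, "num_workers": 4},
    val_dataloader_params={"num_workers": 4},
)
```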
careamics/losses/__init__.py
CHANGED
@@ -3,6 +3,7 @@
 __all__ = [
     "denoisplit_loss",
     "denoisplit_musplit_loss",
+    "hdn_loss",
     "loss_factory",
     "mae_loss",
     "mse_loss",
@@ -12,4 +13,9 @@ __all__ = [
 
 from .fcn.losses import mae_loss, mse_loss, n2v_loss
 from .loss_factory import loss_factory
-from .lvae.losses import
+from .lvae.losses import (
+    denoisplit_loss,
+    denoisplit_musplit_loss,
+    hdn_loss,
+    musplit_loss,
+)
careamics/losses/loss_factory.py
CHANGED
@@ -14,7 +14,12 @@ from torch import Tensor as tensor
 
 from ..config.support import SupportedLoss
 from .fcn.losses import mae_loss, mse_loss, n2v_loss
-from .lvae.losses import
+from .lvae.losses import (
+    denoisplit_loss,
+    denoisplit_musplit_loss,
+    hdn_loss,
+    musplit_loss,
+)
 
 
 @dataclass
@@ -59,6 +64,9 @@ def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
     elif loss == SupportedLoss.MSE:
         return mse_loss
 
+    elif loss == SupportedLoss.HDN:
+        return hdn_loss
+
     elif loss == SupportedLoss.MUSPLIT:
         return musplit_loss
 
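Since `loss_factory` accepts either a `SupportedLoss` member or its string value, the new HDN branch can be exercised as below. A minimal sketch; the string `"hdn"` assumes `SupportedLoss.HDN` follows the lowercase naming of the existing members:

```python
from careamics.config.support import SupportedLoss
from careamics.losses import loss_factory

hdn = loss_factory(SupportedLoss.HDN)
same = loss_factory("hdn")  # assumption: the enum's string value is "hdn"
assert hdn is same
```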
careamics/losses/lvae/losses.py
CHANGED
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Literal,
+from typing import TYPE_CHECKING, Any, Literal, Union
 
 import numpy as np
 import torch
@@ -89,6 +89,7 @@ def _reconstruction_loss_musplit_denoisplit(
     if predictions.shape[1] == 2 * targets.shape[1]:
         # predictions contain both mean and log-variance
         pred_mean, _ = predictions.chunk(2, dim=1)
+        # TODO if this condition does not hold, everything breaks later!
     else:
         pred_mean = predictions
 
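The TODO added above flags the implicit convention that `_reconstruction_loss_musplit_denoisplit` relies on. A self-contained illustration of that convention; the shapes are illustrative:

```python
import torch

targets = torch.randn(4, 2, 64, 64)      # (B, C, Y, X)
predictions = torch.randn(4, 4, 64, 64)  # (B, 2*C, Y, X): mean and log-variance

# When the prediction channel count doubles the target's, the first half of
# the channels is interpreted as the predicted mean and the second half as
# the log-variance.
assert predictions.shape[1] == 2 * targets.shape[1]
pred_mean, pred_logvar = predictions.chunk(2, dim=1)
print(pred_mean.shape, pred_logvar.shape)  # both torch.Size([4, 2, 64, 64])
```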
@@ -112,7 +113,7 @@ def get_kl_divergence_loss(
     rescaling: Literal["latent_dim", "image_dim"],
     aggregation: Literal["mean", "sum"],
     free_bits_coeff: float,
-    img_shape:
+    img_shape: tuple[int] | None = None,
 ) -> torch.Tensor:
     """Compute the KL divergence loss.
 
@@ -269,13 +270,97 @@ def _get_kl_divergence_loss_denoisplit(
 # - `__init__` method initializes the loss parameters now contained in
 #   the `LVAELossParameters` class
 # NOTE: same for the other loss functions
+
+
+def hdn_loss(
+    model_outputs: tuple[torch.Tensor, dict[str, Any]],
+    targets: torch.Tensor,
+    config: LVAELossConfig,
+    gaussian_likelihood: GaussianLikelihood | None,
+    noise_model_likelihood: NoiseModelLikelihood | None,
+) -> dict[str, torch.Tensor] | None:
+    """Loss function for HDN.
+
+    Parameters
+    ----------
+    model_outputs : tuple[torch.Tensor, dict[str, Any]]
+        Tuple containing the model predictions (shape is (B, `target_ch`, [Z], Y, X))
+        and the top-down layer data (e.g., sampled latents, KL-loss values, etc.).
+    targets : torch.Tensor
+        The target image used to compute the reconstruction loss. In this case we use
+        the input patch itself as target. Shape is (B, `target_ch`, [Z], Y, X).
+    config : LVAELossConfig
+        The config for loss function containing all loss hyperparameters.
+    gaussian_likelihood : GaussianLikelihood
+        The Gaussian likelihood object.
+    noise_model_likelihood : NoiseModelLikelihood
+        The noise model likelihood object.
+
+    Returns
+    -------
+    output : Optional[dict[str, torch.Tensor]]
+        A dictionary containing the overall loss `["loss"]`, the reconstruction loss
+        `["reconstruction_loss"]`, and the KL divergence loss `["kl_loss"]`.
+    """
+    if gaussian_likelihood is not None:
+        likelihood = gaussian_likelihood
+    elif noise_model_likelihood is not None:
+        likelihood = noise_model_likelihood
+    else:
+        raise ValueError("Invalid likelihood object.")
+    # TODO refactor loss signature
+    predictions, td_data = model_outputs
+
+    # Reconstruction loss computation
+    recons_loss = config.reconstruction_weight * get_reconstruction_loss(
+        reconstruction=predictions,
+        target=targets,
+        likelihood_obj=likelihood,
+    )
+    if torch.isnan(recons_loss).any():
+        recons_loss = 0.0
+
+    # KL loss computation
+    kl_weight = get_kl_weight(
+        config.kl_params.annealing,
+        config.kl_params.start,
+        config.kl_params.annealtime,
+        config.kl_weight,
+        config.kl_params.current_epoch,
+    )
+    kl_loss = (
+        _get_kl_divergence_loss_denoisplit(
+            topdown_data=td_data,
+            img_shape=targets.shape[2:],
+            kl_type=config.kl_params.loss_type,
+        )
+        * kl_weight
+    )
+
+    net_loss = recons_loss + kl_loss  # TODO add check that losses coefs sum to 1
+    output = {
+        "loss": net_loss,
+        "reconstruction_loss": (
+            recons_loss.detach()
+            if isinstance(recons_loss, torch.Tensor)
+            else recons_loss
+        ),
+        "kl_loss": kl_loss.detach(),
+    }
+    # https://github.com/openai/vdvae/blob/main/train.py#L26
+    if torch.isnan(net_loss).any():
+        return None
+
+    return output
+
+
 def musplit_loss(
     model_outputs: tuple[torch.Tensor, dict[str, Any]],
     targets: torch.Tensor,
     config: LVAELossConfig,
-    gaussian_likelihood:
-    noise_model_likelihood:
-) ->
+    gaussian_likelihood: GaussianLikelihood | None,
+    noise_model_likelihood: NoiseModelLikelihood | None = None,  # TODO: ugly
+) -> dict[str, torch.Tensor] | None:
     """Loss function for muSplit.
 
     Parameters
@@ -351,9 +436,9 @@ def denoisplit_loss(
     model_outputs: tuple[torch.Tensor, dict[str, Any]],
     targets: torch.Tensor,
     config: LVAELossConfig,
-    gaussian_likelihood:
-    noise_model_likelihood:
-) ->
+    gaussian_likelihood: GaussianLikelihood | None = None,
+    noise_model_likelihood: NoiseModelLikelihood | None = None,
+) -> dict[str, torch.Tensor] | None:
     """Loss function for DenoiSplit.
 
     Parameters
@@ -430,7 +515,7 @@ def denoisplit_musplit_loss(
     config: LVAELossConfig,
     gaussian_likelihood: GaussianLikelihood,
     noise_model_likelihood: NoiseModelLikelihood,
-) ->
+) -> dict[str, torch.Tensor] | None:
     """Loss function for DenoiSplit.
 
     Parameters
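All four LVAE losses now share the `dict[str, torch.Tensor] | None` return type, where `None` signals a NaN loss (the vdvae-inspired guard in `hdn_loss` above). A minimal, self-contained sketch of how a caller might honor that contract; the guard function is illustrative, not part of careamics:

```python
import torch


def consume_loss(loss_dict: dict[str, torch.Tensor] | None) -> torch.Tensor | None:
    """Return the loss to backpropagate, or None to skip a NaN batch."""
    if loss_dict is None:
        # The loss function hit NaN; skip the optimizer step for this batch.
        return None
    return loss_dict["loss"]


print(consume_loss({"loss": torch.tensor(0.5)}))  # tensor(0.5000)
print(consume_loss(None))                         # None
```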
careamics/lvae_training/dataset/__init__.py
CHANGED

@@ -1,4 +1,4 @@
-from .config import
+from .config import MicroSplitDataConfig
 from .lc_dataset import LCMultiChDloader
 from .ms_dataset_ref import MultiChDloaderRef
 from .multich_dataset import MultiChDloader
@@ -7,14 +7,14 @@ from .multifile_dataset import MultiFileDset
 from .types import DataSplitType, DataType, TilingMode
 
 __all__ = [
-    "
-    "
+    "DataSplitType",
+    "DataType",
     "LCMultiChDloader",
-    "MultiFileDset",
-    "MultiCropDset",
-    "MultiChDloaderRef",
     "LCMultiChDloaderRef",
-    "
-    "
+    "MicroSplitDataConfig",
+    "MultiChDloader",
+    "MultiChDloaderRef",
+    "MultiCropDset",
+    "MultiFileDset",
     "TilingMode",
 ]
careamics/lvae_training/dataset/config.py
CHANGED

@@ -1,4 +1,4 @@
-from typing import Any,
+from typing import Any, Union
 
 from pydantic import BaseModel, ConfigDict
 
@@ -6,70 +6,70 @@ from .types import DataSplitType, DataType, TilingMode
 
 
 # TODO: check if any bool logic can be removed
-class
-    model_config = ConfigDict(validate_assignment=True, extra="
+class MicroSplitDataConfig(BaseModel):
+    model_config = ConfigDict(validate_assignment=True, extra="allow")
 
-    data_type:
+    data_type: Union[DataType, str] | None  # TODO remove or refactor!!
     """Type of the dataset, should be one of DataType"""
 
-    depth3D:
+    depth3D: int | None = 1
     """Number of slices in 3D. If data is 2D depth3D is equal to 1"""
 
-    datasplit_type:
-    """Whether to return training, validation or test split, should be one of
+    datasplit_type: DataSplitType | None = None
+    """Whether to return training, validation or test split, should be one of
     DataSplitType"""
 
-    num_channels:
+    num_channels: int | None = 2
     """Number of channels in the input"""
 
     # TODO: remove ch*_fname parameters, should be parsed automatically from a name list
-    ch1_fname:
-    ch2_fname:
-    ch_input_fname:
+    ch1_fname: str | None = None
+    ch2_fname: str | None = None
+    ch_input_fname: str | None = None
 
-    input_is_sum:
+    input_is_sum: bool | None = False
     """Whether the input is the sum or average of channels"""
 
-    input_idx:
+    input_idx: int | None = None
    """Index of the channel where the input is stored in the data"""
 
-    target_idx_list:
+    target_idx_list: list[int] | None = None
     """Indices of the channels where the targets are stored in the data"""
 
     # TODO: where are there used?
-    start_alpha:
-    end_alpha:
+    start_alpha: Any | None = None
+    end_alpha: Any | None = None
 
     image_size: tuple  # TODO: revisit, new model_config uses tuple
     """Size of one patch of data"""
 
-    grid_size:
+    grid_size: Union[int, tuple[int, int, int]] | None = None
     """Frame is divided into square grids of this size. A patch centered on a grid
     having size `image_size` is returned. Grid size not used in training,
     used only during val / test, grid size controls the overlap of the patches"""
 
-    empty_patch_replacement_enabled:
+    empty_patch_replacement_enabled: bool | None = False
     """Whether to replace the content of one of the channels
     with background with given probability"""
-    empty_patch_replacement_channel_idx:
-    empty_patch_replacement_probab:
-    empty_patch_max_val_threshold:
+    empty_patch_replacement_channel_idx: Any | None = None
+    empty_patch_replacement_probab: Any | None = None
+    empty_patch_max_val_threshold: Any | None = None
 
-    uncorrelated_channels:
-    """Replace the content in one of the channels with given probability to make
+    uncorrelated_channels: bool | None = False
+    """Replace the content in one of the channels with given probability to make
     channel content 'uncorrelated'"""
-    uncorrelated_channel_probab:
+    uncorrelated_channel_probab: float | None = 0.5
 
-    poisson_noise_factor:
+    poisson_noise_factor: float | None = -1
     """The added poisson noise factor"""
 
-    synthetic_gaussian_scale:
+    synthetic_gaussian_scale: float | None = 0.1
 
     # TODO: set to True in training code, recheck
-    input_has_dependant_noise:
+    input_has_dependant_noise: bool | None = False
 
     # TODO: sometimes max_val differs between runs with fixed seeds with noise enabled
-    enable_gaussian_noise:
+    enable_gaussian_noise: bool | None = False
     """Whether to enable gaussian noise"""
 
     # TODO: is this parameter used?
@@ -80,44 +80,56 @@ class DatasetConfig(BaseModel):
     deterministic_grid: Any = None
 
     # TODO: why is this not used?
-    enable_rotation_aug:
+    enable_rotation_aug: bool | None = False
 
-    max_val:
-    """Maximum data in the dataset. Is calculated for train split, and should be
+    max_val: Union[float, tuple] | None = None
+    """Maximum data in the dataset. Is calculated for train split, and should be
     externally set for val and test splits."""
 
     overlapping_padding_kwargs: Any = None
     """Parameters for np.pad method"""
 
     # TODO: remove this parameter, controls debug print
-    print_vars:
+    print_vars: bool | None = False
 
     # Hard-coded parameters (used to be in the config file)
     normalized_input: bool = True
     """If this is set to true, then one mean and stdev is used
     for both channels. Otherwise, two different mean and stdev are used."""
-    use_one_mu_std:
+    use_one_mu_std: bool | None = True
 
     # TODO: is this parameter used?
-    train_aug_rotate:
-    enable_random_cropping:
+    train_aug_rotate: bool | None = False
+    enable_random_cropping: bool | None = True
 
-    multiscale_lowres_count:
+    multiscale_lowres_count: int | None = None
     """Number of LC scales"""
 
-    tiling_mode:
+    tiling_mode: TilingMode | None = TilingMode.ShiftBoundary
 
-    target_separate_normalization:
+    target_separate_normalization: bool | None = True
 
-    mode_3D:
+    mode_3D: bool | None = False
     """If training in 3D mode or not"""
 
-    trainig_datausage_fraction:
+    trainig_datausage_fraction: float | None = 1.0
 
-    validtarget_random_fraction:
+    validtarget_random_fraction: float | None = None
 
-    validation_datausage_fraction:
+    validation_datausage_fraction: float | None = 1.0
 
-    random_flip_z_3D:
+    random_flip_z_3D: bool | None = False
 
-    padding_kwargs:
+    padding_kwargs: dict = {"mode": "reflect"}  # TODO remove !!
+
+    def __init__(self, **data):
+        # Convert string data_type to enum if needed
+        if "data_type" in data and isinstance(data["data_type"], str):
+            try:
+                data["data_type"] = DataType[data["data_type"]]
+            except KeyError:
+                # Keep original value to let validation handle the error
+                pass
+        super().__init__(**data)
+
+    # TODO add validators !
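The new `__init__` override coerces a string `data_type` into the `DataType` enum before pydantic validation runs; because the field is typed `Union[DataType, str] | None`, an unrecognized name simply stays a string rather than raising. A minimal sketch; the member name `"Pavia2"` is an assumption about the `DataType` enum, substitute any valid member:

```python
from careamics.lvae_training.dataset import DataType, MicroSplitDataConfig

cfg = MicroSplitDataConfig(data_type="Pavia2", image_size=(64, 64))
print(isinstance(cfg.data_type, DataType))  # True if "Pavia2" is a member

cfg2 = MicroSplitDataConfig(data_type="not_a_member", image_size=(64, 64))
print(cfg2.data_type)  # "not_a_member", left for downstream validation
```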
careamics/lvae_training/dataset/lc_dataset.py
CHANGED

@@ -2,23 +2,29 @@
 A place for Datasets and Dataloaders.
 """
 
-
+import logging
+import math
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
 
 import numpy as np
 from skimage.transform import resize
 
-from .config import
+from .config import MicroSplitDataConfig
 from .multich_dataset import MultiChDloader
 
 
 class LCMultiChDloader(MultiChDloader):
+    """Multi-channel dataset loader for LC-style datasets."""
+
     def __init__(
         self,
-        data_config:
-
-        load_data_fn: Callable,
-        val_fraction=
-        test_fraction=
+        data_config: MicroSplitDataConfig,
+        datapath: Union[str, Path],
+        load_data_fn: Optional[Callable] = None,
+        val_fraction: float = 0.1,
+        test_fraction: float = 0.1,
+        allow_generation: bool = False,
     ):
         self._padding_kwargs = (
             data_config.padding_kwargs  # mode=padding_mode, constant_values=constant_value
@@ -27,7 +33,7 @@ class LCMultiChDloader(MultiChDloader):
 
         super().__init__(
             data_config,
-
+            datapath,
             load_data_fn=load_data_fn,
             val_fraction=val_fraction,
             test_fraction=test_fraction,
@@ -111,8 +117,8 @@ class LCMultiChDloader(MultiChDloader):
         return msg
 
     def _load_scaled_img(
-        self, scaled_index, index: Union[int,
-    ) ->
+        self, scaled_index, index: Union[int, tuple[int, int]]
+    ) -> tuple[np.ndarray, np.ndarray]:
         if isinstance(index, int):
             idx = index
         else:
@@ -131,7 +137,7 @@ class LCMultiChDloader(MultiChDloader):
             imgs = tuple([img + noise[0] * factor for img in imgs])
         return imgs
 
-    def _crop_img(self, img: np.ndarray, patch_start_loc:
+    def _crop_img(self, img: np.ndarray, patch_start_loc: tuple):
         """
         Here, h_start, w_start could be negative. That simply means we need to pick the content from 0. So,
         the cropped image will be smaller than self._img_sz * self._img_sz
@@ -202,7 +208,7 @@ class LCMultiChDloader(MultiChDloader):
         )
         return output_img_tuples, cropped_noise_tuples
 
-    def __getitem__(self, index: Union[int,
+    def __getitem__(self, index: Union[int, tuple[int, int]]):
         img_tuples, noise_tuples = self._get_img(index)
         if self._uncorrelated_channels:
             assert (
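For orientation, a construction sketch under the new `LCMultiChDloader` signature, which takes an explicit `datapath` and makes `load_data_fn` optional. Everything below the import is an assumption: the data path, the `DataType` member name, and the config fields are placeholders for whatever a concrete dataset needs:

```python
from careamics.lvae_training.dataset import (
    DataSplitType,
    LCMultiChDloader,
    MicroSplitDataConfig,
)

config = MicroSplitDataConfig(
    data_type="Pavia2",               # assumption: any valid DataType member name
    datasplit_type=DataSplitType.Train,
    image_size=(64, 64),
    multiscale_lowres_count=3,        # number of lateral-context scales
)

dataset = LCMultiChDloader(
    data_config=config,
    datapath="./data/my_multichannel_stack.tif",  # placeholder path
    val_fraction=0.1,
    test_fraction=0.1,
)
```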
careamics/lvae_training/dataset/ms_dataset_ref.py
CHANGED

@@ -10,7 +10,7 @@ from typing import Callable, Union
 import numpy as np
 from skimage.transform import resize
 
-from .config import
+from .config import MicroSplitDataConfig
 from .types import DataSplitType, TilingMode
 from .utils.empty_patch_fetcher import EmptyPatchFetcher
 from .utils.index_manager import GridIndexManagerRef
@@ -19,7 +19,7 @@ from .utils.index_manager import GridIndexManagerRef
 class MultiChDloaderRef:
     def __init__(
         self,
-        data_config:
+        data_config: MicroSplitDataConfig,
         fpath: str,
         load_data_fn: Callable,
         val_fraction: float = None,
@@ -171,8 +171,8 @@ class MultiChDloaderRef:
 
     def load_data(
         self,
-        data_config,
-        datasplit_type,
+        data_config: MicroSplitDataConfig,
+        datasplit_type: DataSplitType,
         load_data_fn: Callable,
         val_fraction=None,
         test_fraction=None,
@@ -813,7 +813,7 @@ class MultiChDloaderRef:
 class LCMultiChDloaderRef(MultiChDloaderRef):
     def __init__(
         self,
-        data_config:
+        data_config: MicroSplitDataConfig,
         fpath: str,
         load_data_fn: Callable,
         val_fraction=None,