careamics 0.0.15__py3-none-any.whl → 0.0.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +6 -12
- careamics/cli/conf.py +18 -3
- careamics/config/__init__.py +8 -0
- careamics/config/algorithms/__init__.py +4 -0
- careamics/config/algorithms/hdn_algorithm_model.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
- careamics/config/algorithms/n2v_algorithm_model.py +1 -2
- careamics/config/algorithms/vae_algorithm_model.py +51 -16
- careamics/config/architectures/lvae_model.py +12 -8
- careamics/config/callback_model.py +7 -3
- careamics/config/configuration.py +9 -8
- careamics/config/configuration_factories.py +843 -29
- careamics/config/data/data_model.py +1 -2
- careamics/config/data/ng_data_model.py +1 -2
- careamics/config/inference_model.py +1 -2
- careamics/config/likelihood_model.py +2 -2
- careamics/config/loss_model.py +6 -2
- careamics/config/nm_model.py +26 -1
- careamics/config/optimizer_models.py +1 -2
- careamics/config/support/supported_algorithms.py +5 -3
- careamics/config/support/supported_losses.py +5 -2
- careamics/config/training_model.py +6 -36
- careamics/config/transformations/normalize_model.py +1 -2
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
- careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
- careamics/file_io/read/__init__.py +0 -1
- careamics/lightning/__init__.py +16 -2
- careamics/lightning/callbacks/__init__.py +2 -0
- careamics/lightning/callbacks/data_stats_callback.py +23 -0
- careamics/lightning/lightning_module.py +161 -61
- careamics/lightning/microsplit_data_module.py +631 -0
- careamics/lightning/predict_data_module.py +8 -1
- careamics/lightning/train_data_module.py +19 -8
- careamics/losses/__init__.py +7 -1
- careamics/losses/loss_factory.py +9 -1
- careamics/losses/lvae/losses.py +85 -0
- careamics/lvae_training/dataset/__init__.py +8 -8
- careamics/lvae_training/dataset/config.py +56 -44
- careamics/lvae_training/dataset/lc_dataset.py +18 -12
- careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
- careamics/lvae_training/dataset/multich_dataset.py +24 -18
- careamics/lvae_training/dataset/multifile_dataset.py +6 -6
- careamics/model_io/bmz_io.py +9 -5
- careamics/models/lvae/likelihoods.py +30 -14
- careamics/models/lvae/lvae.py +2 -2
- careamics/models/lvae/noise_models.py +20 -14
- careamics/prediction_utils/__init__.py +8 -2
- careamics/prediction_utils/prediction_outputs.py +48 -3
- careamics/prediction_utils/stitch_prediction.py +71 -0
- careamics/transforms/xy_random_rotate90.py +1 -1
- {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/METADATA +18 -15
- {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/RECORD +57 -55
- careamics/dataset/zarr_dataset.py +0 -151
- careamics/file_io/read/zarr.py +0 -60
- {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/WHEEL +0 -0
- {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/licenses/LICENSE +0 -0
careamics/losses/__init__.py
CHANGED
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
__all__ = [
|
|
4
4
|
"denoisplit_loss",
|
|
5
5
|
"denoisplit_musplit_loss",
|
|
6
|
+
"hdn_loss",
|
|
6
7
|
"loss_factory",
|
|
7
8
|
"mae_loss",
|
|
8
9
|
"mse_loss",
|
|
@@ -12,4 +13,9 @@ __all__ = [
|
|
|
12
13
|
|
|
13
14
|
from .fcn.losses import mae_loss, mse_loss, n2v_loss
|
|
14
15
|
from .loss_factory import loss_factory
|
|
15
|
-
from .lvae.losses import
|
|
16
|
+
from .lvae.losses import (
|
|
17
|
+
denoisplit_loss,
|
|
18
|
+
denoisplit_musplit_loss,
|
|
19
|
+
hdn_loss,
|
|
20
|
+
musplit_loss,
|
|
21
|
+
)
|
careamics/losses/loss_factory.py
CHANGED
|
@@ -14,7 +14,12 @@ from torch import Tensor as tensor
|
|
|
14
14
|
|
|
15
15
|
from ..config.support import SupportedLoss
|
|
16
16
|
from .fcn.losses import mae_loss, mse_loss, n2v_loss
|
|
17
|
-
from .lvae.losses import
|
|
17
|
+
from .lvae.losses import (
|
|
18
|
+
denoisplit_loss,
|
|
19
|
+
denoisplit_musplit_loss,
|
|
20
|
+
hdn_loss,
|
|
21
|
+
musplit_loss,
|
|
22
|
+
)
|
|
18
23
|
|
|
19
24
|
|
|
20
25
|
@dataclass
|
|
@@ -59,6 +64,9 @@ def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
|
|
|
59
64
|
elif loss == SupportedLoss.MSE:
|
|
60
65
|
return mse_loss
|
|
61
66
|
|
|
67
|
+
elif loss == SupportedLoss.HDN:
|
|
68
|
+
return hdn_loss
|
|
69
|
+
|
|
62
70
|
elif loss == SupportedLoss.MUSPLIT:
|
|
63
71
|
return musplit_loss
|
|
64
72
|
|
careamics/losses/lvae/losses.py
CHANGED
|
@@ -89,6 +89,7 @@ def _reconstruction_loss_musplit_denoisplit(
|
|
|
89
89
|
if predictions.shape[1] == 2 * targets.shape[1]:
|
|
90
90
|
# predictions contain both mean and log-variance
|
|
91
91
|
pred_mean, _ = predictions.chunk(2, dim=1)
|
|
92
|
+
# TODO if this condition does not hold, everything breaks later!
|
|
92
93
|
else:
|
|
93
94
|
pred_mean = predictions
|
|
94
95
|
|
|
@@ -269,6 +270,90 @@ def _get_kl_divergence_loss_denoisplit(
|
|
|
269
270
|
# - `__init__` method initializes the loss parameters now contained in
|
|
270
271
|
# the `LVAELossParameters` class
|
|
271
272
|
# NOTE: same for the other loss functions
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def hdn_loss(
    model_outputs: tuple[torch.Tensor, dict[str, Any]],
    targets: torch.Tensor,
    config: LVAELossConfig,
    gaussian_likelihood: GaussianLikelihood | None,
    noise_model_likelihood: NoiseModelLikelihood | None,
) -> dict[str, torch.Tensor] | None:
    """Compute the HDN loss as a weighted reconstruction term plus a KL term.

    Parameters
    ----------
    model_outputs : tuple[torch.Tensor, dict[str, Any]]
        Tuple containing the model predictions (shape is (B, `target_ch`, [Z], Y, X))
        and the top-down layer data (e.g., sampled latents, KL-loss values, etc.).
    targets : torch.Tensor
        The target image used to compute the reconstruction loss. In this case we use
        the input patch itself as target. Shape is (B, `target_ch`, [Z], Y, X).
    config : LVAELossConfig
        The config for loss function containing all loss hyperparameters.
    gaussian_likelihood : GaussianLikelihood or None
        The Gaussian likelihood object. Used whenever it is not None.
    noise_model_likelihood : NoiseModelLikelihood or None
        The noise model likelihood object. Only used when `gaussian_likelihood`
        is None.

    Returns
    -------
    output : Optional[dict[str, torch.Tensor]]
        A dictionary containing the overall loss `["loss"]`, the reconstruction loss
        `["reconstruction_loss"]`, and the KL divergence loss `["kl_loss"]`, or
        None if the overall loss contains a NaN.

    Raises
    ------
    ValueError
        If both `gaussian_likelihood` and `noise_model_likelihood` are None.
    """
    # The Gaussian likelihood wins when both objects are supplied.
    likelihood = (
        gaussian_likelihood
        if gaussian_likelihood is not None
        else noise_model_likelihood
    )
    if likelihood is None:
        raise ValueError("Invalid likelihood object.")

    # TODO refactor loss signature
    predictions, td_data = model_outputs

    # Weighted reconstruction term; a NaN result is replaced by a plain 0.0
    # so that it does not contribute to the total.
    recons_loss = config.reconstruction_weight * get_reconstruction_loss(
        reconstruction=predictions,
        target=targets,
        likelihood_obj=likelihood,
    )
    if torch.isnan(recons_loss).any():
        recons_loss = 0.0

    # KL term: the weight is derived from the annealing settings held in
    # the config, then applied to the KL divergence of the top-down data.
    kl_params = config.kl_params
    kl_weight = get_kl_weight(
        kl_params.annealing,
        kl_params.start,
        kl_params.annealtime,
        config.kl_weight,
        kl_params.current_epoch,
    )
    kl_divergence = _get_kl_divergence_loss_denoisplit(
        topdown_data=td_data,
        img_shape=targets.shape[2:],
        kl_type=kl_params.loss_type,
    )
    kl_loss = kl_divergence * kl_weight

    net_loss = recons_loss + kl_loss  # TODO add check that losses coefs sum to 1

    # `recons_loss` is a float (not a tensor) when it was reset to 0.0 above.
    if isinstance(recons_loss, torch.Tensor):
        reported_recons = recons_loss.detach()
    else:
        reported_recons = recons_loss
    reported_kl = kl_loss.detach()

    # https://github.com/openai/vdvae/blob/main/train.py#L26
    if torch.isnan(net_loss).any():
        return None

    return {
        "loss": net_loss,
        "reconstruction_loss": reported_recons,
        "kl_loss": reported_kl,
    }
|
|
355
|
+
|
|
356
|
+
|
|
272
357
|
def musplit_loss(
|
|
273
358
|
model_outputs: tuple[torch.Tensor, dict[str, Any]],
|
|
274
359
|
targets: torch.Tensor,
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from .config import
|
|
1
|
+
from .config import MicroSplitDataConfig
|
|
2
2
|
from .lc_dataset import LCMultiChDloader
|
|
3
3
|
from .ms_dataset_ref import MultiChDloaderRef
|
|
4
4
|
from .multich_dataset import MultiChDloader
|
|
@@ -7,14 +7,14 @@ from .multifile_dataset import MultiFileDset
|
|
|
7
7
|
from .types import DataSplitType, DataType, TilingMode
|
|
8
8
|
|
|
9
9
|
__all__ = [
|
|
10
|
-
"
|
|
11
|
-
"
|
|
10
|
+
"DataSplitType",
|
|
11
|
+
"DataType",
|
|
12
12
|
"LCMultiChDloader",
|
|
13
|
-
"MultiFileDset",
|
|
14
|
-
"MultiCropDset",
|
|
15
|
-
"MultiChDloaderRef",
|
|
16
13
|
"LCMultiChDloaderRef",
|
|
17
|
-
"
|
|
18
|
-
"
|
|
14
|
+
"MicroSplitDataConfig",
|
|
15
|
+
"MultiChDloader",
|
|
16
|
+
"MultiChDloaderRef",
|
|
17
|
+
"MultiCropDset",
|
|
18
|
+
"MultiFileDset",
|
|
19
19
|
"TilingMode",
|
|
20
20
|
]
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Any,
|
|
1
|
+
from typing import Any, Union
|
|
2
2
|
|
|
3
3
|
from pydantic import BaseModel, ConfigDict
|
|
4
4
|
|
|
@@ -6,70 +6,70 @@ from .types import DataSplitType, DataType, TilingMode
|
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
# TODO: check if any bool logic can be removed
|
|
9
|
-
class
|
|
10
|
-
model_config = ConfigDict(validate_assignment=True, extra="
|
|
9
|
+
class MicroSplitDataConfig(BaseModel):
|
|
10
|
+
model_config = ConfigDict(validate_assignment=True, extra="allow")
|
|
11
11
|
|
|
12
|
-
data_type:
|
|
12
|
+
data_type: Union[DataType, str] | None # TODO remove or refactor!!
|
|
13
13
|
"""Type of the dataset, should be one of DataType"""
|
|
14
14
|
|
|
15
|
-
depth3D:
|
|
15
|
+
depth3D: int | None = 1
|
|
16
16
|
"""Number of slices in 3D. If data is 2D depth3D is equal to 1"""
|
|
17
17
|
|
|
18
|
-
datasplit_type:
|
|
19
|
-
"""Whether to return training, validation or test split, should be one of
|
|
18
|
+
datasplit_type: DataSplitType | None = None
|
|
19
|
+
"""Whether to return training, validation or test split, should be one of
|
|
20
20
|
DataSplitType"""
|
|
21
21
|
|
|
22
|
-
num_channels:
|
|
22
|
+
num_channels: int | None = 2
|
|
23
23
|
"""Number of channels in the input"""
|
|
24
24
|
|
|
25
25
|
# TODO: remove ch*_fname parameters, should be parsed automatically from a name list
|
|
26
|
-
ch1_fname:
|
|
27
|
-
ch2_fname:
|
|
28
|
-
ch_input_fname:
|
|
26
|
+
ch1_fname: str | None = None
|
|
27
|
+
ch2_fname: str | None = None
|
|
28
|
+
ch_input_fname: str | None = None
|
|
29
29
|
|
|
30
|
-
input_is_sum:
|
|
30
|
+
input_is_sum: bool | None = False
|
|
31
31
|
"""Whether the input is the sum or average of channels"""
|
|
32
32
|
|
|
33
|
-
input_idx:
|
|
33
|
+
input_idx: int | None = None
|
|
34
34
|
"""Index of the channel where the input is stored in the data"""
|
|
35
35
|
|
|
36
|
-
target_idx_list:
|
|
36
|
+
target_idx_list: list[int] | None = None
|
|
37
37
|
"""Indices of the channels where the targets are stored in the data"""
|
|
38
38
|
|
|
39
39
|
# TODO: where are these used?
|
|
40
|
-
start_alpha:
|
|
41
|
-
end_alpha:
|
|
40
|
+
start_alpha: Any | None = None
|
|
41
|
+
end_alpha: Any | None = None
|
|
42
42
|
|
|
43
43
|
image_size: tuple # TODO: revisit, new model_config uses tuple
|
|
44
44
|
"""Size of one patch of data"""
|
|
45
45
|
|
|
46
|
-
grid_size:
|
|
46
|
+
grid_size: Union[int, tuple[int, int, int]] | None = None
|
|
47
47
|
"""Frame is divided into square grids of this size. A patch centered on a grid
|
|
48
48
|
having size `image_size` is returned. Grid size not used in training,
|
|
49
49
|
used only during val / test, grid size controls the overlap of the patches"""
|
|
50
50
|
|
|
51
|
-
empty_patch_replacement_enabled:
|
|
51
|
+
empty_patch_replacement_enabled: bool | None = False
|
|
52
52
|
"""Whether to replace the content of one of the channels
|
|
53
53
|
with background with given probability"""
|
|
54
|
-
empty_patch_replacement_channel_idx:
|
|
55
|
-
empty_patch_replacement_probab:
|
|
56
|
-
empty_patch_max_val_threshold:
|
|
54
|
+
empty_patch_replacement_channel_idx: Any | None = None
|
|
55
|
+
empty_patch_replacement_probab: Any | None = None
|
|
56
|
+
empty_patch_max_val_threshold: Any | None = None
|
|
57
57
|
|
|
58
|
-
uncorrelated_channels:
|
|
59
|
-
"""Replace the content in one of the channels with given probability to make
|
|
58
|
+
uncorrelated_channels: bool | None = False
|
|
59
|
+
"""Replace the content in one of the channels with given probability to make
|
|
60
60
|
channel content 'uncorrelated'"""
|
|
61
|
-
uncorrelated_channel_probab:
|
|
61
|
+
uncorrelated_channel_probab: float | None = 0.5
|
|
62
62
|
|
|
63
|
-
poisson_noise_factor:
|
|
63
|
+
poisson_noise_factor: float | None = -1
|
|
64
64
|
"""The added poisson noise factor"""
|
|
65
65
|
|
|
66
|
-
synthetic_gaussian_scale:
|
|
66
|
+
synthetic_gaussian_scale: float | None = 0.1
|
|
67
67
|
|
|
68
68
|
# TODO: set to True in training code, recheck
|
|
69
|
-
input_has_dependant_noise:
|
|
69
|
+
input_has_dependant_noise: bool | None = False
|
|
70
70
|
|
|
71
71
|
# TODO: sometimes max_val differs between runs with fixed seeds with noise enabled
|
|
72
|
-
enable_gaussian_noise:
|
|
72
|
+
enable_gaussian_noise: bool | None = False
|
|
73
73
|
"""Whether to enable gaussian noise"""
|
|
74
74
|
|
|
75
75
|
# TODO: is this parameter used?
|
|
@@ -80,44 +80,56 @@ class DatasetConfig(BaseModel):
|
|
|
80
80
|
deterministic_grid: Any = None
|
|
81
81
|
|
|
82
82
|
# TODO: why is this not used?
|
|
83
|
-
enable_rotation_aug:
|
|
83
|
+
enable_rotation_aug: bool | None = False
|
|
84
84
|
|
|
85
|
-
max_val:
|
|
86
|
-
"""Maximum data in the dataset. Is calculated for train split, and should be
|
|
85
|
+
max_val: Union[float, tuple] | None = None
|
|
86
|
+
"""Maximum data in the dataset. Is calculated for train split, and should be
|
|
87
87
|
externally set for val and test splits."""
|
|
88
88
|
|
|
89
89
|
overlapping_padding_kwargs: Any = None
|
|
90
90
|
"""Parameters for np.pad method"""
|
|
91
91
|
|
|
92
92
|
# TODO: remove this parameter, controls debug print
|
|
93
|
-
print_vars:
|
|
93
|
+
print_vars: bool | None = False
|
|
94
94
|
|
|
95
95
|
# Hard-coded parameters (used to be in the config file)
|
|
96
96
|
normalized_input: bool = True
|
|
97
97
|
"""If this is set to true, then one mean and stdev is used
|
|
98
98
|
for both channels. Otherwise, two different mean and stdev are used."""
|
|
99
|
-
use_one_mu_std:
|
|
99
|
+
use_one_mu_std: bool | None = True
|
|
100
100
|
|
|
101
101
|
# TODO: is this parameter used?
|
|
102
|
-
train_aug_rotate:
|
|
103
|
-
enable_random_cropping:
|
|
102
|
+
train_aug_rotate: bool | None = False
|
|
103
|
+
enable_random_cropping: bool | None = True
|
|
104
104
|
|
|
105
|
-
multiscale_lowres_count:
|
|
105
|
+
multiscale_lowres_count: int | None = None
|
|
106
106
|
"""Number of LC scales"""
|
|
107
107
|
|
|
108
|
-
tiling_mode:
|
|
108
|
+
tiling_mode: TilingMode | None = TilingMode.ShiftBoundary
|
|
109
109
|
|
|
110
|
-
target_separate_normalization:
|
|
110
|
+
target_separate_normalization: bool | None = True
|
|
111
111
|
|
|
112
|
-
mode_3D:
|
|
112
|
+
mode_3D: bool | None = False
|
|
113
113
|
"""If training in 3D mode or not"""
|
|
114
114
|
|
|
115
|
-
trainig_datausage_fraction:
|
|
115
|
+
trainig_datausage_fraction: float | None = 1.0
|
|
116
116
|
|
|
117
|
-
validtarget_random_fraction:
|
|
117
|
+
validtarget_random_fraction: float | None = None
|
|
118
118
|
|
|
119
|
-
validation_datausage_fraction:
|
|
119
|
+
validation_datausage_fraction: float | None = 1.0
|
|
120
120
|
|
|
121
|
-
random_flip_z_3D:
|
|
121
|
+
random_flip_z_3D: bool | None = False
|
|
122
122
|
|
|
123
|
-
padding_kwargs:
|
|
123
|
+
padding_kwargs: dict = {"mode": "reflect"} # TODO remove !!
|
|
124
|
+
|
|
125
|
+
def __init__(self, **data):
    """Build the config, mapping a string `data_type` to its `DataType` member.

    Parameters
    ----------
    **data
        Field values forwarded to the parent initializer. If `data_type` is
        a string naming a `DataType` member, it is replaced by that member
        before forwarding.
    """
    raw_data_type = data.get("data_type")
    if isinstance(raw_data_type, str):
        try:
            data["data_type"] = DataType[raw_data_type]
        except KeyError:
            # Keep original value to let validation handle the error
            pass
    super().__init__(**data)
|
|
134
|
+
|
|
135
|
+
# TODO add validators !
|
|
@@ -2,23 +2,29 @@
|
|
|
2
2
|
A place for Datasets and Dataloaders.
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
|
|
5
|
+
import logging
|
|
6
|
+
import math
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any, Callable, Optional, Union
|
|
6
9
|
|
|
7
10
|
import numpy as np
|
|
8
11
|
from skimage.transform import resize
|
|
9
12
|
|
|
10
|
-
from .config import
|
|
13
|
+
from .config import MicroSplitDataConfig
|
|
11
14
|
from .multich_dataset import MultiChDloader
|
|
12
15
|
|
|
13
16
|
|
|
14
17
|
class LCMultiChDloader(MultiChDloader):
|
|
18
|
+
"""Multi-channel dataset loader for LC-style datasets."""
|
|
19
|
+
|
|
15
20
|
def __init__(
|
|
16
21
|
self,
|
|
17
|
-
data_config:
|
|
18
|
-
|
|
19
|
-
load_data_fn: Callable,
|
|
20
|
-
val_fraction=
|
|
21
|
-
test_fraction=
|
|
22
|
+
data_config: MicroSplitDataConfig,
|
|
23
|
+
datapath: Union[str, Path],
|
|
24
|
+
load_data_fn: Optional[Callable] = None,
|
|
25
|
+
val_fraction: float = 0.1,
|
|
26
|
+
test_fraction: float = 0.1,
|
|
27
|
+
allow_generation: bool = False,
|
|
22
28
|
):
|
|
23
29
|
self._padding_kwargs = (
|
|
24
30
|
data_config.padding_kwargs # mode=padding_mode, constant_values=constant_value
|
|
@@ -27,7 +33,7 @@ class LCMultiChDloader(MultiChDloader):
|
|
|
27
33
|
|
|
28
34
|
super().__init__(
|
|
29
35
|
data_config,
|
|
30
|
-
|
|
36
|
+
datapath,
|
|
31
37
|
load_data_fn=load_data_fn,
|
|
32
38
|
val_fraction=val_fraction,
|
|
33
39
|
test_fraction=test_fraction,
|
|
@@ -111,8 +117,8 @@ class LCMultiChDloader(MultiChDloader):
|
|
|
111
117
|
return msg
|
|
112
118
|
|
|
113
119
|
def _load_scaled_img(
|
|
114
|
-
self, scaled_index, index: Union[int,
|
|
115
|
-
) ->
|
|
120
|
+
self, scaled_index, index: Union[int, tuple[int, int]]
|
|
121
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
|
116
122
|
if isinstance(index, int):
|
|
117
123
|
idx = index
|
|
118
124
|
else:
|
|
@@ -131,7 +137,7 @@ class LCMultiChDloader(MultiChDloader):
|
|
|
131
137
|
imgs = tuple([img + noise[0] * factor for img in imgs])
|
|
132
138
|
return imgs
|
|
133
139
|
|
|
134
|
-
def _crop_img(self, img: np.ndarray, patch_start_loc:
|
|
140
|
+
def _crop_img(self, img: np.ndarray, patch_start_loc: tuple):
|
|
135
141
|
"""
|
|
136
142
|
Here, h_start, w_start could be negative. That simply means we need to pick the content from 0. So,
|
|
137
143
|
the cropped image will be smaller than self._img_sz * self._img_sz
|
|
@@ -202,7 +208,7 @@ class LCMultiChDloader(MultiChDloader):
|
|
|
202
208
|
)
|
|
203
209
|
return output_img_tuples, cropped_noise_tuples
|
|
204
210
|
|
|
205
|
-
def __getitem__(self, index: Union[int,
|
|
211
|
+
def __getitem__(self, index: Union[int, tuple[int, int]]):
|
|
206
212
|
img_tuples, noise_tuples = self._get_img(index)
|
|
207
213
|
if self._uncorrelated_channels:
|
|
208
214
|
assert (
|
|
@@ -10,7 +10,7 @@ from typing import Callable, Union
|
|
|
10
10
|
import numpy as np
|
|
11
11
|
from skimage.transform import resize
|
|
12
12
|
|
|
13
|
-
from .config import
|
|
13
|
+
from .config import MicroSplitDataConfig
|
|
14
14
|
from .types import DataSplitType, TilingMode
|
|
15
15
|
from .utils.empty_patch_fetcher import EmptyPatchFetcher
|
|
16
16
|
from .utils.index_manager import GridIndexManagerRef
|
|
@@ -19,7 +19,7 @@ from .utils.index_manager import GridIndexManagerRef
|
|
|
19
19
|
class MultiChDloaderRef:
|
|
20
20
|
def __init__(
|
|
21
21
|
self,
|
|
22
|
-
data_config:
|
|
22
|
+
data_config: MicroSplitDataConfig,
|
|
23
23
|
fpath: str,
|
|
24
24
|
load_data_fn: Callable,
|
|
25
25
|
val_fraction: float = None,
|
|
@@ -171,8 +171,8 @@ class MultiChDloaderRef:
|
|
|
171
171
|
|
|
172
172
|
def load_data(
|
|
173
173
|
self,
|
|
174
|
-
data_config,
|
|
175
|
-
datasplit_type,
|
|
174
|
+
data_config: MicroSplitDataConfig,
|
|
175
|
+
datasplit_type: DataSplitType,
|
|
176
176
|
load_data_fn: Callable,
|
|
177
177
|
val_fraction=None,
|
|
178
178
|
test_fraction=None,
|
|
@@ -813,7 +813,7 @@ class MultiChDloaderRef:
|
|
|
813
813
|
class LCMultiChDloaderRef(MultiChDloaderRef):
|
|
814
814
|
def __init__(
|
|
815
815
|
self,
|
|
816
|
-
data_config:
|
|
816
|
+
data_config: MicroSplitDataConfig,
|
|
817
817
|
fpath: str,
|
|
818
818
|
load_data_fn: Callable,
|
|
819
819
|
val_fraction=None,
|
|
@@ -2,29 +2,35 @@
|
|
|
2
2
|
A place for Datasets and Dataloaders.
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Callable, Optional, Union
|
|
6
7
|
|
|
7
8
|
import numpy as np
|
|
9
|
+
import torch
|
|
10
|
+
from torch.utils.data import Dataset
|
|
8
11
|
|
|
9
12
|
from .utils.empty_patch_fetcher import EmptyPatchFetcher
|
|
10
13
|
from .utils.index_manager import GridIndexManager
|
|
11
14
|
from .utils.index_switcher import IndexSwitcher
|
|
12
|
-
from .config import
|
|
15
|
+
from .config import MicroSplitDataConfig
|
|
13
16
|
from .types import DataSplitType, TilingMode
|
|
14
17
|
|
|
15
18
|
|
|
16
|
-
class MultiChDloader:
|
|
19
|
+
class MultiChDloader(Dataset):
|
|
20
|
+
"""Multi-channel dataset loader."""
|
|
21
|
+
|
|
17
22
|
def __init__(
|
|
18
23
|
self,
|
|
19
|
-
data_config:
|
|
20
|
-
|
|
21
|
-
load_data_fn: Callable,
|
|
22
|
-
val_fraction: float =
|
|
23
|
-
test_fraction: float =
|
|
24
|
+
data_config: MicroSplitDataConfig,
|
|
25
|
+
datapath: Union[str, Path],
|
|
26
|
+
load_data_fn: Optional[Callable] = None,
|
|
27
|
+
val_fraction: float = 0.1,
|
|
28
|
+
test_fraction: float = 0.1,
|
|
29
|
+
allow_generation: bool = False,
|
|
24
30
|
):
|
|
25
31
|
""" """
|
|
26
32
|
self._data_type = data_config.data_type
|
|
27
|
-
self._fpath =
|
|
33
|
+
self._fpath = datapath
|
|
28
34
|
self._data = self._noise_data = None
|
|
29
35
|
self.Z = 1
|
|
30
36
|
self._5Ddata = False
|
|
@@ -395,7 +401,7 @@ class MultiChDloader:
|
|
|
395
401
|
)
|
|
396
402
|
|
|
397
403
|
def get_idx_manager_shapes(
|
|
398
|
-
self, patch_size: int, grid_size: Union[int,
|
|
404
|
+
self, patch_size: int, grid_size: Union[int, tuple[int, int, int]]
|
|
399
405
|
):
|
|
400
406
|
numC = self._data.shape[-1]
|
|
401
407
|
if self._5Ddata:
|
|
@@ -415,7 +421,7 @@ class MultiChDloader:
|
|
|
415
421
|
|
|
416
422
|
return patch_shape, grid_shape
|
|
417
423
|
|
|
418
|
-
def set_img_sz(self, image_size, grid_size: Union[int,
|
|
424
|
+
def set_img_sz(self, image_size, grid_size: Union[int, tuple[int, int, int]]):
|
|
419
425
|
"""
|
|
420
426
|
If one wants to change the image size on the go, then this can be used.
|
|
421
427
|
Args:
|
|
@@ -519,7 +525,7 @@ class MultiChDloader:
|
|
|
519
525
|
},
|
|
520
526
|
)
|
|
521
527
|
|
|
522
|
-
def _crop_img(self, img: np.ndarray, patch_start_loc:
|
|
528
|
+
def _crop_img(self, img: np.ndarray, patch_start_loc: tuple):
|
|
523
529
|
if self._tiling_mode in [TilingMode.TrimBoundary, TilingMode.ShiftBoundary]:
|
|
524
530
|
# In training, this is used.
|
|
525
531
|
# NOTE: It is my opinion that if I just use self._crop_img_with_padding, it will work perfectly fine.
|
|
@@ -600,7 +606,7 @@ class MultiChDloader:
|
|
|
600
606
|
return new_img
|
|
601
607
|
|
|
602
608
|
def _crop_flip_img(
|
|
603
|
-
self, img: np.ndarray, patch_start_loc:
|
|
609
|
+
self, img: np.ndarray, patch_start_loc: tuple, h_flip: bool, w_flip: bool
|
|
604
610
|
):
|
|
605
611
|
new_img = self._crop_img(img, patch_start_loc)
|
|
606
612
|
if h_flip:
|
|
@@ -611,8 +617,8 @@ class MultiChDloader:
|
|
|
611
617
|
return new_img.astype(np.float32)
|
|
612
618
|
|
|
613
619
|
def _load_img(
|
|
614
|
-
self, index: Union[int,
|
|
615
|
-
) ->
|
|
620
|
+
self, index: Union[int, tuple[int, int]]
|
|
621
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
|
616
622
|
"""
|
|
617
623
|
Returns the channels and also the respective noise channels.
|
|
618
624
|
"""
|
|
@@ -806,7 +812,7 @@ class MultiChDloader:
|
|
|
806
812
|
w_start = 0
|
|
807
813
|
return h_start, w_start
|
|
808
814
|
|
|
809
|
-
def _get_img(self, index: Union[int,
|
|
815
|
+
def _get_img(self, index: Union[int, tuple[int, int]]):
|
|
810
816
|
"""
|
|
811
817
|
Loads an image.
|
|
812
818
|
Crops the image such that cropped image has content.
|
|
@@ -1056,8 +1062,8 @@ class MultiChDloader:
|
|
|
1056
1062
|
return img_tuples, noise_tuples
|
|
1057
1063
|
|
|
1058
1064
|
def __getitem__(
|
|
1059
|
-
self, index: Union[int,
|
|
1060
|
-
) ->
|
|
1065
|
+
self, index: Union[int, tuple[int, int]]
|
|
1066
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
|
1061
1067
|
# Vera: input can be both real microscopic image and two separate channels that are summed in the code
|
|
1062
1068
|
|
|
1063
1069
|
if self._train_index_switcher is not None:
|
|
@@ -4,7 +4,7 @@ from typing import Callable, Union
|
|
|
4
4
|
import numpy as np
|
|
5
5
|
from numpy.typing import NDArray
|
|
6
6
|
|
|
7
|
-
from .config import
|
|
7
|
+
from .config import MicroSplitDataConfig
|
|
8
8
|
from .lc_dataset import LCMultiChDloader
|
|
9
9
|
from .multich_dataset import MultiChDloader
|
|
10
10
|
from .types import DataSplitType
|
|
@@ -82,7 +82,7 @@ class SingleFileLCDset(LCMultiChDloader):
|
|
|
82
82
|
def __init__(
|
|
83
83
|
self,
|
|
84
84
|
preloaded_data: NDArray,
|
|
85
|
-
data_config:
|
|
85
|
+
data_config: MicroSplitDataConfig,
|
|
86
86
|
fpath: str,
|
|
87
87
|
load_data_fn: Callable,
|
|
88
88
|
val_fraction=None,
|
|
@@ -106,7 +106,7 @@ class SingleFileLCDset(LCMultiChDloader):
|
|
|
106
106
|
|
|
107
107
|
def load_data(
|
|
108
108
|
self,
|
|
109
|
-
data_config:
|
|
109
|
+
data_config: MicroSplitDataConfig,
|
|
110
110
|
datasplit_type: DataSplitType,
|
|
111
111
|
load_data_fn: Callable,
|
|
112
112
|
val_fraction=None,
|
|
@@ -124,7 +124,7 @@ class SingleFileDset(MultiChDloader):
|
|
|
124
124
|
def __init__(
|
|
125
125
|
self,
|
|
126
126
|
preloaded_data: NDArray,
|
|
127
|
-
data_config:
|
|
127
|
+
data_config: MicroSplitDataConfig,
|
|
128
128
|
fpath: str,
|
|
129
129
|
load_data_fn: Callable,
|
|
130
130
|
val_fraction=None,
|
|
@@ -148,7 +148,7 @@ class SingleFileDset(MultiChDloader):
|
|
|
148
148
|
|
|
149
149
|
def load_data(
|
|
150
150
|
self,
|
|
151
|
-
data_config:
|
|
151
|
+
data_config: MicroSplitDataConfig,
|
|
152
152
|
datasplit_type: DataSplitType,
|
|
153
153
|
load_data_fn: Callable[..., NDArray],
|
|
154
154
|
val_fraction=None,
|
|
@@ -175,7 +175,7 @@ class MultiFileDset:
|
|
|
175
175
|
|
|
176
176
|
def __init__(
|
|
177
177
|
self,
|
|
178
|
-
data_config:
|
|
178
|
+
data_config: MicroSplitDataConfig,
|
|
179
179
|
fpath: str,
|
|
180
180
|
load_data_fn: Callable[..., Union[TwoChannelData, MultiChannelData]],
|
|
181
181
|
val_fraction=None,
|
careamics/model_io/bmz_io.py
CHANGED
|
@@ -186,11 +186,15 @@ def export_to_bmz(
|
|
|
186
186
|
)
|
|
187
187
|
|
|
188
188
|
# test model description
|
|
189
|
-
test_kwargs =
|
|
190
|
-
|
|
191
|
-
.
|
|
192
|
-
|
|
193
|
-
|
|
189
|
+
test_kwargs = {}
|
|
190
|
+
if hasattr(model_description, "config") and isinstance(
|
|
191
|
+
model_description.config, dict
|
|
192
|
+
):
|
|
193
|
+
bioimageio_config = model_description.config.get("bioimageio", {})
|
|
194
|
+
test_kwargs = bioimageio_config.get("test_kwargs", {}).get(
|
|
195
|
+
"pytorch_state_dict", {}
|
|
196
|
+
)
|
|
197
|
+
|
|
194
198
|
summary: ValidationSummary = test_model(model_description, **test_kwargs)
|
|
195
199
|
if summary.status == "failed":
|
|
196
200
|
raise ValueError(f"Model description test failed: {summary}")
|