careamics 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/careamist.py +55 -61
- careamics/cli/conf.py +24 -9
- careamics/cli/main.py +8 -8
- careamics/cli/utils.py +2 -4
- careamics/config/__init__.py +8 -0
- careamics/config/algorithms/__init__.py +4 -0
- careamics/config/algorithms/hdn_algorithm_model.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
- careamics/config/algorithms/n2v_algorithm_model.py +1 -2
- careamics/config/algorithms/vae_algorithm_model.py +53 -18
- careamics/config/architectures/lvae_model.py +12 -8
- careamics/config/callback_model.py +15 -11
- careamics/config/configuration.py +9 -8
- careamics/config/configuration_factories.py +892 -78
- careamics/config/data/data_model.py +7 -14
- careamics/config/data/ng_data_model.py +8 -15
- careamics/config/data/patching_strategies/_overlapping_patched_model.py +4 -5
- careamics/config/inference_model.py +6 -11
- careamics/config/likelihood_model.py +4 -4
- careamics/config/loss_model.py +6 -2
- careamics/config/nm_model.py +30 -7
- careamics/config/optimizer_models.py +1 -2
- careamics/config/support/supported_algorithms.py +5 -3
- careamics/config/support/supported_losses.py +5 -2
- careamics/config/training_model.py +8 -38
- careamics/config/transformations/normalize_model.py +3 -4
- careamics/config/transformations/xy_flip_model.py +2 -2
- careamics/config/transformations/xy_random_rotate90_model.py +2 -2
- careamics/config/validators/validator_utils.py +1 -2
- careamics/dataset/dataset_utils/iterate_over_files.py +3 -3
- careamics/dataset/in_memory_dataset.py +2 -2
- careamics/dataset/iterable_dataset.py +1 -2
- careamics/dataset/patching/random_patching.py +6 -6
- careamics/dataset/patching/sequential_patching.py +4 -4
- careamics/dataset/tiling/lvae_tiled_patching.py +2 -2
- careamics/dataset_ng/dataset.py +3 -3
- careamics/dataset_ng/factory.py +19 -19
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
- careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
- careamics/dataset_ng/patching_strategies/random_patching.py +2 -3
- careamics/dataset_ng/patching_strategies/sequential_patching.py +1 -2
- careamics/file_io/read/__init__.py +0 -1
- careamics/lightning/__init__.py +16 -2
- careamics/lightning/callbacks/__init__.py +2 -0
- careamics/lightning/callbacks/data_stats_callback.py +23 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +5 -5
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +5 -5
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +8 -8
- careamics/lightning/dataset_ng/data_module.py +43 -43
- careamics/lightning/lightning_module.py +166 -68
- careamics/lightning/microsplit_data_module.py +631 -0
- careamics/lightning/predict_data_module.py +16 -9
- careamics/lightning/train_data_module.py +29 -18
- careamics/losses/__init__.py +7 -1
- careamics/losses/loss_factory.py +9 -1
- careamics/losses/lvae/losses.py +94 -9
- careamics/lvae_training/dataset/__init__.py +8 -8
- careamics/lvae_training/dataset/config.py +56 -44
- careamics/lvae_training/dataset/lc_dataset.py +18 -12
- careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
- careamics/lvae_training/dataset/multich_dataset.py +24 -18
- careamics/lvae_training/dataset/multifile_dataset.py +6 -6
- careamics/model_io/bioimage/model_description.py +12 -11
- careamics/model_io/bmz_io.py +12 -8
- careamics/models/layers.py +5 -5
- careamics/models/lvae/likelihoods.py +30 -14
- careamics/models/lvae/lvae.py +2 -2
- careamics/models/lvae/noise_models.py +20 -14
- careamics/prediction_utils/__init__.py +8 -2
- careamics/prediction_utils/lvae_prediction.py +5 -5
- careamics/prediction_utils/prediction_outputs.py +48 -3
- careamics/prediction_utils/stitch_prediction.py +71 -0
- careamics/transforms/compose.py +9 -9
- careamics/transforms/n2v_manipulate.py +3 -3
- careamics/transforms/n2v_manipulate_torch.py +4 -4
- careamics/transforms/normalize.py +4 -6
- careamics/transforms/pixel_manipulation.py +6 -8
- careamics/transforms/pixel_manipulation_torch.py +5 -7
- careamics/transforms/xy_flip.py +3 -5
- careamics/transforms/xy_random_rotate90.py +4 -6
- careamics/utils/logging.py +8 -8
- careamics/utils/metrics.py +2 -2
- careamics/utils/plotting.py +1 -3
- {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/METADATA +18 -16
- {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/RECORD +90 -88
- careamics/dataset/zarr_dataset.py +0 -151
- careamics/file_io/read/zarr.py +0 -60
- {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/WHEEL +0 -0
- {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/licenses/LICENSE +0 -0
careamics/lvae_training/dataset/multich_dataset.py
CHANGED

@@ -2,29 +2,35 @@
 A place for Datasets and Dataloaders.
 """
 
-from …
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
 
 import numpy as np
+import torch
+from torch.utils.data import Dataset
 
 from .utils.empty_patch_fetcher import EmptyPatchFetcher
 from .utils.index_manager import GridIndexManager
 from .utils.index_switcher import IndexSwitcher
-from .config import …
+from .config import MicroSplitDataConfig
 from .types import DataSplitType, TilingMode
 
 
-class MultiChDloader:
+class MultiChDloader(Dataset):
+    """Multi-channel dataset loader."""
+
     def __init__(
         self,
-        data_config: …
-        …
-        load_data_fn: Callable,
-        val_fraction: float = …
-        test_fraction: float = …
+        data_config: MicroSplitDataConfig,
+        datapath: Union[str, Path],
+        load_data_fn: Optional[Callable] = None,
+        val_fraction: float = 0.1,
+        test_fraction: float = 0.1,
+        allow_generation: bool = False,
    ):
         """ """
         self._data_type = data_config.data_type
-        self._fpath = …
+        self._fpath = datapath
         self._data = self._noise_data = None
         self.Z = 1
         self._5Ddata = False

@@ -395,7 +401,7 @@ class MultiChDloader:
         )
 
     def get_idx_manager_shapes(
-        self, patch_size: int, grid_size: Union[int, …
+        self, patch_size: int, grid_size: Union[int, tuple[int, int, int]]
     ):
         numC = self._data.shape[-1]
         if self._5Ddata:

@@ -415,7 +421,7 @@ class MultiChDloader:
 
         return patch_shape, grid_shape
 
-    def set_img_sz(self, image_size, grid_size: Union[int, …
+    def set_img_sz(self, image_size, grid_size: Union[int, tuple[int, int, int]]):
         """
         If one wants to change the image size on the go, then this can be used.
         Args:

@@ -519,7 +525,7 @@ class MultiChDloader:
             },
         )
 
-    def _crop_img(self, img: np.ndarray, patch_start_loc: …
+    def _crop_img(self, img: np.ndarray, patch_start_loc: tuple):
         if self._tiling_mode in [TilingMode.TrimBoundary, TilingMode.ShiftBoundary]:
             # In training, this is used.
             # NOTE: It is my opinion that if I just use self._crop_img_with_padding, it will work perfectly fine.

@@ -600,7 +606,7 @@ class MultiChDloader:
         return new_img
 
     def _crop_flip_img(
-        self, img: np.ndarray, patch_start_loc: …
+        self, img: np.ndarray, patch_start_loc: tuple, h_flip: bool, w_flip: bool
     ):
         new_img = self._crop_img(img, patch_start_loc)
         if h_flip:

@@ -611,8 +617,8 @@ class MultiChDloader:
         return new_img.astype(np.float32)
 
     def _load_img(
-        self, index: Union[int, …
-    ) -> …
+        self, index: Union[int, tuple[int, int]]
+    ) -> tuple[np.ndarray, np.ndarray]:
         """
         Returns the channels and also the respective noise channels.
         """

@@ -806,7 +812,7 @@ class MultiChDloader:
             w_start = 0
         return h_start, w_start
 
-    def _get_img(self, index: Union[int, …
+    def _get_img(self, index: Union[int, tuple[int, int]]):
         """
         Loads an image.
         Crops the image such that cropped image has content.

@@ -1056,8 +1062,8 @@ class MultiChDloader:
         return img_tuples, noise_tuples
 
     def __getitem__(
-        self, index: Union[int, …
-    ) -> …
+        self, index: Union[int, tuple[int, int]]
+    ) -> tuple[np.ndarray, np.ndarray]:
         # Vera: input can be both real microscopic image and two separate channels that are summed in the code
 
         if self._train_index_switcher is not None:
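
The hunks above change MultiChDloader from a plain class into a torch.utils.data.Dataset subclass: the data path becomes an explicit `datapath` argument, `load_data_fn` becomes optional, the split fractions gain 0.1 defaults, and an `allow_generation` flag is added. A minimal runnable sketch of what the new contract enables; ToyMultiChDataset is an invented stand-in, not the real loader, which needs a MicroSplitDataConfig whose fields this diff does not show:

    import numpy as np
    from torch.utils.data import DataLoader, Dataset

    # Invented stand-in mirroring the new contract: a Dataset whose
    # __getitem__ returns a tuple of numpy arrays, as MultiChDloader's
    # updated signature declares.
    class ToyMultiChDataset(Dataset):
        def __len__(self) -> int:
            return 16

        def __getitem__(self, index: int) -> tuple[np.ndarray, np.ndarray]:
            patch = np.random.rand(2, 64, 64).astype(np.float32)   # input channels
            target = np.random.rand(2, 64, 64).astype(np.float32)  # target channels
            return patch, target

    # Subclassing Dataset is what lets the loader plug directly into a DataLoader:
    loader = DataLoader(ToyMultiChDataset(), batch_size=4, shuffle=True)
    patches, targets = next(iter(loader))
    print(patches.shape)  # torch.Size([4, 2, 64, 64]); numpy arrays are auto-collated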
careamics/lvae_training/dataset/multifile_dataset.py
CHANGED

@@ -4,7 +4,7 @@ from typing import Callable, Union
 import numpy as np
 from numpy.typing import NDArray
 
-from .config import …
+from .config import MicroSplitDataConfig
 from .lc_dataset import LCMultiChDloader
 from .multich_dataset import MultiChDloader
 from .types import DataSplitType

@@ -82,7 +82,7 @@ class SingleFileLCDset(LCMultiChDloader):
     def __init__(
         self,
         preloaded_data: NDArray,
-        data_config: …
+        data_config: MicroSplitDataConfig,
         fpath: str,
         load_data_fn: Callable,
         val_fraction=None,

@@ -106,7 +106,7 @@ class SingleFileLCDset(LCMultiChDloader):
 
     def load_data(
         self,
-        data_config: …
+        data_config: MicroSplitDataConfig,
         datasplit_type: DataSplitType,
         load_data_fn: Callable,
         val_fraction=None,

@@ -124,7 +124,7 @@ class SingleFileDset(MultiChDloader):
     def __init__(
         self,
         preloaded_data: NDArray,
-        data_config: …
+        data_config: MicroSplitDataConfig,
         fpath: str,
         load_data_fn: Callable,
         val_fraction=None,

@@ -148,7 +148,7 @@ class SingleFileDset(MultiChDloader):
 
     def load_data(
         self,
-        data_config: …
+        data_config: MicroSplitDataConfig,
         datasplit_type: DataSplitType,
         load_data_fn: Callable[..., NDArray],
         val_fraction=None,

@@ -175,7 +175,7 @@ class MultiFileDset:
 
     def __init__(
         self,
-        data_config: …
+        data_config: MicroSplitDataConfig,
         fpath: str,
         load_data_fn: Callable[..., Union[TwoChannelData, MultiChannelData]],
         val_fraction=None,
careamics/model_io/bioimage/model_description.py
CHANGED

@@ -1,10 +1,10 @@
 """Module use to build BMZ model description."""
 
 from pathlib import Path
-from typing import …
+from typing import Union
 
 import numpy as np
-from bioimageio.spec._internal.io import …
+from bioimageio.spec._internal.io import extract
 from bioimageio.spec.model.v0_5 import (
     ArchitectureFromLibraryDescr,
     Author,

@@ -12,7 +12,6 @@ from bioimageio.spec.model.v0_5 import (
     AxisId,
     BatchAxis,
     ChannelAxis,
-    EnvironmentFileDescr,
     FileDescr,
     FixedZeroMeanUnitVarianceAlongAxisKwargs,
     FixedZeroMeanUnitVarianceDescr,

@@ -36,7 +35,7 @@ from ._readme_factory import readme_factory
 def _create_axes(
     array: np.ndarray,
     data_config: DataConfig,
-    channel_names: …
+    channel_names: list[str] | None = None,
     is_input: bool = True,
 ) -> list[AxisBase]:
     """Create axes description.

@@ -105,7 +104,7 @@ def _create_inputs_ouputs(
     data_config: DataConfig,
     input_path: Union[Path, str],
     output_path: Union[Path, str],
-    channel_names: …
+    channel_names: list[str] | None = None,
 ) -> tuple[InputTensorDescr, OutputTensorDescr]:
     """Create input and output tensor description.
 

@@ -197,7 +196,7 @@ def create_model_description(
     config_path: Union[Path, str],
     env_path: Union[Path, str],
     covers: list[Union[Path, str]],
-    channel_names: …
+    channel_names: list[str] | None = None,
     model_version: str = "0.1.0",
 ) -> ModelDescr:
     """Create model description.

@@ -269,7 +268,7 @@ def create_model_description(
             source=weights_path,
             architecture=architecture_descr,
             pytorch_version=Version(torch_version),
-            dependencies=…
+            dependencies=FileDescr(source=Path(env_path)),
         ),
     )
 

@@ -322,9 +321,11 @@ def extract_model_path(model_desc: ModelDescr) -> tuple[Path, Path]:
     """
     if model_desc.weights.pytorch_state_dict is None:
         raise ValueError("No model weights found in model description.")
-    …
-    …
-    )
+
+    # extract the zip model and return the directory
+    model_dir = extract(model_desc.root)
+
+    weights_path = model_dir.joinpath(model_desc.weights.pytorch_state_dict.source.path)
 
     for file in model_desc.attachments:
         file_path = file.source if isinstance(file.source, Path) else file.source.path

@@ -332,7 +333,7 @@ def extract_model_path(model_desc: ModelDescr) -> tuple[Path, Path]:
             continue
         file_path = Path(file_path)
         if file_path.name == "careamics.yaml":
-            config_path = …
+            config_path = model_dir.joinpath(file.source.path)
             break
     else:
         raise ValueError("Configuration file not found.")
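
The rewritten extract_model_path now unzips the BMZ archive with bioimageio.spec's `extract` helper and resolves both the weights file and the bundled careamics.yaml inside the extracted directory. A hedged call sketch; the archive path is a placeholder, and the import path is inferred from the file list above rather than confirmed by the diff:

    # Sketch only: recover the two paths that extract_model_path returns.
    from bioimageio.core import load_model_description

    from careamics.model_io.bioimage.model_description import extract_model_path

    model_desc = load_model_description("path/to/model.zip")  # placeholder archive
    weights_path, config_path = extract_model_path(model_desc)
    print(weights_path)  # PyTorch state-dict file inside the extracted directory
    print(config_path)   # the bundled careamics.yaml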
careamics/model_io/bmz_io.py
CHANGED

@@ -2,7 +2,7 @@
 
 import tempfile
 from pathlib import Path
-from typing import …
+from typing import Union
 
 import numpy as np
 from bioimageio.core import load_model_description, test_model

@@ -90,8 +90,8 @@ def export_to_bmz(
     authors: list[dict],
     input_array: np.ndarray,
     output_array: np.ndarray,
-    covers: …
-    channel_names: …
+    covers: list[Union[Path, str]] | None = None,
+    channel_names: list[str] | None = None,
     model_version: str = "0.1.0",
 ) -> None:
     """Export the model to BioImage Model Zoo format.

@@ -186,11 +186,15 @@ def export_to_bmz(
     )
 
     # test model description
-    test_kwargs = …
-    …
-    .…
-    …
-    …
+    test_kwargs = {}
+    if hasattr(model_description, "config") and isinstance(
+        model_description.config, dict
+    ):
+        bioimageio_config = model_description.config.get("bioimageio", {})
+        test_kwargs = bioimageio_config.get("test_kwargs", {}).get(
+            "pytorch_state_dict", {}
+        )
+
     summary: ValidationSummary = test_model(model_description, **test_kwargs)
     if summary.status == "failed":
         raise ValueError(f"Model description test failed: {summary}")
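
The new test_kwargs block in export_to_bmz replaces a hard-coded expression with a defensive walk through the model description's `config` dict, falling back to an empty dict at every missing level. A standalone, runnable illustration of that lookup chain; the sample `config` contents are invented for demonstration:

    # The lookup chain mirrors the diff above; the config dict is invented.
    config = {
        "bioimageio": {
            "test_kwargs": {
                "pytorch_state_dict": {"decimal": 2},
            }
        }
    }

    test_kwargs = {}
    if isinstance(config, dict):
        bioimageio_config = config.get("bioimageio", {})
        test_kwargs = bioimageio_config.get("test_kwargs", {}).get(
            "pytorch_state_dict", {}
        )

    print(test_kwargs)  # {'decimal': 2}; any missing level yields {}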
careamics/models/layers.py
CHANGED

@@ -4,7 +4,7 @@ Layer module.
 This submodule contains layers used in the CAREamics models.
 """
 
-from typing import …
+from typing import Union
 
 import torch
 import torch.nn as nn

@@ -207,8 +207,8 @@ def get_pascal_kernel_1d(
     kernel_size: int,
     norm: bool = False,
     *,
-    device: …
-    dtype: …
+    device: torch.device | None = None,
+    dtype: torch.dtype | None = None,
 ) -> torch.Tensor:
     """Generate Yang Hui triangle (Pascal's triangle) for a given number.
 

@@ -270,8 +270,8 @@ def _get_pascal_kernel_nd(
     norm: bool = True,
     dim: int = 2,
     *,
-    device: …
-    dtype: …
+    device: torch.device | None = None,
+    dtype: torch.dtype | None = None,
 ) -> torch.Tensor:
     """Generate pascal filter kernel by kernel size.
 
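
Both Pascal-kernel builders now accept optional device/dtype keywords typed with PEP 604 unions. For context, a 1D Pascal kernel of size n is row n - 1 of Pascal's triangle (the binomial coefficients), optionally normalized to sum to one. A small runnable sketch of that idea, not CAREamics' actual implementation:

    import torch

    def pascal_kernel_1d_sketch(
        kernel_size: int,
        norm: bool = False,
        *,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> torch.Tensor:
        # Row (kernel_size - 1) of Pascal's triangle via the multiplicative
        # recurrence C(n, k + 1) = C(n, k) * (n - k) / (k + 1).
        n = kernel_size - 1
        row = [1.0]
        for k in range(n):
            row.append(row[-1] * (n - k) / (k + 1))
        kernel = torch.tensor(row, device=device, dtype=dtype)
        return kernel / kernel.sum() if norm else kernel

    print(pascal_kernel_1d_sketch(5))  # tensor([1., 4., 6., 4., 1.])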
careamics/models/lvae/likelihoods.py
CHANGED

@@ -54,12 +54,8 @@ def likelihood_factory(
         )
     elif isinstance(config, NMLikelihoodConfig):
         return NoiseModelLikelihood(
-            data_mean=config.data_mean,
-            data_std=config.data_std,
             noise_model=noise_model,
         )
-    else:
-        raise ValueError(f"Invalid likelihood model type: {config.model_type}")
 
 
 # TODO: is it really worth to have this class? Or it just adds complexity? --> REFACTOR

@@ -290,27 +286,40 @@ class NoiseModelLikelihood(LikelihoodModule):
 
     def __init__(
         self,
-        data_mean: Union[np.ndarray, torch.Tensor],
-        data_std: Union[np.ndarray, torch.Tensor],
         noise_model: NoiseModel,
     ):
         """Constructor.
 
         Parameters
         ----------
-        data_mean: Union[np.ndarray, torch.Tensor]
-            The mean of the data, used to unnormalize data for noise model evaluation.
-        data_std: Union[np.ndarray, torch.Tensor]
-            The standard deviation of the data, used to unnormalize data for noise
-            model evaluation.
         noiseModel: NoiseModel
             The noise model instance used to compute the likelihood.
         """
         super().__init__()
-        self.data_mean = …
-        self.data_std = …
+        self.data_mean = None
+        self.data_std = None
         self.noiseModel = noise_model
 
+    def set_data_stats(
+        self,
+        data_mean: Union[np.ndarray, torch.Tensor],
+        data_std: Union[np.ndarray, torch.Tensor],
+    ) -> None:
+        """Set the data mean and std for denormalization.
+        # TODO check this !!
+        Parameters
+        ----------
+        data_mean : Union[np.ndarray, torch.Tensor]
+            Mean values for each channel. Will be reshaped to (1, C, 1, 1, 1) for broadcasting.
+        data_std : Union[np.ndarray, torch.Tensor]
+            Standard deviation values for each channel. Will be reshaped to (1, C, 1, 1, 1) for broadcasting.
+        """
+        # Convert to tensor if needed
+        self.data_mean = torch.as_tensor(data_mean, dtype=torch.float32)
+        self.data_std = torch.as_tensor(data_std, dtype=torch.float32)
+
+        # TODO add extra dim for 3D ?
+
     def _set_params_to_same_device_as(
         self, correct_device_tensor: torch.Tensor
     ) -> None:

@@ -321,7 +330,10 @@ class NoiseModelLikelihood(LikelihoodModule):
         correct_device_tensor: torch.Tensor
             The tensor whose device is used to set the parameters.
         """
-        if …
+        if (
+            self.data_mean is not None
+            and self.data_mean.device != correct_device_tensor.device
+        ):
             self.data_mean = self.data_mean.to(correct_device_tensor.device)
             self.data_std = self.data_std.to(correct_device_tensor.device)
         if correct_device_tensor.device != self.noiseModel.device:

@@ -367,6 +379,10 @@ class NoiseModelLikelihood(LikelihoodModule):
         torch.Tensor
             The log-likelihood tensor. Shape is (B, C, [Z], Y, X).
         """
+        if self.data_mean is None or self.data_std is None:
+            raise RuntimeError(
+                "NoiseModelLikelihood: data_mean and data_std must be set before calling log_likelihood."
+            )
         self._set_params_to_same_device_as(x)
         predicted_s_denormalized = params["mean"] * self.data_std + self.data_mean
         x_denormalized = x * self.data_std + self.data_mean
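
Taken together, these hunks move data statistics out of NoiseModelLikelihood's constructor: it now only stores the noise model, `set_data_stats` injects `data_mean`/`data_std` afterwards, and `log_likelihood` raises a RuntimeError if they were never set. A sketch of the resulting two-step flow; passing None as the noise model and the statistic values are illustrative shortcuts, and the connection to the new data_stats_callback is an inference from the file list, not confirmed by this diff:

    import numpy as np

    from careamics.models.lvae.likelihoods import NoiseModelLikelihood

    # The constructor now only stores the noise model; stats start as None.
    likelihood = NoiseModelLikelihood(noise_model=None)  # placeholder noise model
    print(likelihood.data_mean, likelihood.data_std)  # None None

    # Calling log_likelihood at this point would raise RuntimeError.
    # Stats are injected later, plausibly by the new DataStatsCallback:
    likelihood.set_data_stats(
        data_mean=np.array([128.0]),  # illustrative per-channel mean
        data_std=np.array([32.0]),    # illustrative per-channel std
    )
    print(likelihood.data_mean, likelihood.data_std)  # float32 tensors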
careamics/models/lvae/lvae.py
CHANGED

@@ -6,7 +6,7 @@ and Artefact Removal, Prakash et al."
 """
 
 from collections.abc import Iterable
-from typing import …
+from typing import Union
 
 import numpy as np
 import torch

@@ -835,7 +835,7 @@ class LadderVAE(nn.Module):
         top_layer_shape = (n_imgs, mu_logvar, self._model_3D_depth, h, w)
         return top_layer_shape
 
-    def reset_for_inference(self, tile_size: …
+    def reset_for_inference(self, tile_size: tuple[int, int] | None = None):
         """Should be called if we want to predict for a different input/output size."""
         self.mode_pred = True
         if tile_size is None:
careamics/models/lvae/noise_models.py
CHANGED

@@ -3,10 +3,10 @@ from __future__ import annotations
 import os
 from typing import TYPE_CHECKING, Optional
 
-from numpy.typing import NDArray
 import numpy as np
 import torch
 import torch.nn as nn
+from numpy.typing import NDArray
 
 if TYPE_CHECKING:
     from careamics.config import GaussianMixtureNMConfig, MultiChannelNMConfig

@@ -355,16 +355,16 @@ class GaussianMixtureNoiseModel(nn.Module):
 
         Parameters
         ----------
-        x: Tensor
-            …
-        mean: Tensor
-            …
-        std: Tensor
-            …
+        x: torch.Tensor
+            The ground-truth tensor. Shape is (batch, 1, dim1, dim2).
+        mean: torch.Tensor
+            The inferred mean of distribution. Shape is (batch, 1, dim1, dim2).
+        std: torch.Tensor
+            The inferred standard deviation of distribution. Shape is (batch, 1, dim1, dim2).
 
         Returns
         -------
-        tmp: Tensor
+        tmp: torch.Tensor
             Normal probability density of `x` given `mean` and `std`
         """
         tmp = -((x - mean) ** 2)

@@ -382,9 +382,9 @@ class GaussianMixtureNoiseModel(nn.Module):
         Parameters
         ----------
         observations : Tensor
-            Noisy observations
+            Noisy observations. Shape is (batch, 1, dim1, dim2).
         signals : Tensor
-            Underlying signals
+            Underlying signals. Shape is (batch, 1, dim1, dim2).
 
         Returns
         -------

@@ -392,15 +392,21 @@ class GaussianMixtureNoiseModel(nn.Module):
             Likelihood of observations given the signals and the GMM noise model
         """
         gaussian_parameters: list[torch.Tensor] = self.get_gaussian_parameters(signals)
-        p = …
+        p = torch.zeros_like(observations)
         for gaussian in range(self.n_gaussian):
+            # Ensure all tensors have compatible shapes
+            mean = gaussian_parameters[gaussian]
+            std = gaussian_parameters[self.n_gaussian + gaussian]
+            weight = gaussian_parameters[2 * self.n_gaussian + gaussian]
+
+            # Compute normal density
             p += (
                 self.normal_density(
                     observations,
-                    …
-                    …
+                    mean,
+                    std,
                 )
-                * …
+                * weight
             )
         return p + self.tolerance
 
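
The rewritten likelihood loop makes the mixture structure explicit: get_gaussian_parameters returns a flat list of length 3 * n_gaussian holding means, then stds, then weights, and the likelihood is the weighted sum of per-component normal densities plus a tolerance, i.e. p(obs | signal) = sum_i w_i * N(obs; mu_i, sigma_i) + tol. A self-contained numeric sketch of that indexing convention; the toy values are invented, and normal_density below is a plain Gaussian pdf standing in for the class method:

    import torch

    # Toy mixture with n_gaussian = 2; the flat layout [mu0, mu1, s0, s1, w0, w1]
    # mirrors the diff's indexing: mean at [i], std at [n + i], weight at [2n + i].
    n_gaussian = 2
    obs = torch.tensor([0.5])
    params = [
        torch.tensor([0.0]), torch.tensor([1.0]),  # means
        torch.tensor([1.0]), torch.tensor([0.5]),  # stds
        torch.tensor([0.7]), torch.tensor([0.3]),  # weights
    ]

    def normal_density(x, mean, std):
        # Plain Gaussian pdf, evaluated elementwise.
        return torch.exp(-((x - mean) ** 2) / (2 * std**2)) / (
            std * (2 * torch.pi) ** 0.5
        )

    p = torch.zeros_like(obs)
    for i in range(n_gaussian):
        mean = params[i]
        std = params[n_gaussian + i]
        weight = params[2 * n_gaussian + i]
        p += normal_density(obs, mean, std) * weight

    print(p)  # weighted two-component density at obs = 0.5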
careamics/prediction_utils/__init__.py
CHANGED

@@ -2,9 +2,15 @@
 
 __all__ = [
     "convert_outputs",
+    "convert_outputs_microsplit",
     "stitch_prediction",
     "stitch_prediction_single",
+    "stitch_prediction_vae",
 ]
 
-from .prediction_outputs import convert_outputs
-from .stitch_prediction import …
+from .prediction_outputs import convert_outputs, convert_outputs_microsplit
+from .stitch_prediction import (
+    stitch_prediction,
+    stitch_prediction_single,
+    stitch_prediction_vae,
+)
careamics/prediction_utils/lvae_prediction.py
CHANGED

@@ -1,6 +1,6 @@
 """Module containing pytorch implementations for obtaining predictions from an LVAE."""
 
-from typing import Any…
+from typing import Any
 
 import torch
 

@@ -18,7 +18,7 @@ def lvae_predict_single_sample(
     model: LVAE,
     likelihood_obj: LikelihoodModule,
     input: torch.Tensor,
-) -> tuple[torch.Tensor, …
+) -> tuple[torch.Tensor, torch.Tensor | None]:
     """
     Generate a single sample prediction from an LVAE model, for a given input.
 

@@ -57,7 +57,7 @@ def lvae_predict_tiled_batch(
     model: LVAE,
     likelihood_obj: LikelihoodModule,
     input: tuple[Any],
-) -> tuple[tuple[Any], …
+) -> tuple[tuple[Any], tuple[Any] | None]:
     # TODO: fix docstring return types, ... too many output options
     """
     Generate a single sample prediction from an LVAE model, for a given input.

@@ -98,7 +98,7 @@ def lvae_predict_mmse_tiled_batch(
     likelihood_obj: LikelihoodModule,
     input: tuple[Any],
     mmse_count: int,
-) -> tuple[tuple[Any], tuple[Any], …
+) -> tuple[tuple[Any], tuple[Any], tuple[Any] | None]:
     # TODO: fix docstring return types, ... hard to make readable
     """
     Generate the MMSE (minimum mean squared error) prediction, for a given input.

@@ -137,7 +137,7 @@ def lvae_predict_mmse_tiled_batch(
 
     input_shape = x.shape
     output_shape = (input_shape[0], model.target_ch, *input_shape[2:])
-    log_var: …
+    log_var: torch.Tensor | None = None
     # pre-declare empty array to fill with individual sample predictions
     sample_predictions = torch.zeros(size=(mmse_count, *output_shape))
     for mmse_idx in range(mmse_count):
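
The typing changes above make explicit that the log-variance outputs are optional (None when the likelihood provides none). The MMSE routine itself, per the visible context lines, pre-allocates a (mmse_count, *output_shape) tensor, fills it with stochastic samples, and takes their mean as the estimate. A minimal runnable sketch of that estimator; sample_fn is an invented stand-in for the real model/likelihood call:

    import torch

    def mmse_estimate(sample_fn, mmse_count: int, output_shape: tuple) -> torch.Tensor:
        # Pre-declare an empty array to fill with individual sample
        # predictions, then average over the sample axis.
        sample_predictions = torch.zeros(size=(mmse_count, *output_shape))
        for mmse_idx in range(mmse_count):
            sample_predictions[mmse_idx] = sample_fn()
        return sample_predictions.mean(dim=0)

    # Toy stand-in for a stochastic LVAE forward pass:
    def noisy_sample() -> torch.Tensor:
        return torch.ones(2, 3, 8, 8) + 0.1 * torch.randn(2, 3, 8, 8)

    mmse = mmse_estimate(noisy_sample, mmse_count=16, output_shape=(2, 3, 8, 8))
    print(mmse.shape)  # torch.Size([2, 3, 8, 8])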
careamics/prediction_utils/prediction_outputs.py
CHANGED

@@ -6,7 +6,7 @@ import numpy as np
 from numpy.typing import NDArray
 
 from ..config.tile_information import TileInformation
-from .stitch_prediction import stitch_prediction
+from .stitch_prediction import stitch_prediction, stitch_prediction_vae
 
 
 def convert_outputs(predictions: list[Any], tiled: bool) -> list[NDArray]:

@@ -41,6 +41,48 @@ def convert_outputs(predictions: list[Any], tiled: bool) -> list[NDArray]:
     return predictions_output
 
 
+def convert_outputs_microsplit(
+    predictions: list[tuple[NDArray, NDArray]], dataset
+) -> tuple[NDArray, NDArray]:
+    """
+    Convert microsplit Lightning trainer outputs using eval_utils stitching functions.
+
+    This function processes microsplit predictions that return (tile_prediction, tile_std)
+    tuples and stitches them back together using the same logic as get_single_file_mmse.
+
+    Parameters
+    ----------
+    predictions : list of tuple[NDArray, NDArray]
+        Predictions from Lightning trainer for microsplit. Each element is a tuple of
+        (tile_prediction, tile_std) where both are numpy arrays from predict_step.
+    dataset : Dataset
+        The dataset object used for prediction, needed for stitching function selection
+        and stitching process.
+
+    Returns
+    -------
+    tuple[NDArray, NDArray]
+        A tuple of (stitched_predictions, stitched_stds) representing the full
+        stitched predictions and standard deviations.
+    """
+    if len(predictions) == 0:
+        raise ValueError("No predictions provided")
+
+    # Separate predictions and stds from the list of tuples
+    tile_predictions = [pred for pred, _ in predictions]
+    tile_stds = [std for _, std in predictions]
+
+    # Concatenate all tiles exactly like get_single_file_mmse
+    tiles_arr = np.concatenate(tile_predictions, axis=0)
+    tile_stds_arr = np.concatenate(tile_stds, axis=0)
+
+    # Apply stitching using stitch_predictions_new
+    stitched_predictions = stitch_prediction_vae(tiles_arr, dataset)
+    stitched_stds = stitch_prediction_vae(tile_stds_arr, dataset)
+
+    return stitched_predictions, stitched_stds
+
+
 # for mypy
 @overload
 def combine_batches(  # numpydoc ignore=GL08

@@ -68,6 +110,8 @@ def combine_batches(
     """
     If predictions are in batches, they will be combined.
 
+    # TODO improve description!
+
     Parameters
     ----------
     predictions : list

@@ -107,11 +151,12 @@ def _combine_tiled_batches(
     """
     # turn list of lists into single list
     tile_infos = [
-        tile_info for _, tile_info_list in predictions for tile_info in tile_info_list
+        tile_info for *_, tile_info_list in predictions for tile_info in tile_info_list
     ]
     prediction_tiles: list[NDArray] = _combine_array_batches(
-        [preds for preds, _ in predictions]
+        [preds for preds, *_ in predictions]
     )
+
     return prediction_tiles, tile_infos
 
 
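
convert_outputs_microsplit consumes the Lightning trainer's prediction list, where each element pairs a tile prediction with its per-tile standard deviation. The final stitching step needs a real dataset object, but the tuple handling is easy to see in isolation; a runnable sketch of just that portion, with invented toy tiles:

    import numpy as np

    # Invented toy batches: 4 elements of (tile_pred, tile_std), shape (B, C, H, W).
    predictions = [(np.zeros((2, 1, 8, 8)), np.ones((2, 1, 8, 8))) for _ in range(4)]

    # Same separation/concatenation steps as convert_outputs_microsplit,
    # minus the dataset-dependent stitch_prediction_vae calls at the end.
    tile_predictions = [pred for pred, _ in predictions]
    tile_stds = [std for _, std in predictions]

    tiles_arr = np.concatenate(tile_predictions, axis=0)
    tile_stds_arr = np.concatenate(tile_stds, axis=0)

    print(tiles_arr.shape, tile_stds_arr.shape)  # (8, 1, 8, 8) (8, 1, 8, 8)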