careamics 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +17 -2
- careamics/careamist.py +4 -3
- careamics/cli/conf.py +1 -2
- careamics/cli/main.py +1 -2
- careamics/cli/utils.py +3 -3
- careamics/config/__init__.py +47 -25
- careamics/config/algorithms/__init__.py +15 -0
- careamics/config/algorithms/care_algorithm_model.py +38 -0
- careamics/config/algorithms/n2n_algorithm_model.py +30 -0
- careamics/config/algorithms/n2v_algorithm_model.py +29 -0
- careamics/config/algorithms/unet_algorithm_model.py +88 -0
- careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +14 -12
- careamics/config/architectures/__init__.py +1 -11
- careamics/config/architectures/architecture_model.py +3 -3
- careamics/config/architectures/lvae_model.py +6 -1
- careamics/config/architectures/unet_model.py +1 -0
- careamics/config/care_configuration.py +100 -0
- careamics/config/configuration.py +354 -0
- careamics/config/{configuration_factory.py → configuration_factories.py} +185 -57
- careamics/config/configuration_io.py +85 -0
- careamics/config/data/__init__.py +10 -0
- careamics/config/{data_model.py → data/data_model.py} +91 -186
- careamics/config/data/n2v_data_model.py +193 -0
- careamics/config/likelihood_model.py +1 -2
- careamics/config/n2n_configuration.py +101 -0
- careamics/config/n2v_configuration.py +266 -0
- careamics/config/nm_model.py +1 -2
- careamics/config/support/__init__.py +7 -7
- careamics/config/support/supported_algorithms.py +5 -4
- careamics/config/support/supported_architectures.py +0 -4
- careamics/config/transformations/__init__.py +10 -4
- careamics/config/transformations/transform_model.py +3 -3
- careamics/config/transformations/transform_unions.py +42 -0
- careamics/config/validators/__init__.py +12 -1
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/validator_utils.py +3 -3
- careamics/dataset/__init__.py +2 -2
- careamics/dataset/dataset_utils/__init__.py +3 -3
- careamics/dataset/dataset_utils/dataset_utils.py +4 -6
- careamics/dataset/dataset_utils/file_utils.py +9 -9
- careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
- careamics/dataset/in_memory_dataset.py +11 -12
- careamics/dataset/iterable_dataset.py +4 -4
- careamics/dataset/iterable_pred_dataset.py +2 -1
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
- careamics/dataset/patching/random_patching.py +11 -10
- careamics/dataset/patching/sequential_patching.py +26 -26
- careamics/dataset/patching/validate_patch_dimension.py +3 -3
- careamics/dataset/tiling/__init__.py +2 -2
- careamics/dataset/tiling/collate_tiles.py +3 -3
- careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
- careamics/dataset/tiling/tiled_patching.py +11 -10
- careamics/file_io/__init__.py +5 -5
- careamics/file_io/read/__init__.py +1 -1
- careamics/file_io/read/get_func.py +2 -2
- careamics/file_io/write/__init__.py +2 -2
- careamics/lightning/__init__.py +5 -5
- careamics/lightning/callbacks/__init__.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
- careamics/lightning/callbacks/progress_bar_callback.py +3 -3
- careamics/lightning/lightning_module.py +11 -7
- careamics/lightning/train_data_module.py +36 -45
- careamics/losses/__init__.py +3 -3
- careamics/lvae_training/calibration.py +64 -57
- careamics/lvae_training/dataset/lc_dataset.py +2 -1
- careamics/lvae_training/dataset/multich_dataset.py +2 -2
- careamics/lvae_training/dataset/types.py +1 -1
- careamics/lvae_training/eval_utils.py +123 -128
- careamics/model_io/__init__.py +1 -1
- careamics/model_io/bioimage/__init__.py +1 -1
- careamics/model_io/bioimage/_readme_factory.py +1 -1
- careamics/model_io/bioimage/model_description.py +17 -17
- careamics/model_io/bmz_io.py +6 -17
- careamics/model_io/model_io_utils.py +9 -9
- careamics/models/layers.py +16 -16
- careamics/models/lvae/likelihoods.py +2 -0
- careamics/models/lvae/lvae.py +13 -4
- careamics/models/lvae/noise_models.py +280 -217
- careamics/models/lvae/stochastic.py +1 -0
- careamics/models/model_factory.py +2 -15
- careamics/models/unet.py +8 -8
- careamics/prediction_utils/__init__.py +1 -1
- careamics/prediction_utils/prediction_outputs.py +15 -15
- careamics/prediction_utils/stitch_prediction.py +6 -6
- careamics/transforms/__init__.py +5 -5
- careamics/transforms/compose.py +13 -13
- careamics/transforms/n2v_manipulate.py +3 -3
- careamics/transforms/pixel_manipulation.py +9 -9
- careamics/transforms/xy_random_rotate90.py +4 -4
- careamics/utils/__init__.py +5 -5
- careamics/utils/context.py +2 -1
- careamics/utils/logging.py +11 -10
- careamics/utils/metrics.py +25 -0
- careamics/utils/plotting.py +78 -0
- careamics/utils/torch_utils.py +7 -7
- {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/METADATA +13 -11
- careamics-0.0.7.dist-info/RECORD +178 -0
- careamics/config/architectures/custom_model.py +0 -162
- careamics/config/architectures/register_model.py +0 -103
- careamics/config/configuration_model.py +0 -603
- careamics/config/fcn_algorithm_model.py +0 -152
- careamics/config/references/__init__.py +0 -45
- careamics/config/references/algorithm_descriptions.py +0 -132
- careamics/config/references/references.py +0 -39
- careamics/config/transformations/transform_union.py +0 -20
- careamics-0.0.5.dist-info/RECORD +0 -171
- {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/WHEEL +0 -0
- {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/licenses/LICENSE +0 -0
|
@@ -2,8 +2,9 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
from collections.abc import Sequence
|
|
5
6
|
from pathlib import Path
|
|
6
|
-
from typing import Any, Optional,
|
|
7
|
+
from typing import Any, Optional, Union
|
|
7
8
|
|
|
8
9
|
from pytorch_lightning import LightningModule, Trainer
|
|
9
10
|
from pytorch_lightning.callbacks import BasePredictionWriter
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
"""Module containing different strategies for writing predictions."""
|
|
2
2
|
|
|
3
|
+
from collections.abc import Sequence
|
|
3
4
|
from pathlib import Path
|
|
4
|
-
from typing import Any, Optional, Protocol,
|
|
5
|
+
from typing import Any, Optional, Protocol, Union
|
|
5
6
|
|
|
6
7
|
import numpy as np
|
|
7
8
|
from numpy.typing import NDArray
|
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
"""Progressbar callback."""
|
|
2
2
|
|
|
3
3
|
import sys
|
|
4
|
-
from typing import
|
|
4
|
+
from typing import Union
|
|
5
5
|
|
|
6
6
|
from pytorch_lightning import LightningModule, Trainer
|
|
7
7
|
from pytorch_lightning.callbacks import TQDMProgressBar
|
|
8
|
-
from tqdm import tqdm
|
|
8
|
+
from tqdm.auto import tqdm
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
class ProgressBarCallback(TQDMProgressBar):
|
|
@@ -71,7 +71,7 @@ class ProgressBarCallback(TQDMProgressBar):
|
|
|
71
71
|
|
|
72
72
|
def get_metrics(
|
|
73
73
|
self, trainer: Trainer, pl_module: LightningModule
|
|
74
|
-
) ->
|
|
74
|
+
) -> dict[str, Union[int, str, float, dict[str, float]]]:
|
|
75
75
|
"""Override this to customize the metrics displayed in the progress bar.
|
|
76
76
|
|
|
77
77
|
Parameters
|
|
@@ -6,7 +6,7 @@ import numpy as np
|
|
|
6
6
|
import pytorch_lightning as L
|
|
7
7
|
from torch import Tensor, nn
|
|
8
8
|
|
|
9
|
-
from careamics.config import
|
|
9
|
+
from careamics.config import UNetBasedAlgorithm, VAEBasedAlgorithm
|
|
10
10
|
from careamics.config.support import (
|
|
11
11
|
SupportedAlgorithm,
|
|
12
12
|
SupportedArchitecture,
|
|
@@ -34,6 +34,7 @@ from careamics.utils.torch_utils import get_optimizer, get_scheduler
|
|
|
34
34
|
NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
|
|
35
35
|
|
|
36
36
|
|
|
37
|
+
# TODO rename to UNetModule
|
|
37
38
|
class FCNModule(L.LightningModule):
|
|
38
39
|
"""
|
|
39
40
|
CAREamics Lightning module.
|
|
@@ -60,7 +61,7 @@ class FCNModule(L.LightningModule):
|
|
|
60
61
|
Learning rate scheduler name.
|
|
61
62
|
"""
|
|
62
63
|
|
|
63
|
-
def __init__(self, algorithm_config: Union[
|
|
64
|
+
def __init__(self, algorithm_config: Union[UNetBasedAlgorithm, dict]) -> None:
|
|
64
65
|
"""Lightning module for CAREamics.
|
|
65
66
|
|
|
66
67
|
This class encapsulates the a PyTorch model along with the training, validation,
|
|
@@ -74,7 +75,9 @@ class FCNModule(L.LightningModule):
|
|
|
74
75
|
super().__init__()
|
|
75
76
|
# if loading from a checkpoint, AlgorithmModel needs to be instantiated
|
|
76
77
|
if isinstance(algorithm_config, dict):
|
|
77
|
-
algorithm_config =
|
|
78
|
+
algorithm_config = UNetBasedAlgorithm(
|
|
79
|
+
**algorithm_config
|
|
80
|
+
) # TODO this needs to be updated using the algorithm-specific class
|
|
78
81
|
|
|
79
82
|
# create model and loss function
|
|
80
83
|
self.model: nn.Module = model_factory(algorithm_config.model)
|
|
@@ -266,7 +269,7 @@ class VAEModule(L.LightningModule):
|
|
|
266
269
|
Learning rate scheduler name.
|
|
267
270
|
"""
|
|
268
271
|
|
|
269
|
-
def __init__(self, algorithm_config: Union[
|
|
272
|
+
def __init__(self, algorithm_config: Union[VAEBasedAlgorithm, dict]) -> None:
|
|
270
273
|
"""Lightning module for CAREamics.
|
|
271
274
|
|
|
272
275
|
This class encapsulates the a PyTorch model along with the training, validation,
|
|
@@ -280,7 +283,7 @@ class VAEModule(L.LightningModule):
|
|
|
280
283
|
super().__init__()
|
|
281
284
|
# if loading from a checkpoint, AlgorithmModel needs to be instantiated
|
|
282
285
|
self.algorithm_config = (
|
|
283
|
-
|
|
286
|
+
VAEBasedAlgorithm(**algorithm_config)
|
|
284
287
|
if isinstance(algorithm_config, dict)
|
|
285
288
|
else algorithm_config
|
|
286
289
|
)
|
|
@@ -656,9 +659,10 @@ def create_careamics_module(
|
|
|
656
659
|
algorithm_configuration["model"] = model_configuration
|
|
657
660
|
|
|
658
661
|
# call the parent init using an AlgorithmModel instance
|
|
662
|
+
# TODO broken by new configutations!
|
|
659
663
|
algorithm_str = algorithm_configuration["algorithm"]
|
|
660
|
-
if algorithm_str in
|
|
661
|
-
return FCNModule(
|
|
664
|
+
if algorithm_str in UNetBasedAlgorithm.get_compatible_algorithms():
|
|
665
|
+
return FCNModule(UNetBasedAlgorithm(**algorithm_configuration))
|
|
662
666
|
else:
|
|
663
667
|
raise NotImplementedError(
|
|
664
668
|
f"Model {algorithm_str} is not implemented or unknown."
|
|
@@ -2,14 +2,13 @@
|
|
|
2
2
|
|
|
3
3
|
from pathlib import Path
|
|
4
4
|
from typing import Any, Callable, Literal, Optional, Union
|
|
5
|
-
from warnings import warn
|
|
6
5
|
|
|
7
6
|
import numpy as np
|
|
8
7
|
import pytorch_lightning as L
|
|
9
8
|
from numpy.typing import NDArray
|
|
10
9
|
from torch.utils.data import DataLoader, IterableDataset
|
|
11
10
|
|
|
12
|
-
from careamics.config import DataConfig
|
|
11
|
+
from careamics.config.data import DataConfig, GeneralDataConfig, N2VDataConfig
|
|
13
12
|
from careamics.config.support import SupportedData
|
|
14
13
|
from careamics.config.transformations import TransformModel
|
|
15
14
|
from careamics.dataset.dataset_utils import (
|
|
@@ -119,7 +118,7 @@ class TrainDataModule(L.LightningDataModule):
|
|
|
119
118
|
|
|
120
119
|
def __init__(
|
|
121
120
|
self,
|
|
122
|
-
data_config:
|
|
121
|
+
data_config: GeneralDataConfig,
|
|
123
122
|
train_data: Union[Path, str, NDArray],
|
|
124
123
|
val_data: Optional[Union[Path, str, NDArray]] = None,
|
|
125
124
|
train_data_target: Optional[Union[Path, str, NDArray]] = None,
|
|
@@ -219,7 +218,7 @@ class TrainDataModule(L.LightningDataModule):
|
|
|
219
218
|
)
|
|
220
219
|
|
|
221
220
|
# configuration
|
|
222
|
-
self.data_config:
|
|
221
|
+
self.data_config: GeneralDataConfig = data_config
|
|
223
222
|
self.data_type: str = data_config.data_type
|
|
224
223
|
self.batch_size: int = data_config.batch_size
|
|
225
224
|
self.use_in_memory: bool = use_in_memory
|
|
@@ -261,11 +260,6 @@ class TrainDataModule(L.LightningDataModule):
|
|
|
261
260
|
|
|
262
261
|
self.extension_filter: str = extension_filter
|
|
263
262
|
|
|
264
|
-
# Pytorch dataloader parameters
|
|
265
|
-
self.dataloader_params: dict[str, Any] = (
|
|
266
|
-
data_config.dataloader_params if data_config.dataloader_params else {}
|
|
267
|
-
)
|
|
268
|
-
|
|
269
263
|
def prepare_data(self) -> None:
|
|
270
264
|
"""
|
|
271
265
|
Hook used to prepare the data before calling `setup`.
|
|
@@ -447,21 +441,17 @@ class TrainDataModule(L.LightningDataModule):
|
|
|
447
441
|
Any
|
|
448
442
|
Training dataloader.
|
|
449
443
|
"""
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
"Dataloader parameters include `shuffle=False`, this will be "
|
|
457
|
-
"passed to the training dataloader and may result in bad results.",
|
|
458
|
-
stacklevel=1,
|
|
459
|
-
)
|
|
460
|
-
else:
|
|
461
|
-
self.dataloader_params["shuffle"] = True
|
|
444
|
+
train_dataloader_params = self.data_config.train_dataloader_params.copy()
|
|
445
|
+
|
|
446
|
+
# NOTE: When next-gen datasets are completed this can be removed
|
|
447
|
+
# iterable dataset cannot be shuffled
|
|
448
|
+
if isinstance(self.train_dataset, IterableDataset):
|
|
449
|
+
del train_dataloader_params["shuffle"]
|
|
462
450
|
|
|
463
451
|
return DataLoader(
|
|
464
|
-
self.train_dataset,
|
|
452
|
+
self.train_dataset,
|
|
453
|
+
batch_size=self.batch_size,
|
|
454
|
+
**train_dataloader_params,
|
|
465
455
|
)
|
|
466
456
|
|
|
467
457
|
def val_dataloader(self) -> Any:
|
|
@@ -476,6 +466,7 @@ class TrainDataModule(L.LightningDataModule):
|
|
|
476
466
|
return DataLoader(
|
|
477
467
|
self.val_dataset,
|
|
478
468
|
batch_size=self.batch_size,
|
|
469
|
+
**self.data_config.val_dataloader_params,
|
|
479
470
|
)
|
|
480
471
|
|
|
481
472
|
|
|
@@ -502,12 +493,23 @@ def create_train_datamodule(
|
|
|
502
493
|
"""Create a TrainDataModule.
|
|
503
494
|
|
|
504
495
|
This function is used to explicitly pass the parameters usually contained in a
|
|
505
|
-
`
|
|
496
|
+
`GenericDataConfig` to a TrainDataModule.
|
|
506
497
|
|
|
507
498
|
Since the lightning datamodule has no access to the model, make sure that the
|
|
508
499
|
parameters passed to the datamodule are consistent with the model's requirements and
|
|
509
500
|
are coherent.
|
|
510
501
|
|
|
502
|
+
By default, the train DataModule will be set for Noise2Void if no target data is
|
|
503
|
+
provided. That means that it will add a `N2VManipulateModel` transformation to the
|
|
504
|
+
list of augmentations. The default augmentations are XY flip, XY rotation, and N2V
|
|
505
|
+
pixel manipulation. If you pass a training target data, the default behaviour is to
|
|
506
|
+
train a supervised model. It will use the default XY flip and rotation
|
|
507
|
+
augmentations.
|
|
508
|
+
|
|
509
|
+
To use a different set of transformations, you can pass a list of transforms to
|
|
510
|
+
`transforms`. Note that if you intend to use Noise2Void, you should add
|
|
511
|
+
`N2VManipulateModel` as the last transform in the list of transformations.
|
|
512
|
+
|
|
511
513
|
The data module can be used with Path, str or numpy arrays. In the case of
|
|
512
514
|
numpy arrays, it loads and computes all the patches in memory. For Path and str
|
|
513
515
|
inputs, it calculates the total file size and estimate whether it can fit in
|
|
@@ -518,11 +520,6 @@ def create_train_datamodule(
|
|
|
518
520
|
To use array data, set `data_type` to `array` and pass a numpy array to
|
|
519
521
|
`train_data`.
|
|
520
522
|
|
|
521
|
-
In particular, N2V requires a specific transformation (N2V manipulates), which is
|
|
522
|
-
not compatible with supervised training. The default transformations applied to the
|
|
523
|
-
training patches are defined in `careamics.config.data_model`. To use different
|
|
524
|
-
transformations, pass a list of transforms. See examples for more details.
|
|
525
|
-
|
|
526
523
|
By default, CAREamics only supports types defined in
|
|
527
524
|
`careamics.config.support.SupportedData`. To read custom data types, you can set
|
|
528
525
|
`data_type` to `custom` and provide a function that returns a numpy array from a
|
|
@@ -627,12 +624,12 @@ def create_train_datamodule(
|
|
|
627
624
|
transforms:
|
|
628
625
|
>>> import numpy as np
|
|
629
626
|
>>> from careamics.lightning import create_train_datamodule
|
|
627
|
+
>>> from careamics.config.transformations import XYFlipModel, N2VManipulateModel
|
|
630
628
|
>>> from careamics.config.support import SupportedTransform
|
|
631
629
|
>>> my_array = np.arange(256).reshape(16, 16)
|
|
632
630
|
>>> my_transforms = [
|
|
633
|
-
...
|
|
634
|
-
...
|
|
635
|
-
... }
|
|
631
|
+
... XYFlipModel(flip_y=False),
|
|
632
|
+
... N2VManipulateModel()
|
|
636
633
|
... ]
|
|
637
634
|
>>> data_module = create_train_datamodule(
|
|
638
635
|
... train_data=my_array,
|
|
@@ -659,21 +656,15 @@ def create_train_datamodule(
|
|
|
659
656
|
if transforms is not None:
|
|
660
657
|
data_dict["transforms"] = transforms
|
|
661
658
|
|
|
662
|
-
#
|
|
663
|
-
|
|
659
|
+
# TODO not compatible with HDN, consider adding an argument for n2v/hdn
|
|
660
|
+
if train_target_data is None:
|
|
661
|
+
data_config: GeneralDataConfig = N2VDataConfig(**data_dict)
|
|
662
|
+
assert isinstance(data_config, N2VDataConfig)
|
|
664
663
|
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
data_config.set_N2V2(use_n2v2)
|
|
670
|
-
data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span)
|
|
671
|
-
else:
|
|
672
|
-
raise ValueError(
|
|
673
|
-
"Cannot have both supervised training (target data) and "
|
|
674
|
-
"N2V manipulation in the transforms. Pass a list of transforms "
|
|
675
|
-
"that is compatible with your supervised training."
|
|
676
|
-
)
|
|
664
|
+
data_config.set_n2v2(use_n2v2)
|
|
665
|
+
data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span)
|
|
666
|
+
else:
|
|
667
|
+
data_config = DataConfig(**data_dict)
|
|
677
668
|
|
|
678
669
|
# sanity check on the dataloader parameters
|
|
679
670
|
if "batch_size" in dataloader_params:
|
careamics/losses/__init__.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
"""Losses module."""
|
|
2
2
|
|
|
3
3
|
__all__ = [
|
|
4
|
+
"denoisplit_loss",
|
|
5
|
+
"denoisplit_musplit_loss",
|
|
4
6
|
"loss_factory",
|
|
5
7
|
"mae_loss",
|
|
6
8
|
"mse_loss",
|
|
7
|
-
"n2v_loss",
|
|
8
|
-
"denoisplit_loss",
|
|
9
9
|
"musplit_loss",
|
|
10
|
-
"
|
|
10
|
+
"n2v_loss",
|
|
11
11
|
]
|
|
12
12
|
|
|
13
13
|
from .fcn.losses import mae_loss, mse_loss, n2v_loss
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Union
|
|
1
|
+
from typing import Union, Optional
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
import torch
|
|
@@ -34,9 +34,6 @@ class Calibration:
|
|
|
34
34
|
self._bins = num_bins
|
|
35
35
|
self._bin_boundaries = None
|
|
36
36
|
|
|
37
|
-
def logvar_to_std(self, logvar: np.ndarray) -> np.ndarray:
|
|
38
|
-
return np.exp(logvar / 2)
|
|
39
|
-
|
|
40
37
|
def compute_bin_boundaries(self, predict_std: np.ndarray) -> np.ndarray:
|
|
41
38
|
"""Compute the bin boundaries for `num_bins` bins and predicted std values."""
|
|
42
39
|
min_std = np.min(predict_std)
|
|
@@ -104,65 +101,75 @@ class Calibration:
|
|
|
104
101
|
)
|
|
105
102
|
rmse_stderr = np.sqrt(stderr) if stderr is not None else None
|
|
106
103
|
|
|
107
|
-
bin_var = np.mean(
|
|
104
|
+
bin_var = np.mean(std_ch[bin_mask] ** 2)
|
|
108
105
|
stats_dict[ch_idx]["rmse"].append(bin_error)
|
|
109
106
|
stats_dict[ch_idx]["rmse_err"].append(rmse_stderr)
|
|
110
107
|
stats_dict[ch_idx]["rmv"].append(np.sqrt(bin_var))
|
|
111
108
|
stats_dict[ch_idx]["bin_count"].append(bin_size)
|
|
109
|
+
self.stats_dict = stats_dict
|
|
112
110
|
return stats_dict
|
|
113
111
|
|
|
112
|
+
def get_calibrated_factor_for_stdev(
|
|
113
|
+
self,
|
|
114
|
+
pred: Optional[np.ndarray] = None,
|
|
115
|
+
pred_std: Optional[np.ndarray] = None,
|
|
116
|
+
target: Optional[np.ndarray] = None,
|
|
117
|
+
q_s: float = 0.00001,
|
|
118
|
+
q_e: float = 0.99999,
|
|
119
|
+
) -> dict[str, float]:
|
|
120
|
+
"""Calibrate the uncertainty by multiplying the predicted std with a scalar.
|
|
114
121
|
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
122
|
+
Parameters
|
|
123
|
+
----------
|
|
124
|
+
stats_dict : dict[int, dict[str, Union[np.ndarray, list]]]
|
|
125
|
+
Dictionary containing the stats for each channel.
|
|
126
|
+
q_s : float, optional
|
|
127
|
+
Start quantile, by default 0.00001.
|
|
128
|
+
q_e : float, optional
|
|
129
|
+
End quantile, by default 0.99999.
|
|
130
|
+
|
|
131
|
+
Returns
|
|
132
|
+
-------
|
|
133
|
+
dict[str, float]
|
|
134
|
+
Calibrated factor for each channel (slope + intercept).
|
|
135
|
+
"""
|
|
136
|
+
if not hasattr(self, "stats_dict"):
|
|
137
|
+
print("No stats found. Computing stats...")
|
|
138
|
+
if any(v is None for v in [pred, pred_std, target]):
|
|
139
|
+
raise ValueError("pred, pred_std, and target must be provided.")
|
|
140
|
+
self.stats_dict = self.compute_stats(
|
|
141
|
+
pred=pred, pred_std=pred_std, target=target
|
|
142
|
+
)
|
|
143
|
+
outputs = {}
|
|
144
|
+
for ch_idx in self.stats_dict.keys():
|
|
145
|
+
y = self.stats_dict[ch_idx]["rmse"]
|
|
146
|
+
x = self.stats_dict[ch_idx]["rmv"]
|
|
147
|
+
count = self.stats_dict[ch_idx]["bin_count"]
|
|
148
|
+
|
|
149
|
+
first_idx = get_first_index(count, q_s)
|
|
150
|
+
last_idx = get_last_index(count, q_e)
|
|
151
|
+
x = x[first_idx:-last_idx]
|
|
152
|
+
y = y[first_idx:-last_idx]
|
|
153
|
+
slope, intercept, *_ = stats.linregress(x, y)
|
|
154
|
+
output = {"scalar": slope, "offset": intercept}
|
|
155
|
+
outputs[ch_idx] = output
|
|
156
|
+
factors = self.get_factors_array(factors_dict=outputs)
|
|
157
|
+
return outputs, factors
|
|
158
|
+
|
|
159
|
+
def get_factors_array(self, factors_dict: list[dict]):
|
|
160
|
+
"""Get the calibration factors as a numpy array."""
|
|
161
|
+
calib_scalar = [factors_dict[i]["scalar"] for i in range(len(factors_dict))]
|
|
162
|
+
calib_scalar = np.array(calib_scalar).reshape(1, 1, 1, -1)
|
|
163
|
+
calib_offset = [
|
|
164
|
+
factors_dict[i].get("offset", 0.0) for i in range(len(factors_dict))
|
|
165
|
+
]
|
|
166
|
+
calib_offset = np.array(calib_offset).reshape(1, 1, 1, -1)
|
|
167
|
+
return {"scalar": calib_scalar, "offset": calib_offset}
|
|
161
168
|
|
|
162
169
|
|
|
163
170
|
def plot_calibration(ax, calibration_stats):
|
|
164
|
-
first_idx = get_first_index(calibration_stats[0]["bin_count"], 0.
|
|
165
|
-
last_idx = get_last_index(calibration_stats[0]["bin_count"], 0.
|
|
171
|
+
first_idx = get_first_index(calibration_stats[0]["bin_count"], 0.0001)
|
|
172
|
+
last_idx = get_last_index(calibration_stats[0]["bin_count"], 0.9999)
|
|
166
173
|
ax.plot(
|
|
167
174
|
calibration_stats[0]["rmv"][first_idx:-last_idx],
|
|
168
175
|
calibration_stats[0]["rmse"][first_idx:-last_idx],
|
|
@@ -170,15 +177,15 @@ def plot_calibration(ax, calibration_stats):
|
|
|
170
177
|
label=r"$\hat{C}_0$: Ch1",
|
|
171
178
|
)
|
|
172
179
|
|
|
173
|
-
first_idx = get_first_index(calibration_stats[1]["bin_count"], 0.
|
|
174
|
-
last_idx = get_last_index(calibration_stats[1]["bin_count"], 0.
|
|
180
|
+
first_idx = get_first_index(calibration_stats[1]["bin_count"], 0.0001)
|
|
181
|
+
last_idx = get_last_index(calibration_stats[1]["bin_count"], 0.9999)
|
|
175
182
|
ax.plot(
|
|
176
183
|
calibration_stats[1]["rmv"][first_idx:-last_idx],
|
|
177
184
|
calibration_stats[1]["rmse"][first_idx:-last_idx],
|
|
178
185
|
"o",
|
|
179
|
-
label=r"$\hat{C}_1
|
|
186
|
+
label=r"$\hat{C}_1$: Ch2",
|
|
180
187
|
)
|
|
181
|
-
|
|
188
|
+
# TODO add multichannel
|
|
182
189
|
ax.set_xlabel("RMV")
|
|
183
190
|
ax.set_ylabel("RMSE")
|
|
184
191
|
ax.legend()
|
|
@@ -97,7 +97,8 @@ class LCMultiChDloader(MultiChDloader):
|
|
|
97
97
|
]
|
|
98
98
|
|
|
99
99
|
self.N = len(t_list)
|
|
100
|
-
self.
|
|
100
|
+
# TODO where tf is self._img_sz defined?
|
|
101
|
+
self.set_img_sz([self._img_sz, self._img_sz], self._grid_sz)
|
|
101
102
|
print(
|
|
102
103
|
f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
|
|
103
104
|
)
|
|
@@ -359,8 +359,8 @@ class MultiChDloader:
|
|
|
359
359
|
self._noise_data = self._noise_data[
|
|
360
360
|
t_list, h_start:h_end, w_start:w_end, :
|
|
361
361
|
].copy()
|
|
362
|
-
|
|
363
|
-
self.set_img_sz(self._img_sz, self._grid_sz)
|
|
362
|
+
# TODO where tf is self._img_sz defined?
|
|
363
|
+
self.set_img_sz([self._img_sz, self._img_sz], self._grid_sz)
|
|
364
364
|
print(
|
|
365
365
|
f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
|
|
366
366
|
)
|