careamics 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of careamics might be problematic; see the registry listing for more details.

Files changed (111)
  1. careamics/__init__.py +17 -2
  2. careamics/careamist.py +4 -3
  3. careamics/cli/conf.py +1 -2
  4. careamics/cli/main.py +1 -2
  5. careamics/cli/utils.py +3 -3
  6. careamics/config/__init__.py +47 -25
  7. careamics/config/algorithms/__init__.py +15 -0
  8. careamics/config/algorithms/care_algorithm_model.py +38 -0
  9. careamics/config/algorithms/n2n_algorithm_model.py +30 -0
  10. careamics/config/algorithms/n2v_algorithm_model.py +29 -0
  11. careamics/config/algorithms/unet_algorithm_model.py +88 -0
  12. careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +14 -12
  13. careamics/config/architectures/__init__.py +1 -11
  14. careamics/config/architectures/architecture_model.py +3 -3
  15. careamics/config/architectures/lvae_model.py +6 -1
  16. careamics/config/architectures/unet_model.py +1 -0
  17. careamics/config/care_configuration.py +100 -0
  18. careamics/config/configuration.py +354 -0
  19. careamics/config/{configuration_factory.py → configuration_factories.py} +185 -57
  20. careamics/config/configuration_io.py +85 -0
  21. careamics/config/data/__init__.py +10 -0
  22. careamics/config/{data_model.py → data/data_model.py} +91 -186
  23. careamics/config/data/n2v_data_model.py +193 -0
  24. careamics/config/likelihood_model.py +1 -2
  25. careamics/config/n2n_configuration.py +101 -0
  26. careamics/config/n2v_configuration.py +266 -0
  27. careamics/config/nm_model.py +1 -2
  28. careamics/config/support/__init__.py +7 -7
  29. careamics/config/support/supported_algorithms.py +5 -4
  30. careamics/config/support/supported_architectures.py +0 -4
  31. careamics/config/transformations/__init__.py +10 -4
  32. careamics/config/transformations/transform_model.py +3 -3
  33. careamics/config/transformations/transform_unions.py +42 -0
  34. careamics/config/validators/__init__.py +12 -1
  35. careamics/config/validators/model_validators.py +84 -0
  36. careamics/config/validators/validator_utils.py +3 -3
  37. careamics/dataset/__init__.py +2 -2
  38. careamics/dataset/dataset_utils/__init__.py +3 -3
  39. careamics/dataset/dataset_utils/dataset_utils.py +4 -6
  40. careamics/dataset/dataset_utils/file_utils.py +9 -9
  41. careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
  42. careamics/dataset/in_memory_dataset.py +11 -12
  43. careamics/dataset/iterable_dataset.py +4 -4
  44. careamics/dataset/iterable_pred_dataset.py +2 -1
  45. careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
  46. careamics/dataset/patching/random_patching.py +11 -10
  47. careamics/dataset/patching/sequential_patching.py +26 -26
  48. careamics/dataset/patching/validate_patch_dimension.py +3 -3
  49. careamics/dataset/tiling/__init__.py +2 -2
  50. careamics/dataset/tiling/collate_tiles.py +3 -3
  51. careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
  52. careamics/dataset/tiling/tiled_patching.py +11 -10
  53. careamics/file_io/__init__.py +5 -5
  54. careamics/file_io/read/__init__.py +1 -1
  55. careamics/file_io/read/get_func.py +2 -2
  56. careamics/file_io/write/__init__.py +2 -2
  57. careamics/lightning/__init__.py +5 -5
  58. careamics/lightning/callbacks/__init__.py +1 -1
  59. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
  60. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
  61. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
  62. careamics/lightning/callbacks/progress_bar_callback.py +3 -3
  63. careamics/lightning/lightning_module.py +11 -7
  64. careamics/lightning/train_data_module.py +36 -45
  65. careamics/losses/__init__.py +3 -3
  66. careamics/lvae_training/calibration.py +64 -57
  67. careamics/lvae_training/dataset/lc_dataset.py +2 -1
  68. careamics/lvae_training/dataset/multich_dataset.py +2 -2
  69. careamics/lvae_training/dataset/types.py +1 -1
  70. careamics/lvae_training/eval_utils.py +123 -128
  71. careamics/model_io/__init__.py +1 -1
  72. careamics/model_io/bioimage/__init__.py +1 -1
  73. careamics/model_io/bioimage/_readme_factory.py +1 -1
  74. careamics/model_io/bioimage/model_description.py +17 -17
  75. careamics/model_io/bmz_io.py +6 -17
  76. careamics/model_io/model_io_utils.py +9 -9
  77. careamics/models/layers.py +16 -16
  78. careamics/models/lvae/likelihoods.py +2 -0
  79. careamics/models/lvae/lvae.py +13 -4
  80. careamics/models/lvae/noise_models.py +280 -217
  81. careamics/models/lvae/stochastic.py +1 -0
  82. careamics/models/model_factory.py +2 -15
  83. careamics/models/unet.py +8 -8
  84. careamics/prediction_utils/__init__.py +1 -1
  85. careamics/prediction_utils/prediction_outputs.py +15 -15
  86. careamics/prediction_utils/stitch_prediction.py +6 -6
  87. careamics/transforms/__init__.py +5 -5
  88. careamics/transforms/compose.py +13 -13
  89. careamics/transforms/n2v_manipulate.py +3 -3
  90. careamics/transforms/pixel_manipulation.py +9 -9
  91. careamics/transforms/xy_random_rotate90.py +4 -4
  92. careamics/utils/__init__.py +5 -5
  93. careamics/utils/context.py +2 -1
  94. careamics/utils/logging.py +11 -10
  95. careamics/utils/metrics.py +25 -0
  96. careamics/utils/plotting.py +78 -0
  97. careamics/utils/torch_utils.py +7 -7
  98. {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/METADATA +13 -11
  99. careamics-0.0.7.dist-info/RECORD +178 -0
  100. careamics/config/architectures/custom_model.py +0 -162
  101. careamics/config/architectures/register_model.py +0 -103
  102. careamics/config/configuration_model.py +0 -603
  103. careamics/config/fcn_algorithm_model.py +0 -152
  104. careamics/config/references/__init__.py +0 -45
  105. careamics/config/references/algorithm_descriptions.py +0 -132
  106. careamics/config/references/references.py +0 -39
  107. careamics/config/transformations/transform_union.py +0 -20
  108. careamics-0.0.5.dist-info/RECORD +0 -171
  109. {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/WHEEL +0 -0
  110. {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/entry_points.txt +0 -0
  111. {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/licenses/LICENSE +0 -0

careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py
@@ -2,8 +2,9 @@
 
 from __future__ import annotations
 
+from collections.abc import Sequence
 from pathlib import Path
-from typing import Any, Optional, Sequence, Union
+from typing import Any, Optional, Union
 
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.callbacks import BasePredictionWriter

careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py
@@ -1,7 +1,8 @@
 """Module containing different strategies for writing predictions."""
 
+from collections.abc import Sequence
 from pathlib import Path
-from typing import Any, Optional, Protocol, Sequence, Union
+from typing import Any, Optional, Protocol, Union
 
 import numpy as np
 from numpy.typing import NDArray

careamics/lightning/callbacks/progress_bar_callback.py
@@ -1,11 +1,11 @@
 """Progressbar callback."""
 
 import sys
-from typing import Dict, Union
+from typing import Union
 
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.callbacks import TQDMProgressBar
-from tqdm import tqdm
+from tqdm.auto import tqdm
 
 
 class ProgressBarCallback(TQDMProgressBar):
@@ -71,7 +71,7 @@ class ProgressBarCallback(TQDMProgressBar):
 
     def get_metrics(
         self, trainer: Trainer, pl_module: LightningModule
-    ) -> Dict[str, Union[int, str, float, Dict[str, float]]]:
+    ) -> dict[str, Union[int, str, float, dict[str, float]]]:
         """Override this to customize the metrics displayed in the progress bar.
 
         Parameters

careamics/lightning/lightning_module.py
@@ -6,7 +6,7 @@ import numpy as np
 import pytorch_lightning as L
 from torch import Tensor, nn
 
-from careamics.config import FCNAlgorithmConfig, VAEAlgorithmConfig
+from careamics.config import UNetBasedAlgorithm, VAEBasedAlgorithm
 from careamics.config.support import (
     SupportedAlgorithm,
     SupportedArchitecture,
@@ -34,6 +34,7 @@ from careamics.utils.torch_utils import get_optimizer, get_scheduler
 NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
 
 
+# TODO rename to UNetModule
 class FCNModule(L.LightningModule):
     """
     CAREamics Lightning module.
@@ -60,7 +61,7 @@ class FCNModule(L.LightningModule):
         Learning rate scheduler name.
     """
 
-    def __init__(self, algorithm_config: Union[FCNAlgorithmConfig, dict]) -> None:
+    def __init__(self, algorithm_config: Union[UNetBasedAlgorithm, dict]) -> None:
         """Lightning module for CAREamics.
 
         This class encapsulates the a PyTorch model along with the training, validation,
@@ -74,7 +75,9 @@ class FCNModule(L.LightningModule):
         super().__init__()
         # if loading from a checkpoint, AlgorithmModel needs to be instantiated
         if isinstance(algorithm_config, dict):
-            algorithm_config = FCNAlgorithmConfig(**algorithm_config)
+            algorithm_config = UNetBasedAlgorithm(
+                **algorithm_config
+            )  # TODO this needs to be updated using the algorithm-specific class
 
         # create model and loss function
         self.model: nn.Module = model_factory(algorithm_config.model)
@@ -266,7 +269,7 @@ class VAEModule(L.LightningModule):
         Learning rate scheduler name.
     """
 
-    def __init__(self, algorithm_config: Union[VAEAlgorithmConfig, dict]) -> None:
+    def __init__(self, algorithm_config: Union[VAEBasedAlgorithm, dict]) -> None:
         """Lightning module for CAREamics.
 
         This class encapsulates the a PyTorch model along with the training, validation,
@@ -280,7 +283,7 @@ class VAEModule(L.LightningModule):
         super().__init__()
         # if loading from a checkpoint, AlgorithmModel needs to be instantiated
         self.algorithm_config = (
-            VAEAlgorithmConfig(**algorithm_config)
+            VAEBasedAlgorithm(**algorithm_config)
             if isinstance(algorithm_config, dict)
             else algorithm_config
         )
@@ -656,9 +659,10 @@ def create_careamics_module(
     algorithm_configuration["model"] = model_configuration
 
     # call the parent init using an AlgorithmModel instance
+    # TODO broken by new configutations!
     algorithm_str = algorithm_configuration["algorithm"]
-    if algorithm_str in FCNAlgorithmConfig.get_compatible_algorithms():
-        return FCNModule(FCNAlgorithmConfig(**algorithm_configuration))
+    if algorithm_str in UNetBasedAlgorithm.get_compatible_algorithms():
+        return FCNModule(UNetBasedAlgorithm(**algorithm_configuration))
     else:
         raise NotImplementedError(
             f"Model {algorithm_str} is not implemented or unknown."

careamics/lightning/train_data_module.py
@@ -2,14 +2,13 @@
 
 from pathlib import Path
 from typing import Any, Callable, Literal, Optional, Union
-from warnings import warn
 
 import numpy as np
 import pytorch_lightning as L
 from numpy.typing import NDArray
 from torch.utils.data import DataLoader, IterableDataset
 
-from careamics.config import DataConfig
+from careamics.config.data import DataConfig, GeneralDataConfig, N2VDataConfig
 from careamics.config.support import SupportedData
 from careamics.config.transformations import TransformModel
 from careamics.dataset.dataset_utils import (
@@ -119,7 +118,7 @@ class TrainDataModule(L.LightningDataModule):
 
     def __init__(
         self,
-        data_config: DataConfig,
+        data_config: GeneralDataConfig,
         train_data: Union[Path, str, NDArray],
         val_data: Optional[Union[Path, str, NDArray]] = None,
         train_data_target: Optional[Union[Path, str, NDArray]] = None,
@@ -219,7 +218,7 @@ class TrainDataModule(L.LightningDataModule):
         )
 
         # configuration
-        self.data_config: DataConfig = data_config
+        self.data_config: GeneralDataConfig = data_config
         self.data_type: str = data_config.data_type
         self.batch_size: int = data_config.batch_size
         self.use_in_memory: bool = use_in_memory
@@ -261,11 +260,6 @@ class TrainDataModule(L.LightningDataModule):
 
         self.extension_filter: str = extension_filter
 
-        # Pytorch dataloader parameters
-        self.dataloader_params: dict[str, Any] = (
-            data_config.dataloader_params if data_config.dataloader_params else {}
-        )
-
     def prepare_data(self) -> None:
         """
         Hook used to prepare the data before calling `setup`.
@@ -447,21 +441,17 @@ class TrainDataModule(L.LightningDataModule):
         Any
             Training dataloader.
         """
-        # check because iterable dataset cannot be shuffled
-        if not isinstance(self.train_dataset, IterableDataset):
-            if ("shuffle" in self.dataloader_params) and (
-                not self.dataloader_params["shuffle"]
-            ):
-                warn(
-                    "Dataloader parameters include `shuffle=False`, this will be "
-                    "passed to the training dataloader and may result in bad results.",
-                    stacklevel=1,
-                )
-            else:
-                self.dataloader_params["shuffle"] = True
+        train_dataloader_params = self.data_config.train_dataloader_params.copy()
+
+        # NOTE: When next-gen datasets are completed this can be removed
+        # iterable dataset cannot be shuffled
+        if isinstance(self.train_dataset, IterableDataset):
+            del train_dataloader_params["shuffle"]
 
         return DataLoader(
-            self.train_dataset, batch_size=self.batch_size, **self.dataloader_params
+            self.train_dataset,
+            batch_size=self.batch_size,
+            **train_dataloader_params,
         )
 
     def val_dataloader(self) -> Any:
@@ -476,6 +466,7 @@ class TrainDataModule(L.LightningDataModule):
         return DataLoader(
             self.val_dataset,
             batch_size=self.batch_size,
+            **self.data_config.val_dataloader_params,
         )
 
 
@@ -502,12 +493,23 @@ def create_train_datamodule(
     """Create a TrainDataModule.
 
     This function is used to explicitly pass the parameters usually contained in a
-    `data_model` configuration to a TrainDataModule.
+    `GenericDataConfig` to a TrainDataModule.
 
     Since the lightning datamodule has no access to the model, make sure that the
    parameters passed to the datamodule are consistent with the model's requirements and
    are coherent.
 
+    By default, the train DataModule will be set for Noise2Void if no target data is
+    provided. That means that it will add a `N2VManipulateModel` transformation to the
+    list of augmentations. The default augmentations are XY flip, XY rotation, and N2V
+    pixel manipulation. If you pass a training target data, the default behaviour is to
+    train a supervised model. It will use the default XY flip and rotation
+    augmentations.
+
+    To use a different set of transformations, you can pass a list of transforms to
+    `transforms`. Note that if you intend to use Noise2Void, you should add
+    `N2VManipulateModel` as the last transform in the list of transformations.
+
     The data module can be used with Path, str or numpy arrays. In the case of
     numpy arrays, it loads and computes all the patches in memory. For Path and str
     inputs, it calculates the total file size and estimate whether it can fit in
@@ -518,11 +520,6 @@ def create_train_datamodule(
     To use array data, set `data_type` to `array` and pass a numpy array to
     `train_data`.
 
-    In particular, N2V requires a specific transformation (N2V manipulates), which is
-    not compatible with supervised training. The default transformations applied to the
-    training patches are defined in `careamics.config.data_model`. To use different
-    transformations, pass a list of transforms. See examples for more details.
-
     By default, CAREamics only supports types defined in
     `careamics.config.support.SupportedData`. To read custom data types, you can set
     `data_type` to `custom` and provide a function that returns a numpy array from a
@@ -627,12 +624,12 @@ def create_train_datamodule(
     transforms:
     >>> import numpy as np
     >>> from careamics.lightning import create_train_datamodule
+    >>> from careamics.config.transformations import XYFlipModel, N2VManipulateModel
     >>> from careamics.config.support import SupportedTransform
     >>> my_array = np.arange(256).reshape(16, 16)
     >>> my_transforms = [
-    ...     {
-    ...         "name": SupportedTransform.XY_FLIP.value,
-    ...     }
+    ...     XYFlipModel(flip_y=False),
+    ...     N2VManipulateModel()
     ... ]
     >>> data_module = create_train_datamodule(
     ...     train_data=my_array,
@@ -659,21 +656,15 @@ def create_train_datamodule(
     if transforms is not None:
         data_dict["transforms"] = transforms
 
-    # validate configuration
-    data_config = DataConfig(**data_dict)
+    # TODO not compatible with HDN, consider adding an argument for n2v/hdn
+    if train_target_data is None:
+        data_config: GeneralDataConfig = N2VDataConfig(**data_dict)
+        assert isinstance(data_config, N2VDataConfig)
 
-    # N2V specific checks, N2V, structN2V, and transforms
-    if data_config.has_n2v_manipulate():
-        # there is not target, n2v2 and structN2V can be changed
-        if train_target_data is None:
-            data_config.set_N2V2(use_n2v2)
-            data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span)
-        else:
-            raise ValueError(
-                "Cannot have both supervised training (target data) and "
-                "N2V manipulation in the transforms. Pass a list of transforms "
-                "that is compatible with your supervised training."
-            )
+        data_config.set_n2v2(use_n2v2)
+        data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span)
+    else:
+        data_config = DataConfig(**data_dict)
 
     # sanity check on the dataloader parameters
     if "batch_size" in dataloader_params:

careamics/losses/__init__.py
@@ -1,13 +1,13 @@
 """Losses module."""
 
 __all__ = [
+    "denoisplit_loss",
+    "denoisplit_musplit_loss",
     "loss_factory",
     "mae_loss",
     "mse_loss",
-    "n2v_loss",
-    "denoisplit_loss",
     "musplit_loss",
-    "denoisplit_musplit_loss",
+    "n2v_loss",
 ]
 
 from .fcn.losses import mae_loss, mse_loss, n2v_loss

careamics/lvae_training/calibration.py
@@ -1,4 +1,4 @@
-from typing import Union
+from typing import Union, Optional
 
 import numpy as np
 import torch
@@ -34,9 +34,6 @@ class Calibration:
         self._bins = num_bins
         self._bin_boundaries = None
 
-    def logvar_to_std(self, logvar: np.ndarray) -> np.ndarray:
-        return np.exp(logvar / 2)
-
     def compute_bin_boundaries(self, predict_std: np.ndarray) -> np.ndarray:
         """Compute the bin boundaries for `num_bins` bins and predicted std values."""
         min_std = np.min(predict_std)
@@ -104,65 +101,75 @@ class Calibration:
                 )
                 rmse_stderr = np.sqrt(stderr) if stderr is not None else None
 
-                bin_var = np.mean((std_ch[bin_mask] ** 2))
+                bin_var = np.mean(std_ch[bin_mask] ** 2)
                 stats_dict[ch_idx]["rmse"].append(bin_error)
                 stats_dict[ch_idx]["rmse_err"].append(rmse_stderr)
                 stats_dict[ch_idx]["rmv"].append(np.sqrt(bin_var))
                 stats_dict[ch_idx]["bin_count"].append(bin_size)
+        self.stats_dict = stats_dict
         return stats_dict
 
+    def get_calibrated_factor_for_stdev(
+        self,
+        pred: Optional[np.ndarray] = None,
+        pred_std: Optional[np.ndarray] = None,
+        target: Optional[np.ndarray] = None,
+        q_s: float = 0.00001,
+        q_e: float = 0.99999,
+    ) -> dict[str, float]:
+        """Calibrate the uncertainty by multiplying the predicted std with a scalar.
 
-def get_calibrated_factor_for_stdev(
-    pred: Union[np.ndarray, torch.Tensor],
-    pred_std: Union[np.ndarray, torch.Tensor],
-    target: Union[np.ndarray, torch.Tensor],
-    q_s: float = 0.00001,
-    q_e: float = 0.99999,
-    num_bins: int = 30,
-) -> dict[str, float]:
-    """Calibrate the uncertainty by multiplying the predicted std with a scalar.
-
-    Parameters
-    ----------
-    pred : Union[np.ndarray, torch.Tensor]
-        Predicted image, shape (n, h, w, c).
-    pred_std : Union[np.ndarray, torch.Tensor]
-        Predicted std, shape (n, h, w, c).
-    target : Union[np.ndarray, torch.Tensor]
-        Target image, shape (n, h, w, c).
-    q_s : float, optional
-        Start quantile, by default 0.00001.
-    q_e : float, optional
-        End quantile, by default 0.99999.
-    num_bins : int, optional
-        Number of bins to use for calibration, by default 30.
-
-    Returns
-    -------
-    dict[str, float]
-        Calibrated factor for each channel (slope + intercept).
-    """
-    calib = Calibration(num_bins=num_bins)
-    stats_dict = calib.compute_stats(pred, pred_std, target)
-    outputs = {}
-    for ch_idx in stats_dict.keys():
-        y = stats_dict[ch_idx]["rmse"]
-        x = stats_dict[ch_idx]["rmv"]
-        count = stats_dict[ch_idx]["bin_count"]
-
-        first_idx = get_first_index(count, q_s)
-        last_idx = get_last_index(count, q_e)
-        x = x[first_idx:-last_idx]
-        y = y[first_idx:-last_idx]
-        slope, intercept, *_ = stats.linregress(x, y)
-        output = {"scalar": slope, "offset": intercept}
-        outputs[ch_idx] = output
-    return outputs
+        Parameters
+        ----------
+        stats_dict : dict[int, dict[str, Union[np.ndarray, list]]]
+            Dictionary containing the stats for each channel.
+        q_s : float, optional
+            Start quantile, by default 0.00001.
+        q_e : float, optional
+            End quantile, by default 0.99999.
+
+        Returns
+        -------
+        dict[str, float]
+            Calibrated factor for each channel (slope + intercept).
+        """
+        if not hasattr(self, "stats_dict"):
+            print("No stats found. Computing stats...")
+            if any(v is None for v in [pred, pred_std, target]):
+                raise ValueError("pred, pred_std, and target must be provided.")
+            self.stats_dict = self.compute_stats(
+                pred=pred, pred_std=pred_std, target=target
+            )
+        outputs = {}
+        for ch_idx in self.stats_dict.keys():
+            y = self.stats_dict[ch_idx]["rmse"]
+            x = self.stats_dict[ch_idx]["rmv"]
+            count = self.stats_dict[ch_idx]["bin_count"]
+
+            first_idx = get_first_index(count, q_s)
+            last_idx = get_last_index(count, q_e)
+            x = x[first_idx:-last_idx]
+            y = y[first_idx:-last_idx]
+            slope, intercept, *_ = stats.linregress(x, y)
+            output = {"scalar": slope, "offset": intercept}
+            outputs[ch_idx] = output
+        factors = self.get_factors_array(factors_dict=outputs)
+        return outputs, factors
+
+    def get_factors_array(self, factors_dict: list[dict]):
+        """Get the calibration factors as a numpy array."""
+        calib_scalar = [factors_dict[i]["scalar"] for i in range(len(factors_dict))]
+        calib_scalar = np.array(calib_scalar).reshape(1, 1, 1, -1)
+        calib_offset = [
+            factors_dict[i].get("offset", 0.0) for i in range(len(factors_dict))
+        ]
+        calib_offset = np.array(calib_offset).reshape(1, 1, 1, -1)
+        return {"scalar": calib_scalar, "offset": calib_offset}
 
 
 def plot_calibration(ax, calibration_stats):
-    first_idx = get_first_index(calibration_stats[0]["bin_count"], 0.001)
-    last_idx = get_last_index(calibration_stats[0]["bin_count"], 0.999)
+    first_idx = get_first_index(calibration_stats[0]["bin_count"], 0.0001)
+    last_idx = get_last_index(calibration_stats[0]["bin_count"], 0.9999)
     ax.plot(
         calibration_stats[0]["rmv"][first_idx:-last_idx],
         calibration_stats[0]["rmse"][first_idx:-last_idx],
@@ -170,15 +177,15 @@ def plot_calibration(ax, calibration_stats):
         label=r"$\hat{C}_0$: Ch1",
     )
 
-    first_idx = get_first_index(calibration_stats[1]["bin_count"], 0.001)
-    last_idx = get_last_index(calibration_stats[1]["bin_count"], 0.999)
+    first_idx = get_first_index(calibration_stats[1]["bin_count"], 0.0001)
+    last_idx = get_last_index(calibration_stats[1]["bin_count"], 0.9999)
     ax.plot(
         calibration_stats[1]["rmv"][first_idx:-last_idx],
         calibration_stats[1]["rmse"][first_idx:-last_idx],
         "o",
-        label=r"$\hat{C}_1: : Ch2$",
+        label=r"$\hat{C}_1$: Ch2",
    )
-
+    # TODO add multichannel
     ax.set_xlabel("RMV")
     ax.set_ylabel("RMSE")
     ax.legend()
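
get_calibrated_factor_for_stdev is now an instance method that reuses the stats_dict cached by compute_stats and returns both the per-channel regression coefficients and the broadcastable factor arrays built by get_factors_array. A usage sketch based only on the signatures visible in this diff; the (n, h, w, c) shapes follow the removed docstring.

    # Hedged sketch of the reworked Calibration API (careamics/lvae_training/calibration.py).
    import numpy as np
    from careamics.lvae_training.calibration import Calibration

    pred = np.random.rand(4, 64, 64, 2)      # predicted mean, (n, h, w, c)
    pred_std = np.random.rand(4, 64, 64, 2)  # predicted std, (n, h, w, c)
    target = np.random.rand(4, 64, 64, 2)    # ground truth, (n, h, w, c)

    calib = Calibration(num_bins=30)
    stats = calib.compute_stats(pred, pred_std, target)  # also cached as calib.stats_dict

    # Returns per-channel {"scalar", "offset"} plus arrays shaped (1, 1, 1, c).
    per_channel, factors = calib.get_calibrated_factor_for_stdev()
    calibrated_std = pred_std * factors["scalar"] + factors["offset"]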

careamics/lvae_training/dataset/lc_dataset.py
@@ -97,7 +97,8 @@ class LCMultiChDloader(MultiChDloader):
         ]
 
         self.N = len(t_list)
-        self.set_img_sz(self._img_sz, self._grid_sz)
+        # TODO where tf is self._img_sz defined?
+        self.set_img_sz([self._img_sz, self._img_sz], self._grid_sz)
         print(
             f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
         )

careamics/lvae_training/dataset/multich_dataset.py
@@ -359,8 +359,8 @@ class MultiChDloader:
             self._noise_data = self._noise_data[
                 t_list, h_start:h_end, w_start:w_end, :
             ].copy()
-
-        self.set_img_sz(self._img_sz, self._grid_sz)
+        # TODO where tf is self._img_sz defined?
+        self.set_img_sz([self._img_sz, self._img_sz], self._grid_sz)
         print(
             f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
         )

careamics/lvae_training/dataset/types.py
@@ -3,7 +3,7 @@ from enum import Enum
 
 class DataType(Enum):
     Elisa3DData = 0
-    NicolaData = 1
+    HTLIF24Data = 1
     Pavia3SeqData = 2
     TavernaSox2GolgiV2 = 3
     Dao3ChannelWithInput = 4