careamics 0.0.4.2__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (43) hide show
  1. careamics/careamist.py +235 -25
  2. careamics/cli/conf.py +19 -30
  3. careamics/cli/main.py +111 -10
  4. careamics/cli/utils.py +29 -0
  5. careamics/config/__init__.py +2 -0
  6. careamics/config/architectures/lvae_model.py +104 -21
  7. careamics/config/configuration_factory.py +49 -45
  8. careamics/config/configuration_model.py +2 -2
  9. careamics/config/likelihood_model.py +7 -6
  10. careamics/config/loss_model.py +56 -0
  11. careamics/config/nm_model.py +24 -24
  12. careamics/config/vae_algorithm_model.py +14 -13
  13. careamics/dataset/dataset_utils/running_stats.py +22 -23
  14. careamics/lightning/lightning_module.py +58 -27
  15. careamics/lightning/train_data_module.py +15 -1
  16. careamics/losses/loss_factory.py +1 -85
  17. careamics/losses/lvae/losses.py +223 -164
  18. careamics/lvae_training/calibration.py +184 -0
  19. careamics/lvae_training/dataset/config.py +2 -2
  20. careamics/lvae_training/dataset/multich_dataset.py +11 -19
  21. careamics/lvae_training/dataset/multifile_dataset.py +3 -2
  22. careamics/lvae_training/dataset/types.py +15 -26
  23. careamics/lvae_training/dataset/utils/index_manager.py +4 -4
  24. careamics/lvae_training/eval_utils.py +125 -213
  25. careamics/model_io/bioimage/_readme_factory.py +25 -33
  26. careamics/model_io/bioimage/cover_factory.py +171 -0
  27. careamics/model_io/bioimage/model_description.py +39 -17
  28. careamics/model_io/bmz_io.py +36 -25
  29. careamics/models/layers.py +6 -4
  30. careamics/models/lvae/layers.py +348 -975
  31. careamics/models/lvae/likelihoods.py +10 -8
  32. careamics/models/lvae/lvae.py +214 -272
  33. careamics/models/lvae/noise_models.py +179 -112
  34. careamics/models/lvae/stochastic.py +393 -0
  35. careamics/models/lvae/utils.py +82 -73
  36. careamics/utils/lightning_utils.py +57 -0
  37. careamics/utils/serializers.py +2 -0
  38. careamics/utils/torch_utils.py +1 -1
  39. {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/METADATA +12 -9
  40. {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/RECORD +43 -37
  41. {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/WHEEL +1 -1
  42. {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/entry_points.txt +0 -0
  43. {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,184 @@
1
+ from typing import Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ from scipy import stats
6
+
7
+
8
+ def get_last_index(bin_count, quantile):
9
+ cumsum = np.cumsum(bin_count)
10
+ normalized_cumsum = cumsum / cumsum[-1]
11
+ for i in range(1, len(normalized_cumsum)):
12
+ if normalized_cumsum[-i] < quantile:
13
+ return i - 1
14
+ return None
15
+
16
+
17
+ def get_first_index(bin_count, quantile):
18
+ cumsum = np.cumsum(bin_count)
19
+ normalized_cumsum = cumsum / cumsum[-1]
20
+ for i in range(len(normalized_cumsum)):
21
+ if normalized_cumsum[i] > quantile:
22
+ return i
23
+ return None
24
+
25
+
26
+ class Calibration:
27
+ """Calibrate the uncertainty computed over samples from LVAE model.
28
+
29
+ Calibration is done by learning a scalar that maps the pixel-wise standard
30
+ deviation of the the predicted samples into the actual prediction error.
31
+ """
32
+
33
+ def __init__(self, num_bins: int = 15):
34
+ self._bins = num_bins
35
+ self._bin_boundaries = None
36
+
37
+ def logvar_to_std(self, logvar: np.ndarray) -> np.ndarray:
38
+ return np.exp(logvar / 2)
39
+
40
+ def compute_bin_boundaries(self, predict_std: np.ndarray) -> np.ndarray:
41
+ """Compute the bin boundaries for `num_bins` bins and predicted std values."""
42
+ min_std = np.min(predict_std)
43
+ max_std = np.max(predict_std)
44
+ return np.linspace(min_std, max_std, self._bins + 1)
45
+
46
+ def compute_stats(
47
+ self, pred: np.ndarray, pred_std: np.ndarray, target: np.ndarray
48
+ ) -> dict[int, dict[str, Union[np.ndarray, list]]]:
49
+ """
50
+ It computes the bin-wise RMSE and RMV for each channel of the predicted image.
51
+
52
+ Recall that:
53
+ - RMSE = np.sqrt((pred - target)**2 / num_pixels)
54
+ - RMV = np.sqrt(np.mean(pred_std**2))
55
+
56
+ ALGORITHM
57
+ - For each channel:
58
+ - Given the bin boundaries, assign pixels of `std_ch` array to a specific bin index.
59
+ - For each bin index:
60
+ - Compute the RMSE, RMV, and number of pixels for that bin.
61
+
62
+ NOTE: each channel of the predicted image/logvar has its own stats.
63
+
64
+ Parameters
65
+ ----------
66
+ pred: np.ndarray
67
+ Predicted patches, shape (n, h, w, c).
68
+ pred_std: np.ndarray
69
+ Std computed over the predicted patches, shape (n, h, w, c).
70
+ target: np.ndarray
71
+ Target GT image, shape (n, h, w, c).
72
+ """
73
+ self._bin_boundaries = {}
74
+ stats_dict = {}
75
+ for ch_idx in range(pred.shape[-1]):
76
+ stats_dict[ch_idx] = {
77
+ "bin_count": [],
78
+ "rmv": [],
79
+ "rmse": [],
80
+ "bin_boundaries": None,
81
+ "bin_matrix": [],
82
+ "rmse_err": [],
83
+ }
84
+ pred_ch = pred[..., ch_idx]
85
+ std_ch = pred_std[..., ch_idx]
86
+ target_ch = target[..., ch_idx]
87
+ boundaries = self.compute_bin_boundaries(std_ch)
88
+ stats_dict[ch_idx]["bin_boundaries"] = boundaries
89
+ bin_matrix = np.digitize(std_ch.reshape(-1), boundaries)
90
+ bin_matrix = bin_matrix.reshape(std_ch.shape)
91
+ stats_dict[ch_idx]["bin_matrix"] = bin_matrix
92
+ error = (pred_ch - target_ch) ** 2
93
+ for bin_idx in range(1, 1 + self._bins):
94
+ bin_mask = bin_matrix == bin_idx
95
+ bin_error = error[bin_mask]
96
+ bin_size = np.sum(bin_mask)
97
+ bin_error = (
98
+ np.sqrt(np.sum(bin_error) / bin_size) if bin_size > 0 else None
99
+ )
100
+ stderr = (
101
+ np.std(error[bin_mask]) / np.sqrt(bin_size)
102
+ if bin_size > 0
103
+ else None
104
+ )
105
+ rmse_stderr = np.sqrt(stderr) if stderr is not None else None
106
+
107
+ bin_var = np.mean((std_ch[bin_mask] ** 2))
108
+ stats_dict[ch_idx]["rmse"].append(bin_error)
109
+ stats_dict[ch_idx]["rmse_err"].append(rmse_stderr)
110
+ stats_dict[ch_idx]["rmv"].append(np.sqrt(bin_var))
111
+ stats_dict[ch_idx]["bin_count"].append(bin_size)
112
+ return stats_dict
113
+
114
+
115
+ def get_calibrated_factor_for_stdev(
116
+ pred: Union[np.ndarray, torch.Tensor],
117
+ pred_std: Union[np.ndarray, torch.Tensor],
118
+ target: Union[np.ndarray, torch.Tensor],
119
+ q_s: float = 0.00001,
120
+ q_e: float = 0.99999,
121
+ num_bins: int = 30,
122
+ ) -> dict[str, float]:
123
+ """Calibrate the uncertainty by multiplying the predicted std with a scalar.
124
+
125
+ Parameters
126
+ ----------
127
+ pred : Union[np.ndarray, torch.Tensor]
128
+ Predicted image, shape (n, h, w, c).
129
+ pred_std : Union[np.ndarray, torch.Tensor]
130
+ Predicted std, shape (n, h, w, c).
131
+ target : Union[np.ndarray, torch.Tensor]
132
+ Target image, shape (n, h, w, c).
133
+ q_s : float, optional
134
+ Start quantile, by default 0.00001.
135
+ q_e : float, optional
136
+ End quantile, by default 0.99999.
137
+ num_bins : int, optional
138
+ Number of bins to use for calibration, by default 30.
139
+
140
+ Returns
141
+ -------
142
+ dict[str, float]
143
+ Calibrated factor for each channel (slope + intercept).
144
+ """
145
+ calib = Calibration(num_bins=num_bins)
146
+ stats_dict = calib.compute_stats(pred, pred_std, target)
147
+ outputs = {}
148
+ for ch_idx in stats_dict.keys():
149
+ y = stats_dict[ch_idx]["rmse"]
150
+ x = stats_dict[ch_idx]["rmv"]
151
+ count = stats_dict[ch_idx]["bin_count"]
152
+
153
+ first_idx = get_first_index(count, q_s)
154
+ last_idx = get_last_index(count, q_e)
155
+ x = x[first_idx:-last_idx]
156
+ y = y[first_idx:-last_idx]
157
+ slope, intercept, *_ = stats.linregress(x, y)
158
+ output = {"scalar": slope, "offset": intercept}
159
+ outputs[ch_idx] = output
160
+ return outputs
161
+
162
+
163
+ def plot_calibration(ax, calibration_stats):
164
+ first_idx = get_first_index(calibration_stats[0]["bin_count"], 0.001)
165
+ last_idx = get_last_index(calibration_stats[0]["bin_count"], 0.999)
166
+ ax.plot(
167
+ calibration_stats[0]["rmv"][first_idx:-last_idx],
168
+ calibration_stats[0]["rmse"][first_idx:-last_idx],
169
+ "o",
170
+ label=r"$\hat{C}_0$: Ch1",
171
+ )
172
+
173
+ first_idx = get_first_index(calibration_stats[1]["bin_count"], 0.001)
174
+ last_idx = get_last_index(calibration_stats[1]["bin_count"], 0.999)
175
+ ax.plot(
176
+ calibration_stats[1]["rmv"][first_idx:-last_idx],
177
+ calibration_stats[1]["rmse"][first_idx:-last_idx],
178
+ "o",
179
+ label=r"$\hat{C}_1: : Ch2$",
180
+ )
181
+
182
+ ax.set_xlabel("RMV")
183
+ ax.set_ylabel("RMSE")
184
+ ax.legend()
@@ -2,7 +2,7 @@ from typing import Any, Optional
2
2
 
3
3
  from pydantic import BaseModel, ConfigDict
4
4
 
5
- from .types import DataType, DataSplitType, TilingMode
5
+ from .types import DataSplitType, DataType, TilingMode
6
6
 
7
7
 
8
8
  # TODO: check if any bool logic can be removed
@@ -40,7 +40,7 @@ class DatasetConfig(BaseModel):
40
40
  start_alpha: Optional[Any] = None
41
41
  end_alpha: Optional[Any] = None
42
42
 
43
- image_size: int
43
+ image_size: tuple # TODO: revisit, new model_config uses tuple
44
44
  """Size of one patch of data"""
45
45
 
46
46
  grid_size: Optional[int] = None
@@ -91,18 +91,18 @@ class MultiChDloader:
91
91
  self._start_alpha_arr = self._end_alpha_arr = self._return_alpha = None
92
92
 
93
93
  self._img_sz = self._grid_sz = self._repeat_factor = self.idx_manager = None
94
+
95
+ # changed set_img_sz because "grid_size" in data_config returns false
96
+ try:
97
+ grid_size = data_config.grid_size
98
+ except AttributeError:
99
+ grid_size = data_config.image_size
100
+
94
101
  if self._is_train:
95
102
  self._start_alpha_arr = data_config.start_alpha
96
103
  self._end_alpha_arr = data_config.end_alpha
97
104
 
98
- self.set_img_sz(
99
- data_config.image_size,
100
- (
101
- data_config.grid_size
102
- if "grid_size" in data_config
103
- else data_config.image_size
104
- ),
105
- )
105
+ self.set_img_sz(data_config.image_size, grid_size)
106
106
 
107
107
  if self._validtarget_rand_fract is not None:
108
108
  self._train_index_switcher = IndexSwitcher(
@@ -110,15 +110,7 @@ class MultiChDloader:
110
110
  )
111
111
 
112
112
  else:
113
-
114
- self.set_img_sz(
115
- data_config.image_size,
116
- (
117
- data_config.grid_size
118
- if "grid_size" in data_config
119
- else data_config.image_size
120
- ),
121
- )
113
+ self.set_img_sz(data_config.image_size, grid_size)
122
114
 
123
115
  self._return_alpha = False
124
116
  self._return_index = False
@@ -401,8 +393,8 @@ class MultiChDloader:
401
393
  image_size: size of one patch
402
394
  grid_size: frame is divided into square grids of this size. A patch centered on a grid having size `image_size` is returned.
403
395
  """
404
-
405
- self._img_sz = image_size
396
+ # hacky way to deal with image shape from new conf
397
+ self._img_sz = image_size[-1] # TODO revisit!
406
398
  self._grid_sz = grid_size
407
399
  shape = self._data.shape
408
400
 
@@ -1,12 +1,13 @@
1
- from typing import Union, Callable, Sequence
1
+ from collections.abc import Sequence
2
+ from typing import Callable, Union
2
3
 
3
4
  import numpy as np
4
5
  from numpy.typing import NDArray
5
6
 
6
7
  from .config import DatasetConfig
8
+ from .lc_dataset import LCMultiChDloader
7
9
  from .multich_dataset import MultiChDloader
8
10
  from .types import DataSplitType
9
- from .lc_dataset import LCMultiChDloader
10
11
 
11
12
 
12
13
  class TwoChannelData(Sequence):
@@ -2,32 +2,21 @@ from enum import Enum
2
2
 
3
3
 
4
4
  class DataType(Enum):
5
- MNIST = 0
6
- Places365 = 1
7
- NotMNIST = 2
8
- OptiMEM100_014 = 3
9
- CustomSinosoid = 4
10
- Prevedel_EMBL = 5
11
- AllenCellMito = 6
12
- SeparateTiffData = 7
13
- CustomSinosoidThreeCurve = 8
14
- SemiSupBloodVesselsEMBL = 9
15
- Pavia2 = 10
16
- Pavia2VanillaSplitting = 11
17
- ExpansionMicroscopyMitoTub = 12
18
- ShroffMitoEr = 13
19
- HTIba1Ki67 = 14
20
- BSD68 = 15
21
- BioSR_MRC = 16
22
- TavernaSox2Golgi = 17
23
- Dao3Channel = 18
24
- ExpMicroscopyV2 = 19
25
- Dao3ChannelWithInput = 20
26
- TavernaSox2GolgiV2 = 21
27
- TwoDset = 22
28
- PredictedTiffData = 23
29
- Pavia3SeqData = 24
30
- NicolaData = 25
5
+ Elisa3DData = 0
6
+ NicolaData = 1
7
+ Pavia3SeqData = 2
8
+ TavernaSox2GolgiV2 = 3
9
+ Dao3ChannelWithInput = 4
10
+ ExpMicroscopyV1 = 5
11
+ ExpMicroscopyV2 = 6
12
+ Dao3Channel = 7
13
+ TavernaSox2Golgi = 8
14
+ HTIba1Ki67 = 9
15
+ OptiMEM100_014 = 10
16
+ SeparateTiffData = 11
17
+ BioSR_MRC = 12
18
+ PunctaRemoval = 13 # for the case when we have a set of differently sized crops for each channel.
19
+ Care3D = 14
31
20
 
32
21
 
33
22
  class DataSplitType(Enum):
@@ -151,10 +151,10 @@ class GridIndexManager:
151
151
  self.data_shape
152
152
  ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
153
153
  assert dim >= 0, "Dimension must be greater than or equal to 0"
154
- assert dim_index < self.get_individual_dim_grid_count(
155
- dim
156
- ), f"Dimension index {dim_index} is out of bounds for data shape {self.data_shape}"
157
-
154
+ # assert dim_index < self.get_individual_dim_grid_count(
155
+ # dim
156
+ # ), f"Dimension index {dim_index} is out of bounds for data shape {self.data_shape}"
157
+ # TODO comented out this shit cuz I have no interest to dig why it's failing at this point !
158
158
  if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
159
159
  return dim_index
160
160
  elif self.tiling_mode == TilingMode.PadBoundary: