careamics 0.0.3__py3-none-any.whl → 0.0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of careamics might be problematic.

Files changed (56)
  1. careamics/careamist.py +25 -17
  2. careamics/cli/__init__.py +5 -0
  3. careamics/cli/conf.py +391 -0
  4. careamics/cli/main.py +134 -0
  5. careamics/config/architectures/lvae_model.py +0 -4
  6. careamics/config/configuration_factory.py +480 -177
  7. careamics/config/configuration_model.py +1 -2
  8. careamics/config/data_model.py +1 -15
  9. careamics/config/fcn_algorithm_model.py +14 -9
  10. careamics/config/likelihood_model.py +21 -4
  11. careamics/config/nm_model.py +31 -5
  12. careamics/config/optimizer_models.py +3 -1
  13. careamics/config/support/supported_optimizers.py +1 -1
  14. careamics/config/support/supported_transforms.py +1 -0
  15. careamics/config/training_model.py +35 -6
  16. careamics/config/transformations/__init__.py +4 -1
  17. careamics/config/transformations/transform_union.py +20 -0
  18. careamics/config/vae_algorithm_model.py +2 -36
  19. careamics/dataset/tiling/lvae_tiled_patching.py +90 -8
  20. careamics/lightning/lightning_module.py +10 -8
  21. careamics/lightning/train_data_module.py +2 -2
  22. careamics/losses/loss_factory.py +3 -3
  23. careamics/losses/lvae/losses.py +2 -2
  24. careamics/lvae_training/dataset/__init__.py +15 -0
  25. careamics/lvae_training/dataset/{vae_data_config.py → config.py} +25 -81
  26. careamics/lvae_training/dataset/lc_dataset.py +28 -20
  27. careamics/lvae_training/dataset/{vae_dataset.py → multich_dataset.py} +91 -51
  28. careamics/lvae_training/dataset/multifile_dataset.py +334 -0
  29. careamics/lvae_training/dataset/types.py +43 -0
  30. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  31. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  32. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  33. careamics/lvae_training/dataset/utils/index_manager.py +232 -0
  34. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  35. careamics/lvae_training/eval_utils.py +109 -64
  36. careamics/lvae_training/get_config.py +1 -1
  37. careamics/lvae_training/train_lvae.py +1 -1
  38. careamics/model_io/bioimage/bioimage_utils.py +4 -2
  39. careamics/model_io/bmz_io.py +6 -5
  40. careamics/models/lvae/likelihoods.py +18 -9
  41. careamics/models/lvae/lvae.py +12 -16
  42. careamics/models/lvae/noise_models.py +1 -1
  43. careamics/transforms/compose.py +90 -15
  44. careamics/transforms/n2v_manipulate.py +6 -2
  45. careamics/transforms/normalize.py +14 -3
  46. careamics/transforms/xy_flip.py +16 -6
  47. careamics/transforms/xy_random_rotate90.py +16 -7
  48. careamics/utils/metrics.py +204 -24
  49. careamics/utils/serializers.py +60 -0
  50. {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/METADATA +4 -3
  51. {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/RECORD +54 -43
  52. careamics-0.0.4.1.dist-info/entry_points.txt +2 -0
  53. careamics/lvae_training/dataset/data_utils.py +0 -701
  54. careamics/lvae_training/dataset/lc_dataset_config.py +0 -13
  55. {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/WHEEL +0 -0
  56. {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/licenses/LICENSE +0 -0
careamics/lvae_training/dataset/utils/index_manager.py (new file)
@@ -0,0 +1,232 @@
+ from dataclasses import dataclass
+
+ import numpy as np
+
+ from careamics.lvae_training.dataset.types import TilingMode
+
+
+ @dataclass
+ class GridIndexManager:
+     data_shape: tuple
+     grid_shape: tuple
+     patch_shape: tuple
+     tiling_mode: TilingMode
+
+     # The patch is centered on its index in the grid. The grid size is not used
+     # during training (random patches are drawn every time); during val/test it
+     # controls the overlap of the patches. At the borders the data is simply
+     # cropped so that it is perfectly divisible.
+
+     def __post_init__(self):
+         assert len(self.data_shape) == len(
+             self.grid_shape
+         ), f"Data shape:{self.data_shape} and grid size:{self.grid_shape} must have the same dimension"
+         assert len(self.data_shape) == len(
+             self.patch_shape
+         ), f"Data shape:{self.data_shape} and patch shape:{self.patch_shape} must have the same dimension"
+         innerpad = np.array(self.patch_shape) - np.array(self.grid_shape)
+         for dim, pad in enumerate(innerpad):
+             if pad < 0:
+                 raise ValueError(
+                     f"Patch shape:{self.patch_shape} must be greater than or equal to grid shape:{self.grid_shape} in dimension {dim}"
+                 )
+             if pad % 2 != 0:
+                 raise ValueError(
+                     f"Patch shape:{self.patch_shape} must have even padding in dimension {dim}"
+                 )
+
+     def patch_offset(self):
+         return (np.array(self.patch_shape) - np.array(self.grid_shape)) // 2
+
+     def get_individual_dim_grid_count(self, dim: int):
+         """
+         Returns the number of grids in the specified dimension, ignoring all other dimensions.
+         """
+         assert dim < len(
+             self.data_shape
+         ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+         assert dim >= 0, "Dimension must be greater than or equal to 0"
+
+         if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
+             return self.data_shape[dim]
+         elif self.tiling_mode == TilingMode.PadBoundary:
+             return int(np.ceil(self.data_shape[dim] / self.grid_shape[dim]))
+         elif self.tiling_mode == TilingMode.ShiftBoundary:
+             excess_size = self.patch_shape[dim] - self.grid_shape[dim]
+             return int(
+                 np.ceil((self.data_shape[dim] - excess_size) / self.grid_shape[dim])
+             )
+         else:
+             excess_size = self.patch_shape[dim] - self.grid_shape[dim]
+             return int(
+                 np.floor((self.data_shape[dim] - excess_size) / self.grid_shape[dim])
+             )
+
+     def total_grid_count(self):
+         """
+         Returns the total number of grids in the dataset.
+         """
+         return self.grid_count(0) * self.get_individual_dim_grid_count(0)
+
+     def grid_count(self, dim: int):
+         """
+         Returns the total number of grids for one value in the specified dimension.
+         """
+         assert dim < len(
+             self.data_shape
+         ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+         assert dim >= 0, "Dimension must be greater than or equal to 0"
+         if dim == len(self.data_shape) - 1:
+             return 1
+
+         return self.get_individual_dim_grid_count(dim + 1) * self.grid_count(dim + 1)
+
+     def get_grid_index(self, dim: int, coordinate: int):
+         """
+         Returns the index of the grid in the specified dimension.
+         """
+         assert dim < len(
+             self.data_shape
+         ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+         assert dim >= 0, "Dimension must be greater than or equal to 0"
+         assert (
+             coordinate < self.data_shape[dim]
+         ), f"Coordinate {coordinate} is out of bounds for data shape {self.data_shape}"
+
+         if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
+             return coordinate
+         elif self.tiling_mode == TilingMode.PadBoundary:  # self.trim_boundary is False:
+             return np.floor(coordinate / self.grid_shape[dim])
+         elif self.tiling_mode == TilingMode.TrimBoundary:
+             excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
+             # can be <0 if coordinate is in [0, grid_shape[dim]]
+             return max(0, np.floor((coordinate - excess_size) / self.grid_shape[dim]))
+         elif self.tiling_mode == TilingMode.ShiftBoundary:
+             excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
+             if coordinate + self.grid_shape[dim] + excess_size == self.data_shape[dim]:
+                 return self.get_individual_dim_grid_count(dim) - 1
+             else:
+                 # can be <0 if coordinate is in [0, grid_shape[dim]]
+                 return max(
+                     0, np.floor((coordinate - excess_size) / self.grid_shape[dim])
+                 )
+
+         else:
+             raise ValueError(f"Unsupported tiling mode {self.tiling_mode}")
+
+     def dataset_idx_from_grid_idx(self, grid_idx: tuple):
+         """
+         Returns the index of the grid in the dataset.
+         """
+         assert len(grid_idx) == len(
+             self.data_shape
+         ), f"Dimension indices {grid_idx} must have the same dimension as data shape {self.data_shape}"
+         index = 0
+         for dim in range(len(grid_idx)):
+             index += grid_idx[dim] * self.grid_count(dim)
+         return index
+
+     def get_patch_location_from_dataset_idx(self, dataset_idx: int):
+         """
+         Returns the patch location of the grid in the dataset.
+         """
+         grid_location = self.get_location_from_dataset_idx(dataset_idx)
+         offset = self.patch_offset()
+         return tuple(np.array(grid_location) - np.array(offset))
+
+     def get_dataset_idx_from_grid_location(self, location: tuple):
+         assert len(location) == len(
+             self.data_shape
+         ), f"Location {location} must have the same dimension as data shape {self.data_shape}"
+         grid_idx = [
+             self.get_grid_index(dim, location[dim]) for dim in range(len(location))
+         ]
+         return self.dataset_idx_from_grid_idx(tuple(grid_idx))
+
+     def get_gridstart_location_from_dim_index(self, dim: int, dim_index: int):
+         """
+         Returns the grid-start coordinate of the grid in the specified dimension.
+         """
+         assert dim < len(
+             self.data_shape
+         ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+         assert dim >= 0, "Dimension must be greater than or equal to 0"
+         assert dim_index < self.get_individual_dim_grid_count(
+             dim
+         ), f"Dimension index {dim_index} is out of bounds for data shape {self.data_shape}"
+
+         if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
+             return dim_index
+         elif self.tiling_mode == TilingMode.PadBoundary:
+             return dim_index * self.grid_shape[dim]
+         elif self.tiling_mode == TilingMode.TrimBoundary:
+             excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
+             return dim_index * self.grid_shape[dim] + excess_size
+         elif self.tiling_mode == TilingMode.ShiftBoundary:
+             excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
+             if dim_index < self.get_individual_dim_grid_count(dim) - 1:
+                 return dim_index * self.grid_shape[dim] + excess_size
+             else:
+                 # on the boundary, the grid is placed such that the patch covers the entire data.
+                 return self.data_shape[dim] - self.grid_shape[dim] - excess_size
+         else:
+             raise ValueError(f"Unsupported tiling mode {self.tiling_mode}")
+
+     def get_location_from_dataset_idx(self, dataset_idx: int):
+         """
+         Returns the start location of the grid in the dataset.
+         """
+         grid_idx = []
+         for dim in range(len(self.data_shape)):
+             grid_idx.append(dataset_idx // self.grid_count(dim))
+             dataset_idx = dataset_idx % self.grid_count(dim)
+         location = [
+             self.get_gridstart_location_from_dim_index(dim, grid_idx[dim])
+             for dim in range(len(self.data_shape))
+         ]
+         return tuple(location)
+
+     def on_boundary(self, dataset_idx: int, dim: int, only_end: bool = False):
+         """
+         Returns True if the grid is on the boundary in the specified dimension.
+         """
+         assert dim < len(
+             self.data_shape
+         ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+         assert dim >= 0, "Dimension must be greater than or equal to 0"
+
+         if dim > 0:
+             dataset_idx = dataset_idx % self.grid_count(dim - 1)
+
+         dim_index = dataset_idx // self.grid_count(dim)
+         if only_end:
+             return dim_index == self.get_individual_dim_grid_count(dim) - 1
+
+         return (
+             dim_index == 0 or dim_index == self.get_individual_dim_grid_count(dim) - 1
+         )
+
+     def next_grid_along_dim(self, dataset_idx: int, dim: int):
+         """
+         Returns the dataset index of the next grid along the specified dimension, or None at the end.
+         """
+         assert dim < len(
+             self.data_shape
+         ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+         assert dim >= 0, "Dimension must be greater than or equal to 0"
+         new_idx = dataset_idx + self.grid_count(dim)
+         if new_idx >= self.total_grid_count():
+             return None
+         return new_idx
+
+     def prev_grid_along_dim(self, dataset_idx: int, dim: int):
+         """
+         Returns the dataset index of the previous grid along the specified dimension, or None at the start.
+         """
+         assert dim < len(
+             self.data_shape
+         ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+         assert dim >= 0, "Dimension must be greater than or equal to 0"
+         new_idx = dataset_idx - self.grid_count(dim)
+         if new_idx < 0:
+             return None
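
For orientation, a minimal usage sketch of the new GridIndexManager (not part of the diff). The (N, H, W, C) data layout and the concrete shapes below are illustrative assumptions; the tiling mode comes from the TilingMode enum referenced in the file above.

from careamics.lvae_training.dataset.types import TilingMode
from careamics.lvae_training.dataset.utils.index_manager import GridIndexManager

# assumed (N, H, W, C) layout; patch_shape - grid_shape must be non-negative and even per dim
manager = GridIndexManager(
    data_shape=(1, 2048, 2048, 2),
    grid_shape=(1, 32, 32, 2),   # stride between neighbouring tiles
    patch_shape=(1, 64, 64, 2),  # extent of the patch extracted around each grid cell
    tiling_mode=TilingMode.ShiftBoundary,
)

n_grids = manager.total_grid_count()                          # tiles covering the whole array
grid_start = manager.get_location_from_dataset_idx(0)         # start coordinate of grid 0
patch_start = manager.get_patch_location_from_dataset_idx(0)  # grid start minus the patch offset
print(n_grids, grid_start, patch_start)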
careamics/lvae_training/dataset/utils/index_switcher.py (new file)
@@ -0,0 +1,165 @@
+ import numpy as np
+
+
+ class IndexSwitcher:
+     """
+     The idea is to switch between indices that are valid for the target and indices that are not.
+     If an index is invalid for the target, an all-zero vector is returned as the target.
+     This combines two behaviours:
+     1. Using a smaller amount of the total data.
+     2. Using a smaller amount of target data while still using the full input data.
+     """
+
+     def __init__(self, idx_manager, data_config, patch_size) -> None:
+         self.idx_manager = idx_manager
+         self._data_shape = self.idx_manager.get_data_shape()
+         self._training_validtarget_fraction = data_config.get(
+             "training_validtarget_fraction", 1.0
+         )
+         self._validtarget_ceilT = int(
+             np.ceil(self._data_shape[0] * self._training_validtarget_fraction)
+         )
+         self._patch_size = patch_size
+         assert (
+             data_config.deterministic_grid is True
+         ), "This only works when the dataset has a deterministic grid. The needed randomness comes from this class."
+         assert (
+             "grid_size" in data_config and data_config.grid_size == 1
+         ), "We need a one-to-one mapping between index and h, w, t"
+
+         self._h_validmax, self._w_validmax = self.get_reduced_frame_size(
+             self._data_shape[:3], self._training_validtarget_fraction
+         )
+         if self._h_validmax < self._patch_size or self._w_validmax < self._patch_size:
+             print(
+                 "WARNING: The valid target size is smaller than the patch size. This would result in an all-zero target, so this frame is ignored for the target."
+             )
+             self._h_validmax = 0
+             self._w_validmax = 0
+
+         print(
+             f"[{self.__class__.__name__}] Target Indices: [0,{self._validtarget_ceilT - 1}]. Index={self._validtarget_ceilT - 1} has shape [:{self._h_validmax},:{self._w_validmax}]. Available data: {self._data_shape[0]}"
+         )
+
+     def get_valid_target_index(self):
+         """
+         Returns an index which corresponds to a frame which is expected to have a target.
+         """
+         _, h, w, _ = self._data_shape
+         framepixelcount = h * w
+         targetpixels = np.array(
+             [framepixelcount] * (self._validtarget_ceilT - 1)
+             + [self._h_validmax * self._w_validmax]
+         )
+         targetpixels = targetpixels / np.sum(targetpixels)
+         t = np.random.choice(self._validtarget_ceilT, p=targetpixels)
+         # t = np.random.randint(0, self._validtarget_ceilT) if self._validtarget_ceilT >= 1 else 0
+         h, w = self.get_valid_target_hw(t)
+         index = self.idx_manager.idx_from_hwt(h, w, t)
+         # print('Valid', index, h, w, t)
+         return index
+
+     def get_invalid_target_index(self):
+         # if self._validtarget_ceilT == 0:
+         #     TODO: There may not be enough data for this to work. The better way is to skip using 0 for invalid target.
+         #     t = np.random.randint(1, self._data_shape[0])
+         # elif self._validtarget_ceilT < self._data_shape[0]:
+         #     t = np.random.randint(self._validtarget_ceilT, self._data_shape[0])
+         # else:
+         #     t = self._validtarget_ceilT - 1
+         # 5
+         # 1.2 => 2
+         total_t, h, w, _ = self._data_shape
+         framepixelcount = h * w
+         available_h = h - self._h_validmax
+         if available_h < self._patch_size:
+             available_h = 0
+         available_w = w - self._w_validmax
+         if available_w < self._patch_size:
+             available_w = 0
+
+         targetpixels = np.array(
+             [available_h * available_w]
+             + [framepixelcount] * (total_t - self._validtarget_ceilT)
+         )
+         t_probab = targetpixels / np.sum(targetpixels)
+         t = np.random.choice(
+             np.arange(self._validtarget_ceilT - 1, total_t), p=t_probab
+         )
+
+         h, w = self.get_invalid_target_hw(t)
+         index = self.idx_manager.idx_from_hwt(h, w, t)
+         # print('Invalid', index, h, w, t)
+         return index
+
+     def get_valid_target_hw(self, t):
+         """
+         This is the opposite of get_invalid_target_hw. It returns an (h, w) which is valid for the target.
+         This is only valid for the single-frame setup.
+         """
+         if t == self._validtarget_ceilT - 1:
+             h = np.random.randint(0, self._h_validmax - self._patch_size)
+             w = np.random.randint(0, self._w_validmax - self._patch_size)
+         else:
+             h = np.random.randint(0, self._data_shape[1] - self._patch_size)
+             w = np.random.randint(0, self._data_shape[2] - self._patch_size)
+         return h, w
+
+     def get_invalid_target_hw(self, t):
+         """
+         This is the opposite of get_valid_target_hw. It returns an (h, w) which is not valid for the target.
+         This is only valid for the single-frame setup.
+         """
+         if t == self._validtarget_ceilT - 1:
+             h = np.random.randint(
+                 self._h_validmax, self._data_shape[1] - self._patch_size
+             )
+             w = np.random.randint(
+                 self._w_validmax, self._data_shape[2] - self._patch_size
+             )
+         else:
+             h = np.random.randint(0, self._data_shape[1] - self._patch_size)
+             w = np.random.randint(0, self._data_shape[2] - self._patch_size)
+         return h, w
+
+     def _get_tidx(self, index):
+         if isinstance(index, int) or isinstance(index, np.int64):
+             idx = index
+         else:
+             idx = index[0]
+         return self.idx_manager.get_t(idx)
+
+     def index_should_have_target(self, index):
+         tidx = self._get_tidx(index)
+         if tidx < self._validtarget_ceilT - 1:
+             return True
+         elif tidx > self._validtarget_ceilT - 1:
+             return False
+         else:
+             h, w, _ = self.idx_manager.hwt_from_idx(index)
+             return (
+                 h + self._patch_size < self._h_validmax
+                 and w + self._patch_size < self._w_validmax
+             )
+
+     @staticmethod
+     def get_reduced_frame_size(data_shape_nhw, fraction):
+         n, h, w = data_shape_nhw
+
+         framepixelcount = h * w
+         targetpixelcount = int(n * framepixelcount * fraction)
+
+         # We currently support this only when there is just one frame.
+         # if np.ceil(pixelcount / framepixelcount) > 1:
+         #     return None, None
+
+         lastframepixelcount = targetpixelcount % framepixelcount
+         assert data_shape_nhw[1] == data_shape_nhw[2]
+         if lastframepixelcount > 0:
+             new_size = int(np.sqrt(lastframepixelcount))
+             return new_size, new_size
+         else:
+             assert (
+                 targetpixelcount / framepixelcount >= 1
+             ), "This is not possible in euclidean space :D (so this is a bug)"
+             return h, w
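
A small illustration of the fraction logic (not part of the diff): the static helper get_reduced_frame_size shrinks the valid-target region to the requested fraction of a frame. With a single 512x512 frame and a fraction of 0.25, the valid-target area becomes a 256x256 region, since 256*256 is a quarter of 512*512.

from careamics.lvae_training.dataset.utils.index_switcher import IndexSwitcher

# single 512x512 frame, keep targets for 25% of the pixels
h_valid, w_valid = IndexSwitcher.get_reduced_frame_size((1, 512, 512), fraction=0.25)
print(h_valid, w_valid)  # 256 256 under the assumptions above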
careamics/lvae_training/eval_utils.py
@@ -14,13 +14,19 @@ import matplotlib
  import matplotlib.pyplot as plt
  import numpy as np
  import torch
+ from torch import nn
+ from torch.utils.data import Dataset
  from matplotlib.gridspec import GridSpec
  from torch.utils.data import DataLoader
  from tqdm import tqdm
 
+ from careamics.lightning import VAEModule
+ from careamics.losses.lvae.losses import (
+     get_reconstruction_loss,
+     reconstruction_loss_musplit_denoisplit,
+ )
  from careamics.models.lvae.utils import ModelType
-
- from .metrics import RangeInvariantPsnr, RunningPSNR
+ from careamics.utils.metrics import scale_invariant_psnr, RunningPSNR
 
 
  # ------------------------------------------------------------------------------------------------
@@ -40,28 +46,26 @@ def clean_ax(ax):
      ax.tick_params(left=False, right=False, top=False, bottom=False)
 
 
- def get_plots_output_dir(
+ def get_eval_output_dir(
      saveplotsdir: str, patch_size: int, mmse_count: int = 50
  ) -> str:
      """
      Given the path to a root directory to save plots, patch size, and mmse count,
      it returns the specific directory to save the plots.
      """
-     plotsrootdir = os.path.join(
-         saveplotsdir, f"plots/patch_{patch_size}_mmse_{mmse_count}"
+     eval_out_dir = os.path.join(
+         saveplotsdir, f"eval_outputs/patch_{patch_size}_mmse_{mmse_count}"
      )
-     os.makedirs(plotsrootdir, exist_ok=True)
-     print(plotsrootdir)
-     return plotsrootdir
+     os.makedirs(eval_out_dir, exist_ok=True)
+     print(eval_out_dir)
+     return eval_out_dir
 
 
  def get_psnr_str(tar_hsnr, pred, col_idx):
      """
      Compute PSNR between the ground truth (`tar_hsnr`) and the predicted image (`pred`).
      """
-     return (
-         f"{RangeInvariantPsnr(tar_hsnr[col_idx][None], pred[col_idx][None]).item():.1f}"
-     )
+     return f"{scale_invariant_psnr(tar_hsnr[col_idx][None], pred[col_idx][None]).item():.1f}"
 
 
  def add_psnr_str(ax_, psnr):
@@ -499,20 +503,40 @@ def get_predictions(idx, val_dset, model, mmse_count=50, patch_size=256):
 
 
  def get_dset_predictions(
-     model,
-     dset,
+     model: VAEModule,
+     dset: Dataset,
      batch_size: int,
-     model_type: ModelType = None,
+     loss_type: Literal["musplit", "denoisplit", "denoisplit_musplit"],
      mmse_count: int = 1,
      num_workers: int = 4,
- ):
-     """
-     Get predictions from a model for the entire dataset.
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[float]]:
+     """Get patch-wise predictions from a model for the entire dataset.
 
      Parameters
      ----------
-     mmse_count : int
-         Number of samples to generate for each input and then to average over for MMSE estimation.
+     model : VAEModule
+         Lightning model used for prediction.
+     dset : Dataset
+         Dataset to predict on.
+     batch_size : int
+         Batch size to use for prediction.
+     loss_type : Literal["musplit", "denoisplit", "denoisplit_musplit"]
+         Type of reconstruction loss used by the model.
+     mmse_count : int, optional
+         Number of samples to generate for each input and then to average over for
+         MMSE estimation, by default 1.
+     num_workers : int, optional
+         Number of workers to use for DataLoader, by default 4.
+
+     Returns
+     -------
+     tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[float]]
+         Tuple containing:
+             - predictions: Predicted images for the dataset.
+             - predictions_std: Standard deviation of the predicted images.
+             - logvar_arr: Log variance of the predicted images.
+             - losses: Reconstruction losses for the predictions.
+             - psnr: PSNR values for the predictions.
      """
      dloader = DataLoader(
          dset,
@@ -521,69 +545,90 @@ def get_dset_predictions(
          shuffle=False,
          batch_size=batch_size,
      )
-     likelihood = model.model.likelihood
+
+     gauss_likelihood = model.gaussian_likelihood
+     nm_likelihood = model.noise_model_likelihood
+
      predictions = []
      predictions_std = []
      losses = []
      logvar_arr = []
-     patch_psnr_channels = [RunningPSNR() for _ in range(dset[0][1].shape[0])]
+     num_channels = dset[0][1].shape[0]
+     patch_psnr_channels = [RunningPSNR() for _ in range(num_channels)]
      with torch.no_grad():
-         for batch in tqdm(dloader):
-             inp, tar = batch[:2]
+         for batch in tqdm(dloader, desc="Predicting patches"):
+             inp, tar = batch
              inp = inp.cuda()
              tar = tar.cuda()
 
-             recon_img_list = []
+             rec_img_list = []
              for mmse_idx in range(mmse_count):
-                 if model_type == ModelType.Denoiser:
-                     assert model.denoise_channel in [
-                         "Ch1",
-                         "Ch2",
-                         "input",
-                     ], '"all" denoise channel not supported for evaluation. Pick one of "Ch1", "Ch2", "input"'
-
-                     x_normalized_new, tar_new = model.get_new_input_target(
-                         (inp, tar, *batch[2:])
+
+                 # TODO: case of HDN left for future refactoring
+                 # if model_type == ModelType.Denoiser:
+                 #     assert model.denoise_channel in [
+                 #         "Ch1",
+                 #         "Ch2",
+                 #         "input",
+                 #     ], '"all" denoise channel not supported for evaluation. Pick one of "Ch1", "Ch2", "input"'
+
+                 #     x_normalized_new, tar_new = model.get_new_input_target(
+                 #         (inp, tar, *batch[2:])
+                 #     )
+                 #     rec, _ = model(x_normalized_new)
+                 #     rec_loss, imgs = model.get_reconstruction_loss(
+                 #         rec,
+                 #         tar,
+                 #         x_normalized_new,
+                 #         return_predicted_img=True,
+                 #     )
+
+                 # get model output
+                 rec, _ = model(inp)
+
+                 # get reconstructed img
+                 if model.model.predict_logvar is None:
+                     rec_img = rec
+                     logvar = torch.tensor([-1])
+                 else:
+                     rec_img, logvar = torch.chunk(rec, chunks=2, dim=1)
+                 rec_img_list.append(rec_img.cpu().unsqueeze(0))  # add MMSE dim
+                 logvar_arr.append(logvar.cpu().numpy())
+
+                 # compute reconstruction loss
+                 if loss_type == "musplit":
+                     rec_loss = get_reconstruction_loss(
+                         reconstruction=rec, target=tar, likelihood_obj=gauss_likelihood
                      )
-                     tar_normalized = model.normalize_target(tar_new)
-                     recon_normalized, _ = model(x_normalized_new)
-                     rec_loss, imgs = model.get_reconstruction_loss(
-                         recon_normalized,
-                         tar_normalized,
-                         x_normalized_new,
-                         return_predicted_img=True,
+                 elif loss_type == "denoisplit":
+                     rec_loss = get_reconstruction_loss(
+                         reconstruction=rec, target=tar, likelihood_obj=nm_likelihood
                      )
-                 else:
-                     x_normalized = model.normalize_input(inp)
-                     tar_normalized = model.normalize_target(tar)
-                     recon_normalized, _ = model(x_normalized)
-                     rec_loss, imgs = model.get_reconstruction_loss(
-                         recon_normalized, tar_normalized, inp, return_predicted_img=True
+                 elif loss_type == "denoisplit_musplit":
+                     rec_loss = reconstruction_loss_musplit_denoisplit(
+                         predictions=rec,
+                         targets=tar,
+                         gaussian_likelihood=gauss_likelihood,
+                         nm_likelihood=nm_likelihood,
+                         nm_weight=model.loss_parameters.denoisplit_weight,
+                         gaussian_weight=model.loss_parameters.musplit_weight,
                      )
+                 rec_loss = {"loss": rec_loss}  # hacky, but ok for now
 
+                 # store rec loss values for first pred
                  if mmse_idx == 0:
-                     q_dic = (
-                         likelihood.distr_params(recon_normalized)
-                         if likelihood is not None
-                         else {"logvar": None}
-                     )
-                     if q_dic["logvar"] is not None:
-                         logvar_arr.append(q_dic["logvar"].cpu().numpy())
-                     else:
-                         logvar_arr.append(np.array([-1]))
-
                      try:
                          losses.append(rec_loss["loss"].cpu().numpy())
                      except:
                          losses.append(rec_loss["loss"])
 
-                     for i in range(imgs.shape[1]):
-                         patch_psnr_channels[i].update(imgs[:, i], tar_normalized[:, i])
+                     # update running PSNR
+                     for i in range(num_channels):
+                         patch_psnr_channels[i].update(rec_img[:, i], tar[:, i])
 
-                 recon_img_list.append(imgs.cpu()[None])
-
-             samples = torch.cat(recon_img_list, dim=0)
-             mmse_imgs = torch.mean(samples, dim=0)
+             # aggregate results
+             samples = torch.cat(rec_img_list, dim=0)
+             mmse_imgs = torch.mean(samples, dim=0)  # avg over MMSE dim
              mmse_std = torch.std(samples, dim=0)
             predictions.append(mmse_imgs.cpu().numpy())
              predictions_std.append(mmse_std.cpu().numpy())
@@ -591,10 +636,10 @@ def get_dset_predictions(
      psnr = [x.get() for x in patch_psnr_channels]
      return (
          np.concatenate(predictions, axis=0),
-         np.array(losses),
+         np.concatenate(predictions_std, axis=0),
          np.concatenate(logvar_arr),
+         np.array(losses),
          psnr,
-         np.concatenate(predictions_std, axis=0),
      )
 
 
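The refactored get_dset_predictions selects the reconstruction loss via a loss_type string instead of a ModelType. A hedged call sketch (not part of the diff), where vae_module and val_dset are placeholders for an already-trained VAEModule and a patch dataset yielding (input, target) pairs:

from careamics.lvae_training.eval_utils import get_dset_predictions

preds, preds_std, logvar, losses, psnr = get_dset_predictions(
    model=vae_module,                # placeholder: trained VAEModule
    dset=val_dset,                   # placeholder: patch dataset
    batch_size=8,
    loss_type="denoisplit_musplit",  # or "musplit" / "denoisplit"
    mmse_count=10,                   # samples averaged for the MMSE estimate
    num_workers=4,
)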
careamics/lvae_training/get_config.py
@@ -6,7 +6,7 @@ import os
 
  import ml_collections
 
- from careamics.lvae_training.dataset.data_utils import DataType
+ from careamics.lvae_training.dataset.utils.data_utils import DataType
  from careamics.models.lvae.utils import LossType
 
 
careamics/lvae_training/train_lvae.py
@@ -24,7 +24,7 @@ from careamics.lvae_training.dataset.data_modules import (
      LCMultiChDloader,
      MultiChDloader,
  )
- from careamics.lvae_training.dataset.data_utils import DataSplitType
+ from careamics.lvae_training.dataset.utils.data_utils import DataSplitType
  from careamics.lvae_training.lightning_module import LadderVAELight
  from careamics.lvae_training.train_utils import *
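
Both one-line changes above reflect the dataset-utils reorganisation: helpers that previously lived in careamics.lvae_training.dataset.data_utils are now imported from the new utils subpackage, e.g.:

from careamics.lvae_training.dataset.utils.data_utils import DataSplitType, DataType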