careamics 0.0.3__py3-none-any.whl → 0.0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (56)
  1. careamics/careamist.py +25 -17
  2. careamics/cli/__init__.py +5 -0
  3. careamics/cli/conf.py +391 -0
  4. careamics/cli/main.py +134 -0
  5. careamics/config/architectures/lvae_model.py +0 -4
  6. careamics/config/configuration_factory.py +480 -177
  7. careamics/config/configuration_model.py +1 -2
  8. careamics/config/data_model.py +1 -15
  9. careamics/config/fcn_algorithm_model.py +14 -9
  10. careamics/config/likelihood_model.py +21 -4
  11. careamics/config/nm_model.py +31 -5
  12. careamics/config/optimizer_models.py +3 -1
  13. careamics/config/support/supported_optimizers.py +1 -1
  14. careamics/config/support/supported_transforms.py +1 -0
  15. careamics/config/training_model.py +35 -6
  16. careamics/config/transformations/__init__.py +4 -1
  17. careamics/config/transformations/transform_union.py +20 -0
  18. careamics/config/vae_algorithm_model.py +2 -36
  19. careamics/dataset/tiling/lvae_tiled_patching.py +90 -8
  20. careamics/lightning/lightning_module.py +10 -8
  21. careamics/lightning/train_data_module.py +2 -2
  22. careamics/losses/loss_factory.py +3 -3
  23. careamics/losses/lvae/losses.py +2 -2
  24. careamics/lvae_training/dataset/__init__.py +15 -0
  25. careamics/lvae_training/dataset/{vae_data_config.py → config.py} +25 -81
  26. careamics/lvae_training/dataset/lc_dataset.py +28 -20
  27. careamics/lvae_training/dataset/{vae_dataset.py → multich_dataset.py} +91 -51
  28. careamics/lvae_training/dataset/multifile_dataset.py +334 -0
  29. careamics/lvae_training/dataset/types.py +43 -0
  30. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  31. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  32. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  33. careamics/lvae_training/dataset/utils/index_manager.py +232 -0
  34. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  35. careamics/lvae_training/eval_utils.py +109 -64
  36. careamics/lvae_training/get_config.py +1 -1
  37. careamics/lvae_training/train_lvae.py +1 -1
  38. careamics/model_io/bioimage/bioimage_utils.py +4 -2
  39. careamics/model_io/bmz_io.py +6 -5
  40. careamics/models/lvae/likelihoods.py +18 -9
  41. careamics/models/lvae/lvae.py +12 -16
  42. careamics/models/lvae/noise_models.py +1 -1
  43. careamics/transforms/compose.py +90 -15
  44. careamics/transforms/n2v_manipulate.py +6 -2
  45. careamics/transforms/normalize.py +14 -3
  46. careamics/transforms/xy_flip.py +16 -6
  47. careamics/transforms/xy_random_rotate90.py +16 -7
  48. careamics/utils/metrics.py +204 -24
  49. careamics/utils/serializers.py +60 -0
  50. {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/METADATA +4 -3
  51. {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/RECORD +54 -43
  52. careamics-0.0.4.1.dist-info/entry_points.txt +2 -0
  53. careamics/lvae_training/dataset/data_utils.py +0 -701
  54. careamics/lvae_training/dataset/lc_dataset_config.py +0 -13
  55. {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/WHEEL +0 -0
  56. {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/licenses/LICENSE +0 -0
careamics/lvae_training/dataset/multifile_dataset.py
@@ -0,0 +1,334 @@
+ from typing import Union, Callable, Sequence
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from .config import DatasetConfig
+ from .multich_dataset import MultiChDloader
+ from .types import DataSplitType
+ from .lc_dataset import LCMultiChDloader
+
+
+ class TwoChannelData(Sequence):
+     """
+     each element in data_arr should be a N*H*W array
+     """
+
+     def __init__(self, data_arr1, data_arr2, paths_data1=None, paths_data2=None):
+         assert len(data_arr1) == len(data_arr2)
+         self.paths1 = paths_data1
+         self.paths2 = paths_data2
+
+         self._data = []
+         for i in range(len(data_arr1)):
+             assert data_arr1[i].shape == data_arr2[i].shape
+             assert (
+                 len(data_arr1[i].shape) == 3
+             ), f"Each element in data arrays should be a N*H*W, but {data_arr1[i].shape}"
+             self._data.append(
+                 np.concatenate(
+                     [data_arr1[i][..., None], data_arr2[i][..., None]], axis=-1
+                 )
+             )
+
+     def __len__(self):
+         n = 0
+         for x in self._data:
+             n += x.shape[0]
+         return n
+
+     def __getitem__(self, idx):
+         n = 0
+         for dataidx, x in enumerate(self._data):
+             if idx < n + x.shape[0]:
+                 if self.paths1 is None:
+                     return x[idx - n], None
+                 else:
+                     return x[idx - n], (self.paths1[dataidx], self.paths2[dataidx])
+             n += x.shape[0]
+         raise IndexError("Index out of range")
+
+
+ class MultiChannelData(Sequence):
+     """
+     each element in data_arr should be a N*H*W array
+     """
+
+     def __init__(self, data_arr, paths=None):
+         self.paths = paths
+
+         self._data = data_arr
+
+     def __len__(self):
+         n = 0
+         for x in self._data:
+             n += x.shape[0]
+         return n
+
+     def __getitem__(self, idx):
+         n = 0
+         for dataidx, x in enumerate(self._data):
+             if idx < n + x.shape[0]:
+                 if self.paths is None:
+                     return x[idx - n], None
+                 else:
+                     return x[idx - n], (self.paths[dataidx])
+             n += x.shape[0]
+         raise IndexError("Index out of range")
+
+
+ class SingleFileLCDset(LCMultiChDloader):
+     def __init__(
+         self,
+         preloaded_data: NDArray,
+         data_config: DatasetConfig,
+         fpath: str,
+         load_data_fn: Callable,
+         val_fraction=None,
+         test_fraction=None,
+     ):
+         self._preloaded_data = preloaded_data
+         super().__init__(
+             data_config,
+             fpath,
+             load_data_fn=load_data_fn,
+             val_fraction=val_fraction,
+             test_fraction=test_fraction,
+         )
+
+     @property
+     def data_path(self):
+         return self._fpath
+
+     def rm_bkground_set_max_val_and_upperclip_data(self, max_val, datasplit_type):
+         pass
+
+     def load_data(
+         self,
+         data_config: DatasetConfig,
+         datasplit_type: DataSplitType,
+         load_data_fn: Callable,
+         val_fraction=None,
+         test_fraction=None,
+         allow_generation=None,
+     ):
+         self._data = self._preloaded_data
+         assert "channel_1" not in data_config or isinstance(data_config.channel_1, str)
+         assert "channel_2" not in data_config or isinstance(data_config.channel_2, str)
+         assert "channel_3" not in data_config or isinstance(data_config.channel_3, str)
+         self._loaded_data_preprocessing(data_config)
+
+
+ class SingleFileDset(MultiChDloader):
+     def __init__(
+         self,
+         preloaded_data: NDArray,
+         data_config: DatasetConfig,
+         fpath: str,
+         load_data_fn: Callable,
+         val_fraction=None,
+         test_fraction=None,
+     ):
+         self._preloaded_data = preloaded_data
+         super().__init__(
+             data_config,
+             fpath,
+             load_data_fn=load_data_fn,
+             val_fraction=val_fraction,
+             test_fraction=test_fraction,
+         )
+
+     def rm_bkground_set_max_val_and_upperclip_data(self, max_val, datasplit_type):
+         pass
+
+     @property
+     def data_path(self):
+         return self._fpath
+
+     def load_data(
+         self,
+         data_config: DatasetConfig,
+         datasplit_type: DataSplitType,
+         load_data_fn: Callable[..., NDArray],
+         val_fraction=None,
+         test_fraction=None,
+         allow_generation=None,
+     ):
+         self._data = self._preloaded_data
+         assert (
+             "channel_1" not in data_config
+         ), "Outdated config file. Please remove channel_1, channel_2, channel_3 from the config file."
+         assert (
+             "channel_2" not in data_config
+         ), "Outdated config file. Please remove channel_1, channel_2, channel_3 from the config file."
+         assert (
+             "channel_3" not in data_config
+         ), "Outdated config file. Please remove channel_1, channel_2, channel_3 from the config file."
+         self._loaded_data_preprocessing(data_config)
+
+
+ class MultiFileDset:
+     """
+     Here, we handle dataset having multiple files. Each file can have a different spatial dimension and number of frames (Z stack).
+     """
+
+     def __init__(
+         self,
+         data_config: DatasetConfig,
+         fpath: str,
+         load_data_fn: Callable[..., Union[TwoChannelData, MultiChannelData]],
+         val_fraction=None,
+         test_fraction=None,
+     ):
+         self._fpath = fpath
+         data: Union[TwoChannelData, MultiChannelData] = load_data_fn(
+             data_config,
+             self._fpath,
+             data_config.datasplit_type,
+             val_fraction=val_fraction,
+             test_fraction=test_fraction,
+         )
+         self.dsets = []
+
+         for i in range(len(data)):
+             prefetched_data, fpath_tuple = data[i]
+             if (
+                 data_config.multiscale_lowres_count is not None
+                 and data_config.multiscale_lowres_count > 1
+             ):
+
+                 self.dsets.append(
+                     SingleFileLCDset(
+                         prefetched_data[None],
+                         data_config,
+                         fpath_tuple,
+                         load_data_fn,
+                         val_fraction=val_fraction,
+                         test_fraction=test_fraction,
+                     )
+                 )
+
+             else:
+                 self.dsets.append(
+                     SingleFileDset(
+                         prefetched_data[None],
+                         data_config,
+                         fpath_tuple,
+                         load_data_fn,
+                         val_fraction=val_fraction,
+                         test_fraction=test_fraction,
+                     )
+                 )
+
+         self.rm_bkground_set_max_val_and_upperclip_data(
+             data_config.max_val, data_config.datasplit_type
+         )
+         count = 0
+         avg_height = 0
+         avg_width = 0
+         for dset in self.dsets:
+             shape = dset.get_data_shape()
+             avg_height += shape[1]
+             avg_width += shape[2]
+             count += shape[0]
+
+         avg_height = int(avg_height / len(self.dsets))
+         avg_width = int(avg_width / len(self.dsets))
+         print(
+             f"{self.__class__.__name__} avg height: {avg_height}, avg width: {avg_width}, count: {count}"
+         )
+
+     def rm_bkground_set_max_val_and_upperclip_data(self, max_val, datasplit_type):
+         self.set_max_val(max_val, datasplit_type)
+         self.upperclip_data()
+
+     def set_mean_std(self, mean_val, std_val):
+         for dset in self.dsets:
+             dset.set_mean_std(mean_val, std_val)
+
+     def get_mean_std(self):
+         return self.dsets[0].get_mean_std()
+
+     def compute_max_val(self):
+         max_val_arr = []
+         for dset in self.dsets:
+             max_val_arr.append(dset.compute_max_val())
+         return np.max(max_val_arr)
+
+     def set_max_val(self, max_val, datasplit_type):
+         if datasplit_type == DataSplitType.Train:
+             assert max_val is None
+             max_val = self.compute_max_val()
+         for dset in self.dsets:
+             dset.set_max_val(max_val, datasplit_type)
+
+     def upperclip_data(self):
+         for dset in self.dsets:
+             dset.upperclip_data()
+
+     def get_max_val(self):
+         return self.dsets[0].get_max_val()
+
+     def get_img_sz(self):
+         return self.dsets[0].get_img_sz()
+
+     def set_img_sz(self, image_size, grid_size):
+         for dset in self.dsets:
+             dset.set_img_sz(image_size, grid_size)
+
+     def compute_mean_std(self):
+         cur_mean = {"target": 0, "input": 0}
+         cur_std = {"target": 0, "input": 0}
+         for dset in self.dsets:
+             mean, std = dset.compute_mean_std()
+             cur_mean["target"] += mean["target"]
+             cur_mean["input"] += mean["input"]
+
+             cur_std["target"] += std["target"]
+             cur_std["input"] += std["input"]
+
+         cur_mean["target"] /= len(self.dsets)
+         cur_mean["input"] /= len(self.dsets)
+         cur_std["target"] /= len(self.dsets)
+         cur_std["input"] /= len(self.dsets)
+         return cur_mean, cur_std
+
+     def compute_individual_mean_std(self):
+         cum_mean = 0
+         cum_std = 0
+         for dset in self.dsets:
+             mean, std = dset.compute_individual_mean_std()
+             cum_mean += mean
+             cum_std += std
+         return cum_mean / len(self.dsets), cum_std / len(self.dsets)
+
+     def get_num_frames(self):
+         return len(self.dsets)
+
+     def reduce_data(
+         self, t_list=None, h_start=None, h_end=None, w_start=None, w_end=None
+     ):
+         assert h_start is None
+         assert h_end is None
+         assert w_start is None
+         assert w_end is None
+         self.dsets = [self.dsets[t] for t in t_list]
+         print(
+             f"[{self.__class__.__name__}] Data reduced. New data count: {len(self.dsets)}"
+         )
+
+     def __len__(self):
+         out = 0
+         for dset in self.dsets:
+             out += len(dset)
+         return out
+
+     def __getitem__(self, idx):
+         cum_len = 0
+         for dset in self.dsets:
+             cum_len += len(dset)
+             if idx < cum_len:
+                 rel_idx = idx - (cum_len - len(dset))
+                 return dset[rel_idx]
+
+         raise IndexError("Index out of range")
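The hunk above is the only new file with 334 added lines, which per the files-changed list is careamics/lvae_training/dataset/multifile_dataset.py. Below is a minimal, self-contained sketch of how its container classes flatten several per-file stacks behind one global index; the synthetic arrays and file names are illustrative assumptions, not part of the diff.

import numpy as np

from careamics.lvae_training.dataset.multifile_dataset import (
    MultiChannelData,
    TwoChannelData,
)

# Two "files" with different frame counts; each element is an N*H*W stack.
ch1 = [np.random.rand(4, 64, 64), np.random.rand(2, 64, 64)]
ch2 = [np.random.rand(4, 64, 64), np.random.rand(2, 64, 64)]

paired = TwoChannelData(
    ch1, ch2, paths_data1=["a1.tif", "b1.tif"], paths_data2=["a2.tif", "b2.tif"]
)
print(len(paired))        # 6 frames in total across both files
frame, paths = paired[5]  # last frame of the second file
print(frame.shape, paths) # (64, 64, 2) and the pair of source paths

multi = MultiChannelData([np.random.rand(3, 32, 32, 2)], paths=["c.tif"])
print(len(multi), multi[0][0].shape)  # 3 frames, each (32, 32, 2)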
careamics/lvae_training/dataset/types.py
@@ -0,0 +1,43 @@
+ from enum import Enum
+
+
+ class DataType(Enum):
+     MNIST = 0
+     Places365 = 1
+     NotMNIST = 2
+     OptiMEM100_014 = 3
+     CustomSinosoid = 4
+     Prevedel_EMBL = 5
+     AllenCellMito = 6
+     SeparateTiffData = 7
+     CustomSinosoidThreeCurve = 8
+     SemiSupBloodVesselsEMBL = 9
+     Pavia2 = 10
+     Pavia2VanillaSplitting = 11
+     ExpansionMicroscopyMitoTub = 12
+     ShroffMitoEr = 13
+     HTIba1Ki67 = 14
+     BSD68 = 15
+     BioSR_MRC = 16
+     TavernaSox2Golgi = 17
+     Dao3Channel = 18
+     ExpMicroscopyV2 = 19
+     Dao3ChannelWithInput = 20
+     TavernaSox2GolgiV2 = 21
+     TwoDset = 22
+     PredictedTiffData = 23
+     Pavia3SeqData = 24
+     NicolaData = 25
+
+
+ class DataSplitType(Enum):
+     All = 0
+     Train = 1
+     Val = 2
+     Test = 3
+
+
+ class TilingMode(Enum):
+     TrimBoundary = 0
+     PadBoundary = 1
+     ShiftBoundary = 2
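These enums (the +43-line careamics/lvae_training/dataset/types.py, going by the files-changed list) are plain value enums; DataSplitType, for example, is what MultiFileDset.set_max_val switches on earlier in this diff. A tiny usage sketch, with the import path assumed from that list:

from careamics.lvae_training.dataset.types import DataSplitType

# In MultiFileDset.set_max_val above, only the training split computes its own
# clipping value; validation and test are expected to receive one explicitly.
for split in (DataSplitType.Train, DataSplitType.Val, DataSplitType.Test):
    print(split.name, "computes max_val itself:", split == DataSplitType.Train)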
careamics/lvae_training/dataset/utils/__init__.py
File without changes
careamics/lvae_training/dataset/utils/data_utils.py
@@ -0,0 +1,114 @@
+ """
+ Utility functions needed by dataloader & co.
+ """
+
+ from typing import List
+
+ import numpy as np
+ from skimage.io import imread, imsave
+
+
+ def load_tiff(path):
+     """
+     Returns a 4d numpy array: num_imgs*h*w*num_channels
+     """
+     data = imread(path, plugin="tifffile")
+     return data
+
+
+ def save_tiff(path, data):
+     imsave(path, data, plugin="tifffile")
+
+
+ def load_tiffs(paths):
+     data = [load_tiff(path) for path in paths]
+     return np.concatenate(data, axis=0)
+
+
+ def split_in_half(s, e):
+     n = e - s
+     s1 = list(np.arange(n // 2))
+     s2 = list(np.arange(n // 2, n))
+     return [x + s for x in s1], [x + s for x in s2]
+
+
+ def adjust_for_imbalance_in_fraction_value(
+     val: List[int],
+     test: List[int],
+     val_fraction: float,
+     test_fraction: float,
+     total_size: int,
+ ):
+     """
+     here, val and test are divided almost equally. Here, we need to take into account their respective fractions
+     and pick elements rendomly from one array and put in the other array.
+     """
+     if val_fraction == 0:
+         test += val
+         val = []
+     elif test_fraction == 0:
+         val += test
+         test = []
+     else:
+         diff_fraction = test_fraction - val_fraction
+         if diff_fraction > 0:
+             imb_count = int(diff_fraction * total_size / 2)
+             val = list(np.random.RandomState(seed=955).permutation(val))
+             test += val[:imb_count]
+             val = val[imb_count:]
+         elif diff_fraction < 0:
+             imb_count = int(-1 * diff_fraction * total_size / 2)
+             test = list(np.random.RandomState(seed=955).permutation(test))
+             val += test[:imb_count]
+             test = test[imb_count:]
+     return val, test
+
+
+ def get_datasplit_tuples(
+     val_fraction: float,
+     test_fraction: float,
+     total_size: int,
+     starting_test: bool = False,
+ ):
+     if starting_test:
+         # test => val => train
+         test = list(range(0, int(total_size * test_fraction)))
+         val = list(range(test[-1] + 1, test[-1] + 1 + int(total_size * val_fraction)))
+         train = list(range(val[-1] + 1, total_size))
+     else:
+         # {test,val}=> train
+         test_val_size = int((val_fraction + test_fraction) * total_size)
+         train = list(range(test_val_size, total_size))
+
+         if test_val_size == 0:
+             test = []
+             val = []
+             return train, val, test
+
+         # Split the test and validation in chunks.
+         chunksize = max(1, min(3, test_val_size // 2))
+
+         nchunks = test_val_size // chunksize
+
+         test = []
+         val = []
+         s = 0
+         for i in range(nchunks):
+             if i % 2 == 0:
+                 val += list(np.arange(s, s + chunksize))
+             else:
+                 test += list(np.arange(s, s + chunksize))
+             s += chunksize
+
+         if i % 2 == 0:
+             test += list(np.arange(s, test_val_size))
+         else:
+             p1, p2 = split_in_half(s, test_val_size)
+             test += p1
+             val += p2
+
+         val, test = adjust_for_imbalance_in_fraction_value(
+             val, test, val_fraction, test_fraction, total_size
+         )
+
+     return train, val, test
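The splitting logic above interleaves val and test indices in chunks of at most three at the start of the index range and hands the remainder to train. A quick sanity check of that behaviour, assuming the hunk is careamics/lvae_training/dataset/utils/data_utils.py as the files-changed list suggests:

from careamics.lvae_training.dataset.utils.data_utils import get_datasplit_tuples

# 10% val and 10% test over 30 frames: 6 val/test indices split into two chunks of 3.
train, val, test = get_datasplit_tuples(val_fraction=0.1, test_fraction=0.1, total_size=30)
print(len(train), len(val), len(test))  # 24 3 3
print(val)   # first chunk of indices (0, 1, 2) goes to val
print(test)  # second chunk (3, 4, 5) goes to test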
careamics/lvae_training/dataset/utils/empty_patch_fetcher.py
@@ -0,0 +1,65 @@
+ import numpy as np
+ from tqdm import tqdm
+
+
+ class EmptyPatchFetcher:
+     """
+     The idea is to fetch empty patches so that real content can be replaced with this.
+     """
+
+     def __init__(self, idx_manager, patch_size, data_frames, max_val_threshold=None):
+         self._frames = data_frames
+         self._idx_manager = idx_manager
+         self._max_val_threshold = max_val_threshold
+         self._idx_list = []
+         self._patch_size = patch_size
+         self._grid_size = 1
+         self.set_empty_idx()
+
+         print(f"[{self.__class__.__name__}] MaxVal:{self._max_val_threshold}")
+
+     def compute_max(self, window):
+         """
+         Rolling compute.
+         """
+         N, H, W = self._frames.shape
+         randnum = -954321
+         assert self._grid_size == 1
+         max_data = np.zeros((N, H - window, W - window)) * randnum
+
+         for h in tqdm(range(H - window)):
+             for w in range(W - window):
+                 max_data[:, h, w] = self._frames[:, h : h + window, w : w + window].max(
+                     axis=(1, 2)
+                 )
+
+         assert (max_data != 954321).any()
+         return max_data
+
+     def set_empty_idx(self):
+         max_data = self.compute_max(self._patch_size)
+         empty_loc = np.where(
+             np.logical_and(max_data >= 0, max_data < self._max_val_threshold)
+         )
+         # print(max_data.shape, len(empty_loc))
+         self._idx_list = []
+         for idx in range(len(empty_loc[0])):
+             n_idx = empty_loc[0][idx]
+             h_start = empty_loc[1][idx]
+             w_start = empty_loc[2][idx]
+             # print(n_idx,h_start,w_start)
+             # channel_idx = self._idx_manager.get_location_from_dataset_idx(0)[-1]
+             loc = (n_idx, h_start, w_start, 0)
+             idx = self._idx_manager.get_dataset_idx_from_location(loc)
+             t, h, w, _ = self._idx_manager.get_location_from_dataset_idx(idx)
+             assert h == h_start, f"{h} != {h_start}"
+             assert w == w_start, f"{w} != {w_start}"
+             assert t == n_idx, f"{t} != {n_idx}"
+             self._idx_list.append(idx)
+
+         self._idx_list = np.array(self._idx_list)
+
+         assert len(self._idx_list) > 0
+
+     def sample(self):
+         return (np.random.choice(self._idx_list), self._grid_size)
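EmptyPatchFetcher.compute_max is a brute-force rolling maximum: for every top-left corner it takes the max over a patch-sized window, and set_empty_idx keeps the corners whose window max stays below max_val_threshold. The following is a standalone plain-numpy sketch of that idea on synthetic data; it is not the class itself and omits the idx_manager bookkeeping that maps corners to dataset indices.

import numpy as np

frames = np.random.rand(2, 32, 32)  # synthetic N x H x W stack
window, threshold = 8, 0.99

N, H, W = frames.shape
max_data = np.empty((N, H - window, W - window))
for h in range(H - window):
    for w in range(W - window):
        # max over the patch-sized window anchored at this top-left corner
        max_data[:, h, w] = frames[:, h:h + window, w:w + window].max(axis=(1, 2))

# corners whose window maximum stays below the threshold are "empty" candidates
n_idx, h_idx, w_idx = np.where(max_data < threshold)
print(f"{len(n_idx)} candidate empty-patch corners found")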