careamics 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (87)
  1. careamics/careamist.py +39 -28
  2. careamics/cli/__init__.py +5 -0
  3. careamics/cli/conf.py +391 -0
  4. careamics/cli/main.py +134 -0
  5. careamics/config/__init__.py +7 -3
  6. careamics/config/architectures/__init__.py +2 -2
  7. careamics/config/architectures/architecture_model.py +1 -1
  8. careamics/config/architectures/custom_model.py +11 -8
  9. careamics/config/architectures/lvae_model.py +170 -0
  10. careamics/config/configuration_factory.py +481 -170
  11. careamics/config/configuration_model.py +6 -3
  12. careamics/config/data_model.py +31 -20
  13. careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +35 -45
  14. careamics/config/likelihood_model.py +60 -0
  15. careamics/config/nm_model.py +127 -0
  16. careamics/config/optimizer_models.py +3 -1
  17. careamics/config/support/supported_activations.py +1 -0
  18. careamics/config/support/supported_algorithms.py +17 -4
  19. careamics/config/support/supported_architectures.py +8 -11
  20. careamics/config/support/supported_losses.py +3 -1
  21. careamics/config/support/supported_optimizers.py +1 -1
  22. careamics/config/support/supported_transforms.py +1 -0
  23. careamics/config/training_model.py +35 -6
  24. careamics/config/transformations/__init__.py +4 -1
  25. careamics/config/transformations/n2v_manipulate_model.py +1 -1
  26. careamics/config/transformations/transform_union.py +20 -0
  27. careamics/config/vae_algorithm_model.py +137 -0
  28. careamics/dataset/tiling/lvae_tiled_patching.py +364 -0
  29. careamics/file_io/read/tiff.py +1 -1
  30. careamics/lightning/__init__.py +3 -2
  31. careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
  32. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
  33. careamics/lightning/lightning_module.py +367 -9
  34. careamics/lightning/predict_data_module.py +2 -2
  35. careamics/lightning/train_data_module.py +4 -4
  36. careamics/losses/__init__.py +11 -1
  37. careamics/losses/fcn/__init__.py +1 -0
  38. careamics/losses/{losses.py → fcn/losses.py} +1 -1
  39. careamics/losses/loss_factory.py +112 -6
  40. careamics/losses/lvae/__init__.py +1 -0
  41. careamics/losses/lvae/loss_utils.py +83 -0
  42. careamics/losses/lvae/losses.py +445 -0
  43. careamics/lvae_training/dataset/__init__.py +15 -0
  44. careamics/lvae_training/dataset/config.py +123 -0
  45. careamics/lvae_training/dataset/lc_dataset.py +267 -0
  46. careamics/lvae_training/{data_modules.py → dataset/multich_dataset.py} +375 -501
  47. careamics/lvae_training/dataset/multifile_dataset.py +334 -0
  48. careamics/lvae_training/dataset/types.py +43 -0
  49. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  50. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  51. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  52. careamics/lvae_training/dataset/utils/index_manager.py +232 -0
  53. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  54. careamics/lvae_training/eval_utils.py +109 -64
  55. careamics/lvae_training/get_config.py +1 -1
  56. careamics/lvae_training/train_lvae.py +6 -3
  57. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  58. careamics/model_io/bioimage/model_description.py +2 -2
  59. careamics/model_io/bmz_io.py +20 -7
  60. careamics/model_io/model_io_utils.py +16 -4
  61. careamics/models/__init__.py +1 -3
  62. careamics/models/activation.py +2 -0
  63. careamics/models/lvae/__init__.py +3 -0
  64. careamics/models/lvae/layers.py +21 -21
  65. careamics/models/lvae/likelihoods.py +190 -129
  66. careamics/models/lvae/lvae.py +60 -148
  67. careamics/models/lvae/noise_models.py +318 -186
  68. careamics/models/lvae/utils.py +2 -2
  69. careamics/models/model_factory.py +22 -7
  70. careamics/prediction_utils/lvae_prediction.py +158 -0
  71. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  72. careamics/prediction_utils/stitch_prediction.py +16 -2
  73. careamics/transforms/compose.py +90 -15
  74. careamics/transforms/n2v_manipulate.py +6 -2
  75. careamics/transforms/normalize.py +14 -3
  76. careamics/transforms/pixel_manipulation.py +1 -1
  77. careamics/transforms/xy_flip.py +16 -6
  78. careamics/transforms/xy_random_rotate90.py +16 -7
  79. careamics/utils/metrics.py +277 -24
  80. careamics/utils/serializers.py +60 -0
  81. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/METADATA +5 -4
  82. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/RECORD +85 -60
  83. careamics-0.0.4.dist-info/entry_points.txt +2 -0
  84. careamics/config/architectures/vae_model.py +0 -42
  85. careamics/lvae_training/data_utils.py +0 -618
  86. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/WHEEL +0 -0
  87. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/licenses/LICENSE +0 -0
--- /dev/null
+++ b/careamics/lvae_training/dataset/lc_dataset.py
@@ -0,0 +1,267 @@
+"""
+A place for Datasets and Dataloaders.
+"""
+
+from typing import Callable, Tuple, Union
+
+import numpy as np
+from skimage.transform import resize
+
+from .config import DatasetConfig
+from .multich_dataset import MultiChDloader
+
+
+class LCMultiChDloader(MultiChDloader):
+    def __init__(
+        self,
+        data_config: DatasetConfig,
+        fpath: str,
+        load_data_fn: Callable,
+        val_fraction=None,
+        test_fraction=None,
+    ):
+        self._padding_kwargs = (
+            data_config.padding_kwargs  # mode=padding_mode, constant_values=constant_value
+        )
+        self._uncorrelated_channel_probab = data_config.uncorrelated_channel_probab
+
+        super().__init__(
+            data_config,
+            fpath,
+            load_data_fn=load_data_fn,
+            val_fraction=val_fraction,
+            test_fraction=test_fraction,
+        )
+
+        if data_config.overlapping_padding_kwargs is not None:
+            assert (
+                self._padding_kwargs == data_config.overlapping_padding_kwargs
+            ), "During evaluation, overlapping_padding_kwargs should be the same as \
+                padding_kwargs, since overlapping_padding_kwargs is only used when it is not None"
+        else:
+            self._overlapping_padding_kwargs = data_config.padding_kwargs
+
+        self.multiscale_lowres_count = data_config.multiscale_lowres_count
+        assert self.multiscale_lowres_count is not None
+        self._scaled_data = [self._data]
+        self._scaled_noise_data = [self._noise_data]
+
+        assert (
+            isinstance(self.multiscale_lowres_count, int)
+            and self.multiscale_lowres_count >= 1
+        )
+        assert isinstance(self._padding_kwargs, dict)
+        assert "mode" in self._padding_kwargs
+
+        for _ in range(1, self.multiscale_lowres_count):
+            shape = self._scaled_data[-1].shape
+            assert len(shape) == 4
+            new_shape = (shape[0], shape[1] // 2, shape[2] // 2, shape[3])
+            ds_data = resize(
+                self._scaled_data[-1].astype(np.float32), new_shape
+            ).astype(self._scaled_data[-1].dtype)
+            # NOTE: These asserts are important. resize expects np.float32; otherwise one gets weird results.
+            assert (
+                ds_data.max() / self._scaled_data[-1].max() < 5
+            ), "Downsampled image should not have very different values"
+            assert (
+                ds_data.max() / self._scaled_data[-1].max() > 0.2
+            ), "Downsampled image should not have very different values"
+
+            self._scaled_data.append(ds_data)
+            # do the same for the noise
+            if self._noise_data is not None:
+                noise_data = resize(self._scaled_noise_data[-1], new_shape)
+                self._scaled_noise_data.append(noise_data)
+
+    def reduce_data(
+        self, t_list=None, h_start=None, h_end=None, w_start=None, w_end=None
+    ):
+        assert t_list is not None
+        assert h_start is None
+        assert h_end is None
+        assert w_start is None
+        assert w_end is None
+
+        self._data = self._data[t_list].copy()
+        self._scaled_data = [
+            self._scaled_data[i][t_list].copy() for i in range(len(self._scaled_data))
+        ]
+
+        if self._noise_data is not None:
+            self._noise_data = self._noise_data[t_list].copy()
+            self._scaled_noise_data = [
+                self._scaled_noise_data[i][t_list].copy()
+                for i in range(len(self._scaled_noise_data))
+            ]
+
+        self.N = len(t_list)
+        self.set_img_sz(self._img_sz, self._grid_sz)
+        print(
+            f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
+        )
+
+    def _init_msg(self):
+        msg = super()._init_msg()
+        msg += f" Pad:{self._padding_kwargs}"
+        if self._uncorrelated_channels:
+            msg += f" UncorrChProbab:{self._uncorrelated_channel_probab}"
+        return msg
+
+    def _load_scaled_img(
+        self, scaled_index, index: Union[int, Tuple[int, int]]
+    ) -> Tuple[np.ndarray, np.ndarray]:
+        if isinstance(index, int):
+            idx = index
+        else:
+            idx, _ = index
+
+        # tidx = self.idx_manager.get_t(idx)
+        patch_loc_list = self.idx_manager.get_patch_location_from_dataset_idx(idx)
+        nidx = patch_loc_list[0]
+
+        imgs = self._scaled_data[scaled_index][nidx]
+        imgs = tuple([imgs[None, ..., i] for i in range(imgs.shape[-1])])
+        if self._noise_data is not None:
+            noisedata = self._scaled_noise_data[scaled_index][nidx]
+            noise = tuple([noisedata[None, ..., i] for i in range(noisedata.shape[-1])])
+            factor = np.sqrt(2) if self._input_is_sum else 1.0
+            imgs = tuple([img + noise[0] * factor for img in imgs])
+        return imgs
+
+    def _crop_img(self, img: np.ndarray, patch_start_loc: Tuple):
+        """
+        Here, h_start and w_start could be negative, which simply means the content is
+        picked from 0, so the cropped image will be smaller than self._img_sz * self._img_sz.
+        """
+        max_len_vals = list(self.idx_manager.data_shape[1:-1])
+        max_len_vals[-2:] = img.shape[-2:]
+        return self._crop_img_with_padding(
+            img, patch_start_loc, max_len_vals=max_len_vals
+        )
+
+    def _get_img(self, index: int):
+        """
+        Return the primary patch along with low-resolution patches centered on the primary patch.
+        """
+        # noise_tuples is populated when there is synthetic noise in training;
+        # it should have the same type of noise as the noise model.
+        # Starting with microsplit, drop the noise and use it as an augmentation instead, if necessary.
+        img_tuples, noise_tuples = self._load_img(index)
+        assert self._img_sz is not None
+        h, w = img_tuples[0].shape[-2:]
+        if self._enable_random_cropping:
+            patch_start_loc = self._get_random_hw(h, w)
+            if self._5Ddata:
+                patch_start_loc = (
+                    np.random.choice(img_tuples[0].shape[-3] - self._depth3D),
+                ) + patch_start_loc
+        else:
+            patch_start_loc = self._get_deterministic_loc(index)
+
+        # The LC logic is located here: crop the image at the highest resolution.
+        cropped_img_tuples = [
+            self._crop_flip_img(img, patch_start_loc, False, False)
+            for img in img_tuples
+        ]
+        cropped_noise_tuples = [
+            self._crop_flip_img(noise, patch_start_loc, False, False)
+            for noise in noise_tuples
+        ]
+        patch_start_loc = list(patch_start_loc)
+        h_start, w_start = patch_start_loc[-2], patch_start_loc[-1]
+        h_center = h_start + self._img_sz // 2
+        w_center = w_start + self._img_sz // 2
+        allres_versions = {
+            i: [cropped_img_tuples[i]] for i in range(len(cropped_img_tuples))
+        }
+        for scale_idx in range(1, self.multiscale_lowres_count):
+            # load the image at the next lower resolution
+            scaled_img_tuples = self._load_scaled_img(scale_idx, index)
+
+            h_center = h_center // 2
+            w_center = w_center // 2
+
+            h_start = h_center - self._img_sz // 2
+            w_start = w_center - self._img_sz // 2
+            patch_start_loc[-2:] = [h_start, w_start]
+            scaled_cropped_img_tuples = [
+                self._crop_flip_img(img, patch_start_loc, False, False)
+                for img in scaled_img_tuples
+            ]
+            for ch_idx in range(len(img_tuples)):
+                allres_versions[ch_idx].append(scaled_cropped_img_tuples[ch_idx])
+
+        output_img_tuples = tuple(
+            [
+                np.concatenate(allres_versions[ch_idx])
+                for ch_idx in range(len(img_tuples))
+            ]
+        )
+        return output_img_tuples, cropped_noise_tuples
+
+    def __getitem__(self, index: Union[int, Tuple[int, int]]):
+        img_tuples, noise_tuples = self._get_img(index)
+        if self._uncorrelated_channels:
+            assert (
+                self._input_idx is None
+            ), "Uncorrelated channels are not implemented when there is a separate input channel."
+            if np.random.rand() < self._uncorrelated_channel_probab:
+                img_tuples_new = [None] * len(img_tuples)
+                img_tuples_new[0] = img_tuples[0]
+                for i in range(1, len(img_tuples)):
+                    new_index = np.random.randint(len(self))
+                    img_tuples_tmp, _ = self._get_img(new_index)
+                    img_tuples_new[i] = img_tuples_tmp[i]
+                img_tuples = img_tuples_new
+
+        if self._is_train:
+            if self._empty_patch_replacement_enabled:
+                if np.random.rand() < self._empty_patch_replacement_probab:
+                    img_tuples = self.replace_with_empty_patch(img_tuples)
+
+        if self._enable_rotation:
+            img_tuples, noise_tuples = self._rotate(img_tuples, noise_tuples)
+
+        # Add noise to the input; if noise is present, combine it with the image.
+        # The factor keeps the computed input from having too much noise, since the input is the average of two Gaussians.
+        if len(noise_tuples) > 0:
+            factor = np.sqrt(2) if self._input_is_sum else 1.0
+            input_tuples = []
+            for x in img_tuples:
+                x = (
+                    x.copy()
+                )  # copy to avoid changing the original image, since it is later used for the target
+                # NOTE: the other LC levels already have noise added, so noise only needs to be added to the highest resolution
+                x[0] = x[0] + noise_tuples[0] * factor
+                input_tuples.append(x)
+        else:
+            input_tuples = img_tuples
+
+        # Compute the input by summing / averaging the channels.
+        # Alpha is the weight applied to the channels when combining them.
+        # How to sample alpha is still under research.
+        inp, alpha = self._compute_input(input_tuples)
+        target_tuples = [img[:1] for img in img_tuples]
+        # add noise to the target
+        if len(noise_tuples) >= 1:
+            target_tuples = [
+                x + noise for x, noise in zip(target_tuples, noise_tuples[1:])
+            ]
+
+        target = self._compute_target(target_tuples, alpha)
+
+        norm_target = self.normalize_target(target)
+
+        output = [inp, norm_target]
+
+        if self._return_alpha:
+            output.append(alpha)
+
+        if isinstance(index, int):
+            return tuple(output)
+
+        _, grid_size = index
+        output.append(grid_size)
+        return tuple(output)
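
For orientation: the new LCMultiChDloader builds a multiscale image pyramid once in __init__ (halving height and width per level via skimage.transform.resize) and, per index, returns an input patch stacked with lower-resolution patches centered on the same content. The centering works by halving the patch-center coordinates at each scale, e.g. a patch starting at h_start=128 with _img_sz=64 has h_center=160; at the next scale the center becomes 80 and the crop starts at 48, so the low-res patch stays centered on the primary one. A minimal instantiation sketch follows. Only the constructor arguments and the four config fields read in the hunk above are taken from this diff; the import paths, the concrete field values, DatasetConfig's other required fields, and the load_data_fn signature are assumptions, not the package's documented API.

    # Hypothetical sketch: field values and loader signature are assumptions.
    from careamics.lvae_training.dataset.config import DatasetConfig
    from careamics.lvae_training.dataset.lc_dataset import LCMultiChDloader

    # Only these four fields are read by the LCMultiChDloader.__init__ shown
    # above; DatasetConfig likely requires additional fields (patch size,
    # data type, ...) not shown here.
    config = DatasetConfig(
        multiscale_lowres_count=3,           # pyramid depth; asserted to be an int >= 1
        padding_kwargs={"mode": "reflect"},  # asserted to be a dict containing "mode"
        overlapping_padding_kwargs=None,     # None -> falls back to padding_kwargs
        uncorrelated_channel_probab=0.0,
    )

    def my_load_fn(*args, **kwargs):
        """Assumed loader callable; the real signature is defined elsewhere in the package."""
        raise NotImplementedError

    dataset = LCMultiChDloader(
        config,
        fpath="path/to/data.tif",
        load_data_fn=my_load_fn,
        val_fraction=0.1,
        test_fraction=0.1,
    )

    # With an int index, __getitem__ returns (input, normalized_target),
    # plus alpha and/or grid_size depending on the configured flags.
    inp, target = dataset[0]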