careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +1 -14
- careamics/careamist.py +212 -294
- careamics/config/__init__.py +0 -3
- careamics/config/algorithm_model.py +8 -15
- careamics/config/architectures/architecture_model.py +1 -0
- careamics/config/architectures/custom_model.py +5 -3
- careamics/config/architectures/unet_model.py +19 -0
- careamics/config/architectures/vae_model.py +1 -0
- careamics/config/callback_model.py +76 -34
- careamics/config/configuration_factory.py +18 -98
- careamics/config/configuration_model.py +23 -18
- careamics/config/data_model.py +103 -54
- careamics/config/inference_model.py +41 -19
- careamics/config/optimizer_models.py +13 -7
- careamics/config/support/supported_data.py +29 -4
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +36 -58
- careamics/config/training_model.py +5 -1
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -7
- careamics/dataset/dataset_utils/file_utils.py +2 -2
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +84 -173
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +97 -250
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/patching.py +97 -52
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/file_io/__init__.py +7 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
- careamics/file_io/write/__init__.py +9 -0
- careamics/file_io/write/get_func.py +59 -0
- careamics/file_io/write/tiff.py +39 -0
- careamics/lightning/__init__.py +17 -0
- careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
- careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
- careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +2 -2
- careamics/model_io/model_io_utils.py +6 -3
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/prediction_outputs.py +137 -0
- careamics/prediction_utils/stitch_prediction.py +103 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/__init__.py +2 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
- careamics-0.1.0rc8.dist-info/RECORD +135 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
- careamics/config/configuration_example.py +0 -89
- careamics/dataset/dataset_utils/read_utils.py +0 -27
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- careamics-0.1.0rc6.dist-info/RECORD +0 -107
- /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
- /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,618 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utility functions needed by dataloader & co.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import List
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from skimage.io import imread, imsave
|
|
9
|
+
|
|
10
|
+
from careamics.models.lvae.utils import Enum
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class DataType(Enum):
    """Integer identifiers for the dataset / data-source kinds known to this module.

    NOTE(review): the numeric values may be stored in configs or checkpoints;
    confirm before renumbering or reordering members.
    """

    MNIST = 0
    Places365 = 1
    NotMNIST = 2
    OptiMEM100_014 = 3
    CustomSinosoid = 4
    Prevedel_EMBL = 5
    AllenCellMito = 6
    SeparateTiffData = 7
    CustomSinosoidThreeCurve = 8
    SemiSupBloodVesselsEMBL = 9
    Pavia2 = 10
    Pavia2VanillaSplitting = 11
    ExpansionMicroscopyMitoTub = 12
    ShroffMitoEr = 13
    HTIba1Ki67 = 14
    BSD68 = 15
    BioSR_MRC = 16
    TavernaSox2Golgi = 17
    Dao3Channel = 18
    ExpMicroscopyV2 = 19
    Dao3ChannelWithInput = 20
    TavernaSox2GolgiV2 = 21
    TwoDset = 22
    PredictedTiffData = 23
    Pavia3SeqData = 24
    # Here, we have 16 splitting tasks.
    NicolaData = 25
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class DataSplitType(Enum):
    """Integer identifiers selecting which split of the data to use.

    ``All`` presumably means the whole dataset without splitting — confirm
    against the data modules that consume it.
    """

    All = 0
    Train = 1
    Val = 2
    Test = 3
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class GridAlignement(Enum):
    """
    How a grid cell is positioned inside the patch built around it.

    A patch is formed by padding a grid cell with surrounding content. If grids
    are 'Center' aligned, padding is done equally on all 4 sides. If grids are
    'LeftTop' aligned, padding is done only on the right and bottom ends of the
    grid cell.
    In the former case, one needs (patch_size - grid_size)//2 pixels of content
    beyond the right end of the last cell in the frame.
    In the latter case, one needs patch_size - grid_size pixels of content there.
    """

    LeftTop = 0
    Center = 1
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def load_tiff(path):
    """
    Read the TIFF file at ``path`` with scikit-image's ``tifffile`` plugin.

    The array is returned as read. It is expected to be 4D
    (num_imgs * h * w * num_channels), but that layout is not checked here.
    """
    return imread(path, plugin="tifffile")
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def save_tiff(path, data):
    """
    Write ``data`` to ``path`` as a TIFF via scikit-image's ``tifffile`` plugin.
    """
    imsave(
        path,
        data,
        plugin="tifffile",
    )
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def load_tiffs(paths):
    """
    Load several TIFF files and stack them along the first axis.

    Each file is read with ``load_tiff``. The arrays must agree on every axis
    except the first. ``np.concatenate`` raises ValueError if ``paths`` is empty.
    """
    stacks = []
    for file_path in paths:
        stacks.append(load_tiff(file_path))
    return np.concatenate(stacks, axis=0)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def split_in_half(s, e):
    """
    Split the half-open index range ``[s, e)`` into two consecutive lists.

    The first list holds ``(e - s) // 2`` indices and the second holds the rest,
    so for an odd-length range the extra index goes to the second list. The
    elements are numpy integers. An empty or reversed range yields two empty lists.
    """
    midpoint = s + (e - s) // 2
    first_half = list(np.arange(s, midpoint))
    second_half = list(np.arange(midpoint, e))
    return first_half, second_half
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def adjust_for_imbalance_in_fraction_value(
    val: List[int],
    test: List[int],
    val_fraction: float,
    test_fraction: float,
    total_size: int,
):
    """
    Rebalance roughly equal val/test index lists so they reflect their fractions.

    ``val`` and ``test`` are expected to be divided almost equally. This moves
    indices from one list to the other so their sizes account for their
    respective fractions:

    * if ``val_fraction == 0``, every index goes to ``test``;
    * else if ``test_fraction == 0``, every index goes to ``val``;
    * otherwise ``int(|test_fraction - val_fraction| * total_size / 2)`` indices
      move from the list with the smaller fraction to the other one. They are
      picked after a fixed-seed (955) random permutation of the source list.

    The input lists are not modified.

    Returns:
        Tuple ``(val, test)`` with the rebalanced index lists.
    """
    # Work on copies so the caller's lists are not extended in place.
    val = list(val)
    test = list(test)

    if val_fraction == 0:
        return [], test + val
    if test_fraction == 0:
        return val + test, []

    diff_fraction = test_fraction - val_fraction
    if diff_fraction > 0:
        imb_count = int(diff_fraction * total_size / 2)
        # A fixed seed keeps the resulting split reproducible across runs.
        val = list(np.random.RandomState(seed=955).permutation(val))
        test += val[:imb_count]
        val = val[imb_count:]
    elif diff_fraction < 0:
        imb_count = int(-1 * diff_fraction * total_size / 2)
        test = list(np.random.RandomState(seed=955).permutation(test))
        val += test[:imb_count]
        test = test[imb_count:]
    return val, test
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def get_datasplit_tuples(
    val_fraction: float,
    test_fraction: float,
    total_size: int,
    starting_test: bool = False,
):
    """
    Split the indices ``0 .. total_size - 1`` into train/val/test lists.

    With ``starting_test=True`` the splits are contiguous blocks laid out as
    test => val => train, of sizes ``int(total_size * test_fraction)`` and
    ``int(total_size * val_fraction)``; train gets the rest.

    Otherwise the first ``int((val_fraction + test_fraction) * total_size)``
    indices are shared between val and test, and train gets the rest. The shared
    part is cut into small chunks (at most 3 indices) that alternate between val
    and test. The leftover tail is then assigned, and the result is rebalanced
    with ``adjust_for_imbalance_in_fraction_value`` to honour the two fractions.

    Returns:
        Tuple ``(train, val, test)`` of index lists.
    """
    if starting_test:
        # test => val => train. Sizes are computed directly, so an empty test
        # or val split does not index into an empty list.
        test_size = int(total_size * test_fraction)
        val_size = int(total_size * val_fraction)
        test = list(range(0, test_size))
        val = list(range(test_size, test_size + val_size))
        train = list(range(test_size + val_size, total_size))
        return train, val, test

    # {test,val} => train
    test_val_size = int((val_fraction + test_fraction) * total_size)
    train = list(range(test_val_size, total_size))

    if test_val_size == 0:
        return train, [], []

    # Split the test and validation indices in alternating chunks.
    chunksize = max(1, min(3, test_val_size // 2))
    nchunks = test_val_size // chunksize

    test = []
    val = []
    s = 0
    for i in range(nchunks):
        chunk = list(np.arange(s, s + chunksize))
        if i % 2 == 0:
            val += chunk
        else:
            test += chunk
        s += chunksize

    # Assign the leftover tail [s, test_val_size). If the last full chunk went
    # to val, the tail goes to test; otherwise it is split between both.
    last_chunk_went_to_val = (nchunks - 1) % 2 == 0
    if last_chunk_went_to_val:
        test += list(np.arange(s, test_val_size))
    else:
        p1, p2 = split_in_half(s, test_val_size)
        test += p1
        val += p2

    val, test = adjust_for_imbalance_in_fraction_value(
        val, test, val_fraction, test_fraction, total_size
    )

    return train, val, test
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def get_mrc_data(fpath):
    """
    Load an MRC image stack with the section axis moved to the front.

    ``read_mrc`` returns pixels laid out as ``(nx, ny, nz)``. This returns a
    view of shape ``(nz, nx, ny)``, one 2D image per section.
    """
    _, volume = read_mrc(fpath)
    # Equivalent to adding a leading axis, swapping axes 0 and 3, and dropping
    # the trailing singleton axis: element [k, i, j] == volume[i, j, k].
    return np.transpose(volume, (2, 0, 1))
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
class GridIndexManager:
    """
    Map flat dataset indices to deterministic grid positions and back.

    A flat index encodes a frame index ``t`` and a grid cell as
    ``index = (row * ncols + col) * N + t`` with ``N = data_shape[0]``, so
    consecutive indices cycle through frames before moving to the next cell.
    The number of grid rows and columns is derived from ``data_shape[-3]`` and
    ``data_shape[-2]``, the grid size and the alignment mode.

    Methods taking ``grid_size=None`` fall back to the default grid size given
    at construction when ``grid_size`` is ``None`` or negative.
    """

    def __init__(self, data_shape, grid_size, patch_size, grid_alignement) -> None:
        self._data_shape = data_shape
        self._default_grid_size = grid_size
        self.patch_size = patch_size
        self.N = self._data_shape[0]  # number of frames
        self._align = grid_alignement

    def get_data_shape(self):
        """Return the data shape given at construction."""
        return self._data_shape

    def use_default_grid(self, grid_size):
        """Return True when ``grid_size`` should be replaced by the default one."""
        return grid_size is None or grid_size < 0

    def _resolve_grid_size(self, grid_size):
        """Return ``grid_size``, or the default grid size if it is None/negative."""
        if self.use_default_grid(grid_size):
            return self._default_grid_size
        return grid_size

    def _extra_pixels(self, grid_size):
        """Return the pixels at the frame end that cannot start a grid cell."""
        if self._align == GridAlignement.LeftTop:
            # LeftTop grids are padded only on the right/bottom, so the last
            # cell needs (patch_size - grid_size) pixels of content after it.
            return self.patch_size - grid_size
        if self._align == GridAlignement.Center:
            # Center is exclusively used during evaluation, where padding
            # handles the edge cases, so all pixels should be covered.
            # Without padding this would be (self.patch_size - grid_size) // 2.
            return 0
        raise ValueError(f"Unsupported grid alignment: {self._align}")

    def grid_rows(self, grid_size):
        """Return the number of grid rows for an explicit ``grid_size``."""
        return (self._data_shape[-3] - self._extra_pixels(grid_size)) // grid_size

    def grid_cols(self, grid_size):
        """Return the number of grid columns for an explicit ``grid_size``."""
        return (self._data_shape[-2] - self._extra_pixels(grid_size)) // grid_size

    def grid_count(self, grid_size=None):
        """Return the total number of flat indices (frames x rows x cols)."""
        grid_size = self._resolve_grid_size(grid_size)
        return self.N * self.grid_rows(grid_size) * self.grid_cols(grid_size)

    def hwt_from_idx(self, index, grid_size=None):
        """Return ``(h_start, w_start, t)`` for a flat ``index``."""
        t = self.get_t(index)
        return (*self.get_deterministic_hw(index, grid_size=grid_size), t)

    def idx_from_hwt(self, h_start, w_start, t, grid_size=None):
        """
        Given h,w,t (where h,w is the top left corner of the patch), return the corresponding index.
        """
        grid_size = self._resolve_grid_size(grid_size)

        nth_row = h_start // grid_size
        nth_col = w_start // grid_size

        index = self.grid_cols(grid_size) * nth_row + nth_col
        return index * self._data_shape[0] + t

    def get_t(self, index):
        """Return the frame index encoded in ``index``."""
        return index % self.N

    def get_top_nbr_idx(self, index, grid_size=None):
        """Return the index of the cell above (same frame), or None on the top row."""
        grid_size = self._resolve_grid_size(grid_size)

        ncols = self.grid_cols(grid_size)
        index -= ncols * self.N
        if index < 0:
            return None

        return index

    def get_bottom_nbr_idx(self, index, grid_size=None):
        """Return the index of the cell below (same frame), or None on the bottom row."""
        grid_size = self._resolve_grid_size(grid_size)

        ncols = self.grid_cols(grid_size)
        index += ncols * self.N
        # Valid indices are 0 .. grid_count - 1, so grid_count itself is out of range.
        if index >= self.grid_count(grid_size=grid_size):
            return None

        return index

    def get_left_nbr_idx(self, index, grid_size=None):
        """Return the index of the cell to the left (same frame), or None."""
        if self.on_left_boundary(index, grid_size=grid_size):
            return None

        index -= self.N
        return index

    def get_right_nbr_idx(self, index, grid_size=None):
        """Return the index of the cell to the right (same frame), or None."""
        if self.on_right_boundary(index, grid_size=grid_size):
            return None
        index += self.N
        return index

    def on_left_boundary(self, index, grid_size=None):
        """Return True if ``index`` lies in the first grid column."""
        grid_size = self._resolve_grid_size(grid_size)

        factor = index // self.N
        ncols = self.grid_cols(grid_size)

        # The previous cell being on a different row means this is column 0.
        left_boundary = (factor // ncols) != (factor - 1) // ncols
        return left_boundary

    def on_right_boundary(self, index, grid_size=None):
        """Return True if ``index`` lies in the last grid column."""
        grid_size = self._resolve_grid_size(grid_size)

        factor = index // self.N
        ncols = self.grid_cols(grid_size)

        # The next cell being on a different row means this is the last column.
        right_boundary = (factor // ncols) != (factor + 1) // ncols
        return right_boundary

    def on_top_boundary(self, index, grid_size=None):
        """Return True if ``index`` lies in the first grid row."""
        grid_size = self._resolve_grid_size(grid_size)

        ncols = self.grid_cols(grid_size)
        return index < self.N * ncols

    def on_bottom_boundary(self, index, grid_size=None):
        """Return True if ``index`` lies in the last grid row."""
        grid_size = self._resolve_grid_size(grid_size)

        ncols = self.grid_cols(grid_size)
        # The cell below would be at index + N * ncols; it is out of range
        # (>= grid_count) exactly when ``index`` is on the last row.
        return index + self.N * ncols >= self.grid_count(grid_size=grid_size)

    def on_boundary(self, idx, grid_size=None):
        """Return True if ``idx`` lies on any edge of the grid."""
        if self.on_left_boundary(idx, grid_size=grid_size):
            return True

        if self.on_right_boundary(idx, grid_size=grid_size):
            return True

        if self.on_top_boundary(idx, grid_size=grid_size):
            return True

        if self.on_bottom_boundary(idx, grid_size=grid_size):
            return True
        return False

    def get_deterministic_hw(self, index: int, grid_size=None):
        """
        Fixed starting position (h_start, w_start) of the crop for the img with index `index`.
        """
        grid_size = self._resolve_grid_size(grid_size)

        factor = index // self.N
        ncols = self.grid_cols(grid_size)

        ith_row = factor // ncols
        jth_col = factor % ncols
        h_start = ith_row * grid_size
        w_start = jth_col * grid_size
        return h_start, w_start
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
class IndexSwitcher:
    """
    Switch between indices that are valid and invalid for the target.

    If an index is invalid for the target, an all-zero vector is meant to be
    used as the target. This combines both logics:
    1. Using less amount of total data.
    2. Using less amount of target data but using full data.

    Frames ``0 .. _validtarget_ceilT - 2`` always have a valid target. On frame
    ``_validtarget_ceilT - 1`` only the top-left ``[:_h_validmax, :_w_validmax]``
    region does. Later frames have none.
    """

    def __init__(self, idx_manager, data_config, patch_size) -> None:
        self.idx_manager = idx_manager
        self._data_shape = self.idx_manager.get_data_shape()
        self._training_validtarget_fraction = data_config.get(
            "training_validtarget_fraction", 1.0
        )
        # Number of frames that carry at least some valid target.
        self._validtarget_ceilT = int(
            np.ceil(self._data_shape[0] * self._training_validtarget_fraction)
        )
        self._patch_size = patch_size
        assert (
            data_config.deterministic_grid is True
        ), "This only works when the dataset has deterministic grid. Needed randomness comes from this class."
        assert (
            "grid_size" in data_config and data_config.grid_size == 1
        ), "We need a one to one mapping between index and h, w, t"

        self._h_validmax, self._w_validmax = self.get_reduced_frame_size(
            self._data_shape[:3], self._training_validtarget_fraction
        )
        if self._h_validmax < self._patch_size or self._w_validmax < self._patch_size:
            print(
                "WARNING: The valid target size is smaller than the patch size. This will result in all zero target. so, we are ignoring this frame for target."
            )
            self._h_validmax = 0
            self._w_validmax = 0

        print(
            f"[{self.__class__.__name__}] Target Indices: [0,{self._validtarget_ceilT-1}]. Index={self._validtarget_ceilT-1} has shape [:{self._h_validmax},:{self._w_validmax}]. Available data: {self._data_shape[0]}"
        )

    def get_valid_target_index(self):
        """
        Return a random index whose patch is expected to have a target.

        The frame is sampled with probability proportional to its number of
        valid-target pixels.
        """
        _, h, w, _ = self._data_shape
        framepixelcount = h * w
        targetpixels = np.array(
            [framepixelcount] * (self._validtarget_ceilT - 1)
            + [self._h_validmax * self._w_validmax]
        )
        targetpixels = targetpixels / np.sum(targetpixels)
        t = np.random.choice(self._validtarget_ceilT, p=targetpixels)
        h, w = self.get_valid_target_hw(t)
        index = self.idx_manager.idx_from_hwt(h, w, t)
        return index

    def get_invalid_target_index(self):
        """
        Return a random index whose patch is expected to have no target.

        The frame is sampled from ``_validtarget_ceilT - 1 .. total_t - 1``, with
        probability proportional to its number of invalid-target pixels. On the
        partially valid frame only the region beyond both valid maxima counts,
        and only if it is at least one patch wide.
        """
        total_t, h, w, _ = self._data_shape
        framepixelcount = h * w
        available_h = h - self._h_validmax
        if available_h < self._patch_size:
            available_h = 0
        available_w = w - self._w_validmax
        if available_w < self._patch_size:
            available_w = 0

        targetpixels = np.array(
            [available_h * available_w]
            + [framepixelcount] * (total_t - self._validtarget_ceilT)
        )
        t_probab = targetpixels / np.sum(targetpixels)
        t = np.random.choice(
            np.arange(self._validtarget_ceilT - 1, total_t), p=t_probab
        )

        h, w = self.get_invalid_target_hw(t)
        index = self.idx_manager.idx_from_hwt(h, w, t)
        return index

    def get_valid_target_hw(self, t):
        """
        This is the opposite of get_invalid_target_hw. It returns a h,w which is valid for target.
        This is only valid for single frame setup.
        """
        if t == self._validtarget_ceilT - 1:
            h = np.random.randint(0, self._h_validmax - self._patch_size)
            w = np.random.randint(0, self._w_validmax - self._patch_size)
        else:
            h = np.random.randint(0, self._data_shape[1] - self._patch_size)
            w = np.random.randint(0, self._data_shape[2] - self._patch_size)
        return h, w

    def get_invalid_target_hw(self, t):
        """
        This is the opposite of get_valid_target_hw. It returns a h,w which is not valid for target.
        This is only valid for single frame setup.
        """
        if t == self._validtarget_ceilT - 1:
            h = np.random.randint(
                self._h_validmax, self._data_shape[1] - self._patch_size
            )
            w = np.random.randint(
                self._w_validmax, self._data_shape[2] - self._patch_size
            )
        else:
            h = np.random.randint(0, self._data_shape[1] - self._patch_size)
            w = np.random.randint(0, self._data_shape[2] - self._patch_size)
        return h, w

    @staticmethod
    def _scalar_index(index):
        """Return the flat index, accepting any int / numpy integer or an (index, ...) tuple."""
        if isinstance(index, (int, np.integer)):
            return index
        return index[0]

    def _get_tidx(self, index):
        """Return the frame index of ``index``."""
        return self.idx_manager.get_t(self._scalar_index(index))

    def index_should_have_target(self, index):
        """
        Return True if the patch at ``index`` lies in a valid-target region.

        Frames before ``_validtarget_ceilT - 1`` always qualify and later frames
        never do. On frame ``_validtarget_ceilT - 1`` the patch must end strictly
        inside ``[:_h_validmax, :_w_validmax]``.
        """
        tidx = self._get_tidx(index)
        if tidx < self._validtarget_ceilT - 1:
            return True
        elif tidx > self._validtarget_ceilT - 1:
            return False
        else:
            # Use the scalar index: ``index`` may be an (index, grid_size) tuple.
            h, w, _ = self.idx_manager.hwt_from_idx(self._scalar_index(index))
            return (
                h + self._patch_size < self._h_validmax
                and w + self._patch_size < self._w_validmax
            )

    @staticmethod
    def get_reduced_frame_size(data_shape_nhw, fraction):
        """
        Return the (h, w) of the valid-target region on the last partially valid frame.

        The region is a square whose area equals the leftover valid pixel count
        ``int(n * h * w * fraction) % (h * w)``. If the fraction fills whole
        frames exactly, the full ``(h, w)`` is returned. Requires ``h == w``.
        """
        n, h, w = data_shape_nhw

        framepixelcount = h * w
        targetpixelcount = int(n * framepixelcount * fraction)

        lastframepixelcount = targetpixelcount % framepixelcount
        assert data_shape_nhw[1] == data_shape_nhw[2]
        if lastframepixelcount > 0:
            new_size = int(np.sqrt(lastframepixelcount))
            return new_size, new_size
        else:
            assert (
                targetpixelcount / framepixelcount >= 1
            ), "This is not possible in euclidean space :D (so this is a bug)"
            return h, w
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
# Structured numpy dtype describing the MRC / IMOD image header; read_mrc loads
# it with ``np.fromfile(..., dtype=rec_header_dtd, count=1)``. The fields add up
# to 1024 bytes. Their order and sizes define the on-disk byte layout, so do not
# reorder, add or remove entries.
rec_header_dtd = [
    ("nx", "i4"),  # Number of columns
    ("ny", "i4"),  # Number of rows
    ("nz", "i4"),  # Number of sections
    ("mode", "i4"),  # Types of pixels in the image. Values used by IMOD:
    #  0 = unsigned or signed bytes depending on flag in imodFlags
    #  1 = signed short integers (16 bits)
    #  2 = float (32 bits)
    #  3 = short * 2, (used for complex data)
    #  4 = float * 2, (used for complex data)
    #  6 = unsigned 16-bit integers (non-standard)
    # 16 = unsigned char * 3 (for rgb data, non-standard)
    ("nxstart", "i4"),  # Starting point of sub-image (not used in IMOD)
    ("nystart", "i4"),
    ("nzstart", "i4"),
    ("mx", "i4"),  # Grid size in X, Y and Z
    ("my", "i4"),
    ("mz", "i4"),
    ("xlen", "f4"),  # Cell size; pixel spacing = xlen/mx, ylen/my, zlen/mz
    ("ylen", "f4"),
    ("zlen", "f4"),
    ("alpha", "f4"),  # Cell angles - ignored by IMOD
    ("beta", "f4"),
    ("gamma", "f4"),
    # These need to be set to 1, 2, and 3 for pixel spacing to be interpreted correctly
    ("mapc", "i4"),  # map column 1=x,2=y,3=z.
    ("mapr", "i4"),  # map row 1=x,2=y,3=z.
    ("maps", "i4"),  # map section 1=x,2=y,3=z.
    # These need to be set for proper scaling of data
    ("amin", "f4"),  # Minimum pixel value
    ("amax", "f4"),  # Maximum pixel value
    ("amean", "f4"),  # Mean pixel value
    ("ispg", "i4"),  # space group number (ignored by IMOD)
    (
        "next",
        "i4",
    ),  # number of bytes in extended header (called nsymbt in MRC standard)
    ("creatid", "i2"),  # used to be an ID number, is 0 as of IMOD 4.2.23
    ("extra_data", "V30"),  # (not used, first two bytes should be 0)
    # These two values specify the structure of data in the extended header; their meaning depends on whether the
    # extended header has the Agard format, a series of 4-byte integers then real numbers, or has data
    # produced by SerialEM, a series of short integers. SerialEM stores a float as two shorts, s1 and s2, by:
    # value = (sign of s1)*(|s1|*256 + (|s2| modulo 256)) * 2**((sign of s2) * (|s2|/256))
    ("nint", "i2"),
    # Number of integers per section (Agard format) or number of bytes per section (SerialEM format)
    ("nreal", "i2"),
    # Number of reals per section (Agard format) or bit
    # flags for which types of short data (SerialEM format):
    # 1 = tilt angle * 100 (2 bytes)
    # 2 = piece coordinates for montage (6 bytes)
    # 4 = Stage position * 25 (4 bytes)
    # 8 = Magnification / 100 (2 bytes)
    # 16 = Intensity * 25000 (2 bytes)
    # 32 = Exposure dose in e-/A2, a float in 4 bytes
    # 128, 512: Reserved for 4-byte items
    # 64, 256, 1024: Reserved for 2-byte items
    # If the number of bytes implied by these flags does
    # not add up to the value in nint, then nint and nreal
    # are interpreted as ints and reals per section
    ("extra_data2", "V20"),  # extra data (not used)
    ("imodStamp", "i4"),  # 1146047817 indicates that file was created by IMOD
    ("imodFlags", "i4"),  # Bit flags: 1 = bytes are stored as signed
    # Explanation of type of data
    ("idtype", "i2"),  # ( 0 = mono, 1 = tilt, 2 = tilts, 3 = lina, 4 = lins)
    ("lens", "i2"),
    # ("nd1", "i2"), # for idtype = 1, nd1 = axis (1, 2, or 3)
    # ("nd2", "i2"),
    ("nphase", "i4"),
    ("vd1", "i2"),  # vd1 = 100. * tilt increment
    ("vd2", "i2"),  # vd2 = 100. * starting angle
    # Current angles are used to rotate a model to match a new rotated image. The three values in each set are
    # rotations about X, Y, and Z axes, applied in the order Z, Y, X.
    ("triangles", "f4", 6),  # 0,1,2 = original: 3,4,5 = current
    ("xorg", "f4"),  # Origin of image
    ("yorg", "f4"),
    ("zorg", "f4"),
    ("cmap", "S4"),  # Contains "MAP "
    (
        "stamp",
        "u1",
        4,
    ),  # First two bytes have 17 and 17 for big-endian or 68 and 65 for little-endian
    ("rms", "f4"),  # RMS deviation of densities from mean density
    ("nlabl", "i4"),  # Number of labels with useful data
    ("labels", "S80", 10),  # 10 labels of 80 characters
]
|
|
588
|
+
|
|
589
|
+
|
|
590
|
+
def read_mrc(filename, filetype="image"):
    """
    Read an MRC file and return its header and pixel data.

    Parameters
    ----------
    filename : path-like
        Path to the MRC file.
    filetype : str
        If ``"image"``, pixels are returned as a float64 array of shape
        ``(nx, ny, nz)``, each section filled in Fortran (column-major) order.
        Otherwise the raw 1D pixel array is returned with its file dtype.

    Returns
    -------
    tuple
        ``(header, data)``. ``header`` is the 1-element structured array
        parsed with ``rec_header_dtd``.

    Raises
    ------
    ValueError
        If the header's pixel ``mode`` is not 1, 2, 4 or 6.
    """
    with open(filename, "rb") as fd:
        header = np.fromfile(fd, dtype=rec_header_dtd, count=1)

        # Python ints prevent int32 overflow in the ``nx * ny * iz`` offsets below.
        nx, ny, nz = (int(header[key][0]) for key in ("nx", "ny", "nz"))

        mode = header["mode"][0]
        if mode == 1:
            data_type = "int16"
        elif mode == 2:
            data_type = "float32"
        elif mode == 4:
            # Complex float data: real and imaginary parts are interleaved.
            data_type = "single"
            nx = nx * 2
        elif mode == 6:
            data_type = "uint16"
        else:
            raise ValueError(f"Unsupported MRC pixel mode {mode} in {filename}")

        # Pixel data follows the fixed header plus ``next`` bytes of extended
        # header; skip the extended header rather than reading it as pixels.
        extended_header_bytes = max(int(header["next"][0]), 0)
        imgrawdata = np.fromfile(fd, data_type, offset=extended_header_bytes)

    if filetype == "image":
        data = np.ndarray(shape=(nx, ny, nz))
        for iz in range(nz):
            data_2d = imgrawdata[nx * ny * iz : nx * ny * (iz + 1)]
            data[:, :, iz] = data_2d.reshape(nx, ny, order="F")
    else:
        data = imgrawdata

    return header, data
|