careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (91) hide show
  1. careamics/__init__.py +1 -14
  2. careamics/careamist.py +212 -294
  3. careamics/config/__init__.py +0 -3
  4. careamics/config/algorithm_model.py +8 -15
  5. careamics/config/architectures/architecture_model.py +1 -0
  6. careamics/config/architectures/custom_model.py +5 -3
  7. careamics/config/architectures/unet_model.py +19 -0
  8. careamics/config/architectures/vae_model.py +1 -0
  9. careamics/config/callback_model.py +76 -34
  10. careamics/config/configuration_factory.py +18 -98
  11. careamics/config/configuration_model.py +23 -18
  12. careamics/config/data_model.py +103 -54
  13. careamics/config/inference_model.py +41 -19
  14. careamics/config/optimizer_models.py +13 -7
  15. careamics/config/support/supported_data.py +29 -4
  16. careamics/config/support/supported_transforms.py +0 -1
  17. careamics/config/tile_information.py +36 -58
  18. careamics/config/training_model.py +5 -1
  19. careamics/config/transformations/normalize_model.py +32 -4
  20. careamics/config/validators/validator_utils.py +1 -1
  21. careamics/dataset/__init__.py +12 -1
  22. careamics/dataset/dataset_utils/__init__.py +8 -7
  23. careamics/dataset/dataset_utils/file_utils.py +2 -2
  24. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  25. careamics/dataset/dataset_utils/running_stats.py +186 -0
  26. careamics/dataset/in_memory_dataset.py +84 -173
  27. careamics/dataset/in_memory_pred_dataset.py +88 -0
  28. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  29. careamics/dataset/iterable_dataset.py +97 -250
  30. careamics/dataset/iterable_pred_dataset.py +122 -0
  31. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  32. careamics/dataset/patching/patching.py +97 -52
  33. careamics/dataset/patching/random_patching.py +9 -4
  34. careamics/dataset/patching/validate_patch_dimension.py +5 -3
  35. careamics/dataset/tiling/__init__.py +10 -0
  36. careamics/dataset/tiling/collate_tiles.py +33 -0
  37. careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
  38. careamics/file_io/__init__.py +7 -0
  39. careamics/file_io/read/__init__.py +11 -0
  40. careamics/file_io/read/get_func.py +56 -0
  41. careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
  42. careamics/file_io/write/__init__.py +9 -0
  43. careamics/file_io/write/get_func.py +59 -0
  44. careamics/file_io/write/tiff.py +39 -0
  45. careamics/lightning/__init__.py +17 -0
  46. careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
  47. careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
  48. careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
  49. careamics/lvae_training/__init__.py +0 -0
  50. careamics/lvae_training/data_modules.py +1220 -0
  51. careamics/lvae_training/data_utils.py +618 -0
  52. careamics/lvae_training/eval_utils.py +905 -0
  53. careamics/lvae_training/get_config.py +84 -0
  54. careamics/lvae_training/lightning_module.py +701 -0
  55. careamics/lvae_training/metrics.py +214 -0
  56. careamics/lvae_training/train_lvae.py +339 -0
  57. careamics/lvae_training/train_utils.py +121 -0
  58. careamics/model_io/bioimage/model_description.py +40 -32
  59. careamics/model_io/bmz_io.py +2 -2
  60. careamics/model_io/model_io_utils.py +6 -3
  61. careamics/models/lvae/__init__.py +0 -0
  62. careamics/models/lvae/layers.py +1998 -0
  63. careamics/models/lvae/likelihoods.py +312 -0
  64. careamics/models/lvae/lvae.py +985 -0
  65. careamics/models/lvae/noise_models.py +409 -0
  66. careamics/models/lvae/utils.py +395 -0
  67. careamics/prediction_utils/__init__.py +10 -0
  68. careamics/prediction_utils/prediction_outputs.py +137 -0
  69. careamics/prediction_utils/stitch_prediction.py +103 -0
  70. careamics/transforms/n2v_manipulate.py +3 -1
  71. careamics/transforms/normalize.py +139 -68
  72. careamics/transforms/pixel_manipulation.py +33 -9
  73. careamics/transforms/tta.py +43 -29
  74. careamics/utils/__init__.py +2 -0
  75. careamics/utils/autocorrelation.py +40 -0
  76. careamics/utils/ram.py +2 -2
  77. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
  78. careamics-0.1.0rc8.dist-info/RECORD +135 -0
  79. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
  80. careamics/config/configuration_example.py +0 -89
  81. careamics/dataset/dataset_utils/read_utils.py +0 -27
  82. careamics/lightning_prediction_loop.py +0 -118
  83. careamics/prediction/__init__.py +0 -7
  84. careamics/prediction/stitch_prediction.py +0 -70
  85. careamics/utils/running_stats.py +0 -43
  86. careamics-0.1.0rc6.dist-info/RECORD +0 -107
  87. /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
  88. /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
  89. /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
  90. /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
  91. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,618 @@
1
+ """
2
+ Utility functions needed by dataloader & co.
3
+ """
4
+
5
+ from typing import List
6
+
7
+ import numpy as np
8
+ from skimage.io import imread, imsave
9
+
10
+ from careamics.models.lvae.utils import Enum
11
+
12
+
13
class DataType(Enum):
    """Integer codes, one per dataset kind referenced by the LVAE training code.

    Values are assigned explicitly. NOTE(review): presumably these codes are
    referenced from configs, so existing members should not be renumbered —
    confirm before changing any value.
    """

    MNIST = 0
    Places365 = 1
    NotMNIST = 2
    OptiMEM100_014 = 3
    CustomSinosoid = 4
    Prevedel_EMBL = 5
    AllenCellMito = 6
    SeparateTiffData = 7
    CustomSinosoidThreeCurve = 8
    SemiSupBloodVesselsEMBL = 9
    Pavia2 = 10
    Pavia2VanillaSplitting = 11
    ExpansionMicroscopyMitoTub = 12
    ShroffMitoEr = 13
    HTIba1Ki67 = 14
    BSD68 = 15
    BioSR_MRC = 16
    TavernaSox2Golgi = 17
    Dao3Channel = 18
    ExpMicroscopyV2 = 19
    Dao3ChannelWithInput = 20
    TavernaSox2GolgiV2 = 21
    TwoDset = 22
    PredictedTiffData = 23
    Pavia3SeqData = 24
    # Here, we have 16 splitting tasks.
    NicolaData = 25
41
+
42
+
43
class DataSplitType(Enum):
    """Integer codes selecting which portion of a dataset is requested."""

    All = 0
    Train = 1
    Val = 2
    Test = 3
48
+
49
+
50
class GridAlignement(Enum):
    """How a grid cell is positioned inside the (larger) patch built around it.

    A patch is the grid cell plus surrounding content used as padding.

    * ``LeftTop``: the cell sits in the top-left corner of the patch, so all
      padding lies to the right and below it. ``patch_size - grid_size``
      pixels of content are needed beyond the cell on the right end of the
      frame.
    * ``Center``: the cell sits in the middle of the patch, so padding is
      spread equally over all four sides. ``(patch_size - grid_size) // 2``
      pixels of content are needed on the right end of the frame.
    """

    LeftTop = 0
    Center = 1
60
+
61
+
62
def load_tiff(path):
    """Read the TIFF file at ``path`` using the ``tifffile`` plugin.

    Returns the array produced by ``imread`` unchanged.
    NOTE(review): the original note describes the result as a 4-D array laid
    out as (num_imgs, h, w, num_channels); nothing here enforces that shape.
    """
    return imread(path, plugin="tifffile")
68
+
69
+
70
def save_tiff(path, data):
    """Write ``data`` to ``path`` as a TIFF file using the ``tifffile`` plugin."""
    imsave(path, data, plugin="tifffile")
72
+
73
+
74
def load_tiffs(paths):
    """Load every TIFF in ``paths`` and join the arrays along their first axis."""
    stacks = []
    for tiff_path in paths:
        stacks.append(load_tiff(tiff_path))
    return np.concatenate(stacks, axis=0)
77
+
78
+
79
def split_in_half(s, e):
    """Split the integer range ``[s, e)`` into two consecutive halves.

    Returns two lists of indices. When the range has odd length, the second
    half receives the extra element.
    """
    length = e - s
    midpoint = length // 2
    lower = [offset + s for offset in np.arange(midpoint)]
    upper = [offset + s for offset in np.arange(midpoint, length)]
    return lower, upper
84
+
85
+
86
def adjust_for_imbalance_in_fraction_value(
    val: List[int],
    test: List[int],
    val_fraction: float,
    test_fraction: float,
    total_size: int,
):
    """Rebalance ``val`` and ``test`` so their sizes reflect the two fractions.

    The incoming lists are expected to have been split almost equally. If one
    fraction is zero, everything is handed to the other list. Otherwise the
    list with the smaller fraction is shuffled with a fixed seed and
    ``int(|test_fraction - val_fraction| * total_size / 2)`` of its elements
    are moved to the other list.

    Note that the receiving list is extended in place; the returned pair is
    what callers should use.

    Returns:
        The ``(val, test)`` pair after rebalancing.
    """
    if val_fraction == 0:
        test += val
        return [], test
    if test_fraction == 0:
        val += test
        return val, []

    gap = test_fraction - val_fraction
    if gap > 0:
        # test should be larger: move randomly chosen elements val -> test.
        n_move = int(gap * total_size / 2)
        shuffled = list(np.random.RandomState(seed=955).permutation(val))
        test += shuffled[:n_move]
        val = shuffled[n_move:]
    elif gap < 0:
        # val should be larger: move randomly chosen elements test -> val.
        n_move = int(-1 * gap * total_size / 2)
        shuffled = list(np.random.RandomState(seed=955).permutation(test))
        val += shuffled[:n_move]
        test = shuffled[n_move:]
    return val, test
116
+
117
+
118
def get_datasplit_tuples(
    val_fraction: float,
    test_fraction: float,
    total_size: int,
    starting_test: bool = False,
):
    """Partition ``range(total_size)`` into train / val / test index lists.

    Args:
        val_fraction: Fraction of ``total_size`` assigned to validation.
        test_fraction: Fraction of ``total_size`` assigned to test.
        total_size: Number of items to split.
        starting_test: When True, the layout is contiguous: test indices
            first, then validation, then train. When False, the leading
            ``int((val_fraction + test_fraction) * total_size)`` indices are
            dealt out to val and test in alternating chunks of at most three
            and the remainder is train.

    Returns:
        ``(train, val, test)`` lists of indices.
    """
    if starting_test:
        # test => val => train
        # Boundaries are derived from the counts rather than from
        # ``test[-1]`` / ``val[-1]``, which raised IndexError whenever a
        # fraction produced an empty list.
        n_test = int(total_size * test_fraction)
        n_val = int(total_size * val_fraction)
        test = list(range(n_test))
        val = list(range(n_test, n_test + n_val))
        train = list(range(n_test + n_val, total_size))
    else:
        # {test,val}=> train
        test_val_size = int((val_fraction + test_fraction) * total_size)
        train = list(range(test_val_size, total_size))

        if test_val_size == 0:
            return train, [], []

        # Split the test and validation in chunks.
        chunksize = max(1, min(3, test_val_size // 2))
        nchunks = test_val_size // chunksize

        test = []
        val = []
        s = 0
        for i in range(nchunks):
            if i % 2 == 0:
                val += list(np.arange(s, s + chunksize))
            else:
                test += list(np.arange(s, s + chunksize))
            s += chunksize

        # Hand out whatever is left after the last full chunk: to test if the
        # last chunk went to val, otherwise split it between the two.
        if i % 2 == 0:
            test += list(np.arange(s, test_val_size))
        else:
            p1, p2 = split_in_half(s, test_val_size)
            test += p1
            val += p2

    # NOTE(review): the rebalancing step is applied after both branches;
    # confirm this matches the intended layout for ``starting_test=True``.
    val, test = adjust_for_imbalance_in_fraction_value(
        val, test, val_fraction, test_fraction, total_size
    )

    return train, val, test
166
+
167
+
168
def get_mrc_data(fpath):
    """Read an MRC file and return its data with the section axis first.

    ``read_mrc`` yields the image data as (H, W, N); the result here is
    (N, H, W).
    """
    _header, volume = read_mrc(fpath)
    # (H, W, N) -> (1, H, W, N) -> (N, H, W, 1) -> (N, H, W)
    return np.swapaxes(volume[None], 0, 3)[..., 0]
174
+
175
+
176
class GridIndexManager:
    """Translate between flat sample indices and grid positions.

    A flat index enumerates every (grid row, grid column, frame) triple as
    ``(row * ncols + col) * N + t`` where ``N = data_shape[0]`` is the number
    of frames. Rows are counted along ``data_shape[-3]`` and columns along
    ``data_shape[-2]``.

    Most methods accept an optional ``grid_size``; ``None`` or a negative
    value selects the default grid size given at construction.
    """

    def __init__(self, data_shape, grid_size, patch_size, grid_alignement) -> None:
        self._data_shape = data_shape
        self._default_grid_size = grid_size
        self.patch_size = patch_size
        self.N = self._data_shape[0]
        self._align = grid_alignement

    def get_data_shape(self):
        """Return the data shape this manager was built with."""
        return self._data_shape

    def use_default_grid(self, grid_size):
        """Return True when ``grid_size`` means "use the default grid size"."""
        return grid_size is None or grid_size < 0

    def _resolve_grid_size(self, grid_size):
        """Return ``grid_size``, or the default one when it is None/negative."""
        if self.use_default_grid(grid_size):
            return self._default_grid_size
        return grid_size

    def _extra_pixels(self, grid_size):
        """Return the pixels reserved at the far end of an axis for padding."""
        if self._align == GridAlignement.LeftTop:
            return self.patch_size - grid_size
        if self._align == GridAlignement.Center:
            # Center is exclusively used during evaluation. In this case, we use the padding to handle edge cases.
            # So, here, we will ideally like to cover all pixels and so extra_pixels is set to 0.
            # If there was no padding, then it should be set to (self.patch_size - grid_size) // 2
            return 0
        raise ValueError(f"Unsupported grid alignment: {self._align!r}")

    def grid_rows(self, grid_size):
        """Return the number of grid rows per frame for ``grid_size``."""
        return (self._data_shape[-3] - self._extra_pixels(grid_size)) // grid_size

    def grid_cols(self, grid_size):
        """Return the number of grid columns per frame for ``grid_size``."""
        return (self._data_shape[-2] - self._extra_pixels(grid_size)) // grid_size

    def grid_count(self, grid_size=None):
        """Return the total number of flat indices (frames * rows * cols)."""
        grid_size = self._resolve_grid_size(grid_size)
        return self.N * self.grid_rows(grid_size) * self.grid_cols(grid_size)

    def hwt_from_idx(self, index, grid_size=None):
        """Return ``(h_start, w_start, t)`` for the flat ``index``."""
        t = self.get_t(index)
        return (*self.get_deterministic_hw(index, grid_size=grid_size), t)

    def idx_from_hwt(self, h_start, w_start, t, grid_size=None):
        """
        Given h,w,t (where h,w constitutes the top left corner of the patch), it returns the corresponding index.
        """
        # Same None/negative convention as every other method of the class.
        grid_size = self._resolve_grid_size(grid_size)

        nth_row = h_start // grid_size
        nth_col = w_start // grid_size

        index = self.grid_cols(grid_size) * nth_row + nth_col
        return index * self._data_shape[0] + t

    def get_t(self, index):
        """Return the frame number encoded in the flat ``index``."""
        return index % self.N

    def get_top_nbr_idx(self, index, grid_size=None):
        """Return the index one grid row above, or None on the top row."""
        grid_size = self._resolve_grid_size(grid_size)

        ncols = self.grid_cols(grid_size)
        index -= ncols * self.N
        if index < 0:
            return None

        return index

    def get_bottom_nbr_idx(self, index, grid_size=None):
        """Return the index one grid row below, or None on the bottom row."""
        grid_size = self._resolve_grid_size(grid_size)

        ncols = self.grid_cols(grid_size)
        index += ncols * self.N
        # Valid indices are 0 .. grid_count - 1, so grid_count itself is
        # already out of range.
        if index >= self.grid_count(grid_size=grid_size):
            return None

        return index

    def get_left_nbr_idx(self, index, grid_size=None):
        """Return the index one grid column to the left, or None if none."""
        if self.on_left_boundary(index, grid_size=grid_size):
            return None

        index -= self.N
        return index

    def get_right_nbr_idx(self, index, grid_size=None):
        """Return the index one grid column to the right, or None if none."""
        if self.on_right_boundary(index, grid_size=grid_size):
            return None
        index += self.N
        return index

    def on_left_boundary(self, index, grid_size=None):
        """Return True when ``index`` lies in the first grid column."""
        grid_size = self._resolve_grid_size(grid_size)

        factor = index // self.N
        ncols = self.grid_cols(grid_size)

        # The cell to the left belongs to a different row exactly when this
        # cell starts a row.
        left_boundary = (factor // ncols) != (factor - 1) // ncols
        return left_boundary

    def on_right_boundary(self, index, grid_size=None):
        """Return True when ``index`` lies in the last grid column."""
        grid_size = self._resolve_grid_size(grid_size)

        factor = index // self.N
        ncols = self.grid_cols(grid_size)

        right_boundary = (factor // ncols) != (factor + 1) // ncols
        return right_boundary

    def on_top_boundary(self, index, grid_size=None):
        """Return True when ``index`` lies in the first grid row."""
        grid_size = self._resolve_grid_size(grid_size)

        ncols = self.grid_cols(grid_size)
        return index < self.N * ncols

    def on_bottom_boundary(self, index, grid_size=None):
        """Return True when ``index`` lies in the last grid row."""
        grid_size = self._resolve_grid_size(grid_size)

        ncols = self.grid_cols(grid_size)
        # ``>=`` so that the first cell of the last row (whose shifted index
        # equals grid_count exactly) is recognised as a boundary cell too.
        return index + self.N * ncols >= self.grid_count(grid_size=grid_size)

    def on_boundary(self, idx, grid_size=None):
        """Return True when ``idx`` touches any of the four grid borders."""
        if self.on_left_boundary(idx, grid_size=grid_size):
            return True

        if self.on_right_boundary(idx, grid_size=grid_size):
            return True

        if self.on_top_boundary(idx, grid_size=grid_size):
            return True

        if self.on_bottom_boundary(idx, grid_size=grid_size):
            return True
        return False

    def get_deterministic_hw(self, index: int, grid_size=None):
        """
        Fixed starting position for the crop for the img with index `index`.
        """
        grid_size = self._resolve_grid_size(grid_size)

        factor = index // self.N
        ncols = self.grid_cols(grid_size)

        ith_row = factor // ncols
        jth_col = factor % ncols
        h_start = ith_row * grid_size
        w_start = jth_col * grid_size
        return h_start, w_start
336
+
337
+
338
class IndexSwitcher:
    """
    The idea is to switch from valid indices for target to invalid indices for target.
    If an index is invalid for the target, then an all-zero vector is returned as target.
    This combines both logic:
    1. Using less amount of total data.
    2. Using less amount of target data but using full data.

    Frames ``0 .. _validtarget_ceilT - 2`` are fully valid for the target, the
    frame ``_validtarget_ceilT - 1`` is valid only in its top-left
    ``_h_validmax x _w_validmax`` region, and later frames have no target.
    """

    def __init__(self, idx_manager, data_config, patch_size) -> None:
        self.idx_manager = idx_manager
        self._data_shape = self.idx_manager.get_data_shape()
        self._training_validtarget_fraction = data_config.get(
            "training_validtarget_fraction", 1.0
        )
        # Number of frames that carry at least some valid target.
        self._validtarget_ceilT = int(
            np.ceil(self._data_shape[0] * self._training_validtarget_fraction)
        )
        self._patch_size = patch_size
        assert (
            data_config.deterministic_grid is True
        ), "This only works when the dataset has deterministic grid. Needed randomness comes from this class."
        assert (
            "grid_size" in data_config and data_config.grid_size == 1
        ), "We need a one to one mapping between index and h, w, t"

        self._h_validmax, self._w_validmax = self.get_reduced_frame_size(
            self._data_shape[:3], self._training_validtarget_fraction
        )
        if self._h_validmax < self._patch_size or self._w_validmax < self._patch_size:
            print(
                "WARNING: The valid target size is smaller than the patch size. This will result in all zero target. so, we are ignoring this frame for target."
            )
            self._h_validmax = 0
            self._w_validmax = 0

        print(
            f"[{self.__class__.__name__}] Target Indices: [0,{self._validtarget_ceilT-1}]. Index={self._validtarget_ceilT-1} has shape [:{self._h_validmax},:{self._w_validmax}]. Available data: {self._data_shape[0]}"
        )

    def get_valid_target_index(self):
        """
        Returns an index which corresponds to a frame which is expected to have a target.
        """
        _, h, w, _ = self._data_shape
        framepixelcount = h * w
        # Frames are drawn proportionally to how many valid pixels they hold;
        # the last valid frame only contributes its reduced region.
        targetpixels = np.array(
            [framepixelcount] * (self._validtarget_ceilT - 1)
            + [self._h_validmax * self._w_validmax]
        )
        targetpixels = targetpixels / np.sum(targetpixels)
        t = np.random.choice(self._validtarget_ceilT, p=targetpixels)
        h, w = self.get_valid_target_hw(t)
        index = self.idx_manager.idx_from_hwt(h, w, t)
        return index

    def get_invalid_target_index(self):
        """
        Returns an index which corresponds to a location that is expected to have no target.
        """
        total_t, h, w, _ = self._data_shape
        framepixelcount = h * w
        # The partially valid frame only offers the area outside its valid
        # region, and only if a whole patch fits there.
        available_h = h - self._h_validmax
        if available_h < self._patch_size:
            available_h = 0
        available_w = w - self._w_validmax
        if available_w < self._patch_size:
            available_w = 0

        targetpixels = np.array(
            [available_h * available_w]
            + [framepixelcount] * (total_t - self._validtarget_ceilT)
        )
        t_probab = targetpixels / np.sum(targetpixels)
        t = np.random.choice(
            np.arange(self._validtarget_ceilT - 1, total_t), p=t_probab
        )

        h, w = self.get_invalid_target_hw(t)
        index = self.idx_manager.idx_from_hwt(h, w, t)
        return index

    def get_valid_target_hw(self, t):
        """
        This is the opposite of get_invalid_target_hw. It returns a h,w which is valid for target.
        This is only valid for single frame setup.
        """
        if t == self._validtarget_ceilT - 1:
            h = np.random.randint(0, self._h_validmax - self._patch_size)
            w = np.random.randint(0, self._w_validmax - self._patch_size)
        else:
            h = np.random.randint(0, self._data_shape[1] - self._patch_size)
            w = np.random.randint(0, self._data_shape[2] - self._patch_size)
        return h, w

    def get_invalid_target_hw(self, t):
        """
        This is the opposite of get_valid_target_hw. It returns a h,w which is not valid for target.
        This is only valid for single frame setup.
        """
        if t == self._validtarget_ceilT - 1:
            h = np.random.randint(
                self._h_validmax, self._data_shape[1] - self._patch_size
            )
            w = np.random.randint(
                self._w_validmax, self._data_shape[2] - self._patch_size
            )
        else:
            h = np.random.randint(0, self._data_shape[1] - self._patch_size)
            w = np.random.randint(0, self._data_shape[2] - self._patch_size)
        return h, w

    @staticmethod
    def _scalar_index(index):
        """Return the flat index from either a bare integer or a sequence.

        Any Python or NumPy integer is taken as is; for anything else the
        first element is used.
        """
        if isinstance(index, (int, np.integer)):
            return index
        return index[0]

    def _get_tidx(self, index):
        """Return the frame number that ``index`` refers to."""
        return self.idx_manager.get_t(self._scalar_index(index))

    def index_should_have_target(self, index):
        """Return True when the patch at ``index`` lies in a valid-target area."""
        tidx = self._get_tidx(index)
        last_valid_t = self._validtarget_ceilT - 1
        if tidx < last_valid_t:
            return True
        if tidx > last_valid_t:
            return False

        # Partially valid frame: the whole patch must fit in the valid region.
        # The flat index is unpacked first so that sequence-style indices work
        # here exactly as they do in ``_get_tidx``.
        h, w, _ = self.idx_manager.hwt_from_idx(self._scalar_index(index))
        return (
            h + self._patch_size < self._h_validmax
            and w + self._patch_size < self._w_validmax
        )

    @staticmethod
    def get_reduced_frame_size(data_shape_nhw, fraction):
        """Return the (h, w) of the valid-target region in the last valid frame.

        ``fraction`` of all ``n * h * w`` pixels are valid. Whatever does not
        fill whole frames is turned into a square region whose side is the
        integer square root of the leftover pixel count. If the valid pixels
        fill whole frames exactly, the full ``(h, w)`` is returned.
        """
        n, h, w = data_shape_nhw

        framepixelcount = h * w
        targetpixelcount = int(n * framepixelcount * fraction)

        lastframepixelcount = targetpixelcount % framepixelcount
        assert data_shape_nhw[1] == data_shape_nhw[2]
        if lastframepixelcount > 0:
            new_size = int(np.sqrt(lastframepixelcount))
            return new_size, new_size
        else:
            assert (
                targetpixelcount / framepixelcount >= 1
            ), "This is not possible in euclidean space :D (so this is a bug)"
            return h, w
500
+
501
+
502
# Field-by-field layout of the 1024-byte MRC file header, written as a NumPy
# structured-dtype specification (field comments follow the IMOD conventions).
rec_header_dtd = [
    ("nx", "i4"),  # Number of columns
    ("ny", "i4"),  # Number of rows
    ("nz", "i4"),  # Number of sections
    # Types of pixels in the image. Values used by IMOD:
    #   0 = unsigned or signed bytes depending on flag in imodFlags
    #   1 = signed short integers (16 bits)
    #   2 = float (32 bits)
    #   3 = short * 2, (used for complex data)
    #   4 = float * 2, (used for complex data)
    #   6 = unsigned 16-bit integers (non-standard)
    #   16 = unsigned char * 3 (for rgb data, non-standard)
    ("mode", "i4"),
    # Starting point of sub-image (not used in IMOD)
    ("nxstart", "i4"),
    ("nystart", "i4"),
    ("nzstart", "i4"),
    # Grid size in X, Y and Z
    ("mx", "i4"),
    ("my", "i4"),
    ("mz", "i4"),
    # Cell size; pixel spacing = xlen/mx, ylen/my, zlen/mz
    ("xlen", "f4"),
    ("ylen", "f4"),
    ("zlen", "f4"),
    # Cell angles - ignored by IMOD
    ("alpha", "f4"),
    ("beta", "f4"),
    ("gamma", "f4"),
    # These need to be set to 1, 2, and 3 for pixel spacing to be interpreted correctly
    ("mapc", "i4"),  # map column 1=x,2=y,3=z.
    ("mapr", "i4"),  # map row 1=x,2=y,3=z.
    ("maps", "i4"),  # map section 1=x,2=y,3=z.
    # These need to be set for proper scaling of data
    ("amin", "f4"),  # Minimum pixel value
    ("amax", "f4"),  # Maximum pixel value
    ("amean", "f4"),  # Mean pixel value
    ("ispg", "i4"),  # space group number (ignored by IMOD)
    # number of bytes in extended header (called nsymbt in MRC standard)
    ("next", "i4"),
    ("creatid", "i2"),  # used to be an ID number, is 0 as of IMOD 4.2.23
    ("extra_data", "V30"),  # (not used, first two bytes should be 0)
    # The next two values specify the structure of data in the extended header; their meaning depends on whether the
    # extended header has the Agard format, a series of 4-byte integers then real numbers, or has data
    # produced by SerialEM, a series of short integers. SerialEM stores a float as two shorts, s1 and s2, by:
    # value = (sign of s1)*(|s1|*256 + (|s2| modulo 256)) * 2**((sign of s2) * (|s2|/256))
    #
    # nint: Number of integers per section (Agard format) or number of bytes per section (SerialEM format)
    ("nint", "i2"),
    # nreal: Number of reals per section (Agard format) or bit
    # flags for which types of short data (SerialEM format):
    #   1 = tilt angle * 100 (2 bytes)
    #   2 = piece coordinates for montage (6 bytes)
    #   4 = Stage position * 25 (4 bytes)
    #   8 = Magnification / 100 (2 bytes)
    #   16 = Intensity * 25000 (2 bytes)
    #   32 = Exposure dose in e-/A2, a float in 4 bytes
    #   128, 512: Reserved for 4-byte items
    #   64, 256, 1024: Reserved for 2-byte items
    # If the number of bytes implied by these flags does
    # not add up to the value in nint, then nint and nreal
    # are interpreted as ints and reals per section
    ("nreal", "i2"),
    ("extra_data2", "V20"),  # extra data (not used)
    ("imodStamp", "i4"),  # 1146047817 indicates that file was created by IMOD
    ("imodFlags", "i4"),  # Bit flags: 1 = bytes are stored as signed
    # Explanation of type of data
    ("idtype", "i2"),  # ( 0 = mono, 1 = tilt, 2 = tilts, 3 = lina, 4 = lins)
    ("lens", "i2"),
    # ("nd1", "i2"),  # for idtype = 1, nd1 = axis (1, 2, or 3)
    # ("nd2", "i2"),
    ("nphase", "i4"),
    ("vd1", "i2"),  # vd1 = 100. * tilt increment
    ("vd2", "i2"),  # vd2 = 100. * starting angle
    # Current angles are used to rotate a model to match a new rotated image. The three values in each set are
    # rotations about X, Y, and Z axes, applied in the order Z, Y, X.
    ("triangles", "f4", 6),  # 0,1,2 = original: 3,4,5 = current
    # Origin of image
    ("xorg", "f4"),
    ("yorg", "f4"),
    ("zorg", "f4"),
    ("cmap", "S4"),  # Contains "MAP "
    # First two bytes have 17 and 17 for big-endian or 68 and 65 for little-endian
    ("stamp", "u1", 4),
    ("rms", "f4"),  # RMS deviation of densities from mean density
    ("nlabl", "i4"),  # Number of labels with useful data
    ("labels", "S80", 10),  # 10 labels of 80 characters
]
588
+
589
+
590
def read_mrc(filename, filetype="image"):
    """Read an MRC file and return its header and data.

    Args:
        filename: Path of the MRC file.
        filetype: With ``"image"`` (default) the data are arranged into a
            ``(nx, ny, nz)`` array, one Fortran-ordered ``(nx, ny)`` plane per
            section. Any other value returns the flat array as read from disk.

    Returns:
        ``(header, data)`` where ``header`` is a one-element structured array
        with the ``rec_header_dtd`` fields.

    Raises:
        ValueError: If the header's ``mode`` is not one of 1, 2, 4 or 6.
    """
    # MRC ``mode`` -> (dtype used to read the payload, multiplier for nx).
    # Mode 4 stores two floats per pixel, hence the doubled first dimension.
    mode_table = {
        1: ("int16", 1),
        2: ("float32", 1),
        4: ("single", 2),
        6: ("uint16", 1),
    }

    # ``with`` guarantees the handle is closed even if parsing fails.
    with open(filename, "rb") as fd:
        header = np.fromfile(fd, dtype=rec_header_dtd, count=1)

        # Convert to Python ints: the header stores int32 values, and products
        # such as nx * ny * iz could otherwise overflow for large volumes.
        nx = int(header["nx"][0])
        ny = int(header["ny"][0])
        nz = int(header["nz"][0])
        mode = int(header["mode"][0])

        if mode not in mode_table:
            raise ValueError(
                f"Unsupported MRC mode {mode} in {filename!r}; "
                f"supported modes are {sorted(mode_table)}"
            )
        data_type, nx_factor = mode_table[mode]
        nx *= nx_factor

        imgrawdata = np.fromfile(fd, data_type)

    if filetype != "image":
        return header, imgrawdata

    data = np.ndarray(shape=(nx, ny, nz))
    plane_size = nx * ny
    for iz in range(nz):
        data_2d = imgrawdata[plane_size * iz : plane_size * (iz + 1)]
        data[:, :, iz] = data_2d.reshape(nx, ny, order="F")

    return header, data