careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

This version of careamics has been flagged as potentially problematic.

Files changed (155)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
careamics/lvae_training/dataset/data_utils.py
@@ -0,0 +1,701 @@
+ """
+ Utility functions needed by dataloader & co.
+ """
+
+ import os
+ from dataclasses import dataclass
+ from typing import List
+
+ import numpy as np
+ from skimage.io import imread, imsave
+
+ from careamics.lvae_training.dataset.vae_data_config import DataSplitType, DataType
+
+
+ def load_tiff(path):
+     """
+     Returns a 4d numpy array: num_imgs*h*w*num_channels
+     """
+     data = imread(path, plugin="tifffile")
+     return data
+
+
+ def save_tiff(path, data):
+     imsave(path, data, plugin="tifffile")
+
+
+ def load_tiffs(paths):
+     data = [load_tiff(path) for path in paths]
+     return np.concatenate(data, axis=0)
+
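# --- editor's sketch (illustrative, not part of the released diff) -------------
# A minimal round trip through the TIFF helpers above, assuming scikit-image's
# tifffile plugin is available; the path is hypothetical.
import numpy as np
from careamics.lvae_training.dataset.data_utils import load_tiff, save_tiff

stack = np.random.rand(4, 64, 64).astype(np.float32)  # num_imgs * h * w
save_tiff("/tmp/example.tif", stack)
assert load_tiff("/tmp/example.tif").shape == stack.shape
# ------------------------------------------------------------------------------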
+ def split_in_half(s, e):
+     n = e - s
+     s1 = list(np.arange(n // 2))
+     s2 = list(np.arange(n // 2, n))
+     return [x + s for x in s1], [x + s for x in s2]
+
+
+ def adjust_for_imbalance_in_fraction_value(
+     val: List[int],
+     test: List[int],
+     val_fraction: float,
+     test_fraction: float,
+     total_size: int,
+ ):
+     """
+     At this point, val and test have been divided almost equally. Here we take their
+     respective fractions into account and randomly move elements from one array to
+     the other.
+     """
+     if val_fraction == 0:
+         test += val
+         val = []
+     elif test_fraction == 0:
+         val += test
+         test = []
+     else:
+         diff_fraction = test_fraction - val_fraction
+         if diff_fraction > 0:
+             imb_count = int(diff_fraction * total_size / 2)
+             val = list(np.random.RandomState(seed=955).permutation(val))
+             test += val[:imb_count]
+             val = val[imb_count:]
+         elif diff_fraction < 0:
+             imb_count = int(-1 * diff_fraction * total_size / 2)
+             test = list(np.random.RandomState(seed=955).permutation(test))
+             val += test[:imb_count]
+             test = test[imb_count:]
+     return val, test
+
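# --- editor's sketch (illustrative, not part of the released diff) -------------
# With a larger test_fraction than val_fraction, int(diff_fraction * total_size / 2)
# elements move from val to test; here imb_count = int(0.1 * 30 / 2) = 1.
from careamics.lvae_training.dataset.data_utils import adjust_for_imbalance_in_fraction_value

val, test = adjust_for_imbalance_in_fraction_value(
    val=[0, 1, 2], test=[3, 4, 5], val_fraction=0.1, test_fraction=0.2, total_size=30
)
assert len(val) == 2 and len(test) == 4
# ------------------------------------------------------------------------------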
+ def get_train_val_data(
+     data_config,
+     fpath,
+     datasplit_type: DataSplitType,
+     val_fraction=None,
+     test_fraction=None,
+     allow_generation=False,  # TODO: what is this
+ ):
+     """
+     Load the data from the given path and split it into training, validation and test sets.
+
+     The data must have shape N*H*W*C, where N is the number of data points, H and W are
+     the image dimensions, and C is the number of channels.
+     """
+     if data_config.data_type == DataType.SeparateTiffData:
+         fpath1 = os.path.join(fpath, data_config.ch1_fname)
+         fpath2 = os.path.join(fpath, data_config.ch2_fname)
+         fpaths = [fpath1, fpath2]
+         fpath0 = ""
+         if "ch_input_fname" in data_config:
+             fpath0 = os.path.join(fpath, data_config.ch_input_fname)
+             fpaths = [fpath0] + fpaths
+
+         print(
+             f"Loading from {fpath} Channels: "
+             f"{fpath1},{fpath2}, inp:{fpath0} Mode:{DataSplitType.name(datasplit_type)}"
+         )
+
+         data = np.concatenate([load_tiff(fpath)[..., None] for fpath in fpaths], axis=3)
+         if data_config.data_type == DataType.PredictedTiffData:
+             assert len(data.shape) == 5 and data.shape[-1] == 1
+             data = data[..., 0].copy()
+         # data = data[::3].copy()
+         # NOTE: This was not the correct way to do it, because the noise present in the
+         # input was directly related to the noise present in the channels, so this is not
+         # the way we would get the data. We need to add the noise independently to the
+         # input and the target.
+
+         # if data_config.get('poisson_noise_factor', False):
+         #     data = np.random.poisson(data)
+         # if data_config.get('enable_gaussian_noise', False):
+         #     synthetic_scale = data_config.get('synthetic_gaussian_scale', 0.1)
+         #     print('Adding Gaussian noise with scale', synthetic_scale)
+         #     noise = np.random.normal(0, synthetic_scale, data.shape)
+         #     data = data + noise
+
+         if datasplit_type == DataSplitType.All:
+             return data.astype(np.float32)
+
+         train_idx, val_idx, test_idx = get_datasplit_tuples(
+             val_fraction, test_fraction, len(data), starting_test=True
+         )
+         if datasplit_type == DataSplitType.Train:
+             return data[train_idx].astype(np.float32)
+         elif datasplit_type == DataSplitType.Val:
+             return data[val_idx].astype(np.float32)
+         elif datasplit_type == DataSplitType.Test:
+             return data[test_idx].astype(np.float32)
+
+     elif data_config.data_type == DataType.BioSR_MRC:
+         num_channels = data_config.num_channels
+         fpaths = []
+         data_list = []
+         for i in range(num_channels):
+             fpath1 = os.path.join(fpath, getattr(data_config, f"ch{i + 1}_fname"))
+             fpaths.append(fpath1)
+             data = get_mrc_data(fpath1)[..., None]
+             data_list.append(data)
+
+         dirname = os.path.dirname(os.path.dirname(fpaths[0])) + "/"
+
+         msg = ",".join([x[len(dirname) :] for x in fpaths])
+         print(
+             f"Loaded from {dirname} Channels:{len(fpaths)} {msg} Mode:{datasplit_type}"
+         )
+         N = data_list[0].shape[0]
+         for data in data_list:
+             N = min(N, data.shape[0])
+
+         cropped_data = []
+         for data in data_list:
+             cropped_data.append(data[:N])
+
+         data = np.concatenate(cropped_data, axis=3)
+
+         if datasplit_type == DataSplitType.All:
+             return data.astype(np.float32)
+
+         train_idx, val_idx, test_idx = get_datasplit_tuples(
+             val_fraction, test_fraction, len(data), starting_test=True
+         )
+         if datasplit_type == DataSplitType.Train:
+             return data[train_idx].astype(np.float32)
+         elif datasplit_type == DataSplitType.Val:
+             return data[val_idx].astype(np.float32)
+         elif datasplit_type == DataSplitType.Test:
+             return data[test_idx].astype(np.float32)
+
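# --- editor's sketch (illustrative, not part of the released diff) -------------
# Calling get_train_val_data on two synthetic single-channel stacks. The tiny config
# class is a hypothetical stand-in for the real data config; it only mimics the
# attribute and `in` accesses made above.
import numpy as np
from types import SimpleNamespace
from careamics.lvae_training.dataset.data_utils import get_train_val_data, save_tiff
from careamics.lvae_training.dataset.vae_data_config import DataSplitType, DataType

class _Cfg(SimpleNamespace):
    def __contains__(self, key):  # supports `"ch_input_fname" in data_config`
        return hasattr(self, key)

save_tiff("/tmp/ch1.tif", np.random.rand(10, 64, 64).astype(np.float32))
save_tiff("/tmp/ch2.tif", np.random.rand(10, 64, 64).astype(np.float32))
cfg = _Cfg(data_type=DataType.SeparateTiffData, ch1_fname="ch1.tif", ch2_fname="ch2.tif")
train = get_train_val_data(cfg, "/tmp", DataSplitType.Train, val_fraction=0.1, test_fraction=0.1)
assert train.shape == (8, 64, 64, 2)  # N*H*W*C after an 80/10/10 split
# ------------------------------------------------------------------------------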
+ def get_datasplit_tuples(
+     val_fraction: float,
+     test_fraction: float,
+     total_size: int,
+     starting_test: bool = False,
+ ):
+     if starting_test:
+         # test => val => train
+         test = list(range(0, int(total_size * test_fraction)))
+         val = list(range(test[-1] + 1, test[-1] + 1 + int(total_size * val_fraction)))
+         train = list(range(val[-1] + 1, total_size))
+     else:
+         # {test,val} => train
+         test_val_size = int((val_fraction + test_fraction) * total_size)
+         train = list(range(test_val_size, total_size))
+
+         if test_val_size == 0:
+             test = []
+             val = []
+             return train, val, test
+
+         # Split the test and validation in chunks.
+         chunksize = max(1, min(3, test_val_size // 2))
+
+         nchunks = test_val_size // chunksize
+
+         test = []
+         val = []
+         s = 0
+         for i in range(nchunks):
+             if i % 2 == 0:
+                 val += list(np.arange(s, s + chunksize))
+             else:
+                 test += list(np.arange(s, s + chunksize))
+             s += chunksize
+
+         if i % 2 == 0:
+             test += list(np.arange(s, test_val_size))
+         else:
+             p1, p2 = split_in_half(s, test_val_size)
+             test += p1
+             val += p2
+
+         val, test = adjust_for_imbalance_in_fraction_value(
+             val, test, val_fraction, test_fraction, total_size
+         )
+
+     return train, val, test
+
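# --- editor's sketch (illustrative, not part of the released diff) -------------
# The two splitting modes above, worked by hand for 20 samples at 10%/10%.
from careamics.lvae_training.dataset.data_utils import get_datasplit_tuples

# starting_test=True slices contiguously: test, then val, then train.
train, val, test = get_datasplit_tuples(0.1, 0.1, 20, starting_test=True)
assert (test, val, train[:2]) == ([0, 1], [2, 3], [4, 5])

# starting_test=False interleaves val/test chunks (chunksize 2 here) before train.
train, val, test = get_datasplit_tuples(0.1, 0.1, 20, starting_test=False)
assert (list(val), list(test)) == ([0, 1], [2, 3])
# ------------------------------------------------------------------------------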
+ def get_mrc_data(fpath):
+     # read_mrc returns data as HxWxN
+     _, data = read_mrc(fpath)
+     data = data[None]
+     data = np.swapaxes(data, 0, 3)
+     return data[..., 0]
+
+
+ @dataclass
+ class GridIndexManager:
+     data_shape: tuple
+     grid_shape: tuple
+     patch_shape: tuple
+     trim_boundary: bool
+
+     # Vera: the patch is centered on its index in the grid. The grid size is not used
+     # in training, only during val/test, where it controls the overlap of the patches;
+     # in training you only get random patches every time.
+     # For borders, the data is simply cropped so that it is perfectly divisible.
+
+     def __post_init__(self):
+         assert len(self.data_shape) == len(
+             self.grid_shape
+         ), f"Data shape:{self.data_shape} and grid size:{self.grid_shape} must have the same dimension"
+         assert len(self.data_shape) == len(
+             self.patch_shape
+         ), f"Data shape:{self.data_shape} and patch shape:{self.patch_shape} must have the same dimension"
+         innerpad = np.array(self.patch_shape) - np.array(self.grid_shape)
+         for dim, pad in enumerate(innerpad):
+             if pad < 0:
+                 raise ValueError(
+                     f"Patch shape:{self.patch_shape} must be greater than or equal to grid shape:{self.grid_shape} in dimension {dim}"
+                 )
+             if pad % 2 != 0:
+                 raise ValueError(
+                     f"Patch shape:{self.patch_shape} must have even padding in dimension {dim}"
+                 )
+
+     def patch_offset(self):
+         return (np.array(self.patch_shape) - np.array(self.grid_shape)) // 2
+
+     def get_individual_dim_grid_count(self, dim: int):
+         """
+         Returns the number of grids along the specified dimension, ignoring all other dimensions.
+         """
+         assert dim < len(
+             self.data_shape
+         ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+         assert dim >= 0, "Dimension must be greater than or equal to 0"
+
+         if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
+             return self.data_shape[dim]
+         elif self.trim_boundary is False:
+             return int(np.ceil(self.data_shape[dim] / self.grid_shape[dim]))
+         else:
+             excess_size = self.patch_shape[dim] - self.grid_shape[dim]
+             return int(
+                 np.floor((self.data_shape[dim] - excess_size) / self.grid_shape[dim])
+             )
+
+     def total_grid_count(self):
+         """
+         Returns the total number of grids in the dataset.
+         """
+         return self.grid_count(0) * self.get_individual_dim_grid_count(0)
+
+     def grid_count(self, dim: int):
+         """
+         Returns the total number of grids for one value in the specified dimension.
+         """
+         assert dim < len(
+             self.data_shape
+         ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+         assert dim >= 0, "Dimension must be greater than or equal to 0"
+         if dim == len(self.data_shape) - 1:
+             return 1
+
+         return self.get_individual_dim_grid_count(dim + 1) * self.grid_count(dim + 1)
+
+     def get_grid_index(self, dim: int, coordinate: int):
+         """
+         Returns the index of the grid in the specified dimension.
+         """
+         assert dim < len(
+             self.data_shape
+         ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+         assert dim >= 0, "Dimension must be greater than or equal to 0"
+         assert (
+             coordinate < self.data_shape[dim]
+         ), f"Coordinate {coordinate} is out of bounds for data shape {self.data_shape}"
+
+         if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
+             return coordinate
+         elif self.trim_boundary is False:
+             return np.floor(coordinate / self.grid_shape[dim])
+         else:
+             excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
+             # can be <0 if coordinate is in [0, grid_shape[dim]]
+             return max(0, np.floor((coordinate - excess_size) / self.grid_shape[dim]))
+
+     def dataset_idx_from_grid_idx(self, grid_idx: tuple):
+         """
+         Returns the index of the grid in the dataset.
+         """
+         assert len(grid_idx) == len(
+             self.data_shape
+         ), f"Dimension indices {grid_idx} must have the same dimension as data shape {self.data_shape}"
+         index = 0
+         for dim in range(len(grid_idx)):
+             index += grid_idx[dim] * self.grid_count(dim)
+         return index
+
+     def get_patch_location_from_dataset_idx(self, dataset_idx: int):
+         """
+         Returns the patch location of the grid in the dataset.
+         """
+         location = self.get_location_from_dataset_idx(dataset_idx)
+         offset = self.patch_offset()
+         return tuple(np.array(location) - np.array(offset))
+
+     def get_dataset_idx_from_grid_location(self, location: tuple):
+         assert len(location) == len(
+             self.data_shape
+         ), f"Location {location} must have the same dimension as data shape {self.data_shape}"
+         grid_idx = [
+             self.get_grid_index(dim, location[dim]) for dim in range(len(location))
+         ]
+         return self.dataset_idx_from_grid_idx(tuple(grid_idx))
+
+     def get_gridstart_location_from_dim_index(self, dim: int, dim_index: int):
+         """
+         Returns the grid-start coordinate of the grid in the specified dimension.
+         """
+         assert dim < len(
+             self.data_shape
+         ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+         assert dim >= 0, "Dimension must be greater than or equal to 0"
+         assert dim_index < self.get_individual_dim_grid_count(
+             dim
+         ), f"Dimension index {dim_index} is out of bounds for data shape {self.data_shape}"
+
+         if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
+             return dim_index
+         elif self.trim_boundary is False:
+             return dim_index * self.grid_shape[dim]
+         else:
+             excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
+             return dim_index * self.grid_shape[dim] + excess_size
+
+     def get_location_from_dataset_idx(self, dataset_idx: int):
+         grid_idx = []
+         for dim in range(len(self.data_shape)):
+             grid_idx.append(dataset_idx // self.grid_count(dim))
+             dataset_idx = dataset_idx % self.grid_count(dim)
+         location = [
+             self.get_gridstart_location_from_dim_index(dim, grid_idx[dim])
+             for dim in range(len(self.data_shape))
+         ]
+         return tuple(location)
+
+     def on_boundary(self, dataset_idx: int, dim: int):
+         """
+         Returns True if the grid is on the boundary in the specified dimension.
+         """
+         assert dim < len(
+             self.data_shape
+         ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+         assert dim >= 0, "Dimension must be greater than or equal to 0"
+
+         if dim > 0:
+             dataset_idx = dataset_idx % self.grid_count(dim - 1)
+
+         dim_index = dataset_idx // self.grid_count(dim)
+         return (
+             dim_index == 0 or dim_index == self.get_individual_dim_grid_count(dim) - 1
+         )
+
+     def next_grid_along_dim(self, dataset_idx: int, dim: int):
+         """
+         Returns the index of the next grid along the specified dimension, or None at the end.
+         """
+         assert dim < len(
+             self.data_shape
+         ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+         assert dim >= 0, "Dimension must be greater than or equal to 0"
+         new_idx = dataset_idx + self.grid_count(dim)
+         if new_idx >= self.total_grid_count():
+             return None
+         return new_idx
+
+     def prev_grid_along_dim(self, dataset_idx: int, dim: int):
+         """
+         Returns the index of the previous grid along the specified dimension, or None at the start.
+         """
+         assert dim < len(
+             self.data_shape
+         ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+         assert dim >= 0, "Dimension must be greater than or equal to 0"
+         new_idx = dataset_idx - self.grid_count(dim)
+         if new_idx < 0:
+             return None
+         return new_idx
+
+
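# --- editor's sketch (illustrative, not part of the released diff) -------------
# An 8x8 image with a 4x4 grid and 6x6 patches gives 2 grids per axis; each patch
# start sits (patch - grid) // 2 = 1 px before its grid start.
from careamics.lvae_training.dataset.data_utils import GridIndexManager

gm = GridIndexManager(
    data_shape=(8, 8), grid_shape=(4, 4), patch_shape=(6, 6), trim_boundary=False
)
assert gm.total_grid_count() == 4
assert gm.get_location_from_dataset_idx(3) == (4, 4)  # grid start of the last grid
assert gm.get_patch_location_from_dataset_idx(3) == (3, 3)  # shifted by the offset
# ------------------------------------------------------------------------------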
+ class IndexSwitcher:
+     """
+     The idea is to switch from indices which are valid for the target to indices which
+     are invalid for it. If an index is invalid for the target, we return an all-zero
+     vector as the target. This combines both use cases:
+     1. Using a smaller amount of total data.
+     2. Using a smaller amount of target data while still using the full input data.
+     """
+
+     def __init__(self, idx_manager, data_config, patch_size) -> None:
+         self.idx_manager = idx_manager
+         self._data_shape = self.idx_manager.get_data_shape()
+         self._training_validtarget_fraction = data_config.get(
+             "training_validtarget_fraction", 1.0
+         )
+         self._validtarget_ceilT = int(
+             np.ceil(self._data_shape[0] * self._training_validtarget_fraction)
+         )
+         self._patch_size = patch_size
+         assert (
+             data_config.deterministic_grid is True
+         ), "This only works when the dataset has a deterministic grid. The needed randomness comes from this class."
+         assert (
+             "grid_size" in data_config and data_config.grid_size == 1
+         ), "We need a one-to-one mapping between index and h, w, t"
+
+         self._h_validmax, self._w_validmax = self.get_reduced_frame_size(
+             self._data_shape[:3], self._training_validtarget_fraction
+         )
+         if self._h_validmax < self._patch_size or self._w_validmax < self._patch_size:
+             print(
+                 "WARNING: The valid target size is smaller than the patch size. This would result in an all-zero target, so we are ignoring this frame for the target."
+             )
+             self._h_validmax = 0
+             self._w_validmax = 0
+
+         print(
+             f"[{self.__class__.__name__}] Target Indices: [0,{self._validtarget_ceilT - 1}]. Index={self._validtarget_ceilT - 1} has shape [:{self._h_validmax},:{self._w_validmax}]. Available data: {self._data_shape[0]}"
+         )
+
+     def get_valid_target_index(self):
+         """
+         Returns an index which corresponds to a frame which is expected to have a target.
+         """
+         _, h, w, _ = self._data_shape
+         framepixelcount = h * w
+         targetpixels = np.array(
+             [framepixelcount] * (self._validtarget_ceilT - 1)
+             + [self._h_validmax * self._w_validmax]
+         )
+         targetpixels = targetpixels / np.sum(targetpixels)
+         t = np.random.choice(self._validtarget_ceilT, p=targetpixels)
+         # t = np.random.randint(0, self._validtarget_ceilT) if self._validtarget_ceilT >= 1 else 0
+         h, w = self.get_valid_target_hw(t)
+         index = self.idx_manager.idx_from_hwt(h, w, t)
+         # print('Valid', index, h, w, t)
+         return index
+
+     def get_invalid_target_index(self):
+         """
+         Returns an index which corresponds to a frame which is not expected to have a target.
+         """
+         # if self._validtarget_ceilT == 0:
+         #     # TODO: There may not be enough data for this to work. The better way is to
+         #     # skip using 0 for the invalid target.
+         #     t = np.random.randint(1, self._data_shape[0])
+         # elif self._validtarget_ceilT < self._data_shape[0]:
+         #     t = np.random.randint(self._validtarget_ceilT, self._data_shape[0])
+         # else:
+         #     t = self._validtarget_ceilT - 1
+         total_t, h, w, _ = self._data_shape
+         framepixelcount = h * w
+         available_h = h - self._h_validmax
+         if available_h < self._patch_size:
+             available_h = 0
+         available_w = w - self._w_validmax
+         if available_w < self._patch_size:
+             available_w = 0
+
+         targetpixels = np.array(
+             [available_h * available_w]
+             + [framepixelcount] * (total_t - self._validtarget_ceilT)
+         )
+         t_probab = targetpixels / np.sum(targetpixels)
+         t = np.random.choice(
+             np.arange(self._validtarget_ceilT - 1, total_t), p=t_probab
+         )
+
+         h, w = self.get_invalid_target_hw(t)
+         index = self.idx_manager.idx_from_hwt(h, w, t)
+         # print('Invalid', index, h, w, t)
+         return index
+
+     def get_valid_target_hw(self, t):
+         """
+         The opposite of get_invalid_target_hw: returns an (h, w) which is valid for the target.
+         This is only valid for the single-frame setup.
+         """
+         if t == self._validtarget_ceilT - 1:
+             h = np.random.randint(0, self._h_validmax - self._patch_size)
+             w = np.random.randint(0, self._w_validmax - self._patch_size)
+         else:
+             h = np.random.randint(0, self._data_shape[1] - self._patch_size)
+             w = np.random.randint(0, self._data_shape[2] - self._patch_size)
+         return h, w
+
+     def get_invalid_target_hw(self, t):
+         """
+         The opposite of get_valid_target_hw: returns an (h, w) which is not valid for the target.
+         This is only valid for the single-frame setup.
+         """
+         if t == self._validtarget_ceilT - 1:
+             h = np.random.randint(
+                 self._h_validmax, self._data_shape[1] - self._patch_size
+             )
+             w = np.random.randint(
+                 self._w_validmax, self._data_shape[2] - self._patch_size
+             )
+         else:
+             h = np.random.randint(0, self._data_shape[1] - self._patch_size)
+             w = np.random.randint(0, self._data_shape[2] - self._patch_size)
+         return h, w
+
+     def _get_tidx(self, index):
+         if isinstance(index, (int, np.int64)):
+             idx = index
+         else:
+             idx = index[0]
+         return self.idx_manager.get_t(idx)
+
+     def index_should_have_target(self, index):
+         tidx = self._get_tidx(index)
+         if tidx < self._validtarget_ceilT - 1:
+             return True
+         elif tidx > self._validtarget_ceilT - 1:
+             return False
+         else:
+             h, w, _ = self.idx_manager.hwt_from_idx(index)
+             return (
+                 h + self._patch_size < self._h_validmax
+                 and w + self._patch_size < self._w_validmax
+             )
+
+     @staticmethod
+     def get_reduced_frame_size(data_shape_nhw, fraction):
+         n, h, w = data_shape_nhw
+
+         framepixelcount = h * w
+         targetpixelcount = int(n * framepixelcount * fraction)
+
+         # We currently support this only when there is just one frame.
+         # if np.ceil(pixelcount / framepixelcount) > 1:
+         #     return None, None
+
+         lastframepixelcount = targetpixelcount % framepixelcount
+         assert data_shape_nhw[1] == data_shape_nhw[2]
+         if lastframepixelcount > 0:
+             new_size = int(np.sqrt(lastframepixelcount))
+             return new_size, new_size
+         else:
+             assert (
+                 targetpixelcount / framepixelcount >= 1
+             ), "This is not possible in euclidean space :D (so this is a bug)"
+             return h, w
+
+
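# --- editor's sketch (illustrative, not part of the released diff) -------------
# The static helper alone is runnable. With one 64x64 frame and a 25% target
# fraction, the valid-target region shrinks to a 32x32 corner (32*32 == 0.25*64*64).
from careamics.lvae_training.dataset.data_utils import IndexSwitcher

assert IndexSwitcher.get_reduced_frame_size((1, 64, 64), 0.25) == (32, 32)
# ------------------------------------------------------------------------------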
+ rec_header_dtd = [
+     ("nx", "i4"),  # Number of columns
+     ("ny", "i4"),  # Number of rows
+     ("nz", "i4"),  # Number of sections
+     ("mode", "i4"),  # Types of pixels in the image. Values used by IMOD:
+     # 0 = unsigned or signed bytes depending on flag in imodFlags
+     # 1 = signed short integers (16 bits)
+     # 2 = float (32 bits)
+     # 3 = short * 2, (used for complex data)
+     # 4 = float * 2, (used for complex data)
+     # 6 = unsigned 16-bit integers (non-standard)
+     # 16 = unsigned char * 3 (for rgb data, non-standard)
+     ("nxstart", "i4"),  # Starting point of sub-image (not used in IMOD)
+     ("nystart", "i4"),
+     ("nzstart", "i4"),
+     ("mx", "i4"),  # Grid size in X, Y and Z
+     ("my", "i4"),
+     ("mz", "i4"),
+     ("xlen", "f4"),  # Cell size; pixel spacing = xlen/mx, ylen/my, zlen/mz
+     ("ylen", "f4"),
+     ("zlen", "f4"),
+     ("alpha", "f4"),  # Cell angles - ignored by IMOD
+     ("beta", "f4"),
+     ("gamma", "f4"),
+     # These need to be set to 1, 2, and 3 for pixel spacing to be interpreted correctly
+     ("mapc", "i4"),  # map column  1=x,2=y,3=z.
+     ("mapr", "i4"),  # map row     1=x,2=y,3=z.
+     ("maps", "i4"),  # map section 1=x,2=y,3=z.
+     # These need to be set for proper scaling of data
+     ("amin", "f4"),  # Minimum pixel value
+     ("amax", "f4"),  # Maximum pixel value
+     ("amean", "f4"),  # Mean pixel value
+     ("ispg", "i4"),  # space group number (ignored by IMOD)
+     ("next", "i4"),  # number of bytes in extended header (called nsymbt in MRC standard)
+     ("creatid", "i2"),  # used to be an ID number, is 0 as of IMOD 4.2.23
+     ("extra_data", "V30"),  # (not used, first two bytes should be 0)
+     # These two values specify the structure of data in the extended header; their
+     # meaning depends on whether the extended header has the Agard format, a series of
+     # 4-byte integers then real numbers, or has data produced by SerialEM, a series of
+     # short integers. SerialEM stores a float as two shorts, s1 and s2, by:
+     # value = (sign of s1)*(|s1|*256 + (|s2| modulo 256)) * 2**((sign of s2) * (|s2|/256))
+     ("nint", "i2"),
+     # Number of integers per section (Agard format) or number of bytes per section (SerialEM format)
+     ("nreal", "i2"),
+     # Number of reals per section (Agard format) or bit
+     # flags for which types of short data (SerialEM format):
+     # 1 = tilt angle * 100 (2 bytes)
+     # 2 = piece coordinates for montage (6 bytes)
+     # 4 = Stage position * 25 (4 bytes)
+     # 8 = Magnification / 100 (2 bytes)
+     # 16 = Intensity * 25000 (2 bytes)
+     # 32 = Exposure dose in e-/A2, a float in 4 bytes
+     # 128, 512: Reserved for 4-byte items
+     # 64, 256, 1024: Reserved for 2-byte items
+     # If the number of bytes implied by these flags does
+     # not add up to the value in nint, then nint and nreal
+     # are interpreted as ints and reals per section
+     ("extra_data2", "V20"),  # extra data (not used)
+     ("imodStamp", "i4"),  # 1146047817 indicates that file was created by IMOD
+     ("imodFlags", "i4"),  # Bit flags: 1 = bytes are stored as signed
+     # Explanation of type of data
+     ("idtype", "i2"),  # ( 0 = mono, 1 = tilt, 2 = tilts, 3 = lina, 4 = lins)
+     ("lens", "i2"),
+     # ("nd1", "i2"),  # for idtype = 1, nd1 = axis (1, 2, or 3)
+     # ("nd2", "i2"),
+     ("nphase", "i4"),
+     ("vd1", "i2"),  # vd1 = 100. * tilt increment
+     ("vd2", "i2"),  # vd2 = 100. * starting angle
+     # Current angles are used to rotate a model to match a new rotated image. The three
+     # values in each set are rotations about X, Y, and Z axes, applied in the order Z, Y, X.
+     ("triangles", "f4", 6),  # 0,1,2 = original: 3,4,5 = current
+     ("xorg", "f4"),  # Origin of image
+     ("yorg", "f4"),
+     ("zorg", "f4"),
+     ("cmap", "S4"),  # Contains "MAP "
+     ("stamp", "u1", 4),  # First two bytes have 17 and 17 for big-endian or 68 and 65 for little-endian
+     ("rms", "f4"),  # RMS deviation of densities from mean density
+     ("nlabl", "i4"),  # Number of labels with useful data
+     ("labels", "S80", 10),  # 10 labels of 80 characters
+ ]
+
+
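# --- editor's sketch (illustrative, not part of the released diff) -------------
# A literal transcription of the SerialEM two-shorts encoding quoted in the header
# comments above. Treating |s2|/256 as integer division is an assumption; the
# comment does not pin it down.
def serialem_short_pair_to_float(s1: int, s2: int) -> float:
    mantissa = abs(s1) * 256 + (abs(s2) % 256)
    exponent = (1 if s2 >= 0 else -1) * (abs(s2) // 256)
    return (1 if s1 >= 0 else -1) * mantissa * 2.0 ** exponent

assert serialem_short_pair_to_float(3, 2) == (3 * 256 + 2) * 1.0  # exponent 0
# ------------------------------------------------------------------------------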
+ def read_mrc(filename, filetype="image"):
+     fd = open(filename, "rb")
+     header = np.fromfile(fd, dtype=rec_header_dtd, count=1)
+
+     nx, ny, nz = header["nx"][0], header["ny"][0], header["nz"][0]
+
+     if header[0][3] == 1:
+         data_type = "int16"
+     elif header[0][3] == 2:
+         data_type = "float32"
+     elif header[0][3] == 4:
+         data_type = "single"
+         nx = nx * 2
+     elif header[0][3] == 6:
+         data_type = "uint16"
+     else:
+         raise ValueError(f"Unsupported MRC mode: {header[0][3]}")
+
+     data = np.ndarray(shape=(nx, ny, nz))
+     imgrawdata = np.fromfile(fd, data_type)
+     fd.close()
+
+     if filetype == "image":
+         for iz in range(nz):
+             data_2d = imgrawdata[nx * ny * iz : nx * ny * (iz + 1)]
+             data[:, :, iz] = data_2d.reshape(nx, ny, order="F")
+     else:
+         data = imgrawdata
+
+     return header, data
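
# --- editor's sketch (illustrative, not part of the released diff) -------------
# The record dtype above is exactly the 1024-byte MRC/IMOD header, so an empty
# record can be built and inspected directly, without a real .rec file.
import numpy as np
from careamics.lvae_training.dataset.data_utils import rec_header_dtd

hdr = np.zeros(1, dtype=rec_header_dtd)
hdr["nx"], hdr["ny"], hdr["nz"] = 64, 64, 10
hdr["mode"] = 2  # float32 pixels, per the mode table above
assert hdr.itemsize == 1024
# ------------------------------------------------------------------------------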