careamics 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- careamics/careamist.py +14 -11
- careamics/config/__init__.py +7 -3
- careamics/config/architectures/__init__.py +2 -2
- careamics/config/architectures/architecture_model.py +1 -1
- careamics/config/architectures/custom_model.py +11 -8
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/configuration_factory.py +11 -3
- careamics/config/configuration_model.py +7 -3
- careamics/config/data_model.py +33 -8
- careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +28 -43
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/support/supported_activations.py +1 -0
- careamics/config/support/supported_algorithms.py +17 -4
- careamics/config/support/supported_architectures.py +8 -11
- careamics/config/support/supported_losses.py +3 -1
- careamics/config/transformations/n2v_manipulate_model.py +1 -1
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/file_io/read/tiff.py +1 -1
- careamics/lightning/__init__.py +3 -2
- careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
- careamics/lightning/lightning_module.py +365 -9
- careamics/lightning/predict_data_module.py +2 -2
- careamics/lightning/train_data_module.py +2 -2
- careamics/losses/__init__.py +11 -1
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/{losses.py → fcn/losses.py} +1 -1
- careamics/losses/loss_factory.py +112 -6
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/{data_utils.py → dataset/data_utils.py} +277 -194
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/{data_modules.py → dataset/vae_dataset.py} +306 -472
- careamics/lvae_training/get_config.py +1 -1
- careamics/lvae_training/train_lvae.py +6 -3
- careamics/model_io/bioimage/bioimage_utils.py +1 -1
- careamics/model_io/bioimage/model_description.py +2 -2
- careamics/model_io/bmz_io.py +19 -6
- careamics/model_io/model_io_utils.py +16 -4
- careamics/models/__init__.py +1 -3
- careamics/models/activation.py +2 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +21 -21
- careamics/models/lvae/likelihoods.py +180 -128
- careamics/models/lvae/lvae.py +52 -136
- careamics/models/lvae/noise_models.py +318 -186
- careamics/models/lvae/utils.py +2 -2
- careamics/models/model_factory.py +22 -7
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/stitch_prediction.py +16 -2
- careamics/transforms/pixel_manipulation.py +1 -1
- careamics/utils/metrics.py +74 -1
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/METADATA +2 -2
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/RECORD +63 -49
- careamics/config/architectures/vae_model.py +0 -42
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/WHEEL +0 -0
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +0 -0
careamics/lvae_training/{data_utils.py → dataset/data_utils.py}

@@ -2,61 +2,14 @@
 Utility functions needed by dataloader & co.
 """

+import os
+from dataclasses import dataclass
 from typing import List

 import numpy as np
 from skimage.io import imread, imsave

-from careamics.
-
-
-class DataType(Enum):
-    MNIST = 0
-    Places365 = 1
-    NotMNIST = 2
-    OptiMEM100_014 = 3
-    CustomSinosoid = 4
-    Prevedel_EMBL = 5
-    AllenCellMito = 6
-    SeparateTiffData = 7
-    CustomSinosoidThreeCurve = 8
-    SemiSupBloodVesselsEMBL = 9
-    Pavia2 = 10
-    Pavia2VanillaSplitting = 11
-    ExpansionMicroscopyMitoTub = 12
-    ShroffMitoEr = 13
-    HTIba1Ki67 = 14
-    BSD68 = 15
-    BioSR_MRC = 16
-    TavernaSox2Golgi = 17
-    Dao3Channel = 18
-    ExpMicroscopyV2 = 19
-    Dao3ChannelWithInput = 20
-    TavernaSox2GolgiV2 = 21
-    TwoDset = 22
-    PredictedTiffData = 23
-    Pavia3SeqData = 24
-    # Here, we have 16 splitting tasks.
-    NicolaData = 25
-
-
-class DataSplitType(Enum):
-    All = 0
-    Train = 1
-    Val = 2
-    Test = 3
-
-
-class GridAlignement(Enum):
-    """
-    A patch is formed by padding the grid with content. If the grids are 'Center' aligned, then padding is to done equally on all 4 sides.
-    On the other hand, if grids are 'LeftTop' aligned, padding is to be done on the right and bottom end of the grid.
-    In the former case, one needs (patch_size - grid_size)//2 amount of content on the right end of the frame.
-    In the latter case, one needs patch_size - grid_size amount of content on the right end of the frame.
-    """
-
-    LeftTop = 0
-    Center = 1
+from careamics.lvae_training.dataset.vae_data_config import DataSplitType, DataType


 def load_tiff(path):
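For reference, the padding arithmetic described by the removed GridAlignement docstring, written out as a minimal sketch. The sizes are hypothetical and not part of the package:

patch_size, grid_size = 64, 32

# 'Center': the grid cell sits in the middle of the patch, so each side
# needs half of the surplus as surrounding content.
center_pad = (patch_size - grid_size) // 2   # 16 px on every side

# 'LeftTop': the grid cell is pinned to the top-left corner, so the whole
# surplus must come from the right and bottom of the frame.
lefttop_pad = patch_size - grid_size         # 32 px on right and bottom

print(center_pad, lefttop_pad)  # 16 32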
@@ -115,6 +68,104 @@ def adjust_for_imbalance_in_fraction_value(
     return val, test


+def get_train_val_data(
+    data_config,
+    fpath,
+    datasplit_type: DataSplitType,
+    val_fraction=None,
+    test_fraction=None,
+    allow_generation=False,  # TODO: what is this
+):
+    """
+    Load the data from the given path and split them in training, validation and test sets.
+
+    Ensure that the shape of data should be N*H*W*C: N is number of data points. H,W are the image dimensions.
+    C is the number of channels.
+    """
+    if data_config.data_type == DataType.SeparateTiffData:
+        fpath1 = os.path.join(fpath, data_config.ch1_fname)
+        fpath2 = os.path.join(fpath, data_config.ch2_fname)
+        fpaths = [fpath1, fpath2]
+        fpath0 = ""
+        if "ch_input_fname" in data_config:
+            fpath0 = os.path.join(fpath, data_config.ch_input_fname)
+            fpaths = [fpath0] + fpaths
+
+        print(
+            f"Loading from {fpath} Channels: "
+            f"{fpath1},{fpath2}, inp:{fpath0} Mode:{DataSplitType.name(datasplit_type)}"
+        )
+
+        data = np.concatenate([load_tiff(fpath)[..., None] for fpath in fpaths], axis=3)
+        if data_config.data_type == DataType.PredictedTiffData:
+            assert len(data.shape) == 5 and data.shape[-1] == 1
+            data = data[..., 0].copy()
+        # data = data[::3].copy()
+        # NOTE: This was not the correct way to do it. It is so because the noise present in the input was directly related
+        # to the noise present in the channels and so this is not the way we would get the data.
+        # We need to add the noise independently to the input and the target.
+
+        # if data_config.get('poisson_noise_factor', False):
+        #     data = np.random.poisson(data)
+        # if data_config.get('enable_gaussian_noise', False):
+        #     synthetic_scale = data_config.get('synthetic_gaussian_scale', 0.1)
+        #     print('Adding Gaussian noise with scale', synthetic_scale)
+        #     noise = np.random.normal(0, synthetic_scale, data.shape)
+        #     data = data + noise
+
+        if datasplit_type == DataSplitType.All:
+            return data.astype(np.float32)
+
+        train_idx, val_idx, test_idx = get_datasplit_tuples(
+            val_fraction, test_fraction, len(data), starting_test=True
+        )
+        if datasplit_type == DataSplitType.Train:
+            return data[train_idx].astype(np.float32)
+        elif datasplit_type == DataSplitType.Val:
+            return data[val_idx].astype(np.float32)
+        elif datasplit_type == DataSplitType.Test:
+            return data[test_idx].astype(np.float32)
+
+    elif data_config.data_type == DataType.BioSR_MRC:
+        num_channels = data_config.num_channels
+        fpaths = []
+        data_list = []
+        for i in range(num_channels):
+            fpath1 = os.path.join(fpath, getattr(data_config, f"ch{i + 1}_fname"))
+            fpaths.append(fpath1)
+            data = get_mrc_data(fpath1)[..., None]
+            data_list.append(data)
+
+        dirname = os.path.dirname(os.path.dirname(fpaths[0])) + "/"
+
+        msg = ",".join([x[len(dirname) :] for x in fpaths])
+        print(
+            f"Loaded from {dirname} Channels:{len(fpaths)} {msg} Mode:{datasplit_type}"
+        )
+        N = data_list[0].shape[0]
+        for data in data_list:
+            N = min(N, data.shape[0])
+
+        cropped_data = []
+        for data in data_list:
+            cropped_data.append(data[:N])
+
+        data = np.concatenate(cropped_data, axis=3)
+
+        if datasplit_type == DataSplitType.All:
+            return data.astype(np.float32)
+
+        train_idx, val_idx, test_idx = get_datasplit_tuples(
+            val_fraction, test_fraction, len(data), starting_test=True
+        )
+        if datasplit_type == DataSplitType.Train:
+            return data[train_idx].astype(np.float32)
+        elif datasplit_type == DataSplitType.Val:
+            return data[val_idx].astype(np.float32)
+        elif datasplit_type == DataSplitType.Test:
+            return data[test_idx].astype(np.float32)
+
+
 def get_datasplit_tuples(
     val_fraction: float,
     test_fraction: float,
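The new get_train_val_data entry point above dispatches entirely on the config's data_type. A minimal usage sketch, assuming the post-rename module path shown in this diff; the config object and paths below are hypothetical stand-ins for the real config model:

from dataclasses import dataclass

from careamics.lvae_training.dataset.data_utils import get_train_val_data
from careamics.lvae_training.dataset.vae_data_config import DataSplitType, DataType

@dataclass
class TwoChannelConfig:
    # Hypothetical stand-in for the real data config object.
    data_type: DataType = DataType.SeparateTiffData
    ch1_fname: str = "channel1.tif"
    ch2_fname: str = "channel2.tif"

    def __contains__(self, key):
        # get_train_val_data probes "ch_input_fname" with `in`.
        return hasattr(self, key)

train = get_train_val_data(
    TwoChannelConfig(),
    "/path/to/data",  # hypothetical directory holding the two TIFFs
    DataSplitType.Train,
    val_fraction=0.1,
    test_fraction=0.1,
)
# Returns a float32 array shaped N*H*W*C, as the docstring requires.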
@@ -173,166 +224,199 @@ def get_mrc_data(fpath):
     return data[..., 0]


+@dataclass
 class GridIndexManager:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    def grid_count(self, grid_size=None):
-        if self.use_default_grid(grid_size):
-            grid_size = self._default_grid_size
-
-        return self.N * self.grid_rows(grid_size) * self.grid_cols(grid_size)
-
-    def hwt_from_idx(self, index, grid_size=None):
-        t = self.get_t(index)
-        return (*self.get_deterministic_hw(index, grid_size=grid_size), t)
-
-    def idx_from_hwt(self, h_start, w_start, t, grid_size=None):
+    data_shape: tuple
+    grid_shape: tuple
+    patch_shape: tuple
+    trim_boundary: bool
+
+    # Vera: patch is centered on index in the grid, grid size not used in training,
+    # used only during val / test, grid size controls the overlap of the patches
+    # in training you only get random patches every time
+    # For borders - just cropped the data, so it perfectly divisible
+
+    def __post_init__(self):
+        assert len(self.data_shape) == len(
+            self.grid_shape
+        ), f"Data shape:{self.data_shape} and grid size:{self.grid_shape} must have the same dimension"
+        assert len(self.data_shape) == len(
+            self.patch_shape
+        ), f"Data shape:{self.data_shape} and patch shape:{self.patch_shape} must have the same dimension"
+        innerpad = np.array(self.patch_shape) - np.array(self.grid_shape)
+        for dim, pad in enumerate(innerpad):
+            if pad < 0:
+                raise ValueError(
+                    f"Patch shape:{self.patch_shape} must be greater than or equal to grid shape:{self.grid_shape} in dimension {dim}"
+                )
+            if pad % 2 != 0:
+                raise ValueError(
+                    f"Patch shape:{self.patch_shape} must have even padding in dimension {dim}"
+                )
+
+    def patch_offset(self):
+        return (np.array(self.patch_shape) - np.array(self.grid_shape)) // 2
+
+    def get_individual_dim_grid_count(self, dim: int):
         """
-
+        Returns the number of the grid in the specified dimension, ignoring all other dimensions.
         """
-
-
+        assert dim < len(
+            self.data_shape
+        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+        assert dim >= 0, "Dimension must be greater than or equal to 0"
+
+        if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
+            return self.data_shape[dim]
+        elif self.trim_boundary is False:
+            return int(np.ceil(self.data_shape[dim] / self.grid_shape[dim]))
+        else:
+            excess_size = self.patch_shape[dim] - self.grid_shape[dim]
+            return int(
+                np.floor((self.data_shape[dim] - excess_size) / self.grid_shape[dim])
+            )

-
-
+    def total_grid_count(self):
+        """
+        Returns the total number of grids in the dataset.
+        """
+        return self.grid_count(0) * self.get_individual_dim_grid_count(0)

-
-
+    def grid_count(self, dim: int):
+        """
+        Returns the total number of grids for one value in the specified dimension.
+        """
+        assert dim < len(
+            self.data_shape
+        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+        assert dim >= 0, "Dimension must be greater than or equal to 0"
+        if dim == len(self.data_shape) - 1:
+            return 1

-
-        return index % self.N
+        return self.get_individual_dim_grid_count(dim + 1) * self.grid_count(dim + 1)

-    def 
-
-
+    def get_grid_index(self, dim: int, coordinate: int):
+        """
+        Returns the index of the grid in the specified dimension.
+        """
+        assert dim < len(
+            self.data_shape
+        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+        assert dim >= 0, "Dimension must be greater than or equal to 0"
+        assert (
+            coordinate < self.data_shape[dim]
+        ), f"Coordinate {coordinate} is out of bounds for data shape {self.data_shape}"

-
-
-
-        return 
+        if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
+            return coordinate
+        elif self.trim_boundary is False:
+            return np.floor(coordinate / self.grid_shape[dim])
+        else:
+            excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
+            # can be <0 if coordinate is in [0,grid_shape[dim]]
+            return max(0, np.floor((coordinate - excess_size) / self.grid_shape[dim]))

+    def dataset_idx_from_grid_idx(self, grid_idx: tuple):
+        """
+        Returns the index of the grid in the dataset.
+        """
+        assert len(grid_idx) == len(
+            self.data_shape
+        ), f"Dimension indices {grid_idx} must have the same dimension as data shape {self.data_shape}"
+        index = 0
+        for dim in range(len(grid_idx)):
+            index += grid_idx[dim] * self.grid_count(dim)
         return index

-    def 
-
-
-
-
-
-
-
-
-
+    def get_patch_location_from_dataset_idx(self, dataset_idx: int):
+        """
+        Returns the patch location of the grid in the dataset.
+        """
+        location = self.get_location_from_dataset_idx(dataset_idx)
+        offset = self.patch_offset()
+        return tuple(np.array(location) - np.array(offset))
+
+    def get_dataset_idx_from_grid_location(self, location: tuple):
+        assert len(location) == len(
+            self.data_shape
+        ), f"Location {location} must have the same dimension as data shape {self.data_shape}"
+        grid_idx = [
+            self.get_grid_index(dim, location[dim]) for dim in range(len(location))
+        ]
+        return self.dataset_idx_from_grid_idx(tuple(grid_idx))
+
+    def get_gridstart_location_from_dim_index(self, dim: int, dim_index: int):
+        """
+        Returns the grid-start coordinate of the grid in the specified dimension.
+        """
+        assert dim < len(
+            self.data_shape
+        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+        assert dim >= 0, "Dimension must be greater than or equal to 0"
+        assert dim_index < self.get_individual_dim_grid_count(
+            dim
+        ), f"Dimension index {dim_index} is out of bounds for data shape {self.data_shape}"
+
+        if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
+            return dim_index
+        elif self.trim_boundary is False:
+            return dim_index * self.grid_shape[dim]
+        else:
+            excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
+            return dim_index * self.grid_shape[dim] + excess_size
+
+    def get_location_from_dataset_idx(self, dataset_idx: int):
+        grid_idx = []
+        for dim in range(len(self.data_shape)):
+            grid_idx.append(dataset_idx // self.grid_count(dim))
+            dataset_idx = dataset_idx % self.grid_count(dim)
+        location = [
+            self.get_gridstart_location_from_dim_index(dim, grid_idx[dim])
+            for dim in range(len(self.data_shape))
+        ]
+        return tuple(location)
+
+    def on_boundary(self, dataset_idx: int, dim: int):
+        """
+        Returns True if the grid is on the boundary in the specified dimension.
+        """
+        assert dim < len(
+            self.data_shape
+        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+        assert dim >= 0, "Dimension must be greater than or equal to 0"

-
-
-        return None
+        if dim > 0:
+            dataset_idx = dataset_idx % self.grid_count(dim - 1)

-
-        return 
+        dim_index = dataset_idx // self.grid_count(dim)
+        return (
+            dim_index == 0 or dim_index == self.get_individual_dim_grid_count(dim) - 1
+        )

-    def 
-
+    def next_grid_along_dim(self, dataset_idx: int, dim: int):
+        """
+        Returns the index of the grid in the specified dimension in the specified direction.
+        """
+        assert dim < len(
+            self.data_shape
+        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+        assert dim >= 0, "Dimension must be greater than or equal to 0"
+        new_idx = dataset_idx + self.grid_count(dim)
+        if new_idx >= self.total_grid_count():
             return None
-
-        return index
-
-    def on_left_boundary(self, index, grid_size=None):
-        if self.use_default_grid(grid_size):
-            grid_size = self._default_grid_size
+        return new_idx

-
-        ncols = self.grid_cols(grid_size)
-
-        left_boundary = (factor // ncols) != (factor - 1) // ncols
-        return left_boundary
-
-    def on_right_boundary(self, index, grid_size=None):
-        if self.use_default_grid(grid_size):
-            grid_size = self._default_grid_size
-
-        factor = index // self.N
-        ncols = self.grid_cols(grid_size)
-
-        right_boundary = (factor // ncols) != (factor + 1) // ncols
-        return right_boundary
-
-    def on_top_boundary(self, index, grid_size=None):
-        if self.use_default_grid(grid_size):
-            grid_size = self._default_grid_size
-
-        ncols = self.grid_cols(grid_size)
-        return index < self.N * ncols
-
-    def on_bottom_boundary(self, index, grid_size=None):
-        if self.use_default_grid(grid_size):
-            grid_size = self._default_grid_size
-
-        ncols = self.grid_cols(grid_size)
-        return index + self.N * ncols > self.grid_count(grid_size=grid_size)
-
-    def on_boundary(self, idx, grid_size=None):
-        if self.on_left_boundary(idx, grid_size=grid_size):
-            return True
-
-        if self.on_right_boundary(idx, grid_size=grid_size):
-            return True
-
-        if self.on_top_boundary(idx, grid_size=grid_size):
-            return True
-
-        if self.on_bottom_boundary(idx, grid_size=grid_size):
-            return True
-        return False
-
-    def get_deterministic_hw(self, index: int, grid_size=None):
+    def prev_grid_along_dim(self, dataset_idx: int, dim: int):
         """
-
+        Returns the index of the grid in the specified dimension in the specified direction.
         """
-
-
-
-
-
-
-
-
-        ith_row = factor // ncols
-        jth_col = factor % ncols
-        h_start = ith_row * grid_size
-        w_start = jth_col * grid_size
-        return h_start, w_start
+        assert dim < len(
+            self.data_shape
+        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+        assert dim >= 0, "Dimension must be greater than or equal to 0"
+        new_idx = dataset_idx - self.grid_count(dim)
+        if new_idx < 0:
+            return None


 class IndexSwitcher:
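The rewritten GridIndexManager is now a shape-generic dataclass: grid_count(dim) acts as a per-dimension stride, and get_location_from_dataset_idx performs a mixed-radix decomposition of a flat dataset index into grid coordinates. A minimal usage sketch, assuming the renamed module path from this diff; the shapes are hypothetical (N, H, W, C):

from careamics.lvae_training.dataset.data_utils import GridIndexManager

manager = GridIndexManager(
    data_shape=(10, 64, 64, 2),   # 10 frames of 64x64 with 2 channels
    grid_shape=(1, 16, 16, 2),    # inner grid cell covered by each patch
    patch_shape=(1, 32, 32, 2),   # surplus (32 - 16) is even, as __post_init__ requires
    trim_boundary=False,
)

print(manager.total_grid_count())                      # 160 grids overall
print(manager.get_location_from_dataset_idx(5))        # grid-start coords, (0, 16, 16, 0)
print(manager.get_patch_location_from_dataset_idx(5))  # start minus patch_offset(), (0, 8, 8, 0)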
@@ -372,7 +456,7 @@ class IndexSwitcher:
         self._w_validmax = 0

         print(
-            f"[{self.__class__.__name__}] Target Indices: [0,{self._validtarget_ceilT-1}]. Index={self._validtarget_ceilT-1} has shape [:{self._h_validmax},:{self._w_validmax}]. Available data: {self._data_shape[0]}"
+            f"[{self.__class__.__name__}] Target Indices: [0,{self._validtarget_ceilT - 1}]. Index={self._validtarget_ceilT - 1} has shape [:{self._h_validmax},:{self._w_validmax}]. Available data: {self._data_shape[0]}"
         )

     def get_valid_target_index(self):
@@ -588,7 +672,6 @@ rec_header_dtd = [


 def read_mrc(filename, filetype="image"):
-
     fd = open(filename, "rb")
     header = np.fromfile(fd, dtype=rec_header_dtd, count=1)
