careamics 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of careamics might be problematic.
Files changed (64)
  1. careamics/careamist.py +14 -11
  2. careamics/config/__init__.py +7 -3
  3. careamics/config/architectures/__init__.py +2 -2
  4. careamics/config/architectures/architecture_model.py +1 -1
  5. careamics/config/architectures/custom_model.py +11 -8
  6. careamics/config/architectures/lvae_model.py +174 -0
  7. careamics/config/configuration_factory.py +11 -3
  8. careamics/config/configuration_model.py +7 -3
  9. careamics/config/data_model.py +33 -8
  10. careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +28 -43
  11. careamics/config/likelihood_model.py +43 -0
  12. careamics/config/nm_model.py +101 -0
  13. careamics/config/support/supported_activations.py +1 -0
  14. careamics/config/support/supported_algorithms.py +17 -4
  15. careamics/config/support/supported_architectures.py +8 -11
  16. careamics/config/support/supported_losses.py +3 -1
  17. careamics/config/transformations/n2v_manipulate_model.py +1 -1
  18. careamics/config/vae_algorithm_model.py +171 -0
  19. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  20. careamics/file_io/read/tiff.py +1 -1
  21. careamics/lightning/__init__.py +3 -2
  22. careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
  23. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
  24. careamics/lightning/lightning_module.py +365 -9
  25. careamics/lightning/predict_data_module.py +2 -2
  26. careamics/lightning/train_data_module.py +2 -2
  27. careamics/losses/__init__.py +11 -1
  28. careamics/losses/fcn/__init__.py +1 -0
  29. careamics/losses/{losses.py → fcn/losses.py} +1 -1
  30. careamics/losses/loss_factory.py +112 -6
  31. careamics/losses/lvae/__init__.py +1 -0
  32. careamics/losses/lvae/loss_utils.py +83 -0
  33. careamics/losses/lvae/losses.py +445 -0
  34. careamics/lvae_training/dataset/__init__.py +0 -0
  35. careamics/lvae_training/{data_utils.py → dataset/data_utils.py} +277 -194
  36. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  37. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  38. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  39. careamics/lvae_training/{data_modules.py → dataset/vae_dataset.py} +306 -472
  40. careamics/lvae_training/get_config.py +1 -1
  41. careamics/lvae_training/train_lvae.py +6 -3
  42. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  43. careamics/model_io/bioimage/model_description.py +2 -2
  44. careamics/model_io/bmz_io.py +19 -6
  45. careamics/model_io/model_io_utils.py +16 -4
  46. careamics/models/__init__.py +1 -3
  47. careamics/models/activation.py +2 -0
  48. careamics/models/lvae/__init__.py +3 -0
  49. careamics/models/lvae/layers.py +21 -21
  50. careamics/models/lvae/likelihoods.py +180 -128
  51. careamics/models/lvae/lvae.py +52 -136
  52. careamics/models/lvae/noise_models.py +318 -186
  53. careamics/models/lvae/utils.py +2 -2
  54. careamics/models/model_factory.py +22 -7
  55. careamics/prediction_utils/lvae_prediction.py +158 -0
  56. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  57. careamics/prediction_utils/stitch_prediction.py +16 -2
  58. careamics/transforms/pixel_manipulation.py +1 -1
  59. careamics/utils/metrics.py +74 -1
  60. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/METADATA +2 -2
  61. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/RECORD +63 -49
  62. careamics/config/architectures/vae_model.py +0 -42
  63. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/WHEEL +0 -0
  64. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +0 -0
careamics/lvae_training/{data_utils.py → dataset/data_utils.py}
@@ -2,61 +2,14 @@
 Utility functions needed by dataloader & co.
 """

+import os
+from dataclasses import dataclass
 from typing import List

 import numpy as np
 from skimage.io import imread, imsave

-from careamics.models.lvae.utils import Enum
-
-
-class DataType(Enum):
-    MNIST = 0
-    Places365 = 1
-    NotMNIST = 2
-    OptiMEM100_014 = 3
-    CustomSinosoid = 4
-    Prevedel_EMBL = 5
-    AllenCellMito = 6
-    SeparateTiffData = 7
-    CustomSinosoidThreeCurve = 8
-    SemiSupBloodVesselsEMBL = 9
-    Pavia2 = 10
-    Pavia2VanillaSplitting = 11
-    ExpansionMicroscopyMitoTub = 12
-    ShroffMitoEr = 13
-    HTIba1Ki67 = 14
-    BSD68 = 15
-    BioSR_MRC = 16
-    TavernaSox2Golgi = 17
-    Dao3Channel = 18
-    ExpMicroscopyV2 = 19
-    Dao3ChannelWithInput = 20
-    TavernaSox2GolgiV2 = 21
-    TwoDset = 22
-    PredictedTiffData = 23
-    Pavia3SeqData = 24
-    # Here, we have 16 splitting tasks.
-    NicolaData = 25
-
-
-class DataSplitType(Enum):
-    All = 0
-    Train = 1
-    Val = 2
-    Test = 3
-
-
-class GridAlignement(Enum):
-    """
-    A patch is formed by padding the grid with content. If the grids are 'Center' aligned, then padding is to done equally on all 4 sides.
-    On the other hand, if grids are 'LeftTop' aligned, padding is to be done on the right and bottom end of the grid.
-    In the former case, one needs (patch_size - grid_size)//2 amount of content on the right end of the frame.
-    In the latter case, one needs patch_size - grid_size amount of content on the right end of the frame.
-    """
-
-    LeftTop = 0
-    Center = 1
+from careamics.lvae_training.dataset.vae_data_config import DataSplitType, DataType


 def load_tiff(path):
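The enums deleted above were not dropped from the package: they now live in the new careamics/lvae_training/dataset/vae_data_config.py module (file 38 in the list above), and this file re-imports them. A minimal sketch of what downstream code looks like after the move; the member names are taken verbatim from the deleted block:

    # Sketch (not part of the diff): the enums are imported from their new home.
    from careamics.lvae_training.dataset.vae_data_config import DataSplitType, DataType

    split = DataSplitType.Train             # splits: All, Train, Val, Test
    data_type = DataType.SeparateTiffData   # member names are unchanged by the move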
@@ -115,6 +68,104 @@ def adjust_for_imbalance_in_fraction_value(
     return val, test


+def get_train_val_data(
+    data_config,
+    fpath,
+    datasplit_type: DataSplitType,
+    val_fraction=None,
+    test_fraction=None,
+    allow_generation=False,  # TODO: what is this
+):
+    """
+    Load the data from the given path and split them in training, validation and test sets.
+
+    Ensure that the shape of data should be N*H*W*C: N is number of data points. H,W are the image dimensions.
+    C is the number of channels.
+    """
+    if data_config.data_type == DataType.SeparateTiffData:
+        fpath1 = os.path.join(fpath, data_config.ch1_fname)
+        fpath2 = os.path.join(fpath, data_config.ch2_fname)
+        fpaths = [fpath1, fpath2]
+        fpath0 = ""
+        if "ch_input_fname" in data_config:
+            fpath0 = os.path.join(fpath, data_config.ch_input_fname)
+            fpaths = [fpath0] + fpaths
+
+        print(
+            f"Loading from {fpath} Channels: "
+            f"{fpath1},{fpath2}, inp:{fpath0} Mode:{DataSplitType.name(datasplit_type)}"
+        )
+
+        data = np.concatenate([load_tiff(fpath)[..., None] for fpath in fpaths], axis=3)
+        if data_config.data_type == DataType.PredictedTiffData:
+            assert len(data.shape) == 5 and data.shape[-1] == 1
+            data = data[..., 0].copy()
+        # data = data[::3].copy()
+        # NOTE: This was not the correct way to do it. It is so because the noise present in the input was directly related
+        # to the noise present in the channels and so this is not the way we would get the data.
+        # We need to add the noise independently to the input and the target.
+
+        # if data_config.get('poisson_noise_factor', False):
+        #     data = np.random.poisson(data)
+        # if data_config.get('enable_gaussian_noise', False):
+        #     synthetic_scale = data_config.get('synthetic_gaussian_scale', 0.1)
+        #     print('Adding Gaussian noise with scale', synthetic_scale)
+        #     noise = np.random.normal(0, synthetic_scale, data.shape)
+        #     data = data + noise
+
+        if datasplit_type == DataSplitType.All:
+            return data.astype(np.float32)
+
+        train_idx, val_idx, test_idx = get_datasplit_tuples(
+            val_fraction, test_fraction, len(data), starting_test=True
+        )
+        if datasplit_type == DataSplitType.Train:
+            return data[train_idx].astype(np.float32)
+        elif datasplit_type == DataSplitType.Val:
+            return data[val_idx].astype(np.float32)
+        elif datasplit_type == DataSplitType.Test:
+            return data[test_idx].astype(np.float32)
+
+    elif data_config.data_type == DataType.BioSR_MRC:
+        num_channels = data_config.num_channels
+        fpaths = []
+        data_list = []
+        for i in range(num_channels):
+            fpath1 = os.path.join(fpath, getattr(data_config, f"ch{i + 1}_fname"))
+            fpaths.append(fpath1)
+            data = get_mrc_data(fpath1)[..., None]
+            data_list.append(data)
+
+        dirname = os.path.dirname(os.path.dirname(fpaths[0])) + "/"
+
+        msg = ",".join([x[len(dirname) :] for x in fpaths])
+        print(
+            f"Loaded from {dirname} Channels:{len(fpaths)} {msg} Mode:{datasplit_type}"
+        )
+        N = data_list[0].shape[0]
+        for data in data_list:
+            N = min(N, data.shape[0])
+
+        cropped_data = []
+        for data in data_list:
+            cropped_data.append(data[:N])
+
+        data = np.concatenate(cropped_data, axis=3)
+
+        if datasplit_type == DataSplitType.All:
+            return data.astype(np.float32)
+
+        train_idx, val_idx, test_idx = get_datasplit_tuples(
+            val_fraction, test_fraction, len(data), starting_test=True
+        )
+        if datasplit_type == DataSplitType.Train:
+            return data[train_idx].astype(np.float32)
+        elif datasplit_type == DataSplitType.Val:
+            return data[val_idx].astype(np.float32)
+        elif datasplit_type == DataSplitType.Test:
+            return data[test_idx].astype(np.float32)
+
+
 def get_datasplit_tuples(
     val_fraction: float,
     test_fraction: float,
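The new get_train_val_data dispatches on data_config.data_type, concatenates the per-channel stacks along the last axis, and returns a float32 array for the requested split (N*H*W*C, as the docstring notes). A hypothetical call, assuming a config object cfg that exposes data_type, ch1_fname, and ch2_fname the way the function expects; the path is made up:

    # Sketch (not part of the diff): cfg and the directory are illustrative only.
    from careamics.lvae_training.dataset.data_utils import get_train_val_data

    val = get_train_val_data(
        data_config=cfg,                   # e.g. cfg.data_type == DataType.SeparateTiffData
        fpath="/data/microscopy",          # directory holding cfg.ch1_fname and cfg.ch2_fname
        datasplit_type=DataSplitType.Val,
        val_fraction=0.1,
        test_fraction=0.1,
    )
    print(val.shape, val.dtype)            # (n_val, H, W, 2), float32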
@@ -173,166 +224,199 @@ def get_mrc_data(fpath):
     return data[..., 0]


+@dataclass
 class GridIndexManager:
-
-    def __init__(self, data_shape, grid_size, patch_size, grid_alignement) -> None:
-        self._data_shape = data_shape
-        self._default_grid_size = grid_size
-        self.patch_size = patch_size
-        self.N = self._data_shape[0]
-        self._align = grid_alignement
-
-    def get_data_shape(self):
-        return self._data_shape
-
-    def use_default_grid(self, grid_size):
-        return grid_size is None or grid_size < 0
-
-    def grid_rows(self, grid_size):
-        if self._align == GridAlignement.LeftTop:
-            extra_pixels = self.patch_size - grid_size
-        elif self._align == GridAlignement.Center:
-            # Center is exclusively used during evaluation. In this case, we use the padding to handle edge cases.
-            # So, here, we will ideally like to cover all pixels and so extra_pixels is set to 0.
-            # If there was no padding, then it should be set to (self.patch_size - grid_size) // 2
-            extra_pixels = 0
-
-        return (self._data_shape[-3] - extra_pixels) // grid_size
-
-    def grid_cols(self, grid_size):
-        if self._align == GridAlignement.LeftTop:
-            extra_pixels = self.patch_size - grid_size
-        elif self._align == GridAlignement.Center:
-            extra_pixels = 0
-
-        return (self._data_shape[-2] - extra_pixels) // grid_size
-
-    def grid_count(self, grid_size=None):
-        if self.use_default_grid(grid_size):
-            grid_size = self._default_grid_size
-
-        return self.N * self.grid_rows(grid_size) * self.grid_cols(grid_size)
-
-    def hwt_from_idx(self, index, grid_size=None):
-        t = self.get_t(index)
-        return (*self.get_deterministic_hw(index, grid_size=grid_size), t)
-
-    def idx_from_hwt(self, h_start, w_start, t, grid_size=None):
+    data_shape: tuple
+    grid_shape: tuple
+    patch_shape: tuple
+    trim_boundary: bool
+
+    # Vera: patch is centered on index in the grid, grid size not used in training,
+    # used only during val / test, grid size controls the overlap of the patches
+    # in training you only get random patches every time
+    # For borders - just cropped the data, so it perfectly divisible
+
+    def __post_init__(self):
+        assert len(self.data_shape) == len(
+            self.grid_shape
+        ), f"Data shape:{self.data_shape} and grid size:{self.grid_shape} must have the same dimension"
+        assert len(self.data_shape) == len(
+            self.patch_shape
+        ), f"Data shape:{self.data_shape} and patch shape:{self.patch_shape} must have the same dimension"
+        innerpad = np.array(self.patch_shape) - np.array(self.grid_shape)
+        for dim, pad in enumerate(innerpad):
+            if pad < 0:
+                raise ValueError(
+                    f"Patch shape:{self.patch_shape} must be greater than or equal to grid shape:{self.grid_shape} in dimension {dim}"
+                )
+            if pad % 2 != 0:
+                raise ValueError(
+                    f"Patch shape:{self.patch_shape} must have even padding in dimension {dim}"
+                )
+
+    def patch_offset(self):
+        return (np.array(self.patch_shape) - np.array(self.grid_shape)) // 2
+
+    def get_individual_dim_grid_count(self, dim: int):
         """
-        Given h,w,t (where h,w constitutes the top left corner of the patch), it returns the corresponding index.
+        Returns the number of the grid in the specified dimension, ignoring all other dimensions.
         """
-        if grid_size is None:
-            grid_size = self._default_grid_size
+        assert dim < len(
+            self.data_shape
+        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+        assert dim >= 0, "Dimension must be greater than or equal to 0"
+
+        if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
+            return self.data_shape[dim]
+        elif self.trim_boundary is False:
+            return int(np.ceil(self.data_shape[dim] / self.grid_shape[dim]))
+        else:
+            excess_size = self.patch_shape[dim] - self.grid_shape[dim]
+            return int(
+                np.floor((self.data_shape[dim] - excess_size) / self.grid_shape[dim])
+            )

-        nth_row = h_start // grid_size
-        nth_col = w_start // grid_size
+    def total_grid_count(self):
+        """
+        Returns the total number of grids in the dataset.
+        """
+        return self.grid_count(0) * self.get_individual_dim_grid_count(0)

-        index = self.grid_cols(grid_size) * nth_row + nth_col
-        return index * self._data_shape[0] + t
+    def grid_count(self, dim: int):
+        """
+        Returns the total number of grids for one value in the specified dimension.
+        """
+        assert dim < len(
+            self.data_shape
+        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+        assert dim >= 0, "Dimension must be greater than or equal to 0"
+        if dim == len(self.data_shape) - 1:
+            return 1

-    def get_t(self, index):
-        return index % self.N
+        return self.get_individual_dim_grid_count(dim + 1) * self.grid_count(dim + 1)

-    def get_top_nbr_idx(self, index, grid_size=None):
-        if self.use_default_grid(grid_size):
-            grid_size = self._default_grid_size
+    def get_grid_index(self, dim: int, coordinate: int):
+        """
+        Returns the index of the grid in the specified dimension.
+        """
+        assert dim < len(
+            self.data_shape
+        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+        assert dim >= 0, "Dimension must be greater than or equal to 0"
+        assert (
+            coordinate < self.data_shape[dim]
+        ), f"Coordinate {coordinate} is out of bounds for data shape {self.data_shape}"

-        ncols = self.grid_cols(grid_size)
-        index -= ncols * self.N
-        if index < 0:
-            return None
+        if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
+            return coordinate
+        elif self.trim_boundary is False:
+            return np.floor(coordinate / self.grid_shape[dim])
+        else:
+            excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
+            # can be <0 if coordinate is in [0,grid_shape[dim]]
+            return max(0, np.floor((coordinate - excess_size) / self.grid_shape[dim]))

+    def dataset_idx_from_grid_idx(self, grid_idx: tuple):
+        """
+        Returns the index of the grid in the dataset.
+        """
+        assert len(grid_idx) == len(
+            self.data_shape
+        ), f"Dimension indices {grid_idx} must have the same dimension as data shape {self.data_shape}"
+        index = 0
+        for dim in range(len(grid_idx)):
+            index += grid_idx[dim] * self.grid_count(dim)
         return index

-    def get_bottom_nbr_idx(self, index, grid_size=None):
-        if self.use_default_grid(grid_size):
-            grid_size = self._default_grid_size
-
-        ncols = self.grid_cols(grid_size)
-        index += ncols * self.N
-        if index > self.grid_count(grid_size=grid_size):
-            return None
-
-        return index
+    def get_patch_location_from_dataset_idx(self, dataset_idx: int):
+        """
+        Returns the patch location of the grid in the dataset.
+        """
+        location = self.get_location_from_dataset_idx(dataset_idx)
+        offset = self.patch_offset()
+        return tuple(np.array(location) - np.array(offset))
+
+    def get_dataset_idx_from_grid_location(self, location: tuple):
+        assert len(location) == len(
+            self.data_shape
+        ), f"Location {location} must have the same dimension as data shape {self.data_shape}"
+        grid_idx = [
+            self.get_grid_index(dim, location[dim]) for dim in range(len(location))
+        ]
+        return self.dataset_idx_from_grid_idx(tuple(grid_idx))
+
+    def get_gridstart_location_from_dim_index(self, dim: int, dim_index: int):
+        """
+        Returns the grid-start coordinate of the grid in the specified dimension.
+        """
+        assert dim < len(
+            self.data_shape
+        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+        assert dim >= 0, "Dimension must be greater than or equal to 0"
+        assert dim_index < self.get_individual_dim_grid_count(
+            dim
+        ), f"Dimension index {dim_index} is out of bounds for data shape {self.data_shape}"
+
+        if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
+            return dim_index
+        elif self.trim_boundary is False:
+            return dim_index * self.grid_shape[dim]
+        else:
+            excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
+            return dim_index * self.grid_shape[dim] + excess_size
+
+    def get_location_from_dataset_idx(self, dataset_idx: int):
+        grid_idx = []
+        for dim in range(len(self.data_shape)):
+            grid_idx.append(dataset_idx // self.grid_count(dim))
+            dataset_idx = dataset_idx % self.grid_count(dim)
+        location = [
+            self.get_gridstart_location_from_dim_index(dim, grid_idx[dim])
+            for dim in range(len(self.data_shape))
+        ]
+        return tuple(location)
+
+    def on_boundary(self, dataset_idx: int, dim: int):
+        """
+        Returns True if the grid is on the boundary in the specified dimension.
+        """
+        assert dim < len(
+            self.data_shape
+        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+        assert dim >= 0, "Dimension must be greater than or equal to 0"

-    def get_left_nbr_idx(self, index, grid_size=None):
-        if self.on_left_boundary(index, grid_size=grid_size):
-            return None
+        if dim > 0:
+            dataset_idx = dataset_idx % self.grid_count(dim - 1)

-        index -= self.N
-        return index
+        dim_index = dataset_idx // self.grid_count(dim)
+        return (
+            dim_index == 0 or dim_index == self.get_individual_dim_grid_count(dim) - 1
+        )

-    def get_right_nbr_idx(self, index, grid_size=None):
-        if self.on_right_boundary(index, grid_size=grid_size):
+    def next_grid_along_dim(self, dataset_idx: int, dim: int):
+        """
+        Returns the index of the grid in the specified dimension in the specified direction.
+        """
+        assert dim < len(
+            self.data_shape
+        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+        assert dim >= 0, "Dimension must be greater than or equal to 0"
+        new_idx = dataset_idx + self.grid_count(dim)
+        if new_idx >= self.total_grid_count():
             return None
-        index += self.N
-        return index
-
-    def on_left_boundary(self, index, grid_size=None):
-        if self.use_default_grid(grid_size):
-            grid_size = self._default_grid_size
+        return new_idx

-        factor = index // self.N
-        ncols = self.grid_cols(grid_size)
-
-        left_boundary = (factor // ncols) != (factor - 1) // ncols
-        return left_boundary
-
-    def on_right_boundary(self, index, grid_size=None):
-        if self.use_default_grid(grid_size):
-            grid_size = self._default_grid_size
-
-        factor = index // self.N
-        ncols = self.grid_cols(grid_size)
-
-        right_boundary = (factor // ncols) != (factor + 1) // ncols
-        return right_boundary
-
-    def on_top_boundary(self, index, grid_size=None):
-        if self.use_default_grid(grid_size):
-            grid_size = self._default_grid_size
-
-        ncols = self.grid_cols(grid_size)
-        return index < self.N * ncols
-
-    def on_bottom_boundary(self, index, grid_size=None):
-        if self.use_default_grid(grid_size):
-            grid_size = self._default_grid_size
-
-        ncols = self.grid_cols(grid_size)
-        return index + self.N * ncols > self.grid_count(grid_size=grid_size)
-
-    def on_boundary(self, idx, grid_size=None):
-        if self.on_left_boundary(idx, grid_size=grid_size):
-            return True
-
-        if self.on_right_boundary(idx, grid_size=grid_size):
-            return True
-
-        if self.on_top_boundary(idx, grid_size=grid_size):
-            return True
-
-        if self.on_bottom_boundary(idx, grid_size=grid_size):
-            return True
-        return False
-
-    def get_deterministic_hw(self, index: int, grid_size=None):
+    def prev_grid_along_dim(self, dataset_idx: int, dim: int):
         """
-        Fixed starting position for the crop for the img with index `index`.
+        Returns the index of the grid in the specified dimension in the specified direction.
         """
-        if self.use_default_grid(grid_size):
-            grid_size = self._default_grid_size
-
-        # _, h, w, _ = self._data_shape
-        # assert h == w
-        factor = index // self.N
-        ncols = self.grid_cols(grid_size)
-
-        ith_row = factor // ncols
-        jth_col = factor % ncols
-        h_start = ith_row * grid_size
-        w_start = jth_col * grid_size
-        return h_start, w_start
+        assert dim < len(
+            self.data_shape
+        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+        assert dim >= 0, "Dimension must be greater than or equal to 0"
+        new_idx = dataset_idx - self.grid_count(dim)
+        if new_idx < 0:
+            return None


 class IndexSwitcher:
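The rewrite above replaces the H/W/T-specific bookkeeping (grid_rows/grid_cols, the left/right/top/bottom neighbour methods) with dimension-agnostic arithmetic: patches of patch_shape are centred on a tiling with stride grid_shape, and a flat dataset index is decomposed one dimension at a time via grid_count. A minimal sketch with made-up shapes that satisfy the __post_init__ constraints (patch >= grid, even padding per dimension):

    # Sketch (not part of the diff): shapes are illustrative only.
    manager = GridIndexManager(
        data_shape=(10, 128, 128, 2),   # N x H x W x C
        grid_shape=(1, 32, 32, 2),      # tiling stride per dimension
        patch_shape=(1, 64, 64, 2),     # pads (64 - 32) // 2 = 16 px on each side
        trim_boundary=True,             # drop grids whose patch would leave the frame
    )
    manager.total_grid_count()                      # 10 * 3 * 3 * 1 = 90
    manager.get_location_from_dataset_idx(0)        # (0, 16, 16, 0): grid-start coordinates
    manager.get_patch_location_from_dataset_idx(0)  # (0, 0, 0, 0): shifted by patch_offset()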
@@ -372,7 +456,7 @@ class IndexSwitcher:
         self._w_validmax = 0

         print(
-            f"[{self.__class__.__name__}] Target Indices: [0,{self._validtarget_ceilT-1}]. Index={self._validtarget_ceilT-1} has shape [:{self._h_validmax},:{self._w_validmax}]. Available data: {self._data_shape[0]}"
+            f"[{self.__class__.__name__}] Target Indices: [0,{self._validtarget_ceilT - 1}]. Index={self._validtarget_ceilT - 1} has shape [:{self._h_validmax},:{self._w_validmax}]. Available data: {self._data_shape[0]}"
         )

     def get_valid_target_index(self):
@@ -588,7 +672,6 @@ rec_header_dtd = [


 def read_mrc(filename, filetype="image"):
-
     fd = open(filename, "rb")
     header = np.fromfile(fd, dtype=rec_header_dtd, count=1)