careamics 0.0.10__py3-none-any.whl → 0.0.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (54) hide show
  1. careamics/careamist.py +20 -4
  2. careamics/config/configuration.py +10 -5
  3. careamics/config/data/data_model.py +38 -1
  4. careamics/config/optimizer_models.py +1 -3
  5. careamics/config/training_model.py +0 -2
  6. careamics/dataset/dataset_utils/running_stats.py +7 -3
  7. careamics/dataset_ng/README.md +212 -0
  8. careamics/dataset_ng/dataset.py +233 -0
  9. careamics/dataset_ng/demos/bsd68_demo.ipynb +356 -0
  10. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  11. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +734 -0
  12. careamics/dataset_ng/demos/demo_datamodule.ipynb +443 -0
  13. careamics/dataset_ng/{demo_dataset.ipynb → demos/demo_dataset.ipynb} +39 -15
  14. careamics/dataset_ng/{demo_patch_extractor.py → demos/demo_patch_extractor.py} +7 -9
  15. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +292 -0
  16. careamics/dataset_ng/factory.py +408 -0
  17. careamics/dataset_ng/legacy_interoperability.py +168 -0
  18. careamics/dataset_ng/patch_extractor/__init__.py +3 -8
  19. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +6 -4
  20. careamics/dataset_ng/patch_extractor/image_stack/__init__.py +2 -1
  21. careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +5 -1
  22. careamics/dataset_ng/patch_extractor/image_stack_loader.py +5 -75
  23. careamics/dataset_ng/patch_extractor/patch_extractor.py +5 -4
  24. careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +73 -106
  25. careamics/dataset_ng/patching_strategies/__init__.py +6 -1
  26. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +31 -0
  27. careamics/dataset_ng/patching_strategies/random_patching.py +3 -1
  28. careamics/dataset_ng/patching_strategies/tiling_strategy.py +171 -0
  29. careamics/dataset_ng/patching_strategies/whole_sample.py +36 -0
  30. careamics/lightning/dataset_ng/data_module.py +488 -0
  31. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  32. careamics/lightning/dataset_ng/lightning_modules/care_module.py +58 -0
  33. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +67 -0
  34. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +143 -0
  35. careamics/lightning/lightning_module.py +3 -0
  36. careamics/lvae_training/dataset/__init__.py +8 -3
  37. careamics/lvae_training/dataset/config.py +3 -3
  38. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  39. careamics/lvae_training/dataset/multich_dataset.py +46 -17
  40. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  41. careamics/lvae_training/dataset/types.py +3 -3
  42. careamics/lvae_training/dataset/utils/index_manager.py +259 -0
  43. careamics/lvae_training/eval_utils.py +93 -3
  44. careamics/transforms/compose.py +1 -0
  45. careamics/transforms/normalize.py +18 -7
  46. careamics/utils/lightning_utils.py +25 -11
  47. {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/METADATA +3 -3
  48. {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/RECORD +51 -36
  49. careamics/dataset_ng/dataset/__init__.py +0 -3
  50. careamics/dataset_ng/dataset/dataset.py +0 -184
  51. careamics/dataset_ng/demo_patch_extractor_factory.py +0 -37
  52. {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/WHEEL +0 -0
  53. {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/entry_points.txt +0 -0
  54. {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/licenses/LICENSE +0 -0
@@ -340,25 +340,54 @@ class MultiChDloader:
340
340
  return self._data.shape[0]
341
341
 
342
342
  def reduce_data(
343
- self, t_list=None, h_start=None, h_end=None, w_start=None, w_end=None
343
+ self,
344
+ t_list=None,
345
+ z_start=None,
346
+ z_end=None,
347
+ h_start=None,
348
+ h_end=None,
349
+ w_start=None,
350
+ w_end=None,
344
351
  ):
345
- assert not self._5Ddata, "This function is not supported for 3D data."
346
- if t_list is None:
347
- t_list = list(range(self._data.shape[0]))
348
- if h_start is None:
349
- h_start = 0
350
- if h_end is None:
351
- h_end = self._data.shape[1]
352
- if w_start is None:
353
- w_start = 0
354
- if w_end is None:
355
- w_end = self._data.shape[2]
356
-
357
- self._data = self._data[t_list, h_start:h_end, w_start:w_end, :].copy()
358
- if self._noise_data is not None:
359
- self._noise_data = self._noise_data[
360
- t_list, h_start:h_end, w_start:w_end, :
352
+ if self._5Ddata:
353
+ if t_list is None:
354
+ t_list = list(range(self._data.shape[0]))
355
+ if z_start is None:
356
+ z_start = 0
357
+ if z_end is None:
358
+ z_end = self._data.shape[1]
359
+ if h_start is None:
360
+ h_start = 0
361
+ if h_end is None:
362
+ h_end = self._data.shape[2]
363
+ if w_start is None:
364
+ w_start = 0
365
+ if w_end is None:
366
+ w_end = self._data.shape[3]
367
+ self._data = self._data[
368
+ t_list, z_start:z_end, h_start:h_end, w_start:w_end, :
361
369
  ].copy()
370
+ if self._noise_data is not None:
371
+ self._noise_data = self._noise_data[
372
+ t_list, z_start:z_end, h_start:h_end, w_start:w_end, :
373
+ ].copy()
374
+ else:
375
+ if t_list is None:
376
+ t_list = list(range(self._data.shape[0]))
377
+ if h_start is None:
378
+ h_start = 0
379
+ if h_end is None:
380
+ h_end = self._data.shape[1]
381
+ if w_start is None:
382
+ w_start = 0
383
+ if w_end is None:
384
+ w_end = self._data.shape[2]
385
+
386
+ self._data = self._data[t_list, h_start:h_end, w_start:w_end, :].copy()
387
+ if self._noise_data is not None:
388
+ self._noise_data = self._noise_data[
389
+ t_list, h_start:h_end, w_start:w_end, :
390
+ ].copy()
362
391
  # TODO where tf is self._img_sz defined?
363
392
  self.set_img_sz([self._img_sz, self._img_sz], self._grid_sz)
364
393
  print(
@@ -0,0 +1,196 @@
1
+ """
2
+ Here, we have multiple folders, each containing images of a single channel.
3
+ """
4
+
5
+ from collections import defaultdict
6
+ from functools import cache
7
+
8
+ import numpy as np
9
+
10
+ from .types import DataSplitType
11
+
12
+
13
def l2(x):
    """Return the root-mean-square (quadratic mean) of the values in ``x``."""
    values = np.asarray(x)
    return np.sqrt(np.square(values).mean())
15
+
16
+
17
class MultiCropDset:
    """Dataset that draws random crops from per-channel collections of images.

    ``self._data`` is indexed as ``self._data[channel][image]``; images within
    a channel may have different sizes. Every item is built from one random
    crop per channel: the input is the (normalized) sum of the crops and the
    target is the (normalized) stack of the crops.
    """

    # Number of random draws `sample_crop` makes before giving up.
    _MAX_CROP_ATTEMPTS = 101

    def __init__(
        self,
        data_config,
        fpath: str,
        load_data_fn=None,
        val_fraction=None,
        test_fraction=None,
    ):
        """Load the data and apply optional clipping / background removal.

        Parameters
        ----------
        data_config
            Configuration object; the attributes read here are
            ``input_is_sum``, ``image_size``, ``enable_rotation_aug``,
            ``background_values``, ``datasplit_type`` and ``max_val``.
        fpath : str
            Forwarded unchanged to ``load_data_fn``.
        load_data_fn : callable
            Called as ``load_data_fn(data_config, fpath,
            data_config.datasplit_type, val_fraction, test_fraction)``; its
            return value becomes ``self._data``.
        val_fraction, test_fraction
            Forwarded unchanged to ``load_data_fn``.

        Raises
        ------
        TypeError
            If ``load_data_fn`` is ``None``.
        """
        assert (
            data_config.input_is_sum == True
        ), "This dataset is designed for sum of images"
        if load_data_fn is None:
            # The default is None but the loader is mandatory; fail with a
            # clear message instead of "'NoneType' object is not callable".
            raise TypeError("load_data_fn must be a callable, got None")

        self._img_sz = data_config.image_size
        self._enable_rotation = data_config.enable_rotation_aug

        self._background_values = data_config.background_values
        self._data = load_data_fn(
            data_config, fpath, data_config.datasplit_type, val_fraction, test_fraction
        )

        # remove upper quantiles, crucial for removing puncta
        # (clipping happens in place on the loaded images)
        self.max_val = data_config.max_val
        if self.max_val is not None:
            for ch_idx, data in enumerate(self._data):
                ch_max = self.max_val[ch_idx]
                if ch_max is None:
                    continue
                for idx in range(len(data)):
                    data[idx][data[idx] > ch_max] = ch_max

        # remove background values (also converts every image to float32)
        if self._background_values is not None:
            self._data = [
                [
                    img.astype(np.float32) - self._background_values[ch_idx]
                    for img in data
                ]
                for ch_idx, data in enumerate(self._data)
            ]

        print(
            f"{self.__class__.__name__} N:{len(self)} Rot:{self._enable_rotation} Ch:{len(self._data)} MaxVal:{self.max_val} Bg:{self._background_values}"
        )

    def get_max_val(self):
        """Return the per-channel clipping values (``data_config.max_val``)."""
        return self.max_val

    @staticmethod
    def _sum_channels(imgs):
        """Sum the given crops in floating point.

        Accumulating in at least float32 prevents silent wrap-around when the
        crops are still stored as (unsigned) integers, which is the case when
        no background subtraction was applied.
        """
        acc_dtype = np.result_type(
            np.float32, *[np.asarray(img).dtype for img in imgs]
        )
        total = 0
        for img in imgs:
            total = total + np.asarray(img, dtype=acc_dtype)
        return total

    def compute_mean_std(self, num_samples=30000):
        """Estimate normalization statistics from random crops.

        Parameters
        ----------
        num_samples : int
            Number of random crop sets (one crop per channel) to draw.

        Returns
        -------
        tuple of dict
            ``(mean, std)``; each dict has a ``"target"`` array of shape
            ``(n_channels, 1, 1)`` and an ``"input"`` array of shape
            ``(1, 1, 1)``. Means are averaged over the draws, standard
            deviations are combined with the root-mean-square (``l2``).

        Raises
        ------
        ValueError
            If ``num_samples`` is smaller than 1.
        """
        if num_samples < 1:
            raise ValueError("num_samples must be at least 1")

        n_channels = len(self._data)
        mean_tar = [[] for _ in range(n_channels)]
        std_tar = [[] for _ in range(n_channels)]
        mean_inp = []
        std_inp = []
        for _ in range(num_samples):
            crops = [self.sample_crop(ch_idx) for ch_idx in range(n_channels)]
            for ch_idx, crop in enumerate(crops):
                mean_tar[ch_idx].append(np.mean(crop))
                std_tar[ch_idx].append(np.std(crop))

            inp = self._sum_channels(crops)
            mean_inp.append(np.mean(inp))
            std_inp.append(np.std(inp))

        output_mean = {
            "target": np.array([np.mean(m) for m in mean_tar]).reshape(
                n_channels, 1, 1
            ),
            "input": np.array([np.mean(mean_inp)]).reshape(1, 1, 1),
        }
        output_std = {
            "target": np.array([l2(s) for s in std_tar]).reshape(n_channels, 1, 1),
            "input": np.array([l2(std_inp)]).reshape(1, 1, 1),
        }
        return output_mean, output_std

    def set_mean_std(self, mean_dict, std_dict):
        """Store the normalization statistics used by ``__getitem__``."""
        self._data_mean = mean_dict
        self._data_std = std_dict

    def get_mean_std(self):
        """Return the ``(mean, std)`` dictionaries set via ``set_mean_std``."""
        return self._data_mean, self._data_std

    def get_num_frames(self):
        """Return ``len(self._data)``, i.e. the number of channel entries."""
        return len(self._data)

    def crop_probablities(self, ch_idx):
        """Return the per-image sampling probabilities of a channel.

        Each image is weighted by its number of elements. The result is
        memoized per instance; a module-level ``functools.cache`` on the method
        would keep every dataset instance alive for the process lifetime.
        """
        cached = self.__dict__.setdefault("_crop_prob_cache", {})
        if ch_idx not in cached:
            sizes = np.array([np.prod(x.shape) for x in self._data[ch_idx]])
            cached[ch_idx] = sizes / sizes.sum()
        return cached[ch_idx]

    def sample_crop(self, ch_idx):
        """Return a random crop of size ``self._img_sz`` from channel ``ch_idx``.

        An image is drawn with probability proportional to its size; if it is
        smaller than the crop size another image is drawn.

        Raises
        ------
        ValueError
            If no sufficiently large image was drawn within the attempt budget.
        """
        crop_h, crop_w = self._img_sz[0], self._img_sz[1]
        images = self._data[ch_idx]
        probabilities = self.crop_probablities(ch_idx)
        for _ in range(self._MAX_CROP_ATTEMPTS):
            idx = np.random.choice(len(images), p=probabilities)
            data = images[idx]
            if data.shape[0] >= crop_h and data.shape[1] >= crop_w:
                # `high` is exclusive: the +1 makes the last valid offset
                # reachable and keeps an exactly crop-sized image valid
                # (randint(0, 0) raises ValueError).
                h = np.random.randint(0, data.shape[0] - crop_h + 1)
                w = np.random.randint(0, data.shape[1] - crop_w + 1)
                return data[h : h + crop_h, w : w + crop_w]

        raise ValueError("Cannot find a valid crop")

    def len_per_channel(self, ch_idx):
        """Return total element count of a channel divided by the crop area."""
        return np.sum([np.prod(x.shape) for x in self._data[ch_idx]]) / np.prod(
            self._img_sz
        )

    def imgs_for_patch(self):
        """Return one freshly sampled crop per channel."""
        return [self.sample_crop(ch_idx) for ch_idx in range(len(self._data))]

    def __len__(self):
        """Return the largest per-channel length, truncated to an int."""
        len_per_channel = [
            self.len_per_channel(ch_idx) for ch_idx in range(len(self._data))
        ]
        return int(np.max(len_per_channel))

    def _rotate(self, img_tuples):
        """Apply the rotation augmentation (delegates to ``_rotate2D``)."""
        return self._rotate2D(img_tuples)

    def _rotate2D(self, img_tuples):
        """Apply ``self._rotation_transform`` jointly to all given images.

        Every ``img[k]`` of every entry is registered as an additional image
        target so that all of them receive the same transform.

        NOTE(review): ``__getitem__`` passes the per-channel crops here; the
        ``img[k]`` indexing presumably expects each entry to have a leading
        channel axis -- confirm against the crops returned by ``sample_crop``.

        Raises
        ------
        AttributeError
            If ``self._rotation_transform`` has not been assigned; this class
            does not create it itself.
        """
        transform = getattr(self, "_rotation_transform", None)
        if transform is None:
            raise AttributeError(
                f"{self.__class__.__name__} has no '_rotation_transform'; "
                "assign one before enabling rotation augmentation"
            )

        img_kwargs = {
            f"img{i}_{k}": img[k]
            for i, img in enumerate(img_tuples)
            for k in range(len(img))
        }
        transform.add_targets({key: "image" for key in img_kwargs})
        rot_dic = transform(image=img_tuples[0][0], **img_kwargs)

        # Re-assemble each entry along a new leading axis.
        return [
            np.stack([rot_dic[f"img{i}_{k}"] for k in range(len(img))], axis=0)
            for i, img in enumerate(img_tuples)
        ]

    def _compute_input(self, imgs):
        """Return the normalized sum of ``imgs`` with a new leading axis."""
        inp = self._sum_channels(imgs)
        inp = inp[None]
        inp = (inp - self._data_mean["input"]) / (self._data_std["input"])
        return inp

    def _compute_target(self, imgs):
        """Return the normalized stack of ``imgs``."""
        imgs = np.stack(imgs)
        target = (imgs - self._data_mean["target"]) / (self._data_std["target"])
        return target

    def __getitem__(self, idx):
        """Return a random ``(input, target)`` pair; ``idx`` is not used."""
        imgs = self.imgs_for_patch()
        if self._enable_rotation:
            imgs = self._rotate(imgs)

        inp = self._compute_input(imgs)
        target = self._compute_target(imgs)
        return inp, target
@@ -2,9 +2,9 @@ from enum import Enum
2
2
 
3
3
 
4
4
  class DataType(Enum):
5
- Elisa3DData = 0
5
+ HTH24Data = 0
6
6
  HTLIF24Data = 1
7
- Pavia3SeqData = 2
7
+ PaviaP24Data = 2
8
8
  TavernaSox2GolgiV2 = 3
9
9
  Dao3ChannelWithInput = 4
10
10
  ExpMicroscopyV1 = 5
@@ -15,7 +15,7 @@ class DataType(Enum):
15
15
  OptiMEM100_014 = 10
16
16
  SeparateTiffData = 11
17
17
  BioSR_MRC = 12
18
- PunctaRemoval = 13 # for the case when we have a set of differently sized crops for each channel.
18
+ HTH23BData = 13 # puncta, in case we have differently sized crops for each channel.
19
19
  Care3D = 14
20
20
 
21
21
 
@@ -230,3 +230,262 @@ class GridIndexManager:
230
230
  new_idx = dataset_idx - self.grid_count(dim)
231
231
  if new_idx < 0:
232
232
  return None
233
+
234
+
235
@dataclass
class GridIndexManagerRef:
    """Patch/grid bookkeeping for datasets made of differently shaped samples.

    This class is used to calculate and store information about patches, and
    to calculate the total length of the dataset in patches. It introduces the
    concept of a grid into which input images are split. The grid is defined
    by ``grid_shape`` and ``patch_shape``, with the former controlling the
    overlap. In this reimplementation it can accept multiple channels with
    different lengths, and every image can have a different shape.

    ``data_shapes`` is indexed as ``data_shapes[channel][sample]`` and holds
    one shape tuple per sample.

    NOTE(review): ``patch_idx_from_grid_idx``, ``get_location_from_patch_idx_o``,
    ``on_boundary``, ``next_grid_along_dim`` and ``prev_grid_along_dim`` refer to
    ``self.grid_count`` and/or ``self.data_shape``, neither of which is defined
    in this class as visible here. They look like leftovers from the
    single-array index manager and presumably fail when called -- confirm
    before relying on them.
    """

    data_shapes: tuple
    grid_shape: tuple
    patch_shape: tuple
    tiling_mode: TilingMode

    def __post_init__(self):
        """Validate dimensionalities and cache the per-sample patch counts."""
        ndim = {len(ds) for ds in self.data_shapes[0]}.pop()
        if len(self.data_shapes) > 1:
            assert (
                ndim == {len(ds) for ds in self.data_shapes[1]}.pop()
            ), "Data shape for all channels must be the same"  # TODO better way to assert this
        assert ndim == len(
            self.grid_shape
        ), "Data shape and grid size must have the same dimension"
        assert ndim == len(
            self.patch_shape
        ), "Data shape and patch shape must have the same dimension"
        innerpad = np.array(self.patch_shape) - np.array(self.grid_shape)
        for dim, pad in enumerate(innerpad):
            if pad < 0:
                raise ValueError(
                    f"Patch shape must be greater than or equal to grid shape in dimension {dim}"
                )
            if pad % 2 != 0:
                raise ValueError(
                    f"Patch shape must have even padding in dimension {dim}"
                )
        # Second element of total_grid_count(): patch counts per channel and sample.
        self.num_patches_per_channel = self.total_grid_count()[1]

    def patch_offset(self):
        """Return the one-sided margin between patch and grid per dimension."""
        return (np.array(self.patch_shape) - np.array(self.grid_shape)) // 2

    def get_individual_dim_grid_count(self, shape: tuple, dim: int):
        """
        Returns the number of grids in the specified dimension of a sample with
        the given shape, ignoring all other dimensions.
        """
        if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
            return shape[dim]
        elif self.tiling_mode == TilingMode.PadBoundary:
            return int(np.ceil(shape[dim] / self.grid_shape[dim]))
        elif self.tiling_mode == TilingMode.ShiftBoundary:
            excess_size = self.patch_shape[dim] - self.grid_shape[dim]
            return int(np.ceil((shape[dim] - excess_size) / self.grid_shape[dim]))
        else:
            # every remaining tiling mode only counts fully contained grids
            excess_size = self.patch_shape[dim] - self.grid_shape[dim]
            return int(np.floor((shape[dim] - excess_size) / self.grid_shape[dim]))

    def total_grid_count(self):
        """Returns the patch counts of the dataset.

        Returns
        -------
        tuple
            ``(len_per_channel, num_patches_per_sample)``: the total number of
            patches of every channel, and for every channel the list of patch
            counts of its samples.
        """
        len_per_channel = []
        num_patches_per_sample = []
        for channel_data in self.data_shapes:
            num_patches = [
                np.prod(self.grid_count_per_sample(file_shape))
                for file_shape in channel_data
            ]
            len_per_channel.append(np.sum(num_patches))
            num_patches_per_sample.append(num_patches)

        return len_per_channel, num_patches_per_sample

    def grid_count_per_sample(self, shape: tuple):
        """Returns the list of per-dimension grid counts for one sample shape."""
        return [
            self.get_individual_dim_grid_count(shape, dim)
            for dim in range(len(shape))
        ]

    def get_grid_index(self, shape, dim: int, coordinate: int):
        """Returns the index of the patch in the specified dimension."""
        assert dim < len(
            shape
        ), f"Dimension {dim} is out of bounds for data shape {shape}"
        assert dim >= 0, "Dimension must be greater than or equal to 0"
        assert (
            coordinate < shape[dim]
        ), f"Coordinate {coordinate} is out of bounds for data shape {shape}"

        if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
            return coordinate
        elif self.tiling_mode == TilingMode.PadBoundary:
            return np.floor(coordinate / self.grid_shape[dim])
        elif self.tiling_mode == TilingMode.TrimBoundary:
            excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
            # can be <0 if coordinate is in [0,grid_shape[dim]]
            return max(0, np.floor((coordinate - excess_size) / self.grid_shape[dim]))
        elif self.tiling_mode == TilingMode.ShiftBoundary:
            excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
            # Compare with the extent of *this sample*; `self.data_shapes[dim]`
            # is a list of shapes and could never equal an integer.
            if coordinate + self.grid_shape[dim] + excess_size == shape[dim]:
                return self.get_individual_dim_grid_count(shape, dim) - 1
            else:
                # can be <0 if coordinate is in [0,grid_shape[dim]]
                return max(
                    0, np.floor((coordinate - excess_size) / self.grid_shape[dim])
                )

        else:
            raise ValueError(f"Unsupported tiling mode {self.tiling_mode}")

    def patch_idx_from_grid_idx(self, shape: tuple, grid_idx: tuple):
        """Returns the index of the patch in the dataset.

        NOTE(review): relies on ``self.grid_count(shape, dim)``, which is not
        defined in this class as visible here -- confirm.
        """
        assert len(grid_idx) == len(
            shape
        ), f"Dimension indices {grid_idx} must have the same dimension as data shape {shape}"
        index = 0
        for dim in range(len(grid_idx)):
            index += grid_idx[dim] * self.grid_count(shape, dim)
        return index

    def get_patch_location_from_patch_idx(self, ch_idx: int, patch_idx: int):
        """Returns the patch location of the grid in the dataset.

        The first entry (sample index) is passed through unchanged; the patch
        offset is subtracted from the remaining entries.
        """
        grid_location = self.get_location_from_patch_idx(ch_idx, patch_idx)
        offset = self.patch_offset()
        return tuple(np.array(grid_location) - np.concatenate((np.array((0,)), offset)))

    def get_patch_idx_from_grid_location(self, shape, location: tuple):
        """Returns the patch index for a grid location inside a sample."""
        assert len(location) == len(
            shape
        ), f"Location {location} must have the same dimension as data shape {shape}"
        # Both helpers take the sample shape as their first argument.
        grid_idx = [
            self.get_grid_index(shape, dim, location[dim])
            for dim in range(len(location))
        ]
        return self.patch_idx_from_grid_idx(shape, tuple(grid_idx))

    def get_gridstart_location_from_dim_index(
        self, shape: tuple, dim_idx: int, dim: int
    ):
        """Returns the grid-start coordinate of the grid in the specified dimension.

        dim_idx: int
            Index of the dimension in the data shape.
        dim: int
            Value of the dimension in the grid (relative to num patches in dimension).
        """
        if self.grid_shape[dim_idx] == 1 and self.patch_shape[dim_idx] == 1:
            # NOTE(review): this returns the axis number, not the grid index
            # `dim`. The visible caller passes the channel index for axis 0, so
            # this yields 0 there; kept as is -- confirm whether that is intended.
            return dim_idx
        elif self.tiling_mode == TilingMode.ShiftBoundary:
            excess_size = (self.patch_shape[dim_idx] - self.grid_shape[dim_idx]) // 2
            if dim < self.get_individual_dim_grid_count(shape, dim_idx) - 1:
                return dim * self.grid_shape[dim_idx] + excess_size
            else:
                # on boundary. grid should be placed such that the patch covers the entire data.
                return shape[dim_idx] - self.grid_shape[dim_idx] - excess_size
        else:
            raise ValueError(f"Unsupported tiling mode {self.tiling_mode}")

    def get_location_from_patch_idx(self, channel_idx: int, patch_idx: int):
        """
        Returns the start location of the grid in the dataset. Per channel!.

        Parameters
        ----------
        channel_idx : int
            Index of the channel whose samples are addressed.
        patch_idx : int
            The index of the patch in a list of samples within a channel. Channels can
            be different in length.

        Returns
        -------
        tuple
            ``(sample_idx, loc_0, loc_1, ...)`` with one grid-start coordinate
            per data dimension.
        """
        # TODO assert patch_idx <= num of patches in the channel
        # cumulative sum of the patch counts of the samples of this channel
        cumulative_indices = np.cumsum(self.total_grid_count()[1][channel_idx])
        # find the sample that contains the patch
        sample_idx = np.searchsorted(cumulative_indices, patch_idx, side="right")
        sample_shape = self.data_shapes[channel_idx][sample_idx]
        # Make the index relative to the selected sample; decomposing the
        # channel-global index would yield out-of-range grid indices for every
        # sample but the first.
        if sample_idx > 0:
            patch_idx = patch_idx - cumulative_indices[sample_idx - 1]
        # TODO duplicated runs, revisit
        # ignoring the channel dimension because we index it explicitly
        grid_count = self.grid_count_per_sample(sample_shape)[1:]

        # Decompose the local index; the first remaining dimension varies fastest.
        grid_idx = []
        for i in range(len(grid_count) - 1, -1, -1):
            stride = np.prod(grid_count[:i]) if i > 0 else 1
            grid_idx.insert(0, patch_idx // stride)
            patch_idx %= stride
        # TODO check for 3D !
        # adding channel index
        grid_idx = [channel_idx] + grid_idx
        location = [
            sample_idx,
        ] + [
            self.get_gridstart_location_from_dim_index(
                shape=sample_shape, dim_idx=dim_idx, dim=grid_idx[dim_idx]
            )
            for dim_idx in range(len(grid_idx))
        ]
        return tuple(location)

    def get_location_from_patch_idx_o(self, dataset_idx: int):
        """
        Returns the start location of the grid in the dataset.

        NOTE(review): uses ``self.data_shape`` and ``self.grid_count``, which are
        not defined in this class as visible here -- confirm.
        """
        grid_idx = []
        for dim in range(len(self.data_shape)):
            grid_idx.append(dataset_idx // self.grid_count(dim))
            dataset_idx = dataset_idx % self.grid_count(dim)
        location = [
            self.get_gridstart_location_from_dim_index(dim, grid_idx[dim])
            for dim in range(len(self.data_shape))
        ]
        return tuple(location)

    def on_boundary(self, dataset_idx: int, dim: int, only_end: bool = False):
        """
        Returns True if the grid is on the boundary in the specified dimension.

        NOTE(review): uses ``self.grid_count``, which is not defined in this
        class as visible here -- confirm.
        """
        assert dim < len(
            self.data_shapes
        ), f"Dimension {dim} is out of bounds for data shape {self.data_shapes}"
        assert dim >= 0, "Dimension must be greater than or equal to 0"

        if dim > 0:
            dataset_idx = dataset_idx % self.grid_count(dim - 1)

        dim_index = dataset_idx // self.grid_count(dim)
        if only_end:
            return dim_index == self.get_individual_dim_grid_count(dim) - 1

        return (
            dim_index == 0 or dim_index == self.get_individual_dim_grid_count(dim) - 1
        )

    def next_grid_along_dim(self, dataset_idx: int, dim: int):
        """
        Returns the index of the next grid along the specified dimension, or
        None when that index is past the end.

        NOTE(review): uses ``self.grid_count``, which is not defined in this
        class as visible here -- confirm.
        """
        assert dim < len(
            self.data_shapes
        ), f"Dimension {dim} is out of bounds for data shape {self.data_shapes}"
        assert dim >= 0, "Dimension must be greater than or equal to 0"
        new_idx = dataset_idx + self.grid_count(dim)
        if new_idx >= self.total_grid_count():
            return None
        return new_idx

    def prev_grid_along_dim(self, dataset_idx: int, dim: int):
        """
        Returns the index of the previous grid along the specified dimension,
        or None when that index would be negative.

        NOTE(review): uses ``self.grid_count``, which is not defined in this
        class as visible here -- confirm.
        """
        assert dim < len(
            self.data_shapes
        ), f"Dimension {dim} is out of bounds for data shape {self.data_shapes}"
        assert dim >= 0, "Dimension must be greater than or equal to 0"
        new_idx = dataset_idx - self.grid_count(dim)
        if new_idx < 0:
            return None
        # mirror next_grid_along_dim: hand the computed index back to the caller
        return new_idx
@@ -14,10 +14,11 @@ import matplotlib.pyplot as plt
14
14
  import numpy as np
15
15
  import torch
16
16
  from matplotlib.gridspec import GridSpec
17
- from torch.utils.data import DataLoader, Dataset, Subset
17
+ from torch.utils.data import DataLoader, Dataset
18
18
  from tqdm import tqdm
19
19
 
20
20
  from careamics.lightning import VAEModule
21
+ from careamics.lvae_training.dataset import MultiChDloaderRef
21
22
  from careamics.utils.metrics import scale_invariant_psnr
22
23
 
23
24
 
@@ -542,7 +543,9 @@ def get_predictions(
542
543
  mmse_count=mmse_count,
543
544
  num_workers=num_workers,
544
545
  )
546
+ # TODO stitching still not working properly for weirdly shaped images
545
547
  # get filename without extension and path
548
+ # TODO in the ref ds this is the name of a folder not file :(
546
549
  filename = dset._fpath.name
547
550
  return (
548
551
  {filename: stitched_predictions},
@@ -656,8 +659,14 @@ def get_single_file_mmse(
656
659
 
657
660
  tiles_arr = np.concatenate(tile_mmse, axis=0)
658
661
  tile_stds = np.concatenate(tile_stds, axis=0)
659
- stitched_predictions = stitch_predictions_new(tiles_arr, dset)
660
- stitched_stds = stitch_predictions_new(tile_stds, dset)
662
+ # TODO temporary hack, because of the stupid jupyter!
663
+ # If a user reruns a cell with class definition, isinstance will return False
664
+ if str(MultiChDloaderRef).split(".")[-1] == str(dset.__class__).split(".")[-1]:
665
+ stitch_func = stitch_predictions_general
666
+ else:
667
+ stitch_func = stitch_predictions_new
668
+ stitched_predictions = stitch_func(tiles_arr, dset)
669
+ stitched_stds = stitch_func(tile_stds, dset)
661
670
  return stitched_predictions, stitched_stds
662
671
 
663
672
 
@@ -873,3 +882,84 @@ def stitch_predictions_new(predictions, dset):
873
882
  raise ValueError(f"Unsupported shape {output.shape}")
874
883
 
875
884
  return output
885
+
886
+
887
def stitch_predictions_general(predictions, dset):
    """Stitching for the dataset with multiple files of different shape.

    Parameters
    ----------
    predictions : numpy.ndarray
        Predicted tiles; axis 0 enumerates the patches and axis 1 the
        predicted channels, the remaining axes are the tile extent.
    dset
        Dataset providing ``idx_manager`` (used to locate every patch) and
        ``get_data_shapes()`` (first entry: the shapes of the input samples).

    Returns
    -------
    list of numpy.ndarray
        One stitched array per input sample, each with ``predictions.shape[1]``
        channels on axis 0 followed by the sample's remaining dimensions.

    Raises
    ------
    ValueError
        If a stitched sample is neither 3- nor 4-dimensional.
    """
    mng = dset.idx_manager

    # TODO assert all shapes are equal len
    # adjust number of channels to match with prediction shape #TODO ugly, refac!
    shapes = [
        (predictions.shape[1],) + tuple(shape[1:])
        for shape in dset.get_data_shapes()[0]
    ]

    output = [np.zeros(shape, dtype=predictions.dtype) for shape in shapes]
    for patch_idx in range(predictions.shape[0]):
        # channel_idx is 0 because during prediction we're only use one channel. # TODO revisit this
        # 0th entry is the sample index in the output list
        grid_coords = np.array(
            mng.get_location_from_patch_idx(channel_idx=0, patch_idx=patch_idx),
            dtype=int,
        )
        sample_idx = grid_coords[0]
        grid_start = grid_coords[1:]
        # from here on, coordinates are relative to the sample (file in the list of inputs)
        grid_end = grid_start + mng.grid_shape
        sample_shape = shapes[sample_idx]

        # patch start, patch end
        patch_start = grid_start - mng.patch_offset()
        patch_end = patch_start + mng.patch_shape

        # grid region clipped to the sample extent
        valid_grid_start = np.array([max(0, x) for x in grid_start], dtype=int)
        valid_grid_end = np.array(
            [min(x, y) for x, y in zip(grid_end, sample_shape)], dtype=int
        )

        if mng.tiling_mode == TilingMode.ShiftBoundary:
            # A patch touching a sample border also contributes its margin.
            # The border is taken from the current sample's shape; the index
            # manager holds one shape per sample, not a single `data_shape`.
            for dim in range(len(valid_grid_start)):
                if patch_start[dim] == 0:
                    valid_grid_start[dim] = 0
                if patch_end[dim] == sample_shape[dim]:
                    valid_grid_end[dim] = sample_shape[dim]

        # relative start, relative end. This will be used on the predicted tile
        relative_start = valid_grid_start - patch_start
        relative_end = relative_start + (valid_grid_end - valid_grid_start)

        stitched = output[sample_idx]
        if stitched.ndim not in (3, 4):
            raise ValueError(f"Unsupported shape {stitched.shape}")

        # Entry 0 of the coordinates is the channel relative to the input; the
        # channel axis of the stitched output is relative to the model output,
        # so all predicted channels are copied over the spatial region at once.
        region = tuple(
            slice(start, end)
            for start, end in zip(valid_grid_start[1:], valid_grid_end[1:])
        )
        tile_region = tuple(
            slice(start, end)
            for start, end in zip(relative_start[1:], relative_end[1:])
        )
        stitched[(slice(None),) + region] = predictions[patch_idx][
            (slice(None),) + tile_region
        ]

    return output
@@ -86,6 +86,7 @@ class Compose:
86
86
  *params, _ = t(*params) # ignore additional_arrays dict
87
87
 
88
88
  # avoid None values that create problems for collating
89
+ # TODO: removing None should be handled in dataset, not here
89
90
  return tuple(p for p in params if p is not None)
90
91
 
91
92
  def _chain_transforms_additional_arrays(