careamics 0.0.8__py3-none-any.whl → 0.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics has been flagged as possibly problematic.

Files changed (63)
  1. careamics/__init__.py +0 -4
  2. careamics/careamist.py +0 -1
  3. careamics/config/__init__.py +1 -13
  4. careamics/config/algorithms/care_algorithm_model.py +84 -0
  5. careamics/config/algorithms/n2n_algorithm_model.py +85 -0
  6. careamics/config/algorithms/n2v_algorithm_model.py +269 -1
  7. careamics/config/configuration.py +21 -13
  8. careamics/config/configuration_factories.py +179 -187
  9. careamics/config/configuration_io.py +2 -2
  10. careamics/config/data/__init__.py +1 -4
  11. careamics/config/data/data_model.py +46 -62
  12. careamics/config/support/supported_transforms.py +1 -1
  13. careamics/config/transformations/__init__.py +0 -2
  14. careamics/config/transformations/n2v_manipulate_model.py +15 -0
  15. careamics/config/transformations/transform_unions.py +0 -13
  16. careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
  17. careamics/dataset/in_memory_dataset.py +3 -10
  18. careamics/dataset/in_memory_pred_dataset.py +3 -5
  19. careamics/dataset/in_memory_tiled_pred_dataset.py +2 -2
  20. careamics/dataset/iterable_dataset.py +2 -2
  21. careamics/dataset/iterable_pred_dataset.py +3 -5
  22. careamics/dataset/iterable_tiled_pred_dataset.py +3 -3
  23. careamics/dataset_ng/dataset/__init__.py +3 -0
  24. careamics/dataset_ng/dataset/dataset.py +184 -0
  25. careamics/dataset_ng/demo_dataset.ipynb +271 -0
  26. careamics/dataset_ng/demo_patch_extractor.py +53 -0
  27. careamics/dataset_ng/demo_patch_extractor_factory.py +37 -0
  28. careamics/dataset_ng/patch_extractor/__init__.py +10 -0
  29. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +111 -0
  30. careamics/dataset_ng/patch_extractor/image_stack/__init__.py +9 -0
  31. careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +53 -0
  32. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +55 -0
  33. careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +163 -0
  34. careamics/dataset_ng/patch_extractor/image_stack_loader.py +140 -0
  35. careamics/dataset_ng/patch_extractor/patch_extractor.py +29 -0
  36. careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +208 -0
  37. careamics/dataset_ng/patching_strategies/__init__.py +11 -0
  38. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +82 -0
  39. careamics/dataset_ng/patching_strategies/random_patching.py +338 -0
  40. careamics/dataset_ng/patching_strategies/sequential_patching.py +75 -0
  41. careamics/lightning/lightning_module.py +78 -27
  42. careamics/lightning/train_data_module.py +8 -39
  43. careamics/losses/fcn/losses.py +17 -10
  44. careamics/lvae_training/eval_utils.py +21 -8
  45. careamics/model_io/bioimage/bioimage_utils.py +5 -3
  46. careamics/model_io/bioimage/model_description.py +3 -3
  47. careamics/model_io/bmz_io.py +2 -2
  48. careamics/model_io/model_io_utils.py +2 -2
  49. careamics/transforms/__init__.py +2 -1
  50. careamics/transforms/compose.py +5 -15
  51. careamics/transforms/n2v_manipulate_torch.py +143 -0
  52. careamics/transforms/pixel_manipulation.py +1 -0
  53. careamics/transforms/pixel_manipulation_torch.py +418 -0
  54. careamics/utils/version.py +38 -0
  55. {careamics-0.0.8.dist-info → careamics-0.0.10.dist-info}/METADATA +7 -8
  56. {careamics-0.0.8.dist-info → careamics-0.0.10.dist-info}/RECORD +59 -42
  57. careamics/config/care_configuration.py +0 -100
  58. careamics/config/data/n2v_data_model.py +0 -193
  59. careamics/config/n2n_configuration.py +0 -101
  60. careamics/config/n2v_configuration.py +0 -266
  61. {careamics-0.0.8.dist-info → careamics-0.0.10.dist-info}/WHEEL +0 -0
  62. {careamics-0.0.8.dist-info → careamics-0.0.10.dist-info}/entry_points.txt +0 -0
  63. {careamics-0.0.8.dist-info → careamics-0.0.10.dist-info}/licenses/LICENSE +0 -0
careamics/dataset_ng/patching_strategies/random_patching.py (new file)
@@ -0,0 +1,338 @@
+"""A module for random patching strategies."""
+
+from collections.abc import Sequence
+from typing import Optional
+
+import numpy as np
+
+from .patching_strategy_protocol import PatchSpecs
+
+
+class RandomPatchingStrategy:
+    """
+    A patching strategy for sampling random patches; it implements the
+    `PatchingStrategy` `Protocol`.
+
+    The output of `get_patch_spec` will be random, i.e. if the same index is given
+    twice the two outputs can be different.
+
+    However, the strategy still ensures that there will be a known number of patches
+    for each sample in each image stack. This is achieved by defining a set of bins
+    that map to each sample in each image stack. Whichever bin an `index` passed to
+    `get_patch_spec` falls into determines the `"data_idx"` and `"sample_idx"` in
+    the returned `PatchSpecs`, but the `"coords"` will be random.
+
+    The number of patches in each sample is based on the number of patches that would
+    fit if they were sampled sequentially, non-overlapping, and covering the entire
+    array.
+    """
+
+    def __init__(
+        self,
+        data_shapes: Sequence[Sequence[int]],
+        patch_size: Sequence[int],
+        seed: Optional[int] = None,
+    ):
+        """
+        A patching strategy for sampling random patches.
+
+        Parameters
+        ----------
+        data_shapes : sequence of (sequence of int)
+            The shapes of the underlying data. Each element is the dimension of the
+            axes SC(Z)YX.
+        patch_size : sequence of int
+            The size of the patch. The sequence will have length 2 or 3, for 2D and
+            3D data respectively.
+        seed : int, optional
+            An optional seed to ensure the reproducibility of the random patches.
+        """
+        self.rng = np.random.default_rng(seed=seed)
+        self.patch_size = patch_size
+        self.data_shapes = data_shapes
+
+        # these bins determine which image stack and sample a patch comes from:
+        # - image_stack_cumulative_patches maps a patch index to an image stack
+        # - sample_cumulative_patches maps a patch index to a sample
+        # - image_stack_cumulative_samples maps a sample index to an image stack
+        (
+            self.image_stack_cumulative_patches,
+            self.sample_cumulative_patches,
+            self.image_stack_cumulative_samples,
+        ) = self._calc_bins(self.data_shapes, self.patch_size)
+
+    @property
+    def n_patches(self) -> int:
+        """
+        The number of patches that this patching strategy will return.
+
+        It also determines the maximum index that can be given to `get_patch_spec`.
+        """
+        # the last bin boundary is the total number of patches
+        return self.image_stack_cumulative_patches[-1]
+
+    def get_patch_spec(self, index: int) -> PatchSpecs:
+        """Return the patch specs for a given index.
+
+        Parameters
+        ----------
+        index : int
+            A patch index.
+
+        Returns
+        -------
+        PatchSpecs
+            A dictionary that specifies a single patch in a series of `ImageStacks`.
+        """
+        # TODO: break into smaller testable functions?
+        if index >= self.n_patches:
+            raise IndexError(
+                f"Index {index} out of bounds for RandomPatchingStrategy with number "
+                f"of patches {self.n_patches}"
+            )
+        # digitize returns the bin that `index` belongs to
+        data_index = np.digitize(index, bins=self.image_stack_cumulative_patches).item()
+        # maps to a particular sample within the whole series of image stacks
+        # (not just a single image stack)
+        total_samples_index = np.digitize(
+            index, bins=self.sample_cumulative_patches
+        ).item()
+
+        data_shape = self.data_shapes[data_index]
+        spatial_shape = data_shape[2:]
+
+        # calculate sample index relative to image stack:
+        # subtract the total number of samples in the previous image stacks
+        if data_index == 0:
+            n_previous_samples = 0
+        else:
+            n_previous_samples = self.image_stack_cumulative_samples[data_index - 1]
+        sample_index = total_samples_index - n_previous_samples
+        coords = _generate_random_coords(spatial_shape, self.patch_size, self.rng)
+        return {
+            "data_idx": data_index,
+            "sample_idx": sample_index,
+            "coords": coords,
+            "patch_size": self.patch_size,
+        }
+
+    @staticmethod
+    def _calc_bins(
+        data_shapes: Sequence[Sequence[int]], patch_size: Sequence[int]
+    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
+        """Calculate bins used to map an index to an image_stack and a sample.
+
+        The number of patches in each sample is based on the number of patches that
+        would fit if they were sampled sequentially.
+
+        Parameters
+        ----------
+        data_shapes : sequence of (sequence of int)
+            The shapes of the underlying data. Each element is the dimension of the
+            axes SC(Z)YX.
+        patch_size : sequence of int
+            The size of the patch. The sequence will have length 2 or 3, for 2D and
+            3D data respectively.
+
+        Returns
+        -------
+        image_stack_cumulative_patches : tuple of int
+            The bins that map a patch index to an image stack. E.g. if a patch index
+            falls below the first bin boundary it belongs to the first image stack,
+            if a patch index falls between the first bin boundary and the second bin
+            boundary it belongs to the second image stack, and so on.
+        sample_cumulative_patches : tuple of int
+            The bins that map a patch index to a sample. E.g. if a patch index
+            falls below the first bin boundary it belongs to the first sample, if
+            a patch index falls between the first bin boundary and the second bin
+            boundary it belongs to the second sample, and so on.
+        image_stack_cumulative_samples : tuple of int
+            The bins that map a sample index to an image stack. E.g. if a sample
+            index falls below the first bin boundary it belongs to the first image
+            stack, if a sample index falls between the first bin boundary and the
+            second bin boundary it belongs to the second image stack, and so on.
+        """
+        patches_per_image_stack: list[int] = []
+        patches_per_sample: list[int] = []
+        samples_per_image_stack: list[int] = []
+        for data_shape in data_shapes:
+            spatial_shape = data_shape[2:]
+            n_single_sample_patches = _calc_n_patches(spatial_shape, patch_size)
+            # multiply by the number of samples in the image stack
+            patches_per_image_stack.append(n_single_sample_patches * data_shape[0])
+            # list of length `sample` filled with `n_single_sample_patches`
+            patches_per_sample.extend([n_single_sample_patches] * data_shape[0])
+            # number of samples in each image stack
+            samples_per_image_stack.append(data_shape[0])
+
+        # the cumulative sums create the bins
+        image_stack_cumulative_patches = np.cumsum(patches_per_image_stack)
+        sample_cumulative_patches = np.cumsum(patches_per_sample)
+        image_stack_cumulative_samples = np.cumsum(samples_per_image_stack)
+        return (
+            tuple(image_stack_cumulative_patches),
+            tuple(sample_cumulative_patches),
+            tuple(image_stack_cumulative_samples),
+        )
+
+
+class FixedRandomPatchingStrategy:
+    """
+    A patching strategy for sampling random patches; it implements the
+    `PatchingStrategy` `Protocol`.
+
+    The output of `get_patch_spec` will be deterministic, i.e. if the same index is
+    given twice the two outputs will be the same.
+
+    The number of patches in each sample is based on the number of patches that would
+    fit if they were sampled sequentially, non-overlapping, and covering the entire
+    array.
+    """
+
+    def __init__(
+        self,
+        data_shapes: Sequence[Sequence[int]],
+        patch_size: Sequence[int],
+        seed: Optional[int] = None,
+    ):
+        """A patching strategy for sampling random patches.
+
+        Parameters
+        ----------
+        data_shapes : sequence of (sequence of int)
+            The shapes of the underlying data. Each element is the dimension of the
+            axes SC(Z)YX.
+        patch_size : sequence of int
+            The size of the patch. The sequence will have length 2 or 3, for 2D and
+            3D data respectively.
+        seed : int, optional
+            An optional seed to ensure the reproducibility of the random patches.
+        """
+        self.rng = np.random.default_rng(seed=seed)
+        self.patch_size = patch_size
+        self.data_shapes = data_shapes
+
+        # simply generate all the patches at initialisation, so they will be fixed
+        self.fixed_patch_specs: list[PatchSpecs] = []
+        for data_idx, data_shape in enumerate(self.data_shapes):
+            spatial_shape = data_shape[2:]
+            n_patches = _calc_n_patches(spatial_shape, self.patch_size)
+            for sample_idx in range(data_shape[0]):
+                for _ in range(n_patches):
+                    random_coords = _generate_random_coords(
+                        spatial_shape, self.patch_size, self.rng
+                    )
+                    patch_specs: PatchSpecs = {
+                        "data_idx": data_idx,
+                        "sample_idx": sample_idx,
+                        "coords": random_coords,
+                        "patch_size": self.patch_size,
+                    }
+                    self.fixed_patch_specs.append(patch_specs)
+
+    @property
+    def n_patches(self) -> int:
+        """
+        The number of patches that this patching strategy will return.
+
+        It also determines the maximum index that can be given to `get_patch_spec`.
+        """
+        return len(self.fixed_patch_specs)
+
+    def get_patch_spec(self, index: int) -> PatchSpecs:
+        """Return the patch specs for a given index.
+
+        Parameters
+        ----------
+        index : int
+            A patch index.
+
+        Returns
+        -------
+        PatchSpecs
+            A dictionary that specifies a single patch in a series of `ImageStacks`.
+        """
+        if index >= self.n_patches:
+            raise IndexError(
+                f"Index {index} out of bounds for FixedRandomPatchingStrategy with "
+                f"number of patches {self.n_patches}"
+            )
+        # simply index the pre-generated patches to get the correct patch
+        return self.fixed_patch_specs[index]
+
+
+def _generate_random_coords(
+    spatial_shape: Sequence[int], patch_size: Sequence[int], rng: np.random.Generator
+) -> tuple[int, ...]:
+    """Generate random patch coordinates for a given `spatial_shape` and `patch_size`.
+
+    The coords are the top-left (and first z-slice for 3D data) of a patch. The
+    sequence will have length 2 or 3, for 2D and 3D data respectively.
+
+    Parameters
+    ----------
+    spatial_shape : sequence of int
+        The dimension of the axes (Z)YX, a sequence of length 2 or 3, for 2D and 3D
+        data respectively.
+    patch_size : sequence of int
+        The size of the patch. The sequence will have length 2 or 3, for 2D and 3D
+        data respectively.
+    rng : numpy.random.Generator
+        A numpy generator to ensure the reproducibility of the random patches.
+
+    Returns
+    -------
+    coords : tuple of int
+        The top-left (and first z-slice for 3D data) coords of a patch. The tuple
+        will have length 2 or 3, for 2D and 3D data respectively.
+
+    Raises
+    ------
+    ValueError
+        If the number of spatial dimensions does not match the number of patch
+        dimensions.
+    """
+    if len(patch_size) != len(spatial_shape):
+        raise ValueError(
+            f"Number of patch dimensions {len(patch_size)} does not match the number "
+            f"of spatial dimensions {len(spatial_shape)}, for `patch_size={patch_size}` "
+            f"and `spatial_shape={spatial_shape}`."
+        )
+    return tuple(
+        rng.integers(
+            np.zeros(len(patch_size), dtype=int),
+            np.array(spatial_shape) - np.array(patch_size),
+            endpoint=True,  # inclusive, so a patch can touch the far edge of the array
+            dtype=int,
+        ).tolist()
+    )
+
+
+def _calc_n_patches(spatial_shape: Sequence[int], patch_size: Sequence[int]) -> int:
+    """
+    Calculate the number of patches for a given `spatial_shape` and `patch_size`.
+
+    This is based on the number of patches that would fit if they were sampled
+    sequentially.
+
+    Parameters
+    ----------
+    spatial_shape : sequence of int
+        The dimension of the axes (Z)YX, a sequence of length 2 or 3, for 2D and 3D
+        data respectively.
+    patch_size : sequence of int
+        The size of the patch. The sequence will have length 2 or 3, for 2D and 3D
+        data respectively.
+
+    Returns
+    -------
+    int
+        The number of patches.
+    """
+    if len(patch_size) != len(spatial_shape):
+        raise ValueError(
+            f"Number of patch dimensions {len(patch_size)} does not match the number "
+            f"of spatial dimensions {len(spatial_shape)}, for `patch_size={patch_size}` "
+            f"and `spatial_shape={spatial_shape}`."
+        )
+    return int(np.ceil(np.prod(spatial_shape) / np.prod(patch_size)))
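
To make the binning logic of `RandomPatchingStrategy` concrete, here is a small usage sketch (an editor's illustration, not part of the diff). It assumes the class is re-exported from `careamics.dataset_ng.patching_strategies`, which the new `__init__.py` in the file list suggests; the shapes follow the SC(Z)YX convention from the docstrings.

```python
from careamics.dataset_ng.patching_strategies import RandomPatchingStrategy

# two image stacks in SCYX order: 2 samples of 64x64 and 1 sample of 32x48
data_shapes = [(2, 1, 64, 64), (1, 1, 32, 48)]
patch_size = (16, 16)

strategy = RandomPatchingStrategy(data_shapes, patch_size, seed=42)

# patches per sample = ceil(prod(spatial_shape) / prod(patch_size)):
# 64*64 / (16*16) = 16 per sample -> 32 for the first stack,
# 32*48 / (16*16) = 6 for the second, so 38 patches in total
assert strategy.n_patches == 38

# the cumulative-patch bins are (32, 38): indices 0-31 map to the first
# stack, 32-37 to the second; only the coords are random on each call
spec = strategy.get_patch_spec(37)
assert spec["data_idx"] == 1 and spec["sample_idx"] == 0
```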
careamics/dataset_ng/patching_strategies/sequential_patching.py (new file)
@@ -0,0 +1,75 @@
+import itertools
+from collections.abc import Sequence
+from typing import Optional
+
+import numpy as np
+from typing_extensions import ParamSpec
+
+from .patching_strategy_protocol import PatchSpecs
+
+P = ParamSpec("P")
+
+
+# TODO: this is an unfinished prototype based on current tiling implementation
+# not guaranteed to work!
+class SequentialPatchingStrategy:
+    # TODO: docs
+    def __init__(
+        self,
+        data_shapes: Sequence[Sequence[int]],
+        patch_size: Sequence[int],
+        overlap: Optional[Sequence[int]] = None,
+    ):
+        self.data_shapes = data_shapes
+        self.patch_size = patch_size
+        if overlap is None:
+            overlap = [0] * len(patch_size)
+        self.overlap = np.asarray(overlap)
+
+        self.patch_specs: list[PatchSpecs] = self._initialize_patch_specs()
+
+    @property
+    def n_patches(self) -> int:
+        return len(self.patch_specs)
+
+    def get_patch_spec(self, index: int) -> PatchSpecs:
+        return self.patch_specs[index]
+
+    def _compute_coords_1d(
+        self, patch_size: int, spatial_shape: int, overlap: int
+    ) -> list[tuple[int, int]]:
+        step = patch_size - overlap
+        crop_coords = []
+
+        current_pos = 0
+        while current_pos <= spatial_shape - patch_size:
+            crop_coords.append((current_pos, current_pos + patch_size))
+            current_pos += step
+
+        if crop_coords[-1][1] < spatial_shape:
+            crop_coords.append((spatial_shape - patch_size, spatial_shape))
+
+        return crop_coords
+
+    def _initialize_patch_specs(self) -> list[PatchSpecs]:
+        patch_specs: list[PatchSpecs] = []
+        for data_idx, data_shape in enumerate(self.data_shapes):
+
+            data_spatial_shape = data_shape[-len(self.patch_size) :]
+            coords_list = [
+                self._compute_coords_1d(
+                    self.patch_size[i], data_spatial_shape[i], self.overlap[i]
+                )
+                for i in range(len(self.patch_size))
+            ]
+            for sample_idx in range(data_shape[0]):
+                for crop_coord in itertools.product(*coords_list):
+                    patch_specs.append(
+                        PatchSpecs(
+                            data_idx=data_idx,
+                            sample_idx=sample_idx,
+                            coords=tuple(coord[0] for coord in crop_coord),
+                            patch_size=self.patch_size,
+                        )
+                    )
+        return patch_specs
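
The heart of the sequential strategy is the 1D coordinate computation: step along an axis by `patch_size - overlap`, then clamp a final patch to the end of the axis if coverage would otherwise fall short. Below is a standalone sketch of that arithmetic (mirroring `_compute_coords_1d` above, with an extra guard added for axes shorter than the patch):

```python
def compute_coords_1d(patch_size: int, axis_len: int, overlap: int) -> list[tuple[int, int]]:
    """(start, end) coords of patches along one axis, stepping by patch_size - overlap."""
    step = patch_size - overlap
    coords = []
    pos = 0
    while pos <= axis_len - patch_size:
        coords.append((pos, pos + patch_size))
        pos += step
    # clamp a final patch to the end of the axis if the last one falls short
    if coords and coords[-1][1] < axis_len:
        coords.append((axis_len - patch_size, axis_len))
    return coords

# non-overlapping patches of 4 on an axis of length 10; the tail patch is clamped
print(compute_coords_1d(4, 10, 0))  # [(0, 4), (4, 8), (6, 10)]
# with overlap 2 the step is 2 and the axis is covered exactly
print(compute_coords_1d(4, 10, 2))  # [(0, 4), (2, 6), (4, 8), (6, 10)]
```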
careamics/lightning/lightning_module.py
@@ -1,12 +1,17 @@
 """CAREamics Lightning module."""

-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Literal, Optional, Union

 import numpy as np
 import pytorch_lightning as L
 from torch import Tensor, nn

-from careamics.config import UNetBasedAlgorithm, VAEBasedAlgorithm
+from careamics.config import (
+    N2VAlgorithm,
+    UNetBasedAlgorithm,
+    VAEBasedAlgorithm,
+    algorithm_factory,
+)
 from careamics.config.support import (
     SupportedAlgorithm,
     SupportedArchitecture,
@@ -27,7 +32,11 @@ from careamics.models.lvae.noise_models import (
     noise_model_factory,
 )
 from careamics.models.model_factory import model_factory
-from careamics.transforms import Denormalize, ImageRestorationTTA
+from careamics.transforms import (
+    Denormalize,
+    ImageRestorationTTA,
+    N2VManipulateTorch,
+)
 from careamics.utils.metrics import RunningPSNR, scale_invariant_psnr
 from careamics.utils.torch_utils import get_optimizer, get_scheduler

@@ -73,13 +82,21 @@ class FCNModule(L.LightningModule):
             Algorithm configuration.
         """
         super().__init__()
-        # if loading from a checkpoint, AlgorithmModel needs to be instantiated
+
         if isinstance(algorithm_config, dict):
-            algorithm_config = UNetBasedAlgorithm(
-                **algorithm_config
-            )  # TODO this needs to be updated using the algorithm-specific class
+            algorithm_config = algorithm_factory(algorithm_config)
+
+        # create preprocessing, model and loss function
+        if isinstance(algorithm_config, N2VAlgorithm):
+            self.use_n2v = True
+            self.n2v_preprocess: Optional[N2VManipulateTorch] = N2VManipulateTorch(
+                n2v_manipulate_config=algorithm_config.n2v_config
+            )
+        else:
+            self.use_n2v = False
+            self.n2v_preprocess = None

-        # create model and loss function
+        self.algorithm = algorithm_config.algorithm
         self.model: nn.Module = model_factory(algorithm_config.model)
         self.loss_func = loss_factory(algorithm_config.loss)

@@ -119,10 +136,15 @@
         Any
             Loss value.
         """
-        # TODO can N2V be simplified by returning mask*original_patch
-        x, *aux = batch
-        out = self.model(x)
-        loss = self.loss_func(out, *aux)
+        x, *targets = batch
+        if self.use_n2v and self.n2v_preprocess is not None:
+            x_preprocessed, *aux = self.n2v_preprocess(x)
+        else:
+            x_preprocessed = x
+            aux = []
+
+        out = self.model(x_preprocessed)
+        loss = self.loss_func(out, *aux, *targets)
         self.log(
             "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
         )
@@ -138,9 +160,15 @@
         batch_idx : Any
             Batch index.
         """
-        x, *aux = batch
-        out = self.model(x)
-        val_loss = self.loss_func(out, *aux)
+        x, *targets = batch
+        if self.use_n2v and self.n2v_preprocess is not None:
+            x_preprocessed, *aux = self.n2v_preprocess(x)
+        else:
+            x_preprocessed = x
+            aux = []
+
+        out = self.model(x_preprocessed)
+        val_loss = self.loss_func(out, *aux, *targets)

         # log validation loss
         self.log(
@@ -177,10 +205,16 @@
             and isinstance(batch[1][0], TileInformation)
         )

+        # TODO add explanations for what is happening here
         if is_tiled:
             x, *aux = batch
+            if type(x) in [list, tuple]:
+                x = x[0]
         else:
-            x = batch
+            if type(batch) in [list, tuple]:
+                x = batch[0]  # TODO change, ugly way to deal with n2v refac
+            else:
+                x = batch
             aux = []

         # apply test-time augmentation if available
@@ -593,6 +627,9 @@ def create_careamics_module(
     algorithm: Union[SupportedAlgorithm, str],
     loss: Union[SupportedLoss, str],
     architecture: Union[SupportedArchitecture, str],
+    use_n2v2: bool = False,
+    struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
+    struct_n2v_span: int = 5,
     model_parameters: Optional[dict] = None,
     optimizer: Union[SupportedOptimizer, str] = "Adam",
     optimizer_parameters: Optional[dict] = None,
@@ -612,6 +649,12 @@
         Loss function to use for training (see SupportedLoss).
     architecture : SupportedArchitecture or str
         Model architecture to use for training (see SupportedArchitecture).
+    use_n2v2 : bool, default=False
+        Whether to use N2V2 or Noise2Void.
+    struct_n2v_axis : "horizontal", "vertical", or "none", default="none"
+        Axis of the StructN2V mask.
+    struct_n2v_span : int, default=5
+        Span of the StructN2V mask.
     model_parameters : dict, optional
         Model parameters to use for training, by default {}. Model parameters are
         defined in the relevant `torch.nn.Module` class, or Pydantic model (see
@@ -633,14 +676,15 @@
     CAREamicsModule
         CAREamics Lightning module.
     """
-    # create a AlgorithmModel compatible dictionary
+    # TODO should use the same functions as in configuration_factory.py
+    # create an AlgorithmModel compatible dictionary
     if lr_scheduler_parameters is None:
         lr_scheduler_parameters = {}
     if optimizer_parameters is None:
         optimizer_parameters = {}
    if model_parameters is None:
         model_parameters = {}
-    algorithm_configuration: dict[str, Any] = {
+    algorithm_dict: dict[str, Any] = {
         "algorithm": algorithm,
         "loss": loss,
         "optimizer": {
@@ -652,18 +696,25 @@
             "parameters": lr_scheduler_parameters,
         },
     }
-    model_configuration = {"architecture": architecture}
-    model_configuration.update(model_parameters)
+
+    model_dict = {"architecture": architecture}
+    model_dict.update(model_parameters)

     # add model parameters to algorithm configuration
-    algorithm_configuration["model"] = model_configuration
+    algorithm_dict["model"] = model_dict
+
+    which_algo = algorithm_dict["algorithm"]
+    if which_algo in UNetBasedAlgorithm.get_compatible_algorithms():
+        algorithm_cfg = algorithm_factory(algorithm_dict)
+
+        # if using N2V
+        if isinstance(algorithm_cfg, N2VAlgorithm):
+            algorithm_cfg.n2v_config.struct_mask_axis = struct_n2v_axis
+            algorithm_cfg.n2v_config.struct_mask_span = struct_n2v_span
+            algorithm_cfg.set_n2v2(use_n2v2)

-    # call the parent init using an AlgorithmModel instance
-    # TODO broken by new configutations!
-    algorithm_str = algorithm_configuration["algorithm"]
-    if algorithm_str in UNetBasedAlgorithm.get_compatible_algorithms():
-        return FCNModule(UNetBasedAlgorithm(**algorithm_configuration))
+        return FCNModule(algorithm_cfg)
     else:
         raise NotImplementedError(
-            f"Model {algorithm_str} is not implemented or unknown."
+            f"Algorithm {which_algo} is not implemented or unknown."
         )
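
With this wiring, N2V-specific settings flow through `create_careamics_module` into an `N2VAlgorithm` built by `algorithm_factory`, instead of a bare `UNetBasedAlgorithm`. A hedged usage sketch follows; the import path and the string values are assumptions based on the package layout and the Supported* enums, not taken from the diff itself:

```python
from careamics.lightning import create_careamics_module  # import path assumed

# builds an N2VAlgorithm via algorithm_factory, applies the StructN2V/N2V2
# options to its n2v_config, and wraps the result in FCNModule
module = create_careamics_module(
    algorithm="n2v",
    loss="n2v",
    architecture="UNet",
    use_n2v2=True,                 # switch the pixel manipulation to N2V2
    struct_n2v_axis="horizontal",  # StructN2V mask along the horizontal axis
    struct_n2v_span=5,
    optimizer="Adam",
    optimizer_parameters={"lr": 1e-4},
)
```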