careamics 0.1.0rc4__py3-none-any.whl → 0.1.0rc6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (103):
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +92 -55
  4. careamics/config/__init__.py +0 -1
  5. careamics/config/algorithm_model.py +5 -3
  6. careamics/config/architectures/architecture_model.py +7 -0
  7. careamics/config/architectures/custom_model.py +8 -1
  8. careamics/config/architectures/register_model.py +3 -1
  9. careamics/config/architectures/unet_model.py +3 -0
  10. careamics/config/architectures/vae_model.py +2 -0
  11. careamics/config/callback_model.py +4 -15
  12. careamics/config/configuration_example.py +4 -4
  13. careamics/config/configuration_factory.py +113 -55
  14. careamics/config/configuration_model.py +14 -16
  15. careamics/config/data_model.py +63 -165
  16. careamics/config/inference_model.py +9 -75
  17. careamics/config/optimizer_models.py +4 -4
  18. careamics/config/references/algorithm_descriptions.py +1 -0
  19. careamics/config/references/references.py +1 -0
  20. careamics/config/support/__init__.py +0 -2
  21. careamics/config/support/supported_activations.py +2 -0
  22. careamics/config/support/supported_algorithms.py +3 -1
  23. careamics/config/support/supported_architectures.py +2 -0
  24. careamics/config/support/supported_data.py +2 -0
  25. careamics/config/support/supported_loggers.py +2 -0
  26. careamics/config/support/supported_losses.py +2 -0
  27. careamics/config/support/supported_optimizers.py +2 -0
  28. careamics/config/support/supported_pixel_manipulations.py +3 -3
  29. careamics/config/support/supported_struct_axis.py +2 -0
  30. careamics/config/support/supported_transforms.py +4 -15
  31. careamics/config/tile_information.py +2 -0
  32. careamics/config/training_model.py +1 -0
  33. careamics/config/transformations/__init__.py +3 -2
  34. careamics/config/transformations/n2v_manipulate_model.py +1 -0
  35. careamics/config/transformations/normalize_model.py +1 -0
  36. careamics/config/transformations/transform_model.py +1 -0
  37. careamics/config/transformations/xy_flip_model.py +43 -0
  38. careamics/config/transformations/xy_random_rotate90_model.py +13 -7
  39. careamics/config/validators/validator_utils.py +1 -0
  40. careamics/conftest.py +13 -0
  41. careamics/dataset/dataset_utils/__init__.py +0 -1
  42. careamics/dataset/dataset_utils/dataset_utils.py +5 -4
  43. careamics/dataset/dataset_utils/file_utils.py +4 -3
  44. careamics/dataset/dataset_utils/read_tiff.py +6 -2
  45. careamics/dataset/dataset_utils/read_utils.py +2 -0
  46. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  47. careamics/dataset/in_memory_dataset.py +84 -76
  48. careamics/dataset/iterable_dataset.py +166 -134
  49. careamics/dataset/patching/__init__.py +0 -7
  50. careamics/dataset/patching/patching.py +56 -14
  51. careamics/dataset/patching/random_patching.py +8 -2
  52. careamics/dataset/patching/sequential_patching.py +20 -14
  53. careamics/dataset/patching/tiled_patching.py +13 -7
  54. careamics/dataset/patching/validate_patch_dimension.py +2 -0
  55. careamics/dataset/zarr_dataset.py +2 -0
  56. careamics/lightning_datamodule.py +63 -41
  57. careamics/lightning_module.py +9 -3
  58. careamics/lightning_prediction_datamodule.py +15 -20
  59. careamics/lightning_prediction_loop.py +8 -6
  60. careamics/losses/__init__.py +1 -3
  61. careamics/losses/loss_factory.py +2 -1
  62. careamics/losses/losses.py +11 -7
  63. careamics/model_io/__init__.py +0 -1
  64. careamics/model_io/bioimage/_readme_factory.py +2 -1
  65. careamics/model_io/bioimage/bioimage_utils.py +1 -0
  66. careamics/model_io/bioimage/model_description.py +1 -0
  67. careamics/model_io/bmz_io.py +4 -3
  68. careamics/models/activation.py +2 -0
  69. careamics/models/layers.py +122 -25
  70. careamics/models/model_factory.py +2 -1
  71. careamics/models/unet.py +114 -19
  72. careamics/prediction/stitch_prediction.py +2 -5
  73. careamics/transforms/__init__.py +4 -25
  74. careamics/transforms/compose.py +124 -0
  75. careamics/transforms/n2v_manipulate.py +65 -34
  76. careamics/transforms/normalize.py +91 -28
  77. careamics/transforms/pixel_manipulation.py +7 -7
  78. careamics/transforms/struct_mask_parameters.py +3 -1
  79. careamics/transforms/transform.py +24 -0
  80. careamics/transforms/tta.py +2 -2
  81. careamics/transforms/xy_flip.py +123 -0
  82. careamics/transforms/xy_random_rotate90.py +66 -60
  83. careamics/utils/__init__.py +0 -1
  84. careamics/utils/base_enum.py +28 -0
  85. careamics/utils/context.py +1 -0
  86. careamics/utils/logging.py +1 -0
  87. careamics/utils/metrics.py +1 -0
  88. careamics/utils/path_utils.py +2 -0
  89. careamics/utils/ram.py +2 -0
  90. careamics/utils/receptive_field.py +93 -87
  91. careamics/utils/torch_utils.py +1 -0
  92. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/METADATA +17 -61
  93. careamics-0.1.0rc6.dist-info/RECORD +107 -0
  94. careamics/config/noise_models.py +0 -162
  95. careamics/config/support/supported_extraction_strategies.py +0 -24
  96. careamics/config/transformations/nd_flip_model.py +0 -32
  97. careamics/dataset/patching/patch_transform.py +0 -44
  98. careamics/losses/noise_model_factory.py +0 -40
  99. careamics/losses/noise_models.py +0 -524
  100. careamics/transforms/nd_flip.py +0 -93
  101. careamics-0.1.0rc4.dist-info/RECORD +0 -110
  102. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/WHEEL +0 -0
  103. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/licenses/LICENSE +0 -0
@@ -1,8 +1,6 @@
1
1
  """Convenience functions to create configurations for training and inference."""
2
2
 
3
- from typing import Any, Dict, List, Literal, Optional, Tuple, Union
4
-
5
- from albumentations import Compose
3
+ from typing import Any, Dict, List, Literal, Optional, Tuple
6
4
 
7
5
  from .algorithm_model import AlgorithmConfig
8
6
  from .architectures import UNetModel
@@ -28,8 +26,10 @@ def _create_supervised_configuration(
28
26
  batch_size: int,
29
27
  num_epochs: int,
30
28
  use_augmentations: bool = True,
29
+ independent_channels: bool = False,
31
30
  loss: Literal["mae", "mse"] = "mae",
32
- n_channels: int = -1,
31
+ n_channels_in: int = 1,
32
+ n_channels_out: int = 1,
33
33
  logger: Literal["wandb", "tensorboard", "none"] = "none",
34
34
  model_kwargs: Optional[dict] = None,
35
35
  ) -> Configuration:
@@ -54,10 +54,14 @@ def _create_supervised_configuration(
54
54
  Number of epochs.
55
55
  use_augmentations : bool, optional
56
56
  Whether to use augmentations, by default True.
57
+ independent_channels : bool, optional
58
+ Whether to train all channels independently, by default False.
57
59
  loss : Literal["mae", "mse"], optional
58
60
  Loss function to use, by default "mae".
59
- n_channels : int, optional
60
- Number of channels (in and out), by default -1.
61
+ n_channels_in : int, optional
62
+ Number of channels in, by default 1.
63
+ n_channels_out : int, optional
64
+ Number of channels out, by default 1.
61
65
  logger : Literal["wandb", "tensorboard", "none"], optional
62
66
  Logger to use, by default "none".
63
67
  model_kwargs : dict, optional
@@ -69,23 +73,24 @@ def _create_supervised_configuration(
69
73
  Configuration for training CARE or Noise2Noise.
70
74
  """
71
75
  # if there are channels, we need to specify their number
72
- if "C" in axes and n_channels == 1:
76
+ if "C" in axes and n_channels_in == 1:
73
77
  raise ValueError(
74
- f"Number of channels must be specified when using channels "
75
- f"(got {n_channels} channel)."
78
+ f"Number of channels in must be specified when using channels "
79
+ f"(got {n_channels_in} channel)."
76
80
  )
77
- elif "C" not in axes and n_channels > 1:
81
+ elif "C" not in axes and n_channels_in > 1:
78
82
  raise ValueError(
79
83
  f"C is not present in the axes, but number of channels is specified "
80
- f"(got {n_channels} channel)."
84
+ f"(got {n_channels_in} channels)."
81
85
  )
82
86
 
83
87
  # model
84
88
  if model_kwargs is None:
85
89
  model_kwargs = {}
86
90
  model_kwargs["conv_dims"] = 3 if "Z" in axes else 2
87
- model_kwargs["in_channels"] = n_channels
88
- model_kwargs["num_classes"] = n_channels
91
+ model_kwargs["in_channels"] = n_channels_in
92
+ model_kwargs["num_classes"] = n_channels_out
93
+ model_kwargs["independent_channels"] = independent_channels
89
94
 
90
95
  unet_model = UNetModel(
91
96
  architecture=SupportedArchitecture.UNET.value,
@@ -106,7 +111,7 @@ def _create_supervised_configuration(
106
111
  "name": SupportedTransform.NORMALIZE.value,
107
112
  },
108
113
  {
109
- "name": SupportedTransform.NDFLIP.value,
114
+ "name": SupportedTransform.XY_FLIP.value,
110
115
  },
111
116
  {
112
117
  "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
@@ -154,8 +159,10 @@ def create_care_configuration(
154
159
  batch_size: int,
155
160
  num_epochs: int,
156
161
  use_augmentations: bool = True,
162
+ independent_channels: bool = False,
157
163
  loss: Literal["mae", "mse"] = "mae",
158
- n_channels: int = 1,
164
+ n_channels_in: int = 1,
165
+ n_channels_out: int = -1,
159
166
  logger: Literal["wandb", "tensorboard", "none"] = "none",
160
167
  model_kwargs: Optional[dict] = None,
161
168
  ) -> Configuration:
@@ -165,10 +172,16 @@ def create_care_configuration(
165
172
  If "Z" is present in `axes`, then `path_size` must be a list of length 3, otherwise
166
173
  2.
167
174
 
168
- If "C" is present in `axes`, then you need to set `n_channels` to the number of
175
+ If "C" is present in `axes`, then you need to set `n_channels_in` to the number of
169
176
  channels. Likewise, if you set the number of channels, then "C" must be present in
170
177
  `axes`.
171
178
 
179
+ To set the number of output channels, use the `n_channels_out` parameter. If it is
180
+ not specified, it will be assumed to be equal to `n_channels_in`.
181
+
182
+ By default, all channels are trained together. To train all channels independently,
183
+ set `independent_channels` to True.
184
+
172
185
  By setting `use_augmentations` to False, the only transformation applied will be
173
186
  normalization.
174
187
 
@@ -188,10 +201,14 @@ def create_care_configuration(
188
201
  Number of epochs.
189
202
  use_augmentations : bool, optional
190
203
  Whether to use augmentations, by default True.
204
+ independent_channels : bool, optional
205
+ Whether to train all channels independently, by default False.
191
206
  loss : Literal["mae", "mse"], optional
192
207
  Loss function to use, by default "mae".
193
- n_channels : int, optional
194
- Number of channels (in and out), by default 1.
208
+ n_channels_in : int, optional
209
+ Number of channels in, by default 1.
210
+ n_channels_out : int, optional
211
+ Number of channels out, by default -1.
195
212
  logger : Literal["wandb", "tensorboard", "none"], optional
196
213
  Logger to use, by default "none".
197
214
  model_kwargs : dict, optional
@@ -202,6 +219,9 @@ def create_care_configuration(
202
219
  Configuration
203
220
  Configuration for training CARE.
204
221
  """
222
+ if n_channels_out == -1:
223
+ n_channels_out = n_channels_in
224
+
205
225
  return _create_supervised_configuration(
206
226
  algorithm="care",
207
227
  experiment_name=experiment_name,
@@ -211,9 +231,10 @@ def create_care_configuration(
211
231
  batch_size=batch_size,
212
232
  num_epochs=num_epochs,
213
233
  use_augmentations=use_augmentations,
234
+ independent_channels=independent_channels,
214
235
  loss=loss,
215
- # TODO in the future we might support different in and out channels for CARE
216
- n_channels=n_channels,
236
+ n_channels_in=n_channels_in,
237
+ n_channels_out=n_channels_out,
217
238
  logger=logger,
218
239
  model_kwargs=model_kwargs,
219
240
  )
@@ -227,6 +248,7 @@ def create_n2n_configuration(
227
248
  batch_size: int,
228
249
  num_epochs: int,
229
250
  use_augmentations: bool = True,
251
+ independent_channels: bool = False,
230
252
  loss: Literal["mae", "mse"] = "mae",
231
253
  n_channels: int = 1,
232
254
  logger: Literal["wandb", "tensorboard", "none"] = "none",
@@ -242,6 +264,9 @@ def create_n2n_configuration(
242
264
  channels. Likewise, if you set the number of channels, then "C" must be present in
243
265
  `axes`.
244
266
 
267
+ By default, all channels are trained together. To train all channels independently,
268
+ set `independent_channels` to True.
269
+
245
270
  By setting `use_augmentations` to False, the only transformation applied will be
246
271
  normalization.
247
272
 
@@ -261,6 +286,8 @@ def create_n2n_configuration(
261
286
  Number of epochs.
262
287
  use_augmentations : bool, optional
263
288
  Whether to use augmentations, by default True.
289
+ independent_channels : bool, optional
290
+ Whether to train all channels independently, by default False.
264
291
  loss : Literal["mae", "mse"], optional
265
292
  Loss function to use, by default "mae".
266
293
  n_channels : int, optional
@@ -284,8 +311,10 @@ def create_n2n_configuration(
284
311
  batch_size=batch_size,
285
312
  num_epochs=num_epochs,
286
313
  use_augmentations=use_augmentations,
314
+ independent_channels=independent_channels,
287
315
  loss=loss,
288
- n_channels=n_channels,
316
+ n_channels_in=n_channels,
317
+ n_channels_out=n_channels,
289
318
  logger=logger,
290
319
  model_kwargs=model_kwargs,
291
320
  )
@@ -299,6 +328,7 @@ def create_n2v_configuration(
299
328
  batch_size: int,
300
329
  num_epochs: int,
301
330
  use_augmentations: bool = True,
331
+ independent_channels: bool = True,
302
332
  use_n2v2: bool = False,
303
333
  n_channels: int = 1,
304
334
  roi_size: int = 11,
@@ -320,11 +350,14 @@ def create_n2v_configuration(
320
350
  or horizontal correlations are present in the noise; it applies an additional mask
321
351
  to the manipulated pixel neighbors.
322
352
 
353
+ If "Z" is present in `axes`, then `path_size` must be a list of length 3, otherwise
354
+ 2.
355
+
323
356
  If "C" is present in `axes`, then you need to set `n_channels` to the number of
324
357
  channels.
325
358
 
326
- If "Z" is present in `axes`, then `path_size` must be a list of length 3, otherwise
327
- 2.
359
+ By default, all channels are trained independently. To train all channels together,
360
+ set `independent_channels` to False.
328
361
 
329
362
  By setting `use_augmentations` to False, the only transformations applied will be
330
363
  normalization and N2V manipulation.
@@ -356,6 +389,8 @@ def create_n2v_configuration(
356
389
  Number of epochs.
357
390
  use_augmentations : bool, optional
358
391
  Whether to use augmentations, by default True.
392
+ independent_channels : bool, optional
393
+ Whether to train all channels together, by default True.
359
394
  use_n2v2 : bool, optional
360
395
  Whether to use N2V2, by default False.
361
396
  n_channels : int, optional
@@ -414,8 +449,8 @@ def create_n2v_configuration(
414
449
  ... struct_n2v_span=7
415
450
  ... )
416
451
 
417
- If you are training multiple channels together, then you need to specify the number
418
- of channels:
452
+ If you are training multiple channels independently, then you need to specify the
453
+ number of channels:
419
454
  >>> config = create_n2v_configuration(
420
455
  ... experiment_name="n2v_experiment",
421
456
  ... data_type="array",
@@ -426,6 +461,19 @@ def create_n2v_configuration(
426
461
  ... n_channels=3
427
462
  ... )
428
463
 
464
+ If instead you want to train multiple channels together, you need to turn off the
465
+ `independent_channels` parameter:
466
+ >>> config = create_n2v_configuration(
467
+ ... experiment_name="n2v_experiment",
468
+ ... data_type="array",
469
+ ... axes="YXC",
470
+ ... patch_size=[64, 64],
471
+ ... batch_size=32,
472
+ ... num_epochs=100,
473
+ ... independent_channels=False,
474
+ ... n_channels=3
475
+ ... )
476
+
429
477
  To turn off the augmentations, except normalization and N2V manipulation, use the
430
478
  relevant keyword argument:
431
479
  >>> config = create_n2v_configuration(
@@ -457,6 +505,7 @@ def create_n2v_configuration(
457
505
  model_kwargs["conv_dims"] = 3 if "Z" in axes else 2
458
506
  model_kwargs["in_channels"] = n_channels
459
507
  model_kwargs["num_classes"] = n_channels
508
+ model_kwargs["independent_channels"] = independent_channels
460
509
 
461
510
  unet_model = UNetModel(
462
511
  architecture=SupportedArchitecture.UNET.value,
@@ -477,7 +526,7 @@ def create_n2v_configuration(
477
526
  "name": SupportedTransform.NORMALIZE.value,
478
527
  },
479
528
  {
480
- "name": SupportedTransform.NDFLIP.value,
529
+ "name": SupportedTransform.XY_FLIP.value,
481
530
  },
482
531
  {
483
532
  "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
@@ -493,9 +542,11 @@ def create_n2v_configuration(
493
542
  # n2v2 and structn2v
494
543
  nv2_transform = {
495
544
  "name": SupportedTransform.N2V_MANIPULATE.value,
496
- "strategy": SupportedPixelManipulation.MEDIAN.value
497
- if use_n2v2
498
- else SupportedPixelManipulation.UNIFORM.value,
545
+ "strategy": (
546
+ SupportedPixelManipulation.MEDIAN.value
547
+ if use_n2v2
548
+ else SupportedPixelManipulation.UNIFORM.value
549
+ ),
499
550
  "roi_size": roi_size,
500
551
  "masked_pixel_percentage": masked_pixel_percentage,
501
552
  "struct_mask_axis": struct_n2v_axis,
@@ -530,14 +581,12 @@ def create_n2v_configuration(
530
581
  return configuration
531
582
 
532
583
 
533
- # TODO add tests
534
584
  def create_inference_configuration(
535
- training_configuration: Configuration,
585
+ configuration: Configuration,
536
586
  tile_size: Optional[Tuple[int, ...]] = None,
537
587
  tile_overlap: Optional[Tuple[int, ...]] = None,
538
588
  data_type: Optional[Literal["array", "tiff", "custom"]] = None,
539
589
  axes: Optional[str] = None,
540
- transforms: Optional[Union[List[Dict[str, Any]], Compose]] = None,
541
590
  tta_transforms: bool = True,
542
591
  batch_size: Optional[int] = 1,
543
592
  ) -> InferenceConfig:
@@ -545,12 +594,12 @@ def create_inference_configuration(
545
594
  Create a configuration for inference with N2V.
546
595
 
547
596
  If not provided, `data_type` and `axes` are taken from the training
548
- configuration. If `transforms` are not provided, only normalization is applied.
597
+ configuration.
549
598
 
550
599
  Parameters
551
600
  ----------
552
- training_configuration : Configuration
553
- Configuration used for training.
601
+ configuration : Configuration
602
+ Global configuration.
554
603
  tile_size : Tuple[int, ...], optional
555
604
  Size of the tiles.
556
605
  tile_overlap : Tuple[int, ...], optional
@@ -559,8 +608,6 @@ def create_inference_configuration(
559
608
  Type of the data, by default "tiff".
560
609
  axes : str, optional
561
610
  Axes of the data, by default "YX".
562
- transforms : List[Dict[str, Any]] or Compose, optional
563
- Transformations to apply to the data, by default None.
564
611
  tta_transforms : bool, optional
565
612
  Whether to apply test-time augmentations, by default True.
566
613
  batch_size : int, optional
@@ -569,29 +616,40 @@ def create_inference_configuration(
569
616
  Returns
570
617
  -------
571
618
  InferenceConfiguration
572
- Configuration for inference with N2V.
619
+ Configuration used to configure CAREamicsPredictData.
573
620
  """
574
- if (
575
- training_configuration.data_config.mean is None
576
- or training_configuration.data_config.std is None
577
- ):
578
- raise ValueError("Mean and std must be provided in the training configuration.")
579
-
580
- if transforms is None:
581
- transforms = [
582
- {
583
- "name": SupportedTransform.NORMALIZE.value,
584
- },
585
- ]
621
+ if configuration.data_config.mean is None or configuration.data_config.std is None:
622
+ raise ValueError("Mean and std must be provided in the configuration.")
623
+
624
+ # tile size for UNets
625
+ if tile_size is not None:
626
+ model = configuration.algorithm_config.model
627
+
628
+ if model.architecture == SupportedArchitecture.UNET.value:
629
+ # tile size must be equal to k*2^n, where n is the number of pooling layers
630
+ # (equal to the depth) and k is an integer
631
+ depth = model.depth
632
+ tile_increment = 2**depth
633
+
634
+ for i, t in enumerate(tile_size):
635
+ if t % tile_increment != 0:
636
+ raise ValueError(
637
+ f"Tile size must be divisible by {tile_increment} along all "
638
+ f"axes (got {t} for axis {i}). If your image size is smaller "
639
+ f"along one axis (e.g. Z), consider padding the image."
640
+ )
641
+
642
+ # tile overlaps must be specified
643
+ if tile_overlap is None:
644
+ raise ValueError("Tile overlap must be specified.")
586
645
 
587
646
  return InferenceConfig(
588
- data_type=data_type or training_configuration.data_config.data_type,
647
+ data_type=data_type or configuration.data_config.data_type,
589
648
  tile_size=tile_size,
590
649
  tile_overlap=tile_overlap,
591
- axes=axes or training_configuration.data_config.axes,
592
- mean=training_configuration.data_config.mean,
593
- std=training_configuration.data_config.std,
594
- transforms=transforms,
650
+ axes=axes or configuration.data_config.axes,
651
+ mean=configuration.data_config.mean,
652
+ std=configuration.data_config.std,
595
653
  tta_transforms=tta_transforms,
596
654
  batch_size=batch_size,
597
655
  )
@@ -1,4 +1,5 @@
1
1
  """Pydantic CAREamics configuration."""
2
+
2
3
  from __future__ import annotations
3
4
 
4
5
  import re
@@ -238,25 +239,22 @@ class Configuration(BaseModel):
238
239
  Validated configuration.
239
240
  """
240
241
  if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
241
- # if we have a list of transform (as opposed to Compose)
242
- if self.data_config.has_transform_list():
243
- # missing N2V_MANIPULATE
244
- if not self.data_config.has_n2v_manipulate():
245
- self.data_config.transforms.append(
246
- N2VManipulateModel(
247
- name=SupportedTransform.N2V_MANIPULATE.value,
248
- )
242
+ # missing N2V_MANIPULATE
243
+ if not self.data_config.has_n2v_manipulate():
244
+ self.data_config.transforms.append(
245
+ N2VManipulateModel(
246
+ name=SupportedTransform.N2V_MANIPULATE.value,
249
247
  )
248
+ )
250
249
 
251
- median = SupportedPixelManipulation.MEDIAN.value
252
- uniform = SupportedPixelManipulation.UNIFORM.value
253
- strategy = median if self.algorithm_config.model.n2v2 else uniform
254
- self.data_config.set_N2V2_strategy(strategy)
250
+ median = SupportedPixelManipulation.MEDIAN.value
251
+ uniform = SupportedPixelManipulation.UNIFORM.value
252
+ strategy = median if self.algorithm_config.model.n2v2 else uniform
253
+ self.data_config.set_N2V2_strategy(strategy)
255
254
  else:
256
- # if we have a list of transform, remove N2V manipulate if present
257
- if self.data_config.has_transform_list():
258
- if self.data_config.has_n2v_manipulate():
259
- self.data_config.remove_n2v_manipulate()
255
+ # remove N2V manipulate if present
256
+ if self.data_config.has_n2v_manipulate():
257
+ self.data_config.remove_n2v_manipulate()
260
258
 
261
259
  return self
262
260