careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (133) hide show
  1. careamics/__init__.py +14 -4
  2. careamics/callbacks/__init__.py +6 -0
  3. careamics/callbacks/hyperparameters_callback.py +42 -0
  4. careamics/callbacks/progress_bar_callback.py +57 -0
  5. careamics/careamist.py +761 -0
  6. careamics/config/__init__.py +27 -3
  7. careamics/config/algorithm_model.py +167 -0
  8. careamics/config/architectures/__init__.py +17 -0
  9. careamics/config/architectures/architecture_model.py +29 -0
  10. careamics/config/architectures/custom_model.py +150 -0
  11. careamics/config/architectures/register_model.py +101 -0
  12. careamics/config/architectures/unet_model.py +96 -0
  13. careamics/config/architectures/vae_model.py +39 -0
  14. careamics/config/callback_model.py +92 -0
  15. careamics/config/configuration_factory.py +460 -0
  16. careamics/config/configuration_model.py +596 -0
  17. careamics/config/data_model.py +555 -0
  18. careamics/config/inference_model.py +283 -0
  19. careamics/config/noise_models.py +162 -0
  20. careamics/config/optimizer_models.py +181 -0
  21. careamics/config/references/__init__.py +45 -0
  22. careamics/config/references/algorithm_descriptions.py +131 -0
  23. careamics/config/references/references.py +38 -0
  24. careamics/config/support/__init__.py +33 -0
  25. careamics/config/support/supported_activations.py +24 -0
  26. careamics/config/support/supported_algorithms.py +18 -0
  27. careamics/config/support/supported_architectures.py +18 -0
  28. careamics/config/support/supported_data.py +82 -0
  29. careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
  30. careamics/config/support/supported_loggers.py +8 -0
  31. careamics/config/support/supported_losses.py +25 -0
  32. careamics/config/support/supported_optimizers.py +55 -0
  33. careamics/config/support/supported_pixel_manipulations.py +15 -0
  34. careamics/config/support/supported_struct_axis.py +19 -0
  35. careamics/config/support/supported_transforms.py +23 -0
  36. careamics/config/tile_information.py +104 -0
  37. careamics/config/training_model.py +65 -0
  38. careamics/config/transformations/__init__.py +14 -0
  39. careamics/config/transformations/n2v_manipulate_model.py +63 -0
  40. careamics/config/transformations/nd_flip_model.py +32 -0
  41. careamics/config/transformations/normalize_model.py +31 -0
  42. careamics/config/transformations/transform_model.py +44 -0
  43. careamics/config/transformations/xy_random_rotate90_model.py +29 -0
  44. careamics/config/validators/__init__.py +5 -0
  45. careamics/config/validators/validator_utils.py +100 -0
  46. careamics/conftest.py +26 -0
  47. careamics/dataset/__init__.py +5 -0
  48. careamics/dataset/dataset_utils/__init__.py +19 -0
  49. careamics/dataset/dataset_utils/dataset_utils.py +100 -0
  50. careamics/dataset/dataset_utils/file_utils.py +140 -0
  51. careamics/dataset/dataset_utils/read_tiff.py +61 -0
  52. careamics/dataset/dataset_utils/read_utils.py +25 -0
  53. careamics/dataset/dataset_utils/read_zarr.py +56 -0
  54. careamics/dataset/in_memory_dataset.py +323 -134
  55. careamics/dataset/iterable_dataset.py +416 -0
  56. careamics/dataset/patching/__init__.py +8 -0
  57. careamics/dataset/patching/patch_transform.py +44 -0
  58. careamics/dataset/patching/patching.py +212 -0
  59. careamics/dataset/patching/random_patching.py +190 -0
  60. careamics/dataset/patching/sequential_patching.py +206 -0
  61. careamics/dataset/patching/tiled_patching.py +158 -0
  62. careamics/dataset/patching/validate_patch_dimension.py +60 -0
  63. careamics/dataset/zarr_dataset.py +149 -0
  64. careamics/lightning_datamodule.py +665 -0
  65. careamics/lightning_module.py +292 -0
  66. careamics/lightning_prediction_datamodule.py +390 -0
  67. careamics/lightning_prediction_loop.py +116 -0
  68. careamics/losses/__init__.py +4 -1
  69. careamics/losses/loss_factory.py +24 -14
  70. careamics/losses/losses.py +65 -5
  71. careamics/losses/noise_model_factory.py +40 -0
  72. careamics/losses/noise_models.py +524 -0
  73. careamics/model_io/__init__.py +8 -0
  74. careamics/model_io/bioimage/__init__.py +11 -0
  75. careamics/model_io/bioimage/_readme_factory.py +120 -0
  76. careamics/model_io/bioimage/bioimage_utils.py +48 -0
  77. careamics/model_io/bioimage/model_description.py +318 -0
  78. careamics/model_io/bmz_io.py +231 -0
  79. careamics/model_io/model_io_utils.py +80 -0
  80. careamics/models/__init__.py +4 -1
  81. careamics/models/activation.py +35 -0
  82. careamics/models/layers.py +244 -0
  83. careamics/models/model_factory.py +21 -221
  84. careamics/models/unet.py +46 -20
  85. careamics/prediction/__init__.py +1 -3
  86. careamics/prediction/stitch_prediction.py +73 -0
  87. careamics/transforms/__init__.py +41 -0
  88. careamics/transforms/n2v_manipulate.py +113 -0
  89. careamics/transforms/nd_flip.py +93 -0
  90. careamics/transforms/normalize.py +109 -0
  91. careamics/transforms/pixel_manipulation.py +383 -0
  92. careamics/transforms/struct_mask_parameters.py +18 -0
  93. careamics/transforms/tta.py +74 -0
  94. careamics/transforms/xy_random_rotate90.py +95 -0
  95. careamics/utils/__init__.py +10 -12
  96. careamics/utils/base_enum.py +32 -0
  97. careamics/utils/context.py +22 -2
  98. careamics/utils/metrics.py +0 -46
  99. careamics/utils/path_utils.py +24 -0
  100. careamics/utils/ram.py +13 -0
  101. careamics/utils/receptive_field.py +102 -0
  102. careamics/utils/running_stats.py +43 -0
  103. careamics/utils/torch_utils.py +112 -75
  104. careamics-0.1.0rc3.dist-info/METADATA +122 -0
  105. careamics-0.1.0rc3.dist-info/RECORD +109 -0
  106. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
  107. careamics/bioimage/__init__.py +0 -15
  108. careamics/bioimage/docs/Noise2Void.md +0 -5
  109. careamics/bioimage/docs/__init__.py +0 -1
  110. careamics/bioimage/io.py +0 -182
  111. careamics/bioimage/rdf.py +0 -105
  112. careamics/config/algorithm.py +0 -231
  113. careamics/config/config.py +0 -297
  114. careamics/config/config_filter.py +0 -44
  115. careamics/config/data.py +0 -194
  116. careamics/config/torch_optim.py +0 -118
  117. careamics/config/training.py +0 -534
  118. careamics/dataset/dataset_utils.py +0 -111
  119. careamics/dataset/patching.py +0 -492
  120. careamics/dataset/prepare_dataset.py +0 -175
  121. careamics/dataset/tiff_dataset.py +0 -212
  122. careamics/engine.py +0 -1014
  123. careamics/manipulation/__init__.py +0 -4
  124. careamics/manipulation/pixel_manipulation.py +0 -158
  125. careamics/prediction/prediction_utils.py +0 -106
  126. careamics/utils/ascii_logo.txt +0 -9
  127. careamics/utils/augment.py +0 -65
  128. careamics/utils/normalization.py +0 -55
  129. careamics/utils/validators.py +0 -170
  130. careamics/utils/wandb.py +0 -121
  131. careamics-0.1.0rc2.dist-info/METADATA +0 -81
  132. careamics-0.1.0rc2.dist-info/RECORD +0 -47
  133. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,555 @@
1
+ """Data configuration."""
2
+ from __future__ import annotations
3
+
4
+ from pprint import pformat
5
+ from typing import Any, List, Literal, Optional, Union
6
+
7
+ from albumentations import Compose
8
+ from pydantic import (
9
+ BaseModel,
10
+ ConfigDict,
11
+ Discriminator,
12
+ Field,
13
+ field_validator,
14
+ model_validator,
15
+ )
16
+ from typing_extensions import Annotated, Self
17
+
18
+ from .support import SupportedTransform
19
+ from .transformations.n2v_manipulate_model import N2VManipulateModel
20
+ from .transformations.nd_flip_model import NDFlipModel
21
+ from .transformations.normalize_model import NormalizeModel
22
+ from .transformations.xy_random_rotate90_model import XYRandomRotate90Model
23
+ from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
24
+
25
+ TRANSFORMS_UNION = Annotated[
26
+ Union[
27
+ NDFlipModel,
28
+ XYRandomRotate90Model,
29
+ NormalizeModel,
30
+ N2VManipulateModel,
31
+ ],
32
+ Discriminator("name"), # used to tell the different transform models apart
33
+ ]
34
+
35
+
36
+ class DataModel(BaseModel):
37
+ """
38
+ Data configuration.
39
+
40
+ If std is specified, mean must be specified as well. Note that setting the std first
41
+ and then the mean (if they were both `None` before) will raise a validation error.
42
+ Prefer instead `set_mean_and_std` to set both at once.
43
+
44
+ Examples
45
+ --------
46
+ Minimum example:
47
+
48
+ >>> data = DataModel(
49
+ ... data_type="array", # defined in SupportedData
50
+ ... patch_size=[128, 128],
51
+ ... batch_size=4,
52
+ ... axes="YX"
53
+ ... )
54
+
55
+ To change the mean and std of the data:
56
+ >>> data.set_mean_and_std(mean=214.3, std=84.5)
57
+
58
+ One can pass also a list of transformations, by keyword, using the
59
+ SupportedTransform or the name of an Albumentation transform:
60
+ >>> from careamics.config.support import SupportedTransform
61
+ >>> data = DataModel(
62
+ ... data_type="tiff",
63
+ ... patch_size=[128, 128],
64
+ ... batch_size=4,
65
+ ... axes="YX",
66
+ ... transforms=[
67
+ ... {
68
+ ... "name": SupportedTransform.NORMALIZE.value,
69
+ ... "mean": 167.6,
70
+ ... "std": 47.2,
71
+ ... },
72
+ ... {
73
+ ... "name": "NDFlip",
74
+ ... "is_3D": True,
75
+ ... "flip_z": True,
76
+ ... }
77
+ ... ]
78
+ ... )
79
+ """
80
+
81
+ # Pydantic class configuration
82
+ model_config = ConfigDict(
83
+ validate_assignment=True,
84
+ arbitrary_types_allowed=True, # Allow Compose declaration
85
+ )
86
+
87
+ # Dataset configuration
88
+ data_type: Literal["array", "tiff", "custom"] # As defined in SupportedData
89
+ patch_size: Union[List[int]] = Field(..., min_length=2, max_length=3)
90
+ batch_size: int = Field(default=1, ge=1, validate_default=True)
91
+ axes: str
92
+
93
+ # Optional fields
94
+ mean: Optional[float] = None
95
+ std: Optional[float] = None
96
+
97
+ transforms: Union[List[TRANSFORMS_UNION], Compose] = Field(
98
+ default=[
99
+ {
100
+ "name": SupportedTransform.NORMALIZE.value,
101
+ },
102
+ {
103
+ "name": SupportedTransform.NDFLIP.value,
104
+ },
105
+ {
106
+ "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
107
+ },
108
+ {
109
+ "name": SupportedTransform.N2V_MANIPULATE.value,
110
+ },
111
+ ],
112
+ validate_default=True,
113
+ )
114
+
115
+ dataloader_params: Optional[dict] = None
116
+
117
+ @field_validator("patch_size")
118
+ @classmethod
119
+ def all_elements_power_of_2_minimum_8(
120
+ cls, patch_list: Union[List[int]]
121
+ ) -> Union[List[int]]:
122
+ """
123
+ Validate patch size.
124
+
125
+ Patch size must be powers of 2 and minimum 8.
126
+
127
+ Parameters
128
+ ----------
129
+ patch_list : Union[List[int]]
130
+ Patch size.
131
+
132
+ Returns
133
+ -------
134
+ Union[List[int]]
135
+ Validated patch size.
136
+
137
+ Raises
138
+ ------
139
+ ValueError
140
+ If the patch size is smaller than 8.
141
+ ValueError
142
+ If the patch size is not a power of 2.
143
+ """
144
+ patch_size_ge_than_8_power_of_2(patch_list)
145
+
146
+ return patch_list
147
+
148
+ @field_validator("axes")
149
+ @classmethod
150
+ def axes_valid(cls, axes: str) -> str:
151
+ """
152
+ Validate axes.
153
+
154
+ Axes must:
155
+ - be a combination of 'STCZYX'
156
+ - not contain duplicates
157
+ - contain at least 2 contiguous axes: X and Y
158
+ - contain at most 4 axes
159
+ - not contain both S and T axes
160
+
161
+ Parameters
162
+ ----------
163
+ axes : str
164
+ Axes to validate.
165
+
166
+ Returns
167
+ -------
168
+ str
169
+ Validated axes.
170
+
171
+ Raises
172
+ ------
173
+ ValueError
174
+ If axes are not valid.
175
+ """
176
+ # Validate axes
177
+ check_axes_validity(axes)
178
+
179
+ return axes
180
+
181
+ @field_validator("transforms")
182
+ @classmethod
183
+ def validate_prediction_transforms(
184
+ cls, transforms: Union[List[TRANSFORMS_UNION], Compose]
185
+ ) -> Union[List[TRANSFORMS_UNION], Compose]:
186
+ """
187
+ Validate N2VManipulate transform position in the transform list.
188
+
189
+ Parameters
190
+ ----------
191
+ transforms : Union[List[Transformations_Union], Compose]
192
+ Transforms.
193
+
194
+ Returns
195
+ -------
196
+ Union[List[Transformations_Union], Compose]
197
+ Validated transforms.
198
+
199
+ Raises
200
+ ------
201
+ ValueError
202
+ If multiple instances of N2VManipulate are found.
203
+ """
204
+ if not isinstance(transforms, Compose):
205
+ transform_list = [t.name for t in transforms]
206
+
207
+ if SupportedTransform.N2V_MANIPULATE in transform_list:
208
+ # multiple N2V_MANIPULATE
209
+ if transform_list.count(SupportedTransform.N2V_MANIPULATE) > 1:
210
+ raise ValueError(
211
+ f"Multiple instances of "
212
+ f"{SupportedTransform.N2V_MANIPULATE} transforms "
213
+ f"are not allowed."
214
+ )
215
+
216
+ # N2V_MANIPULATE not the last transform
217
+ elif transform_list[-1] != SupportedTransform.N2V_MANIPULATE:
218
+ index = transform_list.index(SupportedTransform.N2V_MANIPULATE)
219
+ transform = transforms.pop(index)
220
+ transforms.append(transform)
221
+
222
+ return transforms
223
+
224
+ @model_validator(mode="after")
225
+ def std_only_with_mean(self: Self) -> Self:
226
+ """
227
+ Check that mean and std are either both None, or both specified.
228
+
229
+ Returns
230
+ -------
231
+ Self
232
+ Validated data model.
233
+
234
+ Raises
235
+ ------
236
+ ValueError
237
+ If std is not None and mean is None.
238
+ """
239
+ # check that mean and std are either both None, or both specified
240
+ if (self.mean is None) != (self.std is None):
241
+ raise ValueError(
242
+ "Mean and std must be either both None, or both specified."
243
+ )
244
+
245
+ return self
246
+
247
+ @model_validator(mode="after")
248
+ def add_std_and_mean_to_normalize(self: Self) -> Self:
249
+ """
250
+ Add mean and std to the Normalize transform if it is present.
251
+
252
+ Returns
253
+ -------
254
+ Self
255
+ Data model with mean and std added to the Normalize transform.
256
+ """
257
+ if self.mean is not None or self.std is not None:
258
+ # search in the transforms for Normalize and update parameters
259
+ if self.has_transform_list():
260
+ for transform in self.transforms:
261
+ if transform.name == SupportedTransform.NORMALIZE.value:
262
+ transform.mean = self.mean
263
+ transform.std = self.std
264
+
265
+ return self
266
+
267
+ @model_validator(mode="after")
268
+ def validate_dimensions(self: Self) -> Self:
269
+ """
270
+ Validate 2D/3D dimensions between axes, patch size and transforms.
271
+
272
+ Returns
273
+ -------
274
+ Self
275
+ Validated data model.
276
+
277
+ Raises
278
+ ------
279
+ ValueError
280
+ If the transforms are not valid.
281
+ """
282
+ if "Z" in self.axes:
283
+ if len(self.patch_size) != 3:
284
+ raise ValueError(
285
+ f"Patch size must have 3 dimensions if the data is 3D "
286
+ f"({self.axes})."
287
+ )
288
+
289
+ if self.has_transform_list():
290
+ for transform in self.transforms:
291
+ if transform.name == SupportedTransform.NDFLIP:
292
+ transform.is_3D = True
293
+ elif transform.name == SupportedTransform.XY_RANDOM_ROTATE90:
294
+ transform.is_3D = True
295
+
296
+ else:
297
+ if len(self.patch_size) != 2:
298
+ raise ValueError(
299
+ f"Patch size must have 3 dimensions if the data is 3D "
300
+ f"({self.axes})."
301
+ )
302
+
303
+ if self.has_transform_list():
304
+ for transform in self.transforms:
305
+ if transform.name == SupportedTransform.NDFLIP:
306
+ transform.is_3D = False
307
+ elif transform.name == SupportedTransform.XY_RANDOM_ROTATE90:
308
+ transform.is_3D = False
309
+
310
+ return self
311
+
312
+ def __str__(self) -> str:
313
+ """
314
+ Pretty string reprensenting the configuration.
315
+
316
+ Returns
317
+ -------
318
+ str
319
+ Pretty string.
320
+ """
321
+ return pformat(self.model_dump())
322
+
323
+ def _update(self, **kwargs: Any) -> None:
324
+ """
325
+ Update multiple arguments at once.
326
+
327
+ Parameters
328
+ ----------
329
+ **kwargs : Any
330
+ Keyword arguments to update.
331
+ """
332
+ self.__dict__.update(kwargs)
333
+ self.__class__.model_validate(self.__dict__)
334
+
335
+ def has_transform_list(self) -> bool:
336
+ """
337
+ Check if the transforms are a list, as opposed to a Compose object.
338
+
339
+ Returns
340
+ -------
341
+ bool
342
+ True if the transforms are a list, False otherwise.
343
+ """
344
+ return isinstance(self.transforms, list)
345
+
346
+ def has_n2v_manipulate(self) -> bool:
347
+ """
348
+ Check if the transforms contain N2VManipulate.
349
+
350
+ Use `has_transform_list` to check if the transforms are a list.
351
+
352
+ Returns
353
+ -------
354
+ bool
355
+ True if the transforms contain N2VManipulate, False otherwise.
356
+
357
+ Raises
358
+ ------
359
+ ValueError
360
+ If the transforms are a Compose object.
361
+ """
362
+ if self.has_transform_list():
363
+ return any(
364
+ transform.name == SupportedTransform.N2V_MANIPULATE.value
365
+ for transform in self.transforms
366
+ )
367
+ else:
368
+ raise ValueError(
369
+ "Checking for N2VManipulate with Compose transforms is not allowed. "
370
+ "Check directly in the Compose."
371
+ )
372
+
373
+ def add_n2v_manipulate(self) -> None:
374
+ """
375
+ Add N2VManipulate to the transforms.
376
+
377
+ Use `has_transform_list` to check if the transforms are a list.
378
+
379
+ Raises
380
+ ------
381
+ ValueError
382
+ If the transforms are a Compose object.
383
+ """
384
+ if self.has_transform_list():
385
+ if not self.has_n2v_manipulate():
386
+ self.transforms.append(
387
+ N2VManipulateModel(name=SupportedTransform.N2V_MANIPULATE.value)
388
+ )
389
+ else:
390
+ raise ValueError(
391
+ "Adding N2VManipulate with Compose transforms is not allowed. Add "
392
+ "N2VManipulate directly to the transform in the Compose."
393
+ )
394
+
395
+ def remove_n2v_manipulate(self) -> None:
396
+ """
397
+ Remove N2VManipulate from the transforms.
398
+
399
+ Use `has_transform_list` to check if the transforms are a list.
400
+
401
+ Raises
402
+ ------
403
+ ValueError
404
+ If the transforms are a Compose object.
405
+ """
406
+ if self.has_transform_list() and self.has_n2v_manipulate():
407
+ self.transforms.pop(-1)
408
+ else:
409
+ raise ValueError(
410
+ "Removing N2VManipulate with Compose transforms is not allowed. Remove "
411
+ "N2VManipulate directly from the transform in the Compose."
412
+ )
413
+
414
+ def set_mean_and_std(self, mean: float, std: float) -> None:
415
+ """
416
+ Set mean and standard deviation of the data.
417
+
418
+ This method should be used instead setting the fields directly, as it would
419
+ otherwise trigger a validation error.
420
+
421
+ Parameters
422
+ ----------
423
+ mean : float
424
+ Mean of the data.
425
+ std : float
426
+ Standard deviation of the data.
427
+ """
428
+ self._update(mean=mean, std=std)
429
+
430
+ # search in the transforms for Normalize and update parameters
431
+ if self.has_transform_list():
432
+ for transform in self.transforms:
433
+ if transform.name == SupportedTransform.NORMALIZE.value:
434
+ transform.mean = mean
435
+ transform.std = std
436
+ else:
437
+ raise ValueError(
438
+ "Setting mean and std with Compose transforms is not allowed. Add "
439
+ "mean and std parameters directly to the transform in the Compose."
440
+ )
441
+
442
+ def set_3D(self, axes: str, patch_size: List[int]) -> None:
443
+ """
444
+ Set 3D parameters.
445
+
446
+ Parameters
447
+ ----------
448
+ axes : str
449
+ Axes.
450
+ patch_size : List[int]
451
+ Patch size.
452
+ """
453
+ self._update(axes=axes, patch_size=patch_size)
454
+
455
+ def set_N2V2(self, use_n2v2: bool) -> None:
456
+ """
457
+ Set N2V2.
458
+
459
+ Parameters
460
+ ----------
461
+ use_n2v2 : bool
462
+ Whether to use N2V2.
463
+
464
+ Raises
465
+ ------
466
+ ValueError
467
+ If the N2V pixel manipulate transform is not found in the transforms.
468
+ ValueError
469
+ If the transforms are a Compose object.
470
+ """
471
+ if use_n2v2:
472
+ self.set_N2V2_strategy("median")
473
+ else:
474
+ self.set_N2V2_strategy("uniform")
475
+
476
+ def set_N2V2_strategy(self, strategy: Literal["uniform", "median"]) -> None:
477
+ """
478
+ Set N2V2 strategy.
479
+
480
+ Parameters
481
+ ----------
482
+ strategy : Literal["uniform", "median"]
483
+ Strategy to use for N2V2.
484
+
485
+ Raises
486
+ ------
487
+ ValueError
488
+ If the N2V pixel manipulate transform is not found in the transforms.
489
+ ValueError
490
+ If the transforms are a Compose object.
491
+ """
492
+ if isinstance(self.transforms, list):
493
+ found_n2v = False
494
+
495
+ for transform in self.transforms:
496
+ if transform.name == SupportedTransform.N2V_MANIPULATE.value:
497
+ transform.strategy = strategy
498
+ found_n2v = True
499
+
500
+ if not found_n2v:
501
+ transforms = [t.name for t in self.transforms]
502
+ raise ValueError(
503
+ f"N2V_Manipulate transform not found in the transforms list "
504
+ f"({transforms})."
505
+ )
506
+
507
+ else:
508
+ raise ValueError(
509
+ "Setting N2V2 strategy with Compose transforms is not allowed. Add "
510
+ "N2V2 strategy parameters directly to the transform in the Compose."
511
+ )
512
+
513
+ def set_structN2V_mask(
514
+ self, mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int
515
+ ) -> None:
516
+ """
517
+ Set structN2V mask parameters.
518
+
519
+ Setting `mask_axis` to `none` will disable structN2V.
520
+
521
+ Parameters
522
+ ----------
523
+ mask_axis : Literal["horizontal", "vertical", "none"]
524
+ Axis along which to apply the mask. `none` will disable structN2V.
525
+ mask_span : int
526
+ Total span of the mask in pixels.
527
+
528
+ Raises
529
+ ------
530
+ ValueError
531
+ If the N2V pixel manipulate transform is not found in the transforms.
532
+ ValueError
533
+ If the transforms are a Compose object.
534
+ """
535
+ if isinstance(self.transforms, list):
536
+ found_n2v = False
537
+
538
+ for transform in self.transforms:
539
+ if transform.name == SupportedTransform.N2V_MANIPULATE.value:
540
+ transform.struct_mask_axis = mask_axis
541
+ transform.struct_mask_span = mask_span
542
+ found_n2v = True
543
+
544
+ if not found_n2v:
545
+ transforms = [t.name for t in self.transforms]
546
+ raise ValueError(
547
+ f"N2V pixel manipulate transform not found in the transforms "
548
+ f"({transforms})."
549
+ )
550
+
551
+ else:
552
+ raise ValueError(
553
+ "Setting structN2VMask with Compose transforms is not allowed. Add "
554
+ "structN2VMask parameters directly to the transform in the Compose."
555
+ )