careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +14 -4
- careamics/callbacks/__init__.py +6 -0
- careamics/callbacks/hyperparameters_callback.py +42 -0
- careamics/callbacks/progress_bar_callback.py +57 -0
- careamics/careamist.py +761 -0
- careamics/config/__init__.py +27 -3
- careamics/config/algorithm_model.py +167 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +29 -0
- careamics/config/architectures/custom_model.py +150 -0
- careamics/config/architectures/register_model.py +101 -0
- careamics/config/architectures/unet_model.py +96 -0
- careamics/config/architectures/vae_model.py +39 -0
- careamics/config/callback_model.py +92 -0
- careamics/config/configuration_factory.py +460 -0
- careamics/config/configuration_model.py +596 -0
- careamics/config/data_model.py +555 -0
- careamics/config/inference_model.py +283 -0
- careamics/config/noise_models.py +162 -0
- careamics/config/optimizer_models.py +181 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +131 -0
- careamics/config/references/references.py +38 -0
- careamics/config/support/__init__.py +33 -0
- careamics/config/support/supported_activations.py +24 -0
- careamics/config/support/supported_algorithms.py +18 -0
- careamics/config/support/supported_architectures.py +18 -0
- careamics/config/support/supported_data.py +82 -0
- careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
- careamics/config/support/supported_loggers.py +8 -0
- careamics/config/support/supported_losses.py +25 -0
- careamics/config/support/supported_optimizers.py +55 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +19 -0
- careamics/config/support/supported_transforms.py +23 -0
- careamics/config/tile_information.py +104 -0
- careamics/config/training_model.py +65 -0
- careamics/config/transformations/__init__.py +14 -0
- careamics/config/transformations/n2v_manipulate_model.py +63 -0
- careamics/config/transformations/nd_flip_model.py +32 -0
- careamics/config/transformations/normalize_model.py +31 -0
- careamics/config/transformations/transform_model.py +44 -0
- careamics/config/transformations/xy_random_rotate90_model.py +29 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +100 -0
- careamics/conftest.py +26 -0
- careamics/dataset/__init__.py +5 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +100 -0
- careamics/dataset/dataset_utils/file_utils.py +140 -0
- careamics/dataset/dataset_utils/read_tiff.py +61 -0
- careamics/dataset/dataset_utils/read_utils.py +25 -0
- careamics/dataset/dataset_utils/read_zarr.py +56 -0
- careamics/dataset/in_memory_dataset.py +323 -134
- careamics/dataset/iterable_dataset.py +416 -0
- careamics/dataset/patching/__init__.py +8 -0
- careamics/dataset/patching/patch_transform.py +44 -0
- careamics/dataset/patching/patching.py +212 -0
- careamics/dataset/patching/random_patching.py +190 -0
- careamics/dataset/patching/sequential_patching.py +206 -0
- careamics/dataset/patching/tiled_patching.py +158 -0
- careamics/dataset/patching/validate_patch_dimension.py +60 -0
- careamics/dataset/zarr_dataset.py +149 -0
- careamics/lightning_datamodule.py +665 -0
- careamics/lightning_module.py +292 -0
- careamics/lightning_prediction_datamodule.py +390 -0
- careamics/lightning_prediction_loop.py +116 -0
- careamics/losses/__init__.py +4 -1
- careamics/losses/loss_factory.py +24 -14
- careamics/losses/losses.py +65 -5
- careamics/losses/noise_model_factory.py +40 -0
- careamics/losses/noise_models.py +524 -0
- careamics/model_io/__init__.py +8 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +120 -0
- careamics/model_io/bioimage/bioimage_utils.py +48 -0
- careamics/model_io/bioimage/model_description.py +318 -0
- careamics/model_io/bmz_io.py +231 -0
- careamics/model_io/model_io_utils.py +80 -0
- careamics/models/__init__.py +4 -1
- careamics/models/activation.py +35 -0
- careamics/models/layers.py +244 -0
- careamics/models/model_factory.py +21 -221
- careamics/models/unet.py +46 -20
- careamics/prediction/__init__.py +1 -3
- careamics/prediction/stitch_prediction.py +73 -0
- careamics/transforms/__init__.py +41 -0
- careamics/transforms/n2v_manipulate.py +113 -0
- careamics/transforms/nd_flip.py +93 -0
- careamics/transforms/normalize.py +109 -0
- careamics/transforms/pixel_manipulation.py +383 -0
- careamics/transforms/struct_mask_parameters.py +18 -0
- careamics/transforms/tta.py +74 -0
- careamics/transforms/xy_random_rotate90.py +95 -0
- careamics/utils/__init__.py +10 -12
- careamics/utils/base_enum.py +32 -0
- careamics/utils/context.py +22 -2
- careamics/utils/metrics.py +0 -46
- careamics/utils/path_utils.py +24 -0
- careamics/utils/ram.py +13 -0
- careamics/utils/receptive_field.py +102 -0
- careamics/utils/running_stats.py +43 -0
- careamics/utils/torch_utils.py +112 -75
- careamics-0.1.0rc3.dist-info/METADATA +122 -0
- careamics-0.1.0rc3.dist-info/RECORD +109 -0
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
- careamics/bioimage/__init__.py +0 -15
- careamics/bioimage/docs/Noise2Void.md +0 -5
- careamics/bioimage/docs/__init__.py +0 -1
- careamics/bioimage/io.py +0 -182
- careamics/bioimage/rdf.py +0 -105
- careamics/config/algorithm.py +0 -231
- careamics/config/config.py +0 -297
- careamics/config/config_filter.py +0 -44
- careamics/config/data.py +0 -194
- careamics/config/torch_optim.py +0 -118
- careamics/config/training.py +0 -534
- careamics/dataset/dataset_utils.py +0 -111
- careamics/dataset/patching.py +0 -492
- careamics/dataset/prepare_dataset.py +0 -175
- careamics/dataset/tiff_dataset.py +0 -212
- careamics/engine.py +0 -1014
- careamics/manipulation/__init__.py +0 -4
- careamics/manipulation/pixel_manipulation.py +0 -158
- careamics/prediction/prediction_utils.py +0 -106
- careamics/utils/ascii_logo.txt +0 -9
- careamics/utils/augment.py +0 -65
- careamics/utils/normalization.py +0 -55
- careamics/utils/validators.py +0 -170
- careamics/utils/wandb.py +0 -121
- careamics-0.1.0rc2.dist-info/METADATA +0 -81
- careamics-0.1.0rc2.dist-info/RECORD +0 -47
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,555 @@
|
|
|
1
|
+
"""Data configuration."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from pprint import pformat
|
|
5
|
+
from typing import Any, List, Literal, Optional, Union
|
|
6
|
+
|
|
7
|
+
from albumentations import Compose
|
|
8
|
+
from pydantic import (
|
|
9
|
+
BaseModel,
|
|
10
|
+
ConfigDict,
|
|
11
|
+
Discriminator,
|
|
12
|
+
Field,
|
|
13
|
+
field_validator,
|
|
14
|
+
model_validator,
|
|
15
|
+
)
|
|
16
|
+
from typing_extensions import Annotated, Self
|
|
17
|
+
|
|
18
|
+
from .support import SupportedTransform
|
|
19
|
+
from .transformations.n2v_manipulate_model import N2VManipulateModel
|
|
20
|
+
from .transformations.nd_flip_model import NDFlipModel
|
|
21
|
+
from .transformations.normalize_model import NormalizeModel
|
|
22
|
+
from .transformations.xy_random_rotate90_model import XYRandomRotate90Model
|
|
23
|
+
from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
|
|
24
|
+
|
|
25
|
+
# Discriminated union of every transform configuration model accepted in
# `DataModel.transforms`. The `name` field of each model is used as the tag that
# tells the different transform models apart during validation.
TRANSFORMS_UNION = Annotated[
    Union[NDFlipModel, XYRandomRotate90Model, NormalizeModel, N2VManipulateModel],
    Discriminator("name"),
]
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class DataModel(BaseModel):
    """
    Data configuration.

    If std is specified, mean must be specified as well. Note that setting the std first
    and then the mean (if they were both `None` before) will raise a validation error.
    Prefer instead `set_mean_and_std` to set both at once.

    Examples
    --------
    Minimum example:

    >>> data = DataModel(
    ...     data_type="array", # defined in SupportedData
    ...     patch_size=[128, 128],
    ...     batch_size=4,
    ...     axes="YX"
    ... )

    To change the mean and std of the data:
    >>> data.set_mean_and_std(mean=214.3, std=84.5)

    One can pass also a list of transformations, by keyword, using the
    SupportedTransform or the name of an Albumentation transform:
    >>> from careamics.config.support import SupportedTransform
    >>> data = DataModel(
    ...     data_type="tiff",
    ...     patch_size=[128, 128],
    ...     batch_size=4,
    ...     axes="YX",
    ...     transforms=[
    ...         {
    ...             "name": SupportedTransform.NORMALIZE.value,
    ...             "mean": 167.6,
    ...             "std": 47.2,
    ...         },
    ...         {
    ...             "name": "NDFlip",
    ...             "is_3D": True,
    ...             "flip_z": True,
    ...         }
    ...     ]
    ... )
    """

    # Pydantic class configuration
    model_config = ConfigDict(
        validate_assignment=True,
        arbitrary_types_allowed=True,  # Allow Compose declaration
    )

    # Dataset configuration
    data_type: Literal["array", "tiff", "custom"]  # As defined in SupportedData
    patch_size: List[int] = Field(..., min_length=2, max_length=3)
    batch_size: int = Field(default=1, ge=1, validate_default=True)
    axes: str

    # Optional fields
    mean: Optional[float] = None
    std: Optional[float] = None

    # Either a list of transform models (validated through the discriminated
    # union) or an already built Compose object, which is passed through as is.
    transforms: Union[List[TRANSFORMS_UNION], Compose] = Field(
        default=[
            {
                "name": SupportedTransform.NORMALIZE.value,
            },
            {
                "name": SupportedTransform.NDFLIP.value,
            },
            {
                "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
            },
            {
                "name": SupportedTransform.N2V_MANIPULATE.value,
            },
        ],
        validate_default=True,
    )

    dataloader_params: Optional[dict] = None

    @field_validator("patch_size")
    @classmethod
    def all_elements_power_of_2_minimum_8(cls, patch_list: List[int]) -> List[int]:
        """
        Validate patch size.

        Patch size must be powers of 2 and minimum 8.

        Parameters
        ----------
        patch_list : List[int]
            Patch size.

        Returns
        -------
        List[int]
            Validated patch size.

        Raises
        ------
        ValueError
            If the patch size is smaller than 8.
        ValueError
            If the patch size is not a power of 2.
        """
        patch_size_ge_than_8_power_of_2(patch_list)

        return patch_list

    @field_validator("axes")
    @classmethod
    def axes_valid(cls, axes: str) -> str:
        """
        Validate axes.

        Axes must:
        - be a combination of 'STCZYX'
        - not contain duplicates
        - contain at least 2 contiguous axes: X and Y
        - contain at most 4 axes
        - not contain both S and T axes

        Parameters
        ----------
        axes : str
            Axes to validate.

        Returns
        -------
        str
            Validated axes.

        Raises
        ------
        ValueError
            If axes are not valid.
        """
        # Validate axes
        check_axes_validity(axes)

        return axes

    @field_validator("transforms")
    @classmethod
    def validate_prediction_transforms(
        cls, transforms: Union[List[TRANSFORMS_UNION], Compose]
    ) -> Union[List[TRANSFORMS_UNION], Compose]:
        """
        Validate N2VManipulate transform position in the transform list.

        A Compose object is returned untouched. In a list, at most one
        N2VManipulate is allowed, and it is moved to the last position if it is
        not already there.

        Parameters
        ----------
        transforms : Union[List[TRANSFORMS_UNION], Compose]
            Transforms.

        Returns
        -------
        Union[List[TRANSFORMS_UNION], Compose]
            Validated transforms.

        Raises
        ------
        ValueError
            If multiple instances of N2VManipulate are found.
        """
        if isinstance(transforms, Compose):
            return transforms

        # use the plain string value so that comparisons and the error message do
        # not depend on how the enum member itself is compared or formatted
        n2v_name = SupportedTransform.N2V_MANIPULATE.value
        names = [t.name for t in transforms]
        n_n2v = names.count(n2v_name)

        # multiple N2V_MANIPULATE
        if n_n2v > 1:
            raise ValueError(
                f"Multiple instances of {n2v_name} transforms are not allowed."
            )

        # N2V_MANIPULATE not the last transform: move it to the end
        if n_n2v == 1 and names[-1] != n2v_name:
            transforms.append(transforms.pop(names.index(n2v_name)))

        return transforms

    @model_validator(mode="after")
    def std_only_with_mean(self: Self) -> Self:
        """
        Check that mean and std are either both None, or both specified.

        Returns
        -------
        Self
            Validated data model.

        Raises
        ------
        ValueError
            If only one of mean and std is specified.
        """
        # check that mean and std are either both None, or both specified
        if (self.mean is None) != (self.std is None):
            raise ValueError(
                "Mean and std must be either both None, or both specified."
            )

        return self

    @model_validator(mode="after")
    def add_std_and_mean_to_normalize(self: Self) -> Self:
        """
        Add mean and std to the Normalize transform if it is present.

        Nothing is done if both mean and std are None, or if the transforms are a
        Compose object.

        Returns
        -------
        Self
            Data model with mean and std added to the Normalize transform.
        """
        if self.mean is not None or self.std is not None:
            # search in the transforms for Normalize and update parameters
            if self.has_transform_list():
                for transform in self.transforms:
                    if transform.name == SupportedTransform.NORMALIZE.value:
                        transform.mean = self.mean
                        transform.std = self.std

        return self

    @model_validator(mode="after")
    def validate_dimensions(self: Self) -> Self:
        """
        Validate 2D/3D dimensions between axes, patch size and transforms.

        The data is considered 3D if "Z" is in the axes. The patch size must then
        have 3 dimensions (2 otherwise), and the `is_3D` attribute of the NDFlip
        and XYRandomRotate90 transforms is set accordingly.

        Returns
        -------
        Self
            Validated data model.

        Raises
        ------
        ValueError
            If the number of patch dimensions does not match the axes.
        """
        is_3D = "Z" in self.axes
        expected_dims = 3 if is_3D else 2

        if len(self.patch_size) != expected_dims:
            raise ValueError(
                f"Patch size must have {expected_dims} dimensions if the data is "
                f"{expected_dims}D ({self.axes})."
            )

        if self.has_transform_list():
            dimension_aware = (
                SupportedTransform.NDFLIP.value,
                SupportedTransform.XY_RANDOM_ROTATE90.value,
            )
            for transform in self.transforms:
                if transform.name in dimension_aware:
                    transform.is_3D = is_3D

        return self

    def __str__(self) -> str:
        """
        Pretty string representing the configuration.

        Returns
        -------
        str
            Pretty string.
        """
        return pformat(self.model_dump())

    def _update(self, **kwargs: Any) -> None:
        """
        Update multiple arguments at once.

        The new values are validated together before being stored, so that a
        failed validation leaves the model unchanged.

        Parameters
        ----------
        **kwargs : Any
            Keyword arguments to update.
        """
        # validate the would-be state first; the validated copy is discarded and
        # only the raw keyword values are stored on this instance
        candidate = {**self.__dict__, **kwargs}
        self.__class__.model_validate(candidate)
        self.__dict__.update(kwargs)

    def has_transform_list(self) -> bool:
        """
        Check if the transforms are a list, as opposed to a Compose object.

        Returns
        -------
        bool
            True if the transforms are a list, False otherwise.
        """
        return isinstance(self.transforms, list)

    def has_n2v_manipulate(self) -> bool:
        """
        Check if the transforms contain N2VManipulate.

        Use `has_transform_list` to check if the transforms are a list.

        Returns
        -------
        bool
            True if the transforms contain N2VManipulate, False otherwise.

        Raises
        ------
        ValueError
            If the transforms are a Compose object.
        """
        if not self.has_transform_list():
            raise ValueError(
                "Checking for N2VManipulate with Compose transforms is not allowed. "
                "Check directly in the Compose."
            )

        return any(
            transform.name == SupportedTransform.N2V_MANIPULATE.value
            for transform in self.transforms
        )

    def add_n2v_manipulate(self) -> None:
        """
        Add N2VManipulate to the transforms.

        Nothing is done if N2VManipulate is already in the list.

        Use `has_transform_list` to check if the transforms are a list.

        Raises
        ------
        ValueError
            If the transforms are a Compose object.
        """
        if not self.has_transform_list():
            raise ValueError(
                "Adding N2VManipulate with Compose transforms is not allowed. Add "
                "N2VManipulate directly to the transform in the Compose."
            )

        if not self.has_n2v_manipulate():
            self.transforms.append(
                N2VManipulateModel(name=SupportedTransform.N2V_MANIPULATE.value)
            )

    def remove_n2v_manipulate(self) -> None:
        """
        Remove N2VManipulate from the transforms.

        Nothing is done if N2VManipulate is not in the list.

        Use `has_transform_list` to check if the transforms are a list.

        Raises
        ------
        ValueError
            If the transforms are a Compose object.
        """
        if not self.has_transform_list():
            raise ValueError(
                "Removing N2VManipulate with Compose transforms is not allowed. Remove "
                "N2VManipulate directly from the transform in the Compose."
            )

        # remove by name rather than by position, and in place so that the list
        # object held by the model stays the same
        self.transforms[:] = [
            transform
            for transform in self.transforms
            if transform.name != SupportedTransform.N2V_MANIPULATE.value
        ]

    def set_mean_and_std(self, mean: float, std: float) -> None:
        """
        Set mean and standard deviation of the data.

        This method should be used instead setting the fields directly, as it would
        otherwise trigger a validation error.

        Parameters
        ----------
        mean : float
            Mean of the data.
        std : float
            Standard deviation of the data.

        Raises
        ------
        ValueError
            If the transforms are a Compose object.
        """
        # refuse before touching any state, so that a failure leaves the model as
        # it was
        if not self.has_transform_list():
            raise ValueError(
                "Setting mean and std with Compose transforms is not allowed. Add "
                "mean and std parameters directly to the transform in the Compose."
            )

        self._update(mean=mean, std=std)

        # search in the transforms for Normalize and update parameters
        for transform in self.transforms:
            if transform.name == SupportedTransform.NORMALIZE.value:
                transform.mean = mean
                transform.std = std

    def set_3D(self, axes: str, patch_size: List[int]) -> None:
        """
        Set 3D parameters.

        Axes and patch size are updated together, as they are validated against
        each other.

        Parameters
        ----------
        axes : str
            Axes.
        patch_size : List[int]
            Patch size.
        """
        self._update(axes=axes, patch_size=patch_size)

    def set_N2V2(self, use_n2v2: bool) -> None:
        """
        Set N2V2.

        N2V2 corresponds to the "median" strategy, N2V to the "uniform" one.

        Parameters
        ----------
        use_n2v2 : bool
            Whether to use N2V2.

        Raises
        ------
        ValueError
            If the N2V pixel manipulate transform is not found in the transforms.
        ValueError
            If the transforms are a Compose object.
        """
        self.set_N2V2_strategy("median" if use_n2v2 else "uniform")

    def set_N2V2_strategy(self, strategy: Literal["uniform", "median"]) -> None:
        """
        Set N2V2 strategy.

        Parameters
        ----------
        strategy : Literal["uniform", "median"]
            Strategy to use for N2V2.

        Raises
        ------
        ValueError
            If the N2V pixel manipulate transform is not found in the transforms.
        ValueError
            If the transforms are a Compose object.
        """
        if not self.has_transform_list():
            raise ValueError(
                "Setting N2V2 strategy with Compose transforms is not allowed. Add "
                "N2V2 strategy parameters directly to the transform in the Compose."
            )

        found_n2v = False

        for transform in self.transforms:
            if transform.name == SupportedTransform.N2V_MANIPULATE.value:
                transform.strategy = strategy
                found_n2v = True

        if not found_n2v:
            transforms = [t.name for t in self.transforms]
            raise ValueError(
                f"N2V_Manipulate transform not found in the transforms list "
                f"({transforms})."
            )

    def set_structN2V_mask(
        self, mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int
    ) -> None:
        """
        Set structN2V mask parameters.

        Setting `mask_axis` to `none` will disable structN2V.

        Parameters
        ----------
        mask_axis : Literal["horizontal", "vertical", "none"]
            Axis along which to apply the mask. `none` will disable structN2V.
        mask_span : int
            Total span of the mask in pixels.

        Raises
        ------
        ValueError
            If the N2V pixel manipulate transform is not found in the transforms.
        ValueError
            If the transforms are a Compose object.
        """
        if not self.has_transform_list():
            raise ValueError(
                "Setting structN2VMask with Compose transforms is not allowed. Add "
                "structN2VMask parameters directly to the transform in the Compose."
            )

        found_n2v = False

        for transform in self.transforms:
            if transform.name == SupportedTransform.N2V_MANIPULATE.value:
                transform.struct_mask_axis = mask_axis
                transform.struct_mask_span = mask_span
                found_n2v = True

        if not found_n2v:
            transforms = [t.name for t in self.transforms]
            raise ValueError(
                f"N2V pixel manipulate transform not found in the transforms "
                f"({transforms})."
            )
|