careamics 0.1.0rc4__py3-none-any.whl → 0.1.0rc5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +12 -11
- careamics/config/__init__.py +0 -1
- careamics/config/architectures/unet_model.py +1 -0
- careamics/config/callback_model.py +1 -0
- careamics/config/configuration_example.py +0 -2
- careamics/config/configuration_factory.py +112 -42
- careamics/config/configuration_model.py +14 -16
- careamics/config/data_model.py +59 -157
- careamics/config/inference_model.py +19 -20
- careamics/config/references/algorithm_descriptions.py +1 -0
- careamics/config/references/references.py +1 -0
- careamics/config/support/supported_extraction_strategies.py +1 -0
- careamics/config/training_model.py +1 -0
- careamics/config/transformations/n2v_manipulate_model.py +1 -0
- careamics/config/transformations/nd_flip_model.py +6 -11
- careamics/config/transformations/normalize_model.py +1 -0
- careamics/config/transformations/transform_model.py +1 -0
- careamics/config/transformations/xy_random_rotate90_model.py +6 -8
- careamics/config/validators/validator_utils.py +1 -0
- careamics/conftest.py +1 -0
- careamics/dataset/dataset_utils/__init__.py +0 -1
- careamics/dataset/dataset_utils/dataset_utils.py +1 -0
- careamics/dataset/in_memory_dataset.py +14 -45
- careamics/dataset/iterable_dataset.py +13 -68
- careamics/dataset/patching/__init__.py +0 -7
- careamics/dataset/patching/patching.py +1 -0
- careamics/dataset/patching/sequential_patching.py +6 -6
- careamics/dataset/patching/tiled_patching.py +10 -6
- careamics/lightning_datamodule.py +20 -24
- careamics/lightning_module.py +1 -1
- careamics/lightning_prediction_datamodule.py +15 -10
- careamics/losses/__init__.py +0 -1
- careamics/losses/loss_factory.py +1 -0
- careamics/model_io/__init__.py +0 -1
- careamics/model_io/bioimage/_readme_factory.py +2 -1
- careamics/model_io/bioimage/bioimage_utils.py +1 -0
- careamics/model_io/bioimage/model_description.py +1 -0
- careamics/model_io/bmz_io.py +2 -1
- careamics/models/layers.py +1 -0
- careamics/models/model_factory.py +1 -0
- careamics/models/unet.py +91 -17
- careamics/prediction/stitch_prediction.py +1 -0
- careamics/transforms/__init__.py +2 -23
- careamics/transforms/compose.py +98 -0
- careamics/transforms/n2v_manipulate.py +18 -23
- careamics/transforms/nd_flip.py +38 -64
- careamics/transforms/normalize.py +45 -34
- careamics/transforms/pixel_manipulation.py +2 -2
- careamics/transforms/transform.py +33 -0
- careamics/transforms/tta.py +2 -2
- careamics/transforms/xy_random_rotate90.py +41 -68
- careamics/utils/__init__.py +0 -1
- careamics/utils/context.py +1 -0
- careamics/utils/logging.py +1 -0
- careamics/utils/metrics.py +1 -0
- careamics/utils/torch_utils.py +1 -0
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc5.dist-info}/METADATA +16 -61
- careamics-0.1.0rc5.dist-info/RECORD +111 -0
- careamics/dataset/patching/patch_transform.py +0 -44
- careamics-0.1.0rc4.dist-info/RECORD +0 -110
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc5.dist-info}/WHEEL +0 -0
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc5.dist-info}/licenses/LICENSE +0 -0
careamics/config/data_model.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
"""Data configuration."""
|
|
2
|
+
|
|
2
3
|
from __future__ import annotations
|
|
3
4
|
|
|
4
5
|
from pprint import pformat
|
|
5
6
|
from typing import Any, List, Literal, Optional, Union
|
|
6
7
|
|
|
7
|
-
from albumentations import Compose
|
|
8
8
|
from pydantic import (
|
|
9
9
|
BaseModel,
|
|
10
10
|
ConfigDict,
|
|
@@ -71,8 +71,6 @@ class DataConfig(BaseModel):
|
|
|
71
71
|
... },
|
|
72
72
|
... {
|
|
73
73
|
... "name": "NDFlip",
|
|
74
|
-
... "is_3D": True,
|
|
75
|
-
... "flip_z": True,
|
|
76
74
|
... }
|
|
77
75
|
... ]
|
|
78
76
|
... )
|
|
@@ -81,7 +79,6 @@ class DataConfig(BaseModel):
|
|
|
81
79
|
# Pydantic class configuration
|
|
82
80
|
model_config = ConfigDict(
|
|
83
81
|
validate_assignment=True,
|
|
84
|
-
arbitrary_types_allowed=True, # Allow Compose declaration
|
|
85
82
|
)
|
|
86
83
|
|
|
87
84
|
# Dataset configuration
|
|
@@ -94,7 +91,7 @@ class DataConfig(BaseModel):
|
|
|
94
91
|
mean: Optional[float] = None
|
|
95
92
|
std: Optional[float] = None
|
|
96
93
|
|
|
97
|
-
transforms:
|
|
94
|
+
transforms: List[TRANSFORMS_UNION] = Field(
|
|
98
95
|
default=[
|
|
99
96
|
{
|
|
100
97
|
"name": SupportedTransform.NORMALIZE.value,
|
|
@@ -181,19 +178,19 @@ class DataConfig(BaseModel):
|
|
|
181
178
|
@field_validator("transforms")
|
|
182
179
|
@classmethod
|
|
183
180
|
def validate_prediction_transforms(
|
|
184
|
-
cls, transforms:
|
|
185
|
-
) ->
|
|
181
|
+
cls, transforms: List[TRANSFORMS_UNION]
|
|
182
|
+
) -> List[TRANSFORMS_UNION]:
|
|
186
183
|
"""
|
|
187
184
|
Validate N2VManipulate transform position in the transform list.
|
|
188
185
|
|
|
189
186
|
Parameters
|
|
190
187
|
----------
|
|
191
|
-
transforms :
|
|
188
|
+
transforms : List[Transformations_Union]
|
|
192
189
|
Transforms.
|
|
193
190
|
|
|
194
191
|
Returns
|
|
195
192
|
-------
|
|
196
|
-
|
|
193
|
+
List[TRANSFORMS_UNION]
|
|
197
194
|
Validated transforms.
|
|
198
195
|
|
|
199
196
|
Raises
|
|
@@ -201,23 +198,22 @@ class DataConfig(BaseModel):
|
|
|
201
198
|
ValueError
|
|
202
199
|
If multiple instances of N2VManipulate are found.
|
|
203
200
|
"""
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
transforms.append(transform)
|
|
201
|
+
transform_list = [t.name for t in transforms]
|
|
202
|
+
|
|
203
|
+
if SupportedTransform.N2V_MANIPULATE in transform_list:
|
|
204
|
+
# multiple N2V_MANIPULATE
|
|
205
|
+
if transform_list.count(SupportedTransform.N2V_MANIPULATE) > 1:
|
|
206
|
+
raise ValueError(
|
|
207
|
+
f"Multiple instances of "
|
|
208
|
+
f"{SupportedTransform.N2V_MANIPULATE} transforms "
|
|
209
|
+
f"are not allowed."
|
|
210
|
+
)
|
|
211
|
+
|
|
212
|
+
# N2V_MANIPULATE not the last transform
|
|
213
|
+
elif transform_list[-1] != SupportedTransform.N2V_MANIPULATE:
|
|
214
|
+
index = transform_list.index(SupportedTransform.N2V_MANIPULATE)
|
|
215
|
+
transform = transforms.pop(index)
|
|
216
|
+
transforms.append(transform)
|
|
221
217
|
|
|
222
218
|
return transforms
|
|
223
219
|
|
|
@@ -256,11 +252,10 @@ class DataConfig(BaseModel):
|
|
|
256
252
|
"""
|
|
257
253
|
if self.mean is not None or self.std is not None:
|
|
258
254
|
# search in the transforms for Normalize and update parameters
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
transform.std = self.std
|
|
255
|
+
for transform in self.transforms:
|
|
256
|
+
if transform.name == SupportedTransform.NORMALIZE.value:
|
|
257
|
+
transform.mean = self.mean
|
|
258
|
+
transform.std = self.std
|
|
264
259
|
|
|
265
260
|
return self
|
|
266
261
|
|
|
@@ -286,13 +281,6 @@ class DataConfig(BaseModel):
|
|
|
286
281
|
f"({self.axes})."
|
|
287
282
|
)
|
|
288
283
|
|
|
289
|
-
if self.has_transform_list():
|
|
290
|
-
for transform in self.transforms:
|
|
291
|
-
if transform.name == SupportedTransform.NDFLIP:
|
|
292
|
-
transform.is_3D = True
|
|
293
|
-
elif transform.name == SupportedTransform.XY_RANDOM_ROTATE90:
|
|
294
|
-
transform.is_3D = True
|
|
295
|
-
|
|
296
284
|
else:
|
|
297
285
|
if len(self.patch_size) != 2:
|
|
298
286
|
raise ValueError(
|
|
@@ -300,13 +288,6 @@ class DataConfig(BaseModel):
|
|
|
300
288
|
f"({self.axes})."
|
|
301
289
|
)
|
|
302
290
|
|
|
303
|
-
if self.has_transform_list():
|
|
304
|
-
for transform in self.transforms:
|
|
305
|
-
if transform.name == SupportedTransform.NDFLIP:
|
|
306
|
-
transform.is_3D = False
|
|
307
|
-
elif transform.name == SupportedTransform.XY_RANDOM_ROTATE90:
|
|
308
|
-
transform.is_3D = False
|
|
309
|
-
|
|
310
291
|
return self
|
|
311
292
|
|
|
312
293
|
def __str__(self) -> str:
|
|
@@ -332,84 +313,31 @@ class DataConfig(BaseModel):
|
|
|
332
313
|
self.__dict__.update(kwargs)
|
|
333
314
|
self.__class__.model_validate(self.__dict__)
|
|
334
315
|
|
|
335
|
-
def has_transform_list(self) -> bool:
|
|
336
|
-
"""
|
|
337
|
-
Check if the transforms are a list, as opposed to a Compose object.
|
|
338
|
-
|
|
339
|
-
Returns
|
|
340
|
-
-------
|
|
341
|
-
bool
|
|
342
|
-
True if the transforms are a list, False otherwise.
|
|
343
|
-
"""
|
|
344
|
-
return isinstance(self.transforms, list)
|
|
345
|
-
|
|
346
316
|
def has_n2v_manipulate(self) -> bool:
|
|
347
317
|
"""
|
|
348
318
|
Check if the transforms contain N2VManipulate.
|
|
349
319
|
|
|
350
|
-
Use `has_transform_list` to check if the transforms are a list.
|
|
351
|
-
|
|
352
320
|
Returns
|
|
353
321
|
-------
|
|
354
322
|
bool
|
|
355
323
|
True if the transforms contain N2VManipulate, False otherwise.
|
|
356
|
-
|
|
357
|
-
Raises
|
|
358
|
-
------
|
|
359
|
-
ValueError
|
|
360
|
-
If the transforms are a Compose object.
|
|
361
324
|
"""
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
)
|
|
367
|
-
else:
|
|
368
|
-
raise ValueError(
|
|
369
|
-
"Checking for N2VManipulate with Compose transforms is not allowed. "
|
|
370
|
-
"Check directly in the Compose."
|
|
371
|
-
)
|
|
325
|
+
return any(
|
|
326
|
+
transform.name == SupportedTransform.N2V_MANIPULATE.value
|
|
327
|
+
for transform in self.transforms
|
|
328
|
+
)
|
|
372
329
|
|
|
373
330
|
def add_n2v_manipulate(self) -> None:
|
|
374
|
-
"""
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
Raises
|
|
380
|
-
------
|
|
381
|
-
ValueError
|
|
382
|
-
If the transforms are a Compose object.
|
|
383
|
-
"""
|
|
384
|
-
if self.has_transform_list():
|
|
385
|
-
if not self.has_n2v_manipulate():
|
|
386
|
-
self.transforms.append(
|
|
387
|
-
N2VManipulateModel(name=SupportedTransform.N2V_MANIPULATE.value)
|
|
388
|
-
)
|
|
389
|
-
else:
|
|
390
|
-
raise ValueError(
|
|
391
|
-
"Adding N2VManipulate with Compose transforms is not allowed. Add "
|
|
392
|
-
"N2VManipulate directly to the transform in the Compose."
|
|
331
|
+
"""Add N2VManipulate to the transforms."""
|
|
332
|
+
if not self.has_n2v_manipulate():
|
|
333
|
+
self.transforms.append(
|
|
334
|
+
N2VManipulateModel(name=SupportedTransform.N2V_MANIPULATE.value)
|
|
393
335
|
)
|
|
394
336
|
|
|
395
337
|
def remove_n2v_manipulate(self) -> None:
|
|
396
|
-
"""
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
Use `has_transform_list` to check if the transforms are a list.
|
|
400
|
-
|
|
401
|
-
Raises
|
|
402
|
-
------
|
|
403
|
-
ValueError
|
|
404
|
-
If the transforms are a Compose object.
|
|
405
|
-
"""
|
|
406
|
-
if self.has_transform_list() and self.has_n2v_manipulate():
|
|
338
|
+
"""Remove N2VManipulate from the transforms."""
|
|
339
|
+
if self.has_n2v_manipulate():
|
|
407
340
|
self.transforms.pop(-1)
|
|
408
|
-
else:
|
|
409
|
-
raise ValueError(
|
|
410
|
-
"Removing N2VManipulate with Compose transforms is not allowed. Remove "
|
|
411
|
-
"N2VManipulate directly from the transform in the Compose."
|
|
412
|
-
)
|
|
413
341
|
|
|
414
342
|
def set_mean_and_std(self, mean: float, std: float) -> None:
|
|
415
343
|
"""
|
|
@@ -428,16 +356,10 @@ class DataConfig(BaseModel):
|
|
|
428
356
|
self._update(mean=mean, std=std)
|
|
429
357
|
|
|
430
358
|
# search in the transforms for Normalize and update parameters
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
transform.std = std
|
|
436
|
-
else:
|
|
437
|
-
raise ValueError(
|
|
438
|
-
"Setting mean and std with Compose transforms is not allowed. Add "
|
|
439
|
-
"mean and std parameters directly to the transform in the Compose."
|
|
440
|
-
)
|
|
359
|
+
for transform in self.transforms:
|
|
360
|
+
if transform.name == SupportedTransform.NORMALIZE.value:
|
|
361
|
+
transform.mean = mean
|
|
362
|
+
transform.std = std
|
|
441
363
|
|
|
442
364
|
def set_3D(self, axes: str, patch_size: List[int]) -> None:
|
|
443
365
|
"""
|
|
@@ -465,8 +387,6 @@ class DataConfig(BaseModel):
|
|
|
465
387
|
------
|
|
466
388
|
ValueError
|
|
467
389
|
If the N2V pixel manipulate transform is not found in the transforms.
|
|
468
|
-
ValueError
|
|
469
|
-
If the transforms are a Compose object.
|
|
470
390
|
"""
|
|
471
391
|
if use_n2v2:
|
|
472
392
|
self.set_N2V2_strategy("median")
|
|
@@ -486,28 +406,19 @@ class DataConfig(BaseModel):
|
|
|
486
406
|
------
|
|
487
407
|
ValueError
|
|
488
408
|
If the N2V pixel manipulate transform is not found in the transforms.
|
|
489
|
-
ValueError
|
|
490
|
-
If the transforms are a Compose object.
|
|
491
409
|
"""
|
|
492
|
-
|
|
493
|
-
found_n2v = False
|
|
410
|
+
found_n2v = False
|
|
494
411
|
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
if not found_n2v:
|
|
501
|
-
transforms = [t.name for t in self.transforms]
|
|
502
|
-
raise ValueError(
|
|
503
|
-
f"N2V_Manipulate transform not found in the transforms list "
|
|
504
|
-
f"({transforms})."
|
|
505
|
-
)
|
|
412
|
+
for transform in self.transforms:
|
|
413
|
+
if transform.name == SupportedTransform.N2V_MANIPULATE.value:
|
|
414
|
+
transform.strategy = strategy
|
|
415
|
+
found_n2v = True
|
|
506
416
|
|
|
507
|
-
|
|
417
|
+
if not found_n2v:
|
|
418
|
+
transforms = [t.name for t in self.transforms]
|
|
508
419
|
raise ValueError(
|
|
509
|
-
"
|
|
510
|
-
"
|
|
420
|
+
f"N2V_Manipulate transform not found in the transforms list "
|
|
421
|
+
f"({transforms})."
|
|
511
422
|
)
|
|
512
423
|
|
|
513
424
|
def set_structN2V_mask(
|
|
@@ -529,27 +440,18 @@ class DataConfig(BaseModel):
|
|
|
529
440
|
------
|
|
530
441
|
ValueError
|
|
531
442
|
If the N2V pixel manipulate transform is not found in the transforms.
|
|
532
|
-
ValueError
|
|
533
|
-
If the transforms are a Compose object.
|
|
534
443
|
"""
|
|
535
|
-
|
|
536
|
-
found_n2v = False
|
|
444
|
+
found_n2v = False
|
|
537
445
|
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
446
|
+
for transform in self.transforms:
|
|
447
|
+
if transform.name == SupportedTransform.N2V_MANIPULATE.value:
|
|
448
|
+
transform.struct_mask_axis = mask_axis
|
|
449
|
+
transform.struct_mask_span = mask_span
|
|
450
|
+
found_n2v = True
|
|
543
451
|
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
raise ValueError(
|
|
547
|
-
f"N2V pixel manipulate transform not found in the transforms "
|
|
548
|
-
f"({transforms})."
|
|
549
|
-
)
|
|
550
|
-
|
|
551
|
-
else:
|
|
452
|
+
if not found_n2v:
|
|
453
|
+
transforms = [t.name for t in self.transforms]
|
|
552
454
|
raise ValueError(
|
|
553
|
-
"
|
|
554
|
-
"
|
|
455
|
+
f"N2V pixel manipulate transform not found in the transforms "
|
|
456
|
+
f"({transforms})."
|
|
555
457
|
)
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
"""Pydantic model representing CAREamics prediction configuration."""
|
|
2
|
+
|
|
2
3
|
from __future__ import annotations
|
|
3
4
|
|
|
4
5
|
from typing import Any, List, Literal, Optional, Union
|
|
5
6
|
|
|
6
|
-
from albumentations import Compose
|
|
7
7
|
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
|
8
8
|
from typing_extensions import Self
|
|
9
9
|
|
|
@@ -33,7 +33,7 @@ class InferenceConfig(BaseModel):
|
|
|
33
33
|
mean: float
|
|
34
34
|
std: float = Field(..., ge=0.0)
|
|
35
35
|
|
|
36
|
-
transforms:
|
|
36
|
+
transforms: List[TRANSFORMS_UNION] = Field(
|
|
37
37
|
default=[
|
|
38
38
|
{
|
|
39
39
|
"name": SupportedTransform.NORMALIZE.value,
|
|
@@ -51,22 +51,22 @@ class InferenceConfig(BaseModel):
|
|
|
51
51
|
@field_validator("tile_overlap")
|
|
52
52
|
@classmethod
|
|
53
53
|
def all_elements_non_zero_even(
|
|
54
|
-
cls,
|
|
54
|
+
cls, tile_overlap: Optional[Union[List[int]]]
|
|
55
55
|
) -> Optional[Union[List[int]]]:
|
|
56
56
|
"""
|
|
57
|
-
Validate
|
|
57
|
+
Validate tile overlap.
|
|
58
58
|
|
|
59
|
-
|
|
59
|
+
Overlaps must be non-zero, positive and even.
|
|
60
60
|
|
|
61
61
|
Parameters
|
|
62
62
|
----------
|
|
63
|
-
|
|
63
|
+
tile_overlap : Optional[Union[List[int]]]
|
|
64
64
|
Patch size.
|
|
65
65
|
|
|
66
66
|
Returns
|
|
67
67
|
-------
|
|
68
68
|
Optional[Union[List[int]]]
|
|
69
|
-
Validated
|
|
69
|
+
Validated tile overlap.
|
|
70
70
|
|
|
71
71
|
Raises
|
|
72
72
|
------
|
|
@@ -75,8 +75,8 @@ class InferenceConfig(BaseModel):
|
|
|
75
75
|
ValueError
|
|
76
76
|
If the patch size is not even.
|
|
77
77
|
"""
|
|
78
|
-
if
|
|
79
|
-
for dim in
|
|
78
|
+
if tile_overlap is not None:
|
|
79
|
+
for dim in tile_overlap:
|
|
80
80
|
if dim < 1:
|
|
81
81
|
raise ValueError(
|
|
82
82
|
f"Patch size must be non-zero positive (got {dim})."
|
|
@@ -85,7 +85,7 @@ class InferenceConfig(BaseModel):
|
|
|
85
85
|
if dim % 2 != 0:
|
|
86
86
|
raise ValueError(f"Patch size must be even (got {dim}).")
|
|
87
87
|
|
|
88
|
-
return
|
|
88
|
+
return tile_overlap
|
|
89
89
|
|
|
90
90
|
@field_validator("tile_size")
|
|
91
91
|
@classmethod
|
|
@@ -152,19 +152,19 @@ class InferenceConfig(BaseModel):
|
|
|
152
152
|
@field_validator("transforms")
|
|
153
153
|
@classmethod
|
|
154
154
|
def validate_transforms(
|
|
155
|
-
cls, transforms:
|
|
156
|
-
) ->
|
|
155
|
+
cls, transforms: List[TRANSFORMS_UNION]
|
|
156
|
+
) -> List[TRANSFORMS_UNION]:
|
|
157
157
|
"""
|
|
158
158
|
Validate that transforms do not have N2V pixel manipulate transforms.
|
|
159
159
|
|
|
160
160
|
Parameters
|
|
161
161
|
----------
|
|
162
|
-
transforms :
|
|
162
|
+
transforms : List[TRANSFORMS_UNION]
|
|
163
163
|
Transforms.
|
|
164
164
|
|
|
165
165
|
Returns
|
|
166
166
|
-------
|
|
167
|
-
|
|
167
|
+
List[TRANSFORMS_UNION]
|
|
168
168
|
Validated transforms.
|
|
169
169
|
|
|
170
170
|
Raises
|
|
@@ -172,7 +172,7 @@ class InferenceConfig(BaseModel):
|
|
|
172
172
|
ValueError
|
|
173
173
|
If transforms contain N2V pixel manipulate transforms.
|
|
174
174
|
"""
|
|
175
|
-
if
|
|
175
|
+
if transforms is not None:
|
|
176
176
|
for transform in transforms:
|
|
177
177
|
if transform.name == SupportedTransform.N2V_MANIPULATE.value:
|
|
178
178
|
raise ValueError(
|
|
@@ -247,11 +247,10 @@ class InferenceConfig(BaseModel):
|
|
|
247
247
|
"""
|
|
248
248
|
if self.mean is not None or self.std is not None:
|
|
249
249
|
# search in the transforms for Normalize and update parameters
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
transform.std = self.std
|
|
250
|
+
for transform in self.transforms:
|
|
251
|
+
if transform.name == SupportedTransform.NORMALIZE.value:
|
|
252
|
+
transform.mean = self.mean
|
|
253
|
+
transform.std = self.std
|
|
255
254
|
|
|
256
255
|
return self
|
|
257
256
|
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
"""Pydantic model for the NDFlip transform."""
|
|
2
|
-
from typing import Literal
|
|
3
2
|
|
|
4
|
-
from
|
|
3
|
+
from typing import Literal, Optional
|
|
4
|
+
|
|
5
|
+
from pydantic import ConfigDict
|
|
5
6
|
|
|
6
7
|
from .transform_model import TransformModel
|
|
7
8
|
|
|
@@ -14,12 +15,8 @@ class NDFlipModel(TransformModel):
|
|
|
14
15
|
----------
|
|
15
16
|
name : Literal["NDFlip"]
|
|
16
17
|
Name of the transformation.
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
is_3D : bool
|
|
20
|
-
Whether the transformation should be applied in 3D, by default False.
|
|
21
|
-
flip_z : bool
|
|
22
|
-
Whether to flip the z axis, by default True.
|
|
18
|
+
seed : Optional[int]
|
|
19
|
+
Seed for the random number generator.
|
|
23
20
|
"""
|
|
24
21
|
|
|
25
22
|
model_config = ConfigDict(
|
|
@@ -27,6 +24,4 @@ class NDFlipModel(TransformModel):
|
|
|
27
24
|
)
|
|
28
25
|
|
|
29
26
|
name: Literal["NDFlip"] = "NDFlip"
|
|
30
|
-
|
|
31
|
-
is_3D: bool = Field(default=False)
|
|
32
|
-
flip_z: bool = Field(default=True)
|
|
27
|
+
seed: Optional[int] = None
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
"""Pydantic model for the XYRandomRotate90 transform."""
|
|
2
|
-
from typing import Literal
|
|
3
2
|
|
|
4
|
-
from
|
|
3
|
+
from typing import Literal, Optional
|
|
4
|
+
|
|
5
|
+
from pydantic import ConfigDict
|
|
5
6
|
|
|
6
7
|
from .transform_model import TransformModel
|
|
7
8
|
|
|
@@ -14,10 +15,8 @@ class XYRandomRotate90Model(TransformModel):
|
|
|
14
15
|
----------
|
|
15
16
|
name : Literal["XYRandomRotate90"]
|
|
16
17
|
Name of the transformation.
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
is_3D : bool
|
|
20
|
-
Whether the transformation should be applied in 3D, by default False.
|
|
18
|
+
seed : Optional[int]
|
|
19
|
+
Seed for the random number generator.
|
|
21
20
|
"""
|
|
22
21
|
|
|
23
22
|
model_config = ConfigDict(
|
|
@@ -25,5 +24,4 @@ class XYRandomRotate90Model(TransformModel):
|
|
|
25
24
|
)
|
|
26
25
|
|
|
27
26
|
name: Literal["XYRandomRotate90"] = "XYRandomRotate90"
|
|
28
|
-
|
|
29
|
-
is_3D: bool = Field(default=False)
|
|
27
|
+
seed: Optional[int] = None
|
careamics/conftest.py
CHANGED