careamics 0.1.0rc4__py3-none-any.whl → 0.1.0rc5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (62) hide show
  1. careamics/careamist.py +12 -11
  2. careamics/config/__init__.py +0 -1
  3. careamics/config/architectures/unet_model.py +1 -0
  4. careamics/config/callback_model.py +1 -0
  5. careamics/config/configuration_example.py +0 -2
  6. careamics/config/configuration_factory.py +112 -42
  7. careamics/config/configuration_model.py +14 -16
  8. careamics/config/data_model.py +59 -157
  9. careamics/config/inference_model.py +19 -20
  10. careamics/config/references/algorithm_descriptions.py +1 -0
  11. careamics/config/references/references.py +1 -0
  12. careamics/config/support/supported_extraction_strategies.py +1 -0
  13. careamics/config/training_model.py +1 -0
  14. careamics/config/transformations/n2v_manipulate_model.py +1 -0
  15. careamics/config/transformations/nd_flip_model.py +6 -11
  16. careamics/config/transformations/normalize_model.py +1 -0
  17. careamics/config/transformations/transform_model.py +1 -0
  18. careamics/config/transformations/xy_random_rotate90_model.py +6 -8
  19. careamics/config/validators/validator_utils.py +1 -0
  20. careamics/conftest.py +1 -0
  21. careamics/dataset/dataset_utils/__init__.py +0 -1
  22. careamics/dataset/dataset_utils/dataset_utils.py +1 -0
  23. careamics/dataset/in_memory_dataset.py +14 -45
  24. careamics/dataset/iterable_dataset.py +13 -68
  25. careamics/dataset/patching/__init__.py +0 -7
  26. careamics/dataset/patching/patching.py +1 -0
  27. careamics/dataset/patching/sequential_patching.py +6 -6
  28. careamics/dataset/patching/tiled_patching.py +10 -6
  29. careamics/lightning_datamodule.py +20 -24
  30. careamics/lightning_module.py +1 -1
  31. careamics/lightning_prediction_datamodule.py +15 -10
  32. careamics/losses/__init__.py +0 -1
  33. careamics/losses/loss_factory.py +1 -0
  34. careamics/model_io/__init__.py +0 -1
  35. careamics/model_io/bioimage/_readme_factory.py +2 -1
  36. careamics/model_io/bioimage/bioimage_utils.py +1 -0
  37. careamics/model_io/bioimage/model_description.py +1 -0
  38. careamics/model_io/bmz_io.py +2 -1
  39. careamics/models/layers.py +1 -0
  40. careamics/models/model_factory.py +1 -0
  41. careamics/models/unet.py +91 -17
  42. careamics/prediction/stitch_prediction.py +1 -0
  43. careamics/transforms/__init__.py +2 -23
  44. careamics/transforms/compose.py +98 -0
  45. careamics/transforms/n2v_manipulate.py +18 -23
  46. careamics/transforms/nd_flip.py +38 -64
  47. careamics/transforms/normalize.py +45 -34
  48. careamics/transforms/pixel_manipulation.py +2 -2
  49. careamics/transforms/transform.py +33 -0
  50. careamics/transforms/tta.py +2 -2
  51. careamics/transforms/xy_random_rotate90.py +41 -68
  52. careamics/utils/__init__.py +0 -1
  53. careamics/utils/context.py +1 -0
  54. careamics/utils/logging.py +1 -0
  55. careamics/utils/metrics.py +1 -0
  56. careamics/utils/torch_utils.py +1 -0
  57. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc5.dist-info}/METADATA +16 -61
  58. careamics-0.1.0rc5.dist-info/RECORD +111 -0
  59. careamics/dataset/patching/patch_transform.py +0 -44
  60. careamics-0.1.0rc4.dist-info/RECORD +0 -110
  61. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc5.dist-info}/WHEEL +0 -0
  62. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc5.dist-info}/licenses/LICENSE +0 -0
@@ -1,10 +1,10 @@
1
1
  """Data configuration."""
2
+
2
3
  from __future__ import annotations
3
4
 
4
5
  from pprint import pformat
5
6
  from typing import Any, List, Literal, Optional, Union
6
7
 
7
- from albumentations import Compose
8
8
  from pydantic import (
9
9
  BaseModel,
10
10
  ConfigDict,
@@ -71,8 +71,6 @@ class DataConfig(BaseModel):
71
71
  ... },
72
72
  ... {
73
73
  ... "name": "NDFlip",
74
- ... "is_3D": True,
75
- ... "flip_z": True,
76
74
  ... }
77
75
  ... ]
78
76
  ... )
@@ -81,7 +79,6 @@ class DataConfig(BaseModel):
81
79
  # Pydantic class configuration
82
80
  model_config = ConfigDict(
83
81
  validate_assignment=True,
84
- arbitrary_types_allowed=True, # Allow Compose declaration
85
82
  )
86
83
 
87
84
  # Dataset configuration
@@ -94,7 +91,7 @@ class DataConfig(BaseModel):
94
91
  mean: Optional[float] = None
95
92
  std: Optional[float] = None
96
93
 
97
- transforms: Union[List[TRANSFORMS_UNION], Compose] = Field(
94
+ transforms: List[TRANSFORMS_UNION] = Field(
98
95
  default=[
99
96
  {
100
97
  "name": SupportedTransform.NORMALIZE.value,
@@ -181,19 +178,19 @@ class DataConfig(BaseModel):
181
178
  @field_validator("transforms")
182
179
  @classmethod
183
180
  def validate_prediction_transforms(
184
- cls, transforms: Union[List[TRANSFORMS_UNION], Compose]
185
- ) -> Union[List[TRANSFORMS_UNION], Compose]:
181
+ cls, transforms: List[TRANSFORMS_UNION]
182
+ ) -> List[TRANSFORMS_UNION]:
186
183
  """
187
184
  Validate N2VManipulate transform position in the transform list.
188
185
 
189
186
  Parameters
190
187
  ----------
191
- transforms : Union[List[Transformations_Union], Compose]
188
+ transforms : List[Transformations_Union]
192
189
  Transforms.
193
190
 
194
191
  Returns
195
192
  -------
196
- Union[List[Transformations_Union], Compose]
193
+ List[TRANSFORMS_UNION]
197
194
  Validated transforms.
198
195
 
199
196
  Raises
@@ -201,23 +198,22 @@ class DataConfig(BaseModel):
201
198
  ValueError
202
199
  If multiple instances of N2VManipulate are found.
203
200
  """
204
- if not isinstance(transforms, Compose):
205
- transform_list = [t.name for t in transforms]
206
-
207
- if SupportedTransform.N2V_MANIPULATE in transform_list:
208
- # multiple N2V_MANIPULATE
209
- if transform_list.count(SupportedTransform.N2V_MANIPULATE) > 1:
210
- raise ValueError(
211
- f"Multiple instances of "
212
- f"{SupportedTransform.N2V_MANIPULATE} transforms "
213
- f"are not allowed."
214
- )
215
-
216
- # N2V_MANIPULATE not the last transform
217
- elif transform_list[-1] != SupportedTransform.N2V_MANIPULATE:
218
- index = transform_list.index(SupportedTransform.N2V_MANIPULATE)
219
- transform = transforms.pop(index)
220
- transforms.append(transform)
201
+ transform_list = [t.name for t in transforms]
202
+
203
+ if SupportedTransform.N2V_MANIPULATE in transform_list:
204
+ # multiple N2V_MANIPULATE
205
+ if transform_list.count(SupportedTransform.N2V_MANIPULATE) > 1:
206
+ raise ValueError(
207
+ f"Multiple instances of "
208
+ f"{SupportedTransform.N2V_MANIPULATE} transforms "
209
+ f"are not allowed."
210
+ )
211
+
212
+ # N2V_MANIPULATE not the last transform
213
+ elif transform_list[-1] != SupportedTransform.N2V_MANIPULATE:
214
+ index = transform_list.index(SupportedTransform.N2V_MANIPULATE)
215
+ transform = transforms.pop(index)
216
+ transforms.append(transform)
221
217
 
222
218
  return transforms
223
219
 
@@ -256,11 +252,10 @@ class DataConfig(BaseModel):
256
252
  """
257
253
  if self.mean is not None or self.std is not None:
258
254
  # search in the transforms for Normalize and update parameters
259
- if self.has_transform_list():
260
- for transform in self.transforms:
261
- if transform.name == SupportedTransform.NORMALIZE.value:
262
- transform.mean = self.mean
263
- transform.std = self.std
255
+ for transform in self.transforms:
256
+ if transform.name == SupportedTransform.NORMALIZE.value:
257
+ transform.mean = self.mean
258
+ transform.std = self.std
264
259
 
265
260
  return self
266
261
 
@@ -286,13 +281,6 @@ class DataConfig(BaseModel):
286
281
  f"({self.axes})."
287
282
  )
288
283
 
289
- if self.has_transform_list():
290
- for transform in self.transforms:
291
- if transform.name == SupportedTransform.NDFLIP:
292
- transform.is_3D = True
293
- elif transform.name == SupportedTransform.XY_RANDOM_ROTATE90:
294
- transform.is_3D = True
295
-
296
284
  else:
297
285
  if len(self.patch_size) != 2:
298
286
  raise ValueError(
@@ -300,13 +288,6 @@ class DataConfig(BaseModel):
300
288
  f"({self.axes})."
301
289
  )
302
290
 
303
- if self.has_transform_list():
304
- for transform in self.transforms:
305
- if transform.name == SupportedTransform.NDFLIP:
306
- transform.is_3D = False
307
- elif transform.name == SupportedTransform.XY_RANDOM_ROTATE90:
308
- transform.is_3D = False
309
-
310
291
  return self
311
292
 
312
293
  def __str__(self) -> str:
@@ -332,84 +313,31 @@ class DataConfig(BaseModel):
332
313
  self.__dict__.update(kwargs)
333
314
  self.__class__.model_validate(self.__dict__)
334
315
 
335
- def has_transform_list(self) -> bool:
336
- """
337
- Check if the transforms are a list, as opposed to a Compose object.
338
-
339
- Returns
340
- -------
341
- bool
342
- True if the transforms are a list, False otherwise.
343
- """
344
- return isinstance(self.transforms, list)
345
-
346
316
  def has_n2v_manipulate(self) -> bool:
347
317
  """
348
318
  Check if the transforms contain N2VManipulate.
349
319
 
350
- Use `has_transform_list` to check if the transforms are a list.
351
-
352
320
  Returns
353
321
  -------
354
322
  bool
355
323
  True if the transforms contain N2VManipulate, False otherwise.
356
-
357
- Raises
358
- ------
359
- ValueError
360
- If the transforms are a Compose object.
361
324
  """
362
- if self.has_transform_list():
363
- return any(
364
- transform.name == SupportedTransform.N2V_MANIPULATE.value
365
- for transform in self.transforms
366
- )
367
- else:
368
- raise ValueError(
369
- "Checking for N2VManipulate with Compose transforms is not allowed. "
370
- "Check directly in the Compose."
371
- )
325
+ return any(
326
+ transform.name == SupportedTransform.N2V_MANIPULATE.value
327
+ for transform in self.transforms
328
+ )
372
329
 
373
330
  def add_n2v_manipulate(self) -> None:
374
- """
375
- Add N2VManipulate to the transforms.
376
-
377
- Use `has_transform_list` to check if the transforms are a list.
378
-
379
- Raises
380
- ------
381
- ValueError
382
- If the transforms are a Compose object.
383
- """
384
- if self.has_transform_list():
385
- if not self.has_n2v_manipulate():
386
- self.transforms.append(
387
- N2VManipulateModel(name=SupportedTransform.N2V_MANIPULATE.value)
388
- )
389
- else:
390
- raise ValueError(
391
- "Adding N2VManipulate with Compose transforms is not allowed. Add "
392
- "N2VManipulate directly to the transform in the Compose."
331
+ """Add N2VManipulate to the transforms."""
332
+ if not self.has_n2v_manipulate():
333
+ self.transforms.append(
334
+ N2VManipulateModel(name=SupportedTransform.N2V_MANIPULATE.value)
393
335
  )
394
336
 
395
337
  def remove_n2v_manipulate(self) -> None:
396
- """
397
- Remove N2VManipulate from the transforms.
398
-
399
- Use `has_transform_list` to check if the transforms are a list.
400
-
401
- Raises
402
- ------
403
- ValueError
404
- If the transforms are a Compose object.
405
- """
406
- if self.has_transform_list() and self.has_n2v_manipulate():
338
+ """Remove N2VManipulate from the transforms."""
339
+ if self.has_n2v_manipulate():
407
340
  self.transforms.pop(-1)
408
- else:
409
- raise ValueError(
410
- "Removing N2VManipulate with Compose transforms is not allowed. Remove "
411
- "N2VManipulate directly from the transform in the Compose."
412
- )
413
341
 
414
342
  def set_mean_and_std(self, mean: float, std: float) -> None:
415
343
  """
@@ -428,16 +356,10 @@ class DataConfig(BaseModel):
428
356
  self._update(mean=mean, std=std)
429
357
 
430
358
  # search in the transforms for Normalize and update parameters
431
- if self.has_transform_list():
432
- for transform in self.transforms:
433
- if transform.name == SupportedTransform.NORMALIZE.value:
434
- transform.mean = mean
435
- transform.std = std
436
- else:
437
- raise ValueError(
438
- "Setting mean and std with Compose transforms is not allowed. Add "
439
- "mean and std parameters directly to the transform in the Compose."
440
- )
359
+ for transform in self.transforms:
360
+ if transform.name == SupportedTransform.NORMALIZE.value:
361
+ transform.mean = mean
362
+ transform.std = std
441
363
 
442
364
  def set_3D(self, axes: str, patch_size: List[int]) -> None:
443
365
  """
@@ -465,8 +387,6 @@ class DataConfig(BaseModel):
465
387
  ------
466
388
  ValueError
467
389
  If the N2V pixel manipulate transform is not found in the transforms.
468
- ValueError
469
- If the transforms are a Compose object.
470
390
  """
471
391
  if use_n2v2:
472
392
  self.set_N2V2_strategy("median")
@@ -486,28 +406,19 @@ class DataConfig(BaseModel):
486
406
  ------
487
407
  ValueError
488
408
  If the N2V pixel manipulate transform is not found in the transforms.
489
- ValueError
490
- If the transforms are a Compose object.
491
409
  """
492
- if isinstance(self.transforms, list):
493
- found_n2v = False
410
+ found_n2v = False
494
411
 
495
- for transform in self.transforms:
496
- if transform.name == SupportedTransform.N2V_MANIPULATE.value:
497
- transform.strategy = strategy
498
- found_n2v = True
499
-
500
- if not found_n2v:
501
- transforms = [t.name for t in self.transforms]
502
- raise ValueError(
503
- f"N2V_Manipulate transform not found in the transforms list "
504
- f"({transforms})."
505
- )
412
+ for transform in self.transforms:
413
+ if transform.name == SupportedTransform.N2V_MANIPULATE.value:
414
+ transform.strategy = strategy
415
+ found_n2v = True
506
416
 
507
- else:
417
+ if not found_n2v:
418
+ transforms = [t.name for t in self.transforms]
508
419
  raise ValueError(
509
- "Setting N2V2 strategy with Compose transforms is not allowed. Add "
510
- "N2V2 strategy parameters directly to the transform in the Compose."
420
+ f"N2V_Manipulate transform not found in the transforms list "
421
+ f"({transforms})."
511
422
  )
512
423
 
513
424
  def set_structN2V_mask(
@@ -529,27 +440,18 @@ class DataConfig(BaseModel):
529
440
  ------
530
441
  ValueError
531
442
  If the N2V pixel manipulate transform is not found in the transforms.
532
- ValueError
533
- If the transforms are a Compose object.
534
443
  """
535
- if isinstance(self.transforms, list):
536
- found_n2v = False
444
+ found_n2v = False
537
445
 
538
- for transform in self.transforms:
539
- if transform.name == SupportedTransform.N2V_MANIPULATE.value:
540
- transform.struct_mask_axis = mask_axis
541
- transform.struct_mask_span = mask_span
542
- found_n2v = True
446
+ for transform in self.transforms:
447
+ if transform.name == SupportedTransform.N2V_MANIPULATE.value:
448
+ transform.struct_mask_axis = mask_axis
449
+ transform.struct_mask_span = mask_span
450
+ found_n2v = True
543
451
 
544
- if not found_n2v:
545
- transforms = [t.name for t in self.transforms]
546
- raise ValueError(
547
- f"N2V pixel manipulate transform not found in the transforms "
548
- f"({transforms})."
549
- )
550
-
551
- else:
452
+ if not found_n2v:
453
+ transforms = [t.name for t in self.transforms]
552
454
  raise ValueError(
553
- "Setting structN2VMask with Compose transforms is not allowed. Add "
554
- "structN2VMask parameters directly to the transform in the Compose."
455
+ f"N2V pixel manipulate transform not found in the transforms "
456
+ f"({transforms})."
555
457
  )
@@ -1,9 +1,9 @@
1
1
  """Pydantic model representing CAREamics prediction configuration."""
2
+
2
3
  from __future__ import annotations
3
4
 
4
5
  from typing import Any, List, Literal, Optional, Union
5
6
 
6
- from albumentations import Compose
7
7
  from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
8
8
  from typing_extensions import Self
9
9
 
@@ -33,7 +33,7 @@ class InferenceConfig(BaseModel):
33
33
  mean: float
34
34
  std: float = Field(..., ge=0.0)
35
35
 
36
- transforms: Union[List[TRANSFORMS_UNION], Compose] = Field(
36
+ transforms: List[TRANSFORMS_UNION] = Field(
37
37
  default=[
38
38
  {
39
39
  "name": SupportedTransform.NORMALIZE.value,
@@ -51,22 +51,22 @@ class InferenceConfig(BaseModel):
51
51
  @field_validator("tile_overlap")
52
52
  @classmethod
53
53
  def all_elements_non_zero_even(
54
- cls, patch_list: Optional[Union[List[int]]]
54
+ cls, tile_overlap: Optional[Union[List[int]]]
55
55
  ) -> Optional[Union[List[int]]]:
56
56
  """
57
- Validate patch size.
57
+ Validate tile overlap.
58
58
 
59
- Patch size must be non-zero, positive and even.
59
+ Overlaps must be non-zero, positive and even.
60
60
 
61
61
  Parameters
62
62
  ----------
63
- patch_list : Optional[Union[List[int]]]
63
+ tile_overlap : Optional[Union[List[int]]]
64
64
  Patch size.
65
65
 
66
66
  Returns
67
67
  -------
68
68
  Optional[Union[List[int]]]
69
- Validated patch size.
69
+ Validated tile overlap.
70
70
 
71
71
  Raises
72
72
  ------
@@ -75,8 +75,8 @@ class InferenceConfig(BaseModel):
75
75
  ValueError
76
76
  If the patch size is not even.
77
77
  """
78
- if patch_list is not None:
79
- for dim in patch_list:
78
+ if tile_overlap is not None:
79
+ for dim in tile_overlap:
80
80
  if dim < 1:
81
81
  raise ValueError(
82
82
  f"Patch size must be non-zero positive (got {dim})."
@@ -85,7 +85,7 @@ class InferenceConfig(BaseModel):
85
85
  if dim % 2 != 0:
86
86
  raise ValueError(f"Patch size must be even (got {dim}).")
87
87
 
88
- return patch_list
88
+ return tile_overlap
89
89
 
90
90
  @field_validator("tile_size")
91
91
  @classmethod
@@ -152,19 +152,19 @@ class InferenceConfig(BaseModel):
152
152
  @field_validator("transforms")
153
153
  @classmethod
154
154
  def validate_transforms(
155
- cls, transforms: Union[List[TRANSFORMS_UNION], Compose]
156
- ) -> Union[List[TRANSFORMS_UNION], Compose]:
155
+ cls, transforms: List[TRANSFORMS_UNION]
156
+ ) -> List[TRANSFORMS_UNION]:
157
157
  """
158
158
  Validate that transforms do not have N2V pixel manipulate transforms.
159
159
 
160
160
  Parameters
161
161
  ----------
162
- transforms : Union[List[TransformModel], Compose]
162
+ transforms : List[TRANSFORMS_UNION]
163
163
  Transforms.
164
164
 
165
165
  Returns
166
166
  -------
167
- Union[List[Transformations_Union], Compose]
167
+ List[TRANSFORMS_UNION]
168
168
  Validated transforms.
169
169
 
170
170
  Raises
@@ -172,7 +172,7 @@ class InferenceConfig(BaseModel):
172
172
  ValueError
173
173
  If transforms contain N2V pixel manipulate transforms.
174
174
  """
175
- if not isinstance(transforms, Compose) and transforms is not None:
175
+ if transforms is not None:
176
176
  for transform in transforms:
177
177
  if transform.name == SupportedTransform.N2V_MANIPULATE.value:
178
178
  raise ValueError(
@@ -247,11 +247,10 @@ class InferenceConfig(BaseModel):
247
247
  """
248
248
  if self.mean is not None or self.std is not None:
249
249
  # search in the transforms for Normalize and update parameters
250
- if not isinstance(self.transforms, Compose):
251
- for transform in self.transforms:
252
- if transform.name == SupportedTransform.NORMALIZE.value:
253
- transform.mean = self.mean
254
- transform.std = self.std
250
+ for transform in self.transforms:
251
+ if transform.name == SupportedTransform.NORMALIZE.value:
252
+ transform.mean = self.mean
253
+ transform.std = self.std
255
254
 
256
255
  return self
257
256
 
@@ -1,4 +1,5 @@
1
1
  """Descriptions of the algorithms used in CAREmics."""
2
+
2
3
  from pydantic import BaseModel
3
4
 
4
5
  CUSTOM = "Custom"
@@ -1,4 +1,5 @@
1
1
  """References for the CAREamics algorithms."""
2
+
2
3
  from bioimageio.spec.generic.v0_3 import CiteEntry
3
4
 
4
5
  N2VRef = CiteEntry(
@@ -3,6 +3,7 @@ Extraction strategy module.
3
3
 
4
4
  This module defines the various extraction strategies available in CAREamics.
5
5
  """
6
+
6
7
  from careamics.utils import BaseEnum
7
8
 
8
9
 
@@ -1,4 +1,5 @@
1
1
  """Training configuration."""
2
+
2
3
  from __future__ import annotations
3
4
 
4
5
  from pprint import pformat
@@ -1,4 +1,5 @@
1
1
  """Pydantic model for the N2VManipulate transform."""
2
+
2
3
  from typing import Literal
3
4
 
4
5
  from pydantic import ConfigDict, Field, field_validator
@@ -1,7 +1,8 @@
1
1
  """Pydantic model for the NDFlip transform."""
2
- from typing import Literal
3
2
 
4
- from pydantic import ConfigDict, Field
3
+ from typing import Literal, Optional
4
+
5
+ from pydantic import ConfigDict
5
6
 
6
7
  from .transform_model import TransformModel
7
8
 
@@ -14,12 +15,8 @@ class NDFlipModel(TransformModel):
14
15
  ----------
15
16
  name : Literal["NDFlip"]
16
17
  Name of the transformation.
17
- p : float
18
- Probability of applying the transformation, by default 0.5.
19
- is_3D : bool
20
- Whether the transformation should be applied in 3D, by default False.
21
- flip_z : bool
22
- Whether to flip the z axis, by default True.
18
+ seed : Optional[int]
19
+ Seed for the random number generator.
23
20
  """
24
21
 
25
22
  model_config = ConfigDict(
@@ -27,6 +24,4 @@ class NDFlipModel(TransformModel):
27
24
  )
28
25
 
29
26
  name: Literal["NDFlip"] = "NDFlip"
30
- p: float = Field(default=0.5, ge=0.0, le=1.0)
31
- is_3D: bool = Field(default=False)
32
- flip_z: bool = Field(default=True)
27
+ seed: Optional[int] = None
@@ -1,4 +1,5 @@
1
1
  """Pydantic model for the Normalize transform."""
2
+
2
3
  from typing import Literal
3
4
 
4
5
  from pydantic import ConfigDict, Field
@@ -1,4 +1,5 @@
1
1
  """Parent model for the transforms."""
2
+
2
3
  from typing import Any, Dict
3
4
 
4
5
  from pydantic import BaseModel, ConfigDict
@@ -1,7 +1,8 @@
1
1
  """Pydantic model for the XYRandomRotate90 transform."""
2
- from typing import Literal
3
2
 
4
- from pydantic import ConfigDict, Field
3
+ from typing import Literal, Optional
4
+
5
+ from pydantic import ConfigDict
5
6
 
6
7
  from .transform_model import TransformModel
7
8
 
@@ -14,10 +15,8 @@ class XYRandomRotate90Model(TransformModel):
14
15
  ----------
15
16
  name : Literal["XYRandomRotate90"]
16
17
  Name of the transformation.
17
- p : float
18
- Probability of applying the transformation, by default 0.5.
19
- is_3D : bool
20
- Whether the transformation should be applied in 3D, by default False.
18
+ seed : Optional[int]
19
+ Seed for the random number generator.
21
20
  """
22
21
 
23
22
  model_config = ConfigDict(
@@ -25,5 +24,4 @@ class XYRandomRotate90Model(TransformModel):
25
24
  )
26
25
 
27
26
  name: Literal["XYRandomRotate90"] = "XYRandomRotate90"
28
- p: float = Field(default=0.5, ge=0.0, le=1.0)
29
- is_3D: bool = Field(default=False)
27
+ seed: Optional[int] = None
@@ -3,6 +3,7 @@ Validator functions.
3
3
 
4
4
  These functions are used to validate dimensions and axes of inputs.
5
5
  """
6
+
6
7
  from typing import List, Optional, Tuple, Union
7
8
 
8
9
  _AXES = "STCZYX"
careamics/conftest.py CHANGED
@@ -2,6 +2,7 @@
2
2
 
3
3
  See https://sybil.readthedocs.io/en/latest/use.html#pytest
4
4
  """
5
+
5
6
  from pathlib import Path
6
7
 
7
8
  import pytest
@@ -1,6 +1,5 @@
1
1
  """Files and arrays utils used in the datasets."""
2
2
 
3
-
4
3
  __all__ = [
5
4
  "reshape_array",
6
5
  "get_files_size",
@@ -1,4 +1,5 @@
1
1
  """Convenience methods for datasets."""
2
+
2
3
  from typing import List, Tuple
3
4
 
4
5
  import numpy as np