careamics 0.0.14__py3-none-any.whl → 0.0.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic according to the registry's checks.

Files changed (56)
  1. careamics/careamist.py +49 -49
  2. careamics/cli/conf.py +6 -6
  3. careamics/cli/main.py +8 -8
  4. careamics/cli/utils.py +2 -4
  5. careamics/config/algorithms/vae_algorithm_model.py +4 -4
  6. careamics/config/callback_model.py +8 -8
  7. careamics/config/configuration_factories.py +49 -49
  8. careamics/config/data/data_model.py +7 -13
  9. careamics/config/data/ng_data_model.py +8 -14
  10. careamics/config/data/patching_strategies/_overlapping_patched_model.py +4 -5
  11. careamics/config/inference_model.py +6 -10
  12. careamics/config/likelihood_model.py +2 -2
  13. careamics/config/nm_model.py +5 -7
  14. careamics/config/training_model.py +4 -4
  15. careamics/config/transformations/normalize_model.py +3 -3
  16. careamics/config/transformations/xy_flip_model.py +2 -2
  17. careamics/config/transformations/xy_random_rotate90_model.py +2 -2
  18. careamics/config/validators/validator_utils.py +1 -2
  19. careamics/dataset/dataset_utils/iterate_over_files.py +3 -3
  20. careamics/dataset/in_memory_dataset.py +2 -2
  21. careamics/dataset/iterable_dataset.py +1 -2
  22. careamics/dataset/patching/random_patching.py +6 -6
  23. careamics/dataset/patching/sequential_patching.py +4 -4
  24. careamics/dataset/tiling/lvae_tiled_patching.py +2 -2
  25. careamics/dataset_ng/dataset.py +3 -3
  26. careamics/dataset_ng/factory.py +19 -19
  27. careamics/dataset_ng/patching_strategies/random_patching.py +2 -3
  28. careamics/dataset_ng/patching_strategies/sequential_patching.py +1 -2
  29. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +5 -5
  30. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +5 -5
  31. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +8 -8
  32. careamics/lightning/dataset_ng/data_module.py +43 -43
  33. careamics/lightning/lightning_module.py +12 -14
  34. careamics/lightning/predict_data_module.py +8 -8
  35. careamics/lightning/train_data_module.py +11 -11
  36. careamics/losses/lvae/losses.py +9 -9
  37. careamics/model_io/bioimage/model_description.py +12 -11
  38. careamics/model_io/bmz_io.py +4 -4
  39. careamics/models/layers.py +5 -5
  40. careamics/prediction_utils/lvae_prediction.py +5 -5
  41. careamics/transforms/compose.py +9 -9
  42. careamics/transforms/n2v_manipulate.py +3 -3
  43. careamics/transforms/n2v_manipulate_torch.py +4 -4
  44. careamics/transforms/normalize.py +4 -6
  45. careamics/transforms/pixel_manipulation.py +6 -8
  46. careamics/transforms/pixel_manipulation_torch.py +5 -7
  47. careamics/transforms/xy_flip.py +3 -5
  48. careamics/transforms/xy_random_rotate90.py +3 -5
  49. careamics/utils/logging.py +8 -8
  50. careamics/utils/metrics.py +2 -2
  51. careamics/utils/plotting.py +1 -3
  52. {careamics-0.0.14.dist-info → careamics-0.0.15.dist-info}/METADATA +2 -3
  53. {careamics-0.0.14.dist-info → careamics-0.0.15.dist-info}/RECORD +56 -56
  54. {careamics-0.0.14.dist-info → careamics-0.0.15.dist-info}/WHEEL +0 -0
  55. {careamics-0.0.14.dist-info → careamics-0.0.15.dist-info}/entry_points.txt +0 -0
  56. {careamics-0.0.14.dist-info → careamics-0.0.15.dist-info}/licenses/LICENSE +0 -0
careamics/careamist.py CHANGED
@@ -2,7 +2,7 @@
 
 from collections.abc import Callable
 from pathlib import Path
-from typing import Any, Literal, Optional, Union, overload
+from typing import Any, Literal, Union, overload
 
 import numpy as np
 from numpy.typing import NDArray
@@ -79,8 +79,8 @@ class CAREamist:
     def __init__(  # numpydoc ignore=GL08
         self,
         source: Union[Path, str],
-        work_dir: Optional[Union[Path, str]] = None,
-        callbacks: Optional[list[Callback]] = None,
+        work_dir: Union[Path, str] | None = None,
+        callbacks: list[Callback] | None = None,
         enable_progress_bar: bool = True,
     ) -> None: ...
 
@@ -88,16 +88,16 @@ class CAREamist:
     def __init__(  # numpydoc ignore=GL08
         self,
         source: Configuration,
-        work_dir: Optional[Union[Path, str]] = None,
-        callbacks: Optional[list[Callback]] = None,
+        work_dir: Union[Path, str] | None = None,
+        callbacks: list[Callback] | None = None,
         enable_progress_bar: bool = True,
     ) -> None: ...
 
     def __init__(
         self,
         source: Union[Path, str, Configuration],
-        work_dir: Optional[Union[Path, str]] = None,
-        callbacks: Optional[list[Callback]] = None,
+        work_dir: Union[Path, str] | None = None,
+        callbacks: list[Callback] | None = None,
         enable_progress_bar: bool = True,
     ) -> None:
         """
@@ -222,11 +222,11 @@ class CAREamist:
         )
 
         # place holder for the datamodules
-        self.train_datamodule: Optional[TrainDataModule] = None
-        self.pred_datamodule: Optional[PredictDataModule] = None
+        self.train_datamodule: TrainDataModule | None = None
+        self.pred_datamodule: PredictDataModule | None = None
 
     def _define_callbacks(
-        self, callbacks: Optional[list[Callback]], enable_progress_bar: bool
+        self, callbacks: list[Callback] | None, enable_progress_bar: bool
     ) -> None:
         """Define the callbacks for the training loop.
 
@@ -288,11 +288,11 @@ class CAREamist:
     def train(
         self,
         *,
-        datamodule: Optional[TrainDataModule] = None,
-        train_source: Optional[Union[Path, str, NDArray]] = None,
-        val_source: Optional[Union[Path, str, NDArray]] = None,
-        train_target: Optional[Union[Path, str, NDArray]] = None,
-        val_target: Optional[Union[Path, str, NDArray]] = None,
+        datamodule: TrainDataModule | None = None,
+        train_source: Union[Path, str, NDArray] | None = None,
+        val_source: Union[Path, str, NDArray] | None = None,
+        train_target: Union[Path, str, NDArray] | None = None,
+        val_target: Union[Path, str, NDArray] | None = None,
         use_in_memory: bool = True,
         val_percentage: float = 0.1,
         val_minimum_split: int = 1,
@@ -443,9 +443,9 @@ class CAREamist:
     def _train_on_array(
        self,
         train_data: NDArray,
-        val_data: Optional[NDArray] = None,
-        train_target: Optional[NDArray] = None,
-        val_target: Optional[NDArray] = None,
+        val_data: NDArray | None = None,
+        train_target: NDArray | None = None,
+        val_target: NDArray | None = None,
         val_percentage: float = 0.1,
         val_minimum_split: int = 5,
     ) -> None:
@@ -484,9 +484,9 @@ class CAREamist:
     def _train_on_path(
         self,
         path_to_train_data: Union[Path, str],
-        path_to_val_data: Optional[Union[Path, str]] = None,
-        path_to_train_target: Optional[Union[Path, str]] = None,
-        path_to_val_target: Optional[Union[Path, str]] = None,
+        path_to_val_data: Union[Path, str] | None = None,
+        path_to_train_target: Union[Path, str] | None = None,
+        path_to_val_target: Union[Path, str] | None = None,
         use_in_memory: bool = True,
         val_percentage: float = 0.1,
         val_minimum_split: int = 1,
@@ -549,13 +549,13 @@ class CAREamist:
         source: Union[Path, str],
         *,
         batch_size: int = 1,
-        tile_size: Optional[tuple[int, ...]] = None,
-        tile_overlap: Optional[tuple[int, ...]] = (48, 48),
-        axes: Optional[str] = None,
-        data_type: Optional[Literal["tiff", "custom"]] = None,
+        tile_size: tuple[int, ...] | None = None,
+        tile_overlap: tuple[int, ...] | None = (48, 48),
+        axes: str | None = None,
+        data_type: Literal["tiff", "custom"] | None = None,
         tta_transforms: bool = False,
-        dataloader_params: Optional[dict] = None,
-        read_source_func: Optional[Callable] = None,
+        dataloader_params: dict | None = None,
+        read_source_func: Callable | None = None,
         extension_filter: str = "",
     ) -> Union[list[NDArray], NDArray]: ...
 
@@ -565,12 +565,12 @@ class CAREamist:
         source: NDArray,
         *,
         batch_size: int = 1,
-        tile_size: Optional[tuple[int, ...]] = None,
-        tile_overlap: Optional[tuple[int, ...]] = (48, 48),
-        axes: Optional[str] = None,
-        data_type: Optional[Literal["array"]] = None,
+        tile_size: tuple[int, ...] | None = None,
+        tile_overlap: tuple[int, ...] | None = (48, 48),
+        axes: str | None = None,
+        data_type: Literal["array"] | None = None,
         tta_transforms: bool = False,
-        dataloader_params: Optional[dict] = None,
+        dataloader_params: dict | None = None,
     ) -> Union[list[NDArray], NDArray]: ...
 
     def predict(
@@ -578,13 +578,13 @@ class CAREamist:
         source: Union[PredictDataModule, Path, str, NDArray],
         *,
         batch_size: int = 1,
-        tile_size: Optional[tuple[int, ...]] = None,
-        tile_overlap: Optional[tuple[int, ...]] = (48, 48),
-        axes: Optional[str] = None,
-        data_type: Optional[Literal["array", "tiff", "custom"]] = None,
+        tile_size: tuple[int, ...] | None = None,
+        tile_overlap: tuple[int, ...] | None = (48, 48),
+        axes: str | None = None,
+        data_type: Literal["array", "tiff", "custom"] | None = None,
         tta_transforms: bool = False,
-        dataloader_params: Optional[dict] = None,
-        read_source_func: Optional[Callable] = None,
+        dataloader_params: dict | None = None,
+        read_source_func: Callable | None = None,
         extension_filter: str = "",
         **kwargs: Any,
     ) -> Union[list[NDArray], NDArray]:
@@ -704,18 +704,18 @@ class CAREamist:
         source: Union[PredictDataModule, Path, str],
         *,
         batch_size: int = 1,
-        tile_size: Optional[tuple[int, ...]] = None,
-        tile_overlap: Optional[tuple[int, ...]] = (48, 48),
-        axes: Optional[str] = None,
-        data_type: Optional[Literal["tiff", "custom"]] = None,
+        tile_size: tuple[int, ...] | None = None,
+        tile_overlap: tuple[int, ...] | None = (48, 48),
+        axes: str | None = None,
+        data_type: Literal["tiff", "custom"] | None = None,
         tta_transforms: bool = False,
-        dataloader_params: Optional[dict] = None,
-        read_source_func: Optional[Callable] = None,
+        dataloader_params: dict | None = None,
+        read_source_func: Callable | None = None,
         extension_filter: str = "",
         write_type: Literal["tiff", "custom"] = "tiff",
-        write_extension: Optional[str] = None,
-        write_func: Optional[WriteFunc] = None,
-        write_func_kwargs: Optional[dict[str, Any]] = None,
+        write_extension: str | None = None,
+        write_func: WriteFunc | None = None,
+        write_func_kwargs: dict[str, Any] | None = None,
         prediction_dir: Union[Path, str] = "predictions",
         **kwargs,
     ) -> None:
@@ -885,8 +885,8 @@ class CAREamist:
         authors: list[dict],
         general_description: str,
         data_description: str,
-        covers: Optional[list[Union[Path, str]]] = None,
-        channel_names: Optional[list[str]] = None,
+        covers: list[Union[Path, str]] | None = None,
+        channel_names: list[str] | None = None,
         model_version: str = "0.1.0",
     ) -> None:
         """Export the model to the BioImage Model Zoo format.
careamics/cli/conf.py CHANGED
@@ -3,7 +3,7 @@
 import sys
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Annotated, Optional
+from typing import Annotated
 
 import click
 import typer
@@ -135,10 +135,10 @@ def care(  # numpydoc ignore=PR01
         ),
     ] = "mae",
     n_channels_in: Annotated[
-        Optional[int], typer.Option(help="Number of channels in")
+        int | None, typer.Option(help="Number of channels in")
     ] = None,
     n_channels_out: Annotated[
-        Optional[int], typer.Option(help="Number of channels out")
+        int | None, typer.Option(help="Number of channels out")
     ] = None,
     logger: Annotated[
         click.Choice,
@@ -222,10 +222,10 @@ def n2n(  # numpydoc ignore=PR01
         ),
     ] = "mae",
     n_channels_in: Annotated[
-        Optional[int], typer.Option(help="Number of channels in")
+        int | None, typer.Option(help="Number of channels in")
     ] = None,
     n_channels_out: Annotated[
-        Optional[int], typer.Option(help="Number of channels out")
+        int | None, typer.Option(help="Number of channels out")
     ] = None,
     logger: Annotated[
         click.Choice,
@@ -300,7 +300,7 @@ def n2v(  # numpydoc ignore=PR01
     ] = True,
     use_n2v2: Annotated[bool, typer.Option(help="Whether to use N2V2")] = False,
     n_channels: Annotated[
-        Optional[int], typer.Option(help="Number of channels (in and out)")
+        int | None, typer.Option(help="Number of channels (in and out)")
     ] = None,
     roi_size: Annotated[int, typer.Option(help="N2V pixel manipulation area.")] = 11,
     masked_pixel_percentage: Annotated[
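
The CLI modules apply the same migration inside Annotated option types. Typer inspects these annotations at runtime, so the int | None spelling relies on Python 3.10+ (or a sufficiently recent Typer). A minimal sketch of the pattern, using a hypothetical command rather than the actual careamics CLI:

    # Hypothetical sketch of the Annotated-option pattern seen above;
    # not part of the careamics CLI.
    from typing import Annotated

    import typer

    app = typer.Typer()

    @app.command()
    def demo(
        n_channels_in: Annotated[
            int | None, typer.Option(help="Number of channels in")
        ] = None,
    ) -> None:
        typer.echo(f"n_channels_in={n_channels_in}")

    if __name__ == "__main__":
        app()
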
careamics/cli/main.py CHANGED
@@ -7,7 +7,7 @@ its implementation is contained in the conf.py file.
 """
 
 from pathlib import Path
-from typing import Annotated, Optional
+from typing import Annotated
 
 import click
 import typer
@@ -47,7 +47,7 @@ def train(  # numpydoc ignore=PR01
         ),
     ],
     train_target: Annotated[
-        Optional[Path],
+        Path | None,
         typer.Option(
             "--train-target",
             "-tt",
@@ -58,7 +58,7 @@ def train(  # numpydoc ignore=PR01
         ),
     ] = None,
     val_source: Annotated[
-        Optional[Path],
+        Path | None,
         typer.Option(
             "--val-source",
             "-vs",
@@ -69,7 +69,7 @@ def train(  # numpydoc ignore=PR01
         ),
     ] = None,
     val_target: Annotated[
-        Optional[Path],
+        Path | None,
         typer.Option(
             "--val-target",
             "-vt",
@@ -96,7 +96,7 @@ def train(  # numpydoc ignore=PR01
         typer.Option(help="Minimum number of files to use for validation,"),
     ] = 1,
     work_dir: Annotated[
-        Optional[Path],
+        Path | None,
         typer.Option(
             "--work-dir",
             "-wd",
@@ -142,7 +142,7 @@ def predict(  # numpydoc ignore=PR01
     ],
     batch_size: Annotated[int, typer.Option(help="Batch size.")] = 1,
     tile_size: Annotated[
-        Optional[click.Tuple],
+        click.Tuple | None,
         typer.Option(
             help=(
                 "Size of the tiles to use for prediction, (if the data "
@@ -164,7 +164,7 @@ def predict(  # numpydoc ignore=PR01
         ),
     ] = (48, 48, -1),
     axes: Annotated[
-        Optional[str],
+        str | None,
        typer.Option(
            help="Axes of the input data. If unused the data is assumed to have the "
            "same axes as the original training data."
@@ -190,7 +190,7 @@ def predict(  # numpydoc ignore=PR01
     ] = "tiff",
     # TODO: could make dataloader_params as json, necessary?
     work_dir: Annotated[
-        Optional[Path],
+        Path | None,
         typer.Option(
             "--work-dir",
             "-wd",
careamics/cli/utils.py CHANGED
@@ -1,11 +1,9 @@
 """Utility functions for the CAREamics CLI."""
 
-from typing import Optional
-
 
 def handle_2D_3D_callback(
-    value: Optional[tuple[int, int, int]],
-) -> Optional[tuple[int, ...]]:
+    value: tuple[int, int, int] | None,
+) -> tuple[int, ...] | None:
     """
     Callback for options that require 2D or 3D inputs.
 
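The body of handle_2D_3D_callback is not part of this diff, only its signature. Given the (48, 48, -1) defaults visible in main.py, one plausible implementation (a hypothetical sketch, not the actual careamics code) treats -1 in the last position as a "2D" sentinel and drops it:

    # Hypothetical sketch only; the real implementation is not shown in this diff.
    def handle_2D_3D_callback(
        value: tuple[int, int, int] | None,
    ) -> tuple[int, ...] | None:
        if value is None:
            return None
        # (48, 48, -1) -> (48, 48); a real third dimension is kept as-is.
        return value[:2] if value[2] == -1 else value
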
careamics/config/algorithms/vae_algorithm_model.py CHANGED
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 from pprint import pformat
-from typing import Literal, Optional
+from typing import Literal
 
 from pydantic import BaseModel, ConfigDict, model_validator
 from typing_extensions import Self
@@ -45,9 +45,9 @@ class VAEBasedAlgorithm(BaseModel):
     # NOTE: these are all configs (pydantic models)
     loss: LVAELossConfig
     model: LVAEModel
-    noise_model: Optional[MultiChannelNMConfig] = None
-    noise_model_likelihood: Optional[NMLikelihoodConfig] = None
-    gaussian_likelihood: Optional[GaussianLikelihoodConfig] = None
+    noise_model: MultiChannelNMConfig | None = None
+    noise_model_likelihood: NMLikelihoodConfig | None = None
+    gaussian_likelihood: GaussianLikelihoodConfig | None = None
 
     # Optional fields
     optimizer: OptimizerModel = OptimizerModel()
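
This module imports model_validator alongside the now-optional sub-configs. One plausible use for such a validator, shown here as an illustrative sketch (not the actual VAEBasedAlgorithm validator; class and field names are hypothetical, assumes Python 3.10+ and pydantic v2), is enforcing that mutually dependent optional fields are provided together:

    # Illustrative sketch of cross-field validation over optional sub-configs.
    from pydantic import BaseModel, model_validator
    from typing_extensions import Self

    class NoiseModelConfig(BaseModel):
        path: str

    class AlgorithmConfig(BaseModel):
        noise_model: NoiseModelConfig | None = None
        use_noise_model_likelihood: bool = False

        @model_validator(mode="after")
        def check_noise_model(self) -> Self:
            if self.use_noise_model_likelihood and self.noise_model is None:
                raise ValueError("noise_model is required when the likelihood uses it")
            return self
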
careamics/config/callback_model.py CHANGED
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 from datetime import timedelta
-from typing import Literal, Optional
+from typing import Literal
 
 from pydantic import (
     BaseModel,
@@ -33,7 +33,7 @@ class CheckpointModel(BaseModel):
     save_weights_only: bool = Field(default=False)
     """When `True`, only the model's weights will be saved (model.save_weights)."""
 
-    save_last: Optional[Literal[True, False, "link"]] = Field(default=True)
+    save_last: Literal[True, False, "link"] | None = Field(default=True)
     """When `True`, saves a last.ckpt copy whenever a checkpoint file gets saved."""
 
     save_top_k: int = Field(default=3, ge=-1, le=100)
@@ -51,13 +51,13 @@ class CheckpointModel(BaseModel):
     auto_insert_metric_name: bool = Field(default=False)
     """When `True`, the checkpoints filenames will contain the metric name."""
 
-    every_n_train_steps: Optional[int] = Field(default=None, ge=1, le=1000)
+    every_n_train_steps: int | None = Field(default=None, ge=1, le=1000)
     """Number of training steps between checkpoints."""
 
-    train_time_interval: Optional[timedelta] = Field(default=None)
+    train_time_interval: timedelta | None = Field(default=None)
     """Checkpoints are monitored at the specified time interval."""
 
-    every_n_epochs: Optional[int] = Field(default=None, ge=1, le=100)
+    every_n_epochs: int | None = Field(default=None, ge=1, le=100)
     """Number of epochs between checkpoints."""
 
 
@@ -96,14 +96,14 @@ class EarlyStoppingModel(BaseModel):
     """When `True`, stops training when the monitored quantity becomes `NaN` or
     `inf`."""
 
-    stopping_threshold: Optional[float] = Field(default=None)
+    stopping_threshold: float | None = Field(default=None)
     """Stop training immediately once the monitored quantity reaches this threshold."""
 
-    divergence_threshold: Optional[float] = Field(default=None)
+    divergence_threshold: float | None = Field(default=None)
     """Stop training as soon as the monitored quantity becomes worse than this
     threshold."""
 
-    check_on_train_epoch_end: Optional[bool] = Field(default=False)
+    check_on_train_epoch_end: bool | None = Field(default=False)
     """Whether to run early stopping at the end of the training epoch. If this is
     `False`, then the check runs at the end of the validation."""
 
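A detail worth noting in these hunks: the ge/le bounds declared on the optional fields only apply when a value is actually supplied; None bypasses the constraint entirely. A standalone sketch of that behavior (pydantic v2, not careamics code):

    # Bounds on an optional field apply only to non-None values.
    from pydantic import BaseModel, Field, ValidationError

    class Checkpoint(BaseModel):
        every_n_epochs: int | None = Field(default=None, ge=1, le=100)

    Checkpoint()                      # ok: defaults to None
    Checkpoint(every_n_epochs=10)     # ok: within [1, 100]
    try:
        Checkpoint(every_n_epochs=0)  # rejected: violates ge=1
    except ValidationError as err:
        print(err.error_count(), "validation error")
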
careamics/config/configuration_factories.py CHANGED
@@ -1,7 +1,7 @@
 """Convenience functions to create configurations for training and inference."""
 
 from collections.abc import Sequence
-from typing import Annotated, Any, Literal, Optional, Union
+from typing import Annotated, Any, Literal, Union
 
 from pydantic import Field, TypeAdapter
 
@@ -50,7 +50,7 @@ def algorithm_factory(
 
 
 def _list_spatial_augmentations(
-    augmentations: Optional[list[SPATIAL_TRANSFORMS_UNION]] = None,
+    augmentations: list[SPATIAL_TRANSFORMS_UNION] | None = None,
 ) -> list[SPATIAL_TRANSFORMS_UNION]:
     """
     List the augmentations to apply.
@@ -105,7 +105,7 @@ def _create_unet_configuration(
     n_channels_out: int,
     independent_channels: bool,
     use_n2v2: bool,
-    model_params: Optional[dict[str, Any]] = None,
+    model_params: dict[str, Any] | None = None,
 ) -> UNetModel:
     """
     Create a dictionary with the parameters of the UNet model.
@@ -153,11 +153,11 @@ def _create_algorithm_configuration(
     n_channels_in: int,
     n_channels_out: int,
     use_n2v2: bool = False,
-    model_params: Optional[dict] = None,
+    model_params: dict | None = None,
     optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
-    optimizer_params: Optional[dict[str, Any]] = None,
+    optimizer_params: dict[str, Any] | None = None,
     lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
-    lr_scheduler_params: Optional[dict[str, Any]] = None,
+    lr_scheduler_params: dict[str, Any] | None = None,
 ) -> dict:
     """
     Create a dictionary with the parameters of the algorithm model.
@@ -227,8 +227,8 @@ def _create_data_configuration(
     patch_size: list[int],
     batch_size: int,
     augmentations: Union[list[SPATIAL_TRANSFORMS_UNION]],
-    train_dataloader_params: Optional[dict[str, Any]] = None,
-    val_dataloader_params: Optional[dict[str, Any]] = None,
+    train_dataloader_params: dict[str, Any] | None = None,
+    val_dataloader_params: dict[str, Any] | None = None,
 ) -> DataConfig:
     """
     Create a dictionary with the parameters of the data model.
@@ -283,11 +283,11 @@ def _create_ng_data_configuration(
     patch_size: Sequence[int],
     batch_size: int,
     augmentations: list[SPATIAL_TRANSFORMS_UNION],
-    patch_overlaps: Optional[Sequence[int]] = None,
-    train_dataloader_params: Optional[dict[str, Any]] = None,
-    val_dataloader_params: Optional[dict[str, Any]] = None,
-    test_dataloader_params: Optional[dict[str, Any]] = None,
-    seed: Optional[int] = None,
+    patch_overlaps: Sequence[int] | None = None,
+    train_dataloader_params: dict[str, Any] | None = None,
+    val_dataloader_params: dict[str, Any] | None = None,
+    test_dataloader_params: dict[str, Any] | None = None,
+    seed: int | None = None,
 ) -> NGDataConfig:
     """
     Create a dictionary with the parameters of the data model.
@@ -359,7 +359,7 @@ def _create_ng_data_configuration(
 def _create_training_configuration(
     num_epochs: int,
     logger: Literal["wandb", "tensorboard", "none"],
-    checkpoint_params: Optional[dict[str, Any]] = None,
+    checkpoint_params: dict[str, Any] | None = None,
 ) -> TrainingConfig:
     """
     Create a dictionary with the parameters of the training model.
@@ -395,20 +395,20 @@ def _create_supervised_config_dict(
     patch_size: list[int],
     batch_size: int,
     num_epochs: int,
-    augmentations: Optional[list[SPATIAL_TRANSFORMS_UNION]] = None,
+    augmentations: list[SPATIAL_TRANSFORMS_UNION] | None = None,
     independent_channels: bool = True,
     loss: Literal["mae", "mse"] = "mae",
-    n_channels_in: Optional[int] = None,
-    n_channels_out: Optional[int] = None,
+    n_channels_in: int | None = None,
+    n_channels_out: int | None = None,
     logger: Literal["wandb", "tensorboard", "none"] = "none",
-    model_params: Optional[dict] = None,
+    model_params: dict | None = None,
     optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
-    optimizer_params: Optional[dict[str, Any]] = None,
+    optimizer_params: dict[str, Any] | None = None,
     lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
-    lr_scheduler_params: Optional[dict[str, Any]] = None,
-    train_dataloader_params: Optional[dict[str, Any]] = None,
-    val_dataloader_params: Optional[dict[str, Any]] = None,
-    checkpoint_params: Optional[dict[str, Any]] = None,
+    lr_scheduler_params: dict[str, Any] | None = None,
+    train_dataloader_params: dict[str, Any] | None = None,
+    val_dataloader_params: dict[str, Any] | None = None,
+    checkpoint_params: dict[str, Any] | None = None,
 ) -> dict:
     """
     Create a configuration for training CARE or Noise2Noise.
@@ -540,20 +540,20 @@ def create_care_configuration(
     patch_size: list[int],
     batch_size: int,
     num_epochs: int,
-    augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
+    augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]] | None = None,
     independent_channels: bool = True,
     loss: Literal["mae", "mse"] = "mae",
-    n_channels_in: Optional[int] = None,
-    n_channels_out: Optional[int] = None,
+    n_channels_in: int | None = None,
+    n_channels_out: int | None = None,
     logger: Literal["wandb", "tensorboard", "none"] = "none",
-    model_params: Optional[dict] = None,
+    model_params: dict | None = None,
     optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
-    optimizer_params: Optional[dict[str, Any]] = None,
+    optimizer_params: dict[str, Any] | None = None,
     lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
-    lr_scheduler_params: Optional[dict[str, Any]] = None,
-    train_dataloader_params: Optional[dict[str, Any]] = None,
-    val_dataloader_params: Optional[dict[str, Any]] = None,
-    checkpoint_params: Optional[dict[str, Any]] = None,
+    lr_scheduler_params: dict[str, Any] | None = None,
+    train_dataloader_params: dict[str, Any] | None = None,
+    val_dataloader_params: dict[str, Any] | None = None,
+    checkpoint_params: dict[str, Any] | None = None,
 ) -> Configuration:
     """
     Create a configuration for training CARE.
@@ -756,20 +756,20 @@ def create_n2n_configuration(
     patch_size: list[int],
     batch_size: int,
     num_epochs: int,
-    augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
+    augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]] | None = None,
     independent_channels: bool = True,
     loss: Literal["mae", "mse"] = "mae",
-    n_channels_in: Optional[int] = None,
-    n_channels_out: Optional[int] = None,
+    n_channels_in: int | None = None,
+    n_channels_out: int | None = None,
     logger: Literal["wandb", "tensorboard", "none"] = "none",
-    model_params: Optional[dict] = None,
+    model_params: dict | None = None,
     optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
-    optimizer_params: Optional[dict[str, Any]] = None,
+    optimizer_params: dict[str, Any] | None = None,
     lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
-    lr_scheduler_params: Optional[dict[str, Any]] = None,
-    train_dataloader_params: Optional[dict[str, Any]] = None,
-    val_dataloader_params: Optional[dict[str, Any]] = None,
-    checkpoint_params: Optional[dict[str, Any]] = None,
+    lr_scheduler_params: dict[str, Any] | None = None,
+    train_dataloader_params: dict[str, Any] | None = None,
+    val_dataloader_params: dict[str, Any] | None = None,
+    checkpoint_params: dict[str, Any] | None = None,
 ) -> Configuration:
     """
     Create a configuration for training Noise2Noise.
@@ -972,23 +972,23 @@ def create_n2v_configuration(
     patch_size: list[int],
     batch_size: int,
     num_epochs: int,
-    augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
+    augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]] | None = None,
     independent_channels: bool = True,
     use_n2v2: bool = False,
-    n_channels: Optional[int] = None,
+    n_channels: int | None = None,
     roi_size: int = 11,
     masked_pixel_percentage: float = 0.2,
     struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
     struct_n2v_span: int = 5,
     logger: Literal["wandb", "tensorboard", "none"] = "none",
-    model_params: Optional[dict] = None,
+    model_params: dict | None = None,
     optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
-    optimizer_params: Optional[dict[str, Any]] = None,
+    optimizer_params: dict[str, Any] | None = None,
     lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
-    lr_scheduler_params: Optional[dict[str, Any]] = None,
-    train_dataloader_params: Optional[dict[str, Any]] = None,
-    val_dataloader_params: Optional[dict[str, Any]] = None,
-    checkpoint_params: Optional[dict[str, Any]] = None,
+    lr_scheduler_params: dict[str, Any] | None = None,
+    train_dataloader_params: dict[str, Any] | None = None,
+    val_dataloader_params: dict[str, Any] | None = None,
+    checkpoint_params: dict[str, Any] | None = None,
 ) -> Configuration:
     """
     Create a configuration for training Noise2Void.
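
Unlike vae_algorithm_model.py and callback_model.py above, the import hunk for this module shows no `from __future__ import annotations`, so its X | None annotations are evaluated when the module is imported, which requires Python 3.10+. A minimal sketch of the difference (not careamics code):

    # On Python 3.9, "int | None" in an annotation raises TypeError at
    # definition time, unless evaluation is deferred with the future import.
    from __future__ import annotations  # annotations become lazy strings

    def f(x: int | None = None) -> int | None:
        return x
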
careamics/config/data/data_model.py CHANGED
@@ -6,7 +6,7 @@ import os
 import sys
 from collections.abc import Sequence
 from pprint import pformat
-from typing import Annotated, Any, Literal, Optional, Union
+from typing import Annotated, Any, Literal, Union
 from warnings import warn
 
 import numpy as np
@@ -109,22 +109,16 @@ class DataConfig(BaseModel):
     """Batch size for training."""
 
     # Optional fields
-    image_means: Optional[list[Float]] = Field(
-        default=None, min_length=0, max_length=32
-    )
+    image_means: list[Float] | None = Field(default=None, min_length=0, max_length=32)
     """Means of the data across channels, used for normalization."""
 
-    image_stds: Optional[list[Float]] = Field(default=None, min_length=0, max_length=32)
+    image_stds: list[Float] | None = Field(default=None, min_length=0, max_length=32)
     """Standard deviations of the data across channels, used for normalization."""
 
-    target_means: Optional[list[Float]] = Field(
-        default=None, min_length=0, max_length=32
-    )
+    target_means: list[Float] | None = Field(default=None, min_length=0, max_length=32)
     """Means of the target data across channels, used for normalization."""
 
-    target_stds: Optional[list[Float]] = Field(
-        default=None, min_length=0, max_length=32
-    )
+    target_stds: list[Float] | None = Field(default=None, min_length=0, max_length=32)
     """Standard deviations of the target data across channels, used for
     normalization."""
 
@@ -388,8 +382,8 @@ class DataConfig(BaseModel):
         self,
         image_means: Union[NDArray, tuple, list, None],
         image_stds: Union[NDArray, tuple, list, None],
-        target_means: Optional[Union[NDArray, tuple, list, None]] = None,
-        target_stds: Optional[Union[NDArray, tuple, list, None]] = None,
+        target_means: Union[NDArray, tuple, list, None] | None = None,
+        target_stds: Union[NDArray, tuple, list, None] | None = None,
     ) -> None:
         """
         Set mean and standard deviation of the data across channels.
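
The last hunk preserves a mechanical artifact of the migration: Optional[Union[NDArray, tuple, list, None]] became Union[NDArray, tuple, list, None] | None, a union that names None twice. This is redundant but harmless, since unions flatten and deduplicate; a one-line sketch of the equivalence (Python 3.10+, not careamics code):

    # Unions flatten and deduplicate, so the doubled None collapses.
    from typing import Union

    assert (Union[tuple, list, None] | None) == Union[tuple, list, None]
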