careamics 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (98) hide show
  1. careamics/__init__.py +17 -2
  2. careamics/careamist.py +4 -3
  3. careamics/cli/conf.py +1 -2
  4. careamics/cli/main.py +1 -2
  5. careamics/cli/utils.py +3 -3
  6. careamics/config/__init__.py +47 -25
  7. careamics/config/algorithms/__init__.py +15 -0
  8. careamics/config/algorithms/care_algorithm_model.py +50 -0
  9. careamics/config/algorithms/n2n_algorithm_model.py +42 -0
  10. careamics/config/algorithms/n2v_algorithm_model.py +35 -0
  11. careamics/config/algorithms/unet_algorithm_model.py +88 -0
  12. careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +14 -12
  13. careamics/config/architectures/__init__.py +1 -11
  14. careamics/config/architectures/architecture_model.py +3 -3
  15. careamics/config/architectures/lvae_model.py +6 -1
  16. careamics/config/architectures/unet_model.py +1 -0
  17. careamics/config/care_configuration.py +100 -0
  18. careamics/config/configuration.py +354 -0
  19. careamics/config/{configuration_factory.py → configuration_factories.py} +103 -36
  20. careamics/config/configuration_io.py +85 -0
  21. careamics/config/data/__init__.py +10 -0
  22. careamics/config/{data_model.py → data/data_model.py} +58 -198
  23. careamics/config/data/n2v_data_model.py +193 -0
  24. careamics/config/likelihood_model.py +1 -2
  25. careamics/config/n2n_configuration.py +101 -0
  26. careamics/config/n2v_configuration.py +266 -0
  27. careamics/config/nm_model.py +1 -2
  28. careamics/config/support/__init__.py +7 -7
  29. careamics/config/support/supported_algorithms.py +0 -3
  30. careamics/config/support/supported_architectures.py +0 -4
  31. careamics/config/transformations/__init__.py +10 -4
  32. careamics/config/transformations/transform_model.py +3 -3
  33. careamics/config/transformations/transform_unions.py +42 -0
  34. careamics/config/validators/validator_utils.py +3 -3
  35. careamics/dataset/__init__.py +2 -2
  36. careamics/dataset/dataset_utils/__init__.py +3 -3
  37. careamics/dataset/dataset_utils/dataset_utils.py +4 -6
  38. careamics/dataset/dataset_utils/file_utils.py +9 -9
  39. careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
  40. careamics/dataset/in_memory_dataset.py +11 -12
  41. careamics/dataset/iterable_dataset.py +4 -4
  42. careamics/dataset/iterable_pred_dataset.py +2 -1
  43. careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
  44. careamics/dataset/patching/random_patching.py +11 -10
  45. careamics/dataset/patching/sequential_patching.py +26 -26
  46. careamics/dataset/patching/validate_patch_dimension.py +3 -3
  47. careamics/dataset/tiling/__init__.py +2 -2
  48. careamics/dataset/tiling/collate_tiles.py +3 -3
  49. careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
  50. careamics/dataset/tiling/tiled_patching.py +11 -10
  51. careamics/file_io/__init__.py +5 -5
  52. careamics/file_io/read/__init__.py +1 -1
  53. careamics/file_io/read/get_func.py +2 -2
  54. careamics/file_io/write/__init__.py +2 -2
  55. careamics/lightning/__init__.py +5 -5
  56. careamics/lightning/callbacks/__init__.py +1 -1
  57. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
  58. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
  59. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
  60. careamics/lightning/callbacks/progress_bar_callback.py +2 -2
  61. careamics/lightning/lightning_module.py +11 -7
  62. careamics/lightning/train_data_module.py +26 -26
  63. careamics/losses/__init__.py +3 -3
  64. careamics/model_io/__init__.py +1 -1
  65. careamics/model_io/bioimage/__init__.py +1 -1
  66. careamics/model_io/bioimage/_readme_factory.py +1 -1
  67. careamics/model_io/bioimage/model_description.py +17 -17
  68. careamics/model_io/bmz_io.py +6 -17
  69. careamics/model_io/model_io_utils.py +9 -9
  70. careamics/models/layers.py +16 -16
  71. careamics/models/lvae/lvae.py +0 -3
  72. careamics/models/model_factory.py +2 -15
  73. careamics/models/unet.py +8 -8
  74. careamics/prediction_utils/__init__.py +1 -1
  75. careamics/prediction_utils/prediction_outputs.py +15 -15
  76. careamics/prediction_utils/stitch_prediction.py +6 -6
  77. careamics/transforms/__init__.py +5 -5
  78. careamics/transforms/compose.py +13 -13
  79. careamics/transforms/n2v_manipulate.py +3 -3
  80. careamics/transforms/pixel_manipulation.py +9 -9
  81. careamics/transforms/xy_random_rotate90.py +4 -4
  82. careamics/utils/__init__.py +5 -5
  83. careamics/utils/context.py +2 -1
  84. careamics/utils/logging.py +11 -10
  85. careamics/utils/torch_utils.py +7 -7
  86. {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/METADATA +11 -11
  87. {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/RECORD +90 -85
  88. careamics/config/architectures/custom_model.py +0 -162
  89. careamics/config/architectures/register_model.py +0 -103
  90. careamics/config/configuration_model.py +0 -603
  91. careamics/config/fcn_algorithm_model.py +0 -152
  92. careamics/config/references/__init__.py +0 -45
  93. careamics/config/references/algorithm_descriptions.py +0 -132
  94. careamics/config/references/references.py +0 -39
  95. careamics/config/transformations/transform_union.py +0 -20
  96. {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/WHEEL +0 -0
  97. {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/entry_points.txt +0 -0
  98. {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/licenses/LICENSE +0 -0
@@ -10,7 +10,7 @@ from typing import Callable, Optional
10
10
  import numpy as np
11
11
  from torch.utils.data import IterableDataset
12
12
 
13
- from careamics.config import DataConfig
13
+ from careamics.config import GeneralDataConfig
14
14
  from careamics.config.transformations import NormalizeModel
15
15
  from careamics.file_io.read import read_tiff
16
16
  from careamics.transforms import Compose
@@ -49,7 +49,7 @@ class PathIterableDataset(IterableDataset):
49
49
 
50
50
  def __init__(
51
51
  self,
52
- data_config: DataConfig,
52
+ data_config: GeneralDataConfig,
53
53
  src_files: list[Path],
54
54
  target_files: Optional[list[Path]] = None,
55
55
  read_source_func: Callable = read_tiff,
@@ -58,7 +58,7 @@ class PathIterableDataset(IterableDataset):
58
58
 
59
59
  Parameters
60
60
  ----------
61
- data_config : DataConfig
61
+ data_config : GeneralDataConfig
62
62
  Data configuration.
63
63
  src_files : list[Path]
64
64
  List of data files.
@@ -115,7 +115,7 @@ class PathIterableDataset(IterableDataset):
115
115
  target_stds=self.target_stats.stds,
116
116
  )
117
117
  ]
118
- + data_config.transforms
118
+ + list(data_config.transforms)
119
119
  )
120
120
 
121
121
  def _calculate_mean_and_std(self) -> tuple[Stats, Stats]:
@@ -2,8 +2,9 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ from collections.abc import Generator
5
6
  from pathlib import Path
6
- from typing import Any, Callable, Generator
7
+ from typing import Any, Callable
7
8
 
8
9
  from numpy.typing import NDArray
9
10
  from torch.utils.data import IterableDataset
@@ -2,8 +2,9 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ from collections.abc import Generator
5
6
  from pathlib import Path
6
- from typing import Any, Callable, Generator
7
+ from typing import Any, Callable
7
8
 
8
9
  from numpy.typing import NDArray
9
10
  from torch.utils.data import IterableDataset
@@ -1,6 +1,7 @@
1
1
  """Random patching utilities."""
2
2
 
3
- from typing import Generator, List, Optional, Tuple, Union
3
+ from collections.abc import Generator
4
+ from typing import Optional, Union
4
5
 
5
6
  import numpy as np
6
7
  import zarr
@@ -11,10 +12,10 @@ from .validate_patch_dimension import validate_patch_dimensions
11
12
  # TODO split into testable functions
12
13
  def extract_patches_random(
13
14
  arr: np.ndarray,
14
- patch_size: Union[List[int], Tuple[int, ...]],
15
+ patch_size: Union[list[int], tuple[int, ...]],
15
16
  target: Optional[np.ndarray] = None,
16
17
  seed: Optional[int] = None,
17
- ) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
18
+ ) -> Generator[tuple[np.ndarray, Optional[np.ndarray]], None, None]:
18
19
  """
19
20
  Generate patches from an array in a random manner.
20
21
 
@@ -31,12 +32,12 @@ def extract_patches_random(
31
32
  ----------
32
33
  arr : np.ndarray
33
34
  Input image array.
34
- patch_size : Tuple[int]
35
+ patch_size : tuple of int
35
36
  Patch sizes in each dimension.
36
37
  target : Optional[np.ndarray], optional
37
38
  Target array, by default None.
38
- seed : Optional[int], optional
39
- Random seed, by default None.
39
+ seed : int or None, default=None
40
+ Random seed.
40
41
 
41
42
  Yields
42
43
  ------
@@ -112,8 +113,8 @@ def extract_patches_random(
112
113
 
113
114
  def extract_patches_random_from_chunks(
114
115
  arr: zarr.Array,
115
- patch_size: Union[List[int], Tuple[int, ...]],
116
- chunk_size: Union[List[int], Tuple[int, ...]],
116
+ patch_size: Union[list[int], tuple[int, ...]],
117
+ chunk_size: Union[list[int], tuple[int, ...]],
117
118
  chunk_limit: Optional[int] = None,
118
119
  seed: Optional[int] = None,
119
120
  ) -> Generator[np.ndarray, None, None]:
@@ -127,9 +128,9 @@ def extract_patches_random_from_chunks(
127
128
  ----------
128
129
  arr : np.ndarray
129
130
  Input image array.
130
- patch_size : Union[List[int], Tuple[int, ...]]
131
+ patch_size : Union[list[int], tuple[int, ...]]
131
132
  Patch sizes in each dimension.
132
- chunk_size : Union[List[int], Tuple[int, ...]]
133
+ chunk_size : Union[list[int], tuple[int, ...]]
133
134
  Chunk sizes to load from the array.
134
135
  chunk_limit : Optional[int], optional
135
136
  Number of chunks to load, by default None.
@@ -1,6 +1,6 @@
1
1
  """Sequential patching functions."""
2
2
 
3
- from typing import List, Optional, Tuple, Union
3
+ from typing import Optional, Union
4
4
 
5
5
  import numpy as np
6
6
  from skimage.util import view_as_windows
@@ -9,21 +9,21 @@ from .validate_patch_dimension import validate_patch_dimensions
9
9
 
10
10
 
11
11
  def _compute_number_of_patches(
12
- arr_shape: Tuple[int, ...], patch_sizes: Union[List[int], Tuple[int, ...]]
13
- ) -> Tuple[int, ...]:
12
+ arr_shape: tuple[int, ...], patch_sizes: Union[list[int], tuple[int, ...]]
13
+ ) -> tuple[int, ...]:
14
14
  """
15
15
  Compute the number of patches that fit in each dimension.
16
16
 
17
17
  Parameters
18
18
  ----------
19
- arr_shape : Tuple[int, ...]
19
+ arr_shape : tuple[int, ...]
20
20
  Shape of the input array.
21
- patch_sizes : Union[List[int], Tuple[int, ...]
21
+ patch_sizes : Union[list[int], tuple[int, ...]]
22
22
  Shape of the patches.
23
23
 
24
24
  Returns
25
25
  -------
26
- Tuple[int, ...]
26
+ tuple[int, ...]
27
27
  Number of patches in each dimension.
28
28
  """
29
29
  if len(arr_shape) != len(patch_sizes):
@@ -47,8 +47,8 @@ def _compute_number_of_patches(
47
47
 
48
48
 
49
49
  def _compute_overlap(
50
- arr_shape: Tuple[int, ...], patch_sizes: Union[List[int], Tuple[int, ...]]
51
- ) -> Tuple[int, ...]:
50
+ arr_shape: tuple[int, ...], patch_sizes: Union[list[int], tuple[int, ...]]
51
+ ) -> tuple[int, ...]:
52
52
  """
53
53
  Compute the overlap between patches in each dimension.
54
54
 
@@ -57,14 +57,14 @@ def _compute_overlap(
57
57
 
58
58
  Parameters
59
59
  ----------
60
- arr_shape : Tuple[int, ...]
60
+ arr_shape : tuple[int, ...]
61
61
  Input array shape.
62
- patch_sizes : Union[List[int], Tuple[int, ...]]
62
+ patch_sizes : Union[list[int], tuple[int, ...]]
63
63
  Size of the patches.
64
64
 
65
65
  Returns
66
66
  -------
67
- Tuple[int, ...]
67
+ tuple[int, ...]
68
68
  Overlap between patches in each dimension.
69
69
  """
70
70
  n_patches = _compute_number_of_patches(arr_shape, patch_sizes)
@@ -80,21 +80,21 @@ def _compute_overlap(
80
80
 
81
81
 
82
82
  def _compute_patch_steps(
83
- patch_sizes: Union[List[int], Tuple[int, ...]], overlaps: Tuple[int, ...]
84
- ) -> Tuple[int, ...]:
83
+ patch_sizes: Union[list[int], tuple[int, ...]], overlaps: tuple[int, ...]
84
+ ) -> tuple[int, ...]:
85
85
  """
86
86
  Compute steps between patches.
87
87
 
88
88
  Parameters
89
89
  ----------
90
- patch_sizes : Tuple[int]
90
+ patch_sizes : tuple[int]
91
91
  Size of the patches.
92
- overlaps : Tuple[int]
92
+ overlaps : tuple[int]
93
93
  Overlap between patches.
94
94
 
95
95
  Returns
96
96
  -------
97
- Tuple[int]
97
+ tuple[int]
98
98
  Steps between patches.
99
99
  """
100
100
  steps = [
@@ -107,9 +107,9 @@ def _compute_patch_steps(
107
107
  # TODO why stack the target here and not on a different dimension before this function?
108
108
  def _compute_patch_views(
109
109
  arr: np.ndarray,
110
- window_shape: List[int],
111
- step: Tuple[int, ...],
112
- output_shape: List[int],
110
+ window_shape: list[int],
111
+ step: tuple[int, ...],
112
+ output_shape: list[int],
113
113
  target: Optional[np.ndarray] = None,
114
114
  ) -> np.ndarray:
115
115
  """
@@ -119,11 +119,11 @@ def _compute_patch_views(
119
119
  ----------
120
120
  arr : np.ndarray
121
121
  Array from which the views are extracted.
122
- window_shape : Tuple[int]
122
+ window_shape : tuple[int]
123
123
  Shape of the views.
124
- step : Tuple[int]
124
+ step : tuple[int]
125
125
  Steps between views.
126
- output_shape : Tuple[int]
126
+ output_shape : tuple[int]
127
127
  Shape of the output array.
128
128
  target : Optional[np.ndarray], optional
129
129
  Target array, by default None.
@@ -150,9 +150,9 @@ def _compute_patch_views(
150
150
 
151
151
  def extract_patches_sequential(
152
152
  arr: np.ndarray,
153
- patch_size: Union[List[int], Tuple[int, ...]],
153
+ patch_size: Union[list[int], tuple[int, ...]],
154
154
  target: Optional[np.ndarray] = None,
155
- ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
155
+ ) -> tuple[np.ndarray, Optional[np.ndarray]]:
156
156
  """
157
157
  Generate patches from an array in a sequential manner.
158
158
 
@@ -163,14 +163,14 @@ def extract_patches_sequential(
163
163
  ----------
164
164
  arr : np.ndarray
165
165
  Input image array.
166
- patch_size : Tuple[int]
166
+ patch_size : tuple[int]
167
167
  Patch sizes in each dimension.
168
168
  target : Optional[np.ndarray], optional
169
169
  Target array, by default None.
170
170
 
171
171
  Returns
172
172
  -------
173
- Tuple[np.ndarray, Optional[np.ndarray]]
173
+ tuple[np.ndarray, Optional[np.ndarray]]
174
174
  Patches.
175
175
  """
176
176
  is_3d_patch = len(patch_size) == 3
@@ -1,13 +1,13 @@
1
1
  """Patch validation functions."""
2
2
 
3
- from typing import List, Tuple, Union
3
+ from typing import Union
4
4
 
5
5
  import numpy as np
6
6
 
7
7
 
8
8
  def validate_patch_dimensions(
9
9
  arr: np.ndarray,
10
- patch_size: Union[List[int], Tuple[int, ...]],
10
+ patch_size: Union[list[int], tuple[int, ...]],
11
11
  is_3d_patch: bool,
12
12
  ) -> None:
13
13
  """
@@ -26,7 +26,7 @@ def validate_patch_dimensions(
26
26
  ----------
27
27
  arr : np.ndarray
28
28
  Input array.
29
- patch_size : Union[List[int], Tuple[int, ...]]
29
+ patch_size : Union[list[int], tuple[int, ...]]
30
30
  Size of the patches along each dimension of the array, except the first.
31
31
  is_3d_patch : bool
32
32
  Whether the patch is 3D or not.
@@ -1,9 +1,9 @@
1
1
  """Tiling functions."""
2
2
 
3
3
  __all__ = [
4
- "stitch_prediction",
5
- "extract_tiles",
6
4
  "collate_tiles",
5
+ "extract_tiles",
6
+ "stitch_prediction",
7
7
  ]
8
8
 
9
9
  from .collate_tiles import collate_tiles
@@ -1,6 +1,6 @@
1
1
  """Collate function for tiling."""
2
2
 
3
- from typing import Any, List, Tuple
3
+ from typing import Any
4
4
 
5
5
  import numpy as np
6
6
  from torch.utils.data.dataloader import default_collate
@@ -8,7 +8,7 @@ from torch.utils.data.dataloader import default_collate
8
8
  from careamics.config.tile_information import TileInformation
9
9
 
10
10
 
11
- def collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
11
+ def collate_tiles(batch: list[tuple[np.ndarray, TileInformation]]) -> Any:
12
12
  """
13
13
  Collate tiles received from CAREamics prediction dataloader.
14
14
 
@@ -19,7 +19,7 @@ def collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
19
19
 
20
20
  Parameters
21
21
  ----------
22
- batch : List[Tuple[np.ndarray, TileInformation], ...]
22
+ batch : list[tuple[np.ndarray, TileInformation], ...]
23
23
  Batch of tiles.
24
24
 
25
25
  Returns
@@ -2,7 +2,8 @@
2
2
 
3
3
  import builtins
4
4
  import itertools
5
- from typing import Any, Generator, Optional, Union
5
+ from collections.abc import Generator
6
+ from typing import Any, Optional, Union
6
7
 
7
8
  import numpy as np
8
9
  from numpy.typing import NDArray
@@ -1,7 +1,8 @@
1
1
  """Tiled patching utilities."""
2
2
 
3
3
  import itertools
4
- from typing import Generator, List, Tuple, Union
4
+ from collections.abc import Generator
5
+ from typing import Union
5
6
 
6
7
  import numpy as np
7
8
 
@@ -10,7 +11,7 @@ from careamics.config.tile_information import TileInformation
10
11
 
11
12
  def _compute_crop_and_stitch_coords_1d(
12
13
  axis_size: int, tile_size: int, overlap: int
13
- ) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]], List[Tuple[int, int]]]:
14
+ ) -> tuple[list[tuple[int, int]], list[tuple[int, int]], list[tuple[int, int]]]:
14
15
  """
15
16
  Compute the coordinates of each tile along an axis, given the overlap.
16
17
 
@@ -25,8 +26,8 @@ def _compute_crop_and_stitch_coords_1d(
25
26
 
26
27
  Returns
27
28
  -------
28
- Tuple[Tuple[int, ...], ...]
29
- Tuple of all coordinates for given axis.
29
+ tuple[tuple[int, ...], ...]
30
+ Tuple of all coordinates for the given axis.
30
31
  """
31
32
  # Compute the step between tiles
32
33
  step = tile_size - overlap
@@ -81,9 +82,9 @@ def _compute_crop_and_stitch_coords_1d(
81
82
 
82
83
  def extract_tiles(
83
84
  arr: np.ndarray,
84
- tile_size: Union[List[int], Tuple[int, ...]],
85
- overlaps: Union[List[int], Tuple[int, ...]],
86
- ) -> Generator[Tuple[np.ndarray, TileInformation], None, None]:
85
+ tile_size: Union[list[int], tuple[int, ...]],
86
+ overlaps: Union[list[int], tuple[int, ...]],
87
+ ) -> Generator[tuple[np.ndarray, TileInformation], None, None]:
87
88
  """Generate tiles from the input array with specified overlap.
88
89
 
89
90
  The tiles cover the whole array. The method returns a generator that yields
@@ -98,14 +99,14 @@ def extract_tiles(
98
99
  ----------
99
100
  arr : np.ndarray
100
101
  Array of shape (S, C, (Z), Y, X).
101
- tile_size : Union[List[int], Tuple[int]]
102
+ tile_size : Union[list[int], tuple[int]]
102
103
  Tile sizes in each dimension, of length 2 or 3.
103
- overlaps : Union[List[int], Tuple[int]]
104
+ overlaps : Union[list[int], tuple[int]]
104
105
  Overlap values in each dimension, of length 2 or 3.
105
106
 
106
107
  Yields
107
108
  ------
108
- Generator[Tuple[np.ndarray, TileInformation], None, None]
109
+ Generator[tuple[np.ndarray, TileInformation], None, None]
109
110
  Tile generator, yields the tile and additional information.
110
111
  """
111
112
  # Iterate over num samples (S)
@@ -1,13 +1,13 @@
1
1
  """Functions relating reading and writing image files."""
2
2
 
3
3
  __all__ = [
4
- "read",
5
- "write",
6
- "get_read_func",
7
- "get_write_func",
8
4
  "ReadFunc",
9
- "WriteFunc",
10
5
  "SupportedWriteType",
6
+ "WriteFunc",
7
+ "get_read_func",
8
+ "get_write_func",
9
+ "read",
10
+ "write",
11
11
  ]
12
12
 
13
13
  from . import read, write
@@ -1,10 +1,10 @@
1
1
  """Functions relating to reading image files of different formats."""
2
2
 
3
3
  __all__ = [
4
+ "ReadFunc",
4
5
  "get_read_func",
5
6
  "read_tiff",
6
7
  "read_zarr",
7
- "ReadFunc",
8
8
  ]
9
9
 
10
10
  from .get_func import ReadFunc, get_read_func
@@ -1,7 +1,7 @@
1
1
  """Module to get read functions."""
2
2
 
3
3
  from pathlib import Path
4
- from typing import Callable, Dict, Protocol, Union
4
+ from typing import Callable, Protocol, Union
5
5
 
6
6
  from numpy.typing import NDArray
7
7
 
@@ -30,7 +30,7 @@ class ReadFunc(Protocol):
30
30
  """
31
31
 
32
32
 
33
- READ_FUNCS: Dict[SupportedData, ReadFunc] = {
33
+ READ_FUNCS: dict[SupportedData, ReadFunc] = {
34
34
  SupportedData.TIFF: read_tiff,
35
35
  }
36
36
 
@@ -1,10 +1,10 @@
1
1
  """Functions relating to writing image files of different formats."""
2
2
 
3
3
  __all__ = [
4
+ "SupportedWriteType",
5
+ "WriteFunc",
4
6
  "get_write_func",
5
7
  "write_tiff",
6
- "WriteFunc",
7
- "SupportedWriteType",
8
8
  ]
9
9
 
10
10
  from .get_func import (
@@ -2,14 +2,14 @@
2
2
 
3
3
  __all__ = [
4
4
  "FCNModule",
5
+ "HyperParametersCallback",
6
+ "PredictDataModule",
7
+ "ProgressBarCallback",
8
+ "TrainDataModule",
5
9
  "VAEModule",
6
10
  "create_careamics_module",
7
- "TrainDataModule",
8
- "create_train_datamodule",
9
- "PredictDataModule",
10
11
  "create_predict_datamodule",
11
- "HyperParametersCallback",
12
- "ProgressBarCallback",
12
+ "create_train_datamodule",
13
13
  ]
14
14
 
15
15
  from .callbacks import HyperParametersCallback, ProgressBarCallback
@@ -2,8 +2,8 @@
2
2
 
3
3
  __all__ = [
4
4
  "HyperParametersCallback",
5
- "ProgressBarCallback",
6
5
  "PredictionWriterCallback",
6
+ "ProgressBarCallback",
7
7
  ]
8
8
 
9
9
  from .hyperparameters_callback import HyperParametersCallback
@@ -1,12 +1,12 @@
1
1
  """A package for the `PredictionWriterCallback` class and utilities."""
2
2
 
3
3
  __all__ = [
4
+ "CacheTiles",
4
5
  "PredictionWriterCallback",
5
- "create_write_strategy",
6
- "WriteStrategy",
7
6
  "WriteImage",
8
- "CacheTiles",
7
+ "WriteStrategy",
9
8
  "WriteTilesZarr",
9
+ "create_write_strategy",
10
10
  "select_write_extension",
11
11
  "select_write_func",
12
12
  ]
@@ -2,8 +2,9 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ from collections.abc import Sequence
5
6
  from pathlib import Path
6
- from typing import Any, Optional, Sequence, Union
7
+ from typing import Any, Optional, Union
7
8
 
8
9
  from pytorch_lightning import LightningModule, Trainer
9
10
  from pytorch_lightning.callbacks import BasePredictionWriter
@@ -1,7 +1,8 @@
1
1
  """Module containing different strategies for writing predictions."""
2
2
 
3
+ from collections.abc import Sequence
3
4
  from pathlib import Path
4
- from typing import Any, Optional, Protocol, Sequence, Union
5
+ from typing import Any, Optional, Protocol, Union
5
6
 
6
7
  import numpy as np
7
8
  from numpy.typing import NDArray
@@ -1,7 +1,7 @@
1
1
  """Progressbar callback."""
2
2
 
3
3
  import sys
4
- from typing import Dict, Union
4
+ from typing import Union
5
5
 
6
6
  from pytorch_lightning import LightningModule, Trainer
7
7
  from pytorch_lightning.callbacks import TQDMProgressBar
@@ -71,7 +71,7 @@ class ProgressBarCallback(TQDMProgressBar):
71
71
 
72
72
  def get_metrics(
73
73
  self, trainer: Trainer, pl_module: LightningModule
74
- ) -> Dict[str, Union[int, str, float, Dict[str, float]]]:
74
+ ) -> dict[str, Union[int, str, float, dict[str, float]]]:
75
75
  """Override this to customize the metrics displayed in the progress bar.
76
76
 
77
77
  Parameters
@@ -6,7 +6,7 @@ import numpy as np
6
6
  import pytorch_lightning as L
7
7
  from torch import Tensor, nn
8
8
 
9
- from careamics.config import FCNAlgorithmConfig, VAEAlgorithmConfig
9
+ from careamics.config import UNetBasedAlgorithm, VAEBasedAlgorithm
10
10
  from careamics.config.support import (
11
11
  SupportedAlgorithm,
12
12
  SupportedArchitecture,
@@ -34,6 +34,7 @@ from careamics.utils.torch_utils import get_optimizer, get_scheduler
34
34
  NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
35
35
 
36
36
 
37
+ # TODO rename to UNetModule
37
38
  class FCNModule(L.LightningModule):
38
39
  """
39
40
  CAREamics Lightning module.
@@ -60,7 +61,7 @@ class FCNModule(L.LightningModule):
60
61
  Learning rate scheduler name.
61
62
  """
62
63
 
63
- def __init__(self, algorithm_config: Union[FCNAlgorithmConfig, dict]) -> None:
64
+ def __init__(self, algorithm_config: Union[UNetBasedAlgorithm, dict]) -> None:
64
65
  """Lightning module for CAREamics.
65
66
 
66
67
  This class encapsulates the a PyTorch model along with the training, validation,
@@ -74,7 +75,9 @@ class FCNModule(L.LightningModule):
74
75
  super().__init__()
75
76
  # if loading from a checkpoint, AlgorithmModel needs to be instantiated
76
77
  if isinstance(algorithm_config, dict):
77
- algorithm_config = FCNAlgorithmConfig(**algorithm_config)
78
+ algorithm_config = UNetBasedAlgorithm(
79
+ **algorithm_config
80
+ ) # TODO this needs to be updated using the algorithm-specific class
78
81
 
79
82
  # create model and loss function
80
83
  self.model: nn.Module = model_factory(algorithm_config.model)
@@ -266,7 +269,7 @@ class VAEModule(L.LightningModule):
266
269
  Learning rate scheduler name.
267
270
  """
268
271
 
269
- def __init__(self, algorithm_config: Union[VAEAlgorithmConfig, dict]) -> None:
272
+ def __init__(self, algorithm_config: Union[VAEBasedAlgorithm, dict]) -> None:
270
273
  """Lightning module for CAREamics.
271
274
 
272
275
  This class encapsulates the a PyTorch model along with the training, validation,
@@ -280,7 +283,7 @@ class VAEModule(L.LightningModule):
280
283
  super().__init__()
281
284
  # if loading from a checkpoint, AlgorithmModel needs to be instantiated
282
285
  self.algorithm_config = (
283
- VAEAlgorithmConfig(**algorithm_config)
286
+ VAEBasedAlgorithm(**algorithm_config)
284
287
  if isinstance(algorithm_config, dict)
285
288
  else algorithm_config
286
289
  )
@@ -656,9 +659,10 @@ def create_careamics_module(
656
659
  algorithm_configuration["model"] = model_configuration
657
660
 
658
661
  # call the parent init using an AlgorithmModel instance
662
+ # TODO broken by new configurations!
659
663
  algorithm_str = algorithm_configuration["algorithm"]
660
- if algorithm_str in FCNAlgorithmConfig.get_compatible_algorithms():
661
- return FCNModule(FCNAlgorithmConfig(**algorithm_configuration))
664
+ if algorithm_str in UNetBasedAlgorithm.get_compatible_algorithms():
665
+ return FCNModule(UNetBasedAlgorithm(**algorithm_configuration))
662
666
  else:
663
667
  raise NotImplementedError(
664
668
  f"Model {algorithm_str} is not implemented or unknown."