careamics 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (98) hide show
  1. careamics/__init__.py +17 -2
  2. careamics/careamist.py +4 -3
  3. careamics/cli/conf.py +1 -2
  4. careamics/cli/main.py +1 -2
  5. careamics/cli/utils.py +3 -3
  6. careamics/config/__init__.py +47 -25
  7. careamics/config/algorithms/__init__.py +15 -0
  8. careamics/config/algorithms/care_algorithm_model.py +50 -0
  9. careamics/config/algorithms/n2n_algorithm_model.py +42 -0
  10. careamics/config/algorithms/n2v_algorithm_model.py +35 -0
  11. careamics/config/algorithms/unet_algorithm_model.py +88 -0
  12. careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +14 -12
  13. careamics/config/architectures/__init__.py +1 -11
  14. careamics/config/architectures/architecture_model.py +3 -3
  15. careamics/config/architectures/lvae_model.py +6 -1
  16. careamics/config/architectures/unet_model.py +1 -0
  17. careamics/config/care_configuration.py +100 -0
  18. careamics/config/configuration.py +354 -0
  19. careamics/config/{configuration_factory.py → configuration_factories.py} +103 -36
  20. careamics/config/configuration_io.py +85 -0
  21. careamics/config/data/__init__.py +10 -0
  22. careamics/config/{data_model.py → data/data_model.py} +58 -198
  23. careamics/config/data/n2v_data_model.py +193 -0
  24. careamics/config/likelihood_model.py +1 -2
  25. careamics/config/n2n_configuration.py +101 -0
  26. careamics/config/n2v_configuration.py +266 -0
  27. careamics/config/nm_model.py +1 -2
  28. careamics/config/support/__init__.py +7 -7
  29. careamics/config/support/supported_algorithms.py +0 -3
  30. careamics/config/support/supported_architectures.py +0 -4
  31. careamics/config/transformations/__init__.py +10 -4
  32. careamics/config/transformations/transform_model.py +3 -3
  33. careamics/config/transformations/transform_unions.py +42 -0
  34. careamics/config/validators/validator_utils.py +3 -3
  35. careamics/dataset/__init__.py +2 -2
  36. careamics/dataset/dataset_utils/__init__.py +3 -3
  37. careamics/dataset/dataset_utils/dataset_utils.py +4 -6
  38. careamics/dataset/dataset_utils/file_utils.py +9 -9
  39. careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
  40. careamics/dataset/in_memory_dataset.py +11 -12
  41. careamics/dataset/iterable_dataset.py +4 -4
  42. careamics/dataset/iterable_pred_dataset.py +2 -1
  43. careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
  44. careamics/dataset/patching/random_patching.py +11 -10
  45. careamics/dataset/patching/sequential_patching.py +26 -26
  46. careamics/dataset/patching/validate_patch_dimension.py +3 -3
  47. careamics/dataset/tiling/__init__.py +2 -2
  48. careamics/dataset/tiling/collate_tiles.py +3 -3
  49. careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
  50. careamics/dataset/tiling/tiled_patching.py +11 -10
  51. careamics/file_io/__init__.py +5 -5
  52. careamics/file_io/read/__init__.py +1 -1
  53. careamics/file_io/read/get_func.py +2 -2
  54. careamics/file_io/write/__init__.py +2 -2
  55. careamics/lightning/__init__.py +5 -5
  56. careamics/lightning/callbacks/__init__.py +1 -1
  57. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
  58. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
  59. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
  60. careamics/lightning/callbacks/progress_bar_callback.py +2 -2
  61. careamics/lightning/lightning_module.py +11 -7
  62. careamics/lightning/train_data_module.py +26 -26
  63. careamics/losses/__init__.py +3 -3
  64. careamics/model_io/__init__.py +1 -1
  65. careamics/model_io/bioimage/__init__.py +1 -1
  66. careamics/model_io/bioimage/_readme_factory.py +1 -1
  67. careamics/model_io/bioimage/model_description.py +17 -17
  68. careamics/model_io/bmz_io.py +6 -17
  69. careamics/model_io/model_io_utils.py +9 -9
  70. careamics/models/layers.py +16 -16
  71. careamics/models/lvae/lvae.py +0 -3
  72. careamics/models/model_factory.py +2 -15
  73. careamics/models/unet.py +8 -8
  74. careamics/prediction_utils/__init__.py +1 -1
  75. careamics/prediction_utils/prediction_outputs.py +15 -15
  76. careamics/prediction_utils/stitch_prediction.py +6 -6
  77. careamics/transforms/__init__.py +5 -5
  78. careamics/transforms/compose.py +13 -13
  79. careamics/transforms/n2v_manipulate.py +3 -3
  80. careamics/transforms/pixel_manipulation.py +9 -9
  81. careamics/transforms/xy_random_rotate90.py +4 -4
  82. careamics/utils/__init__.py +5 -5
  83. careamics/utils/context.py +2 -1
  84. careamics/utils/logging.py +11 -10
  85. careamics/utils/torch_utils.py +7 -7
  86. {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/METADATA +11 -11
  87. {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/RECORD +90 -85
  88. careamics/config/architectures/custom_model.py +0 -162
  89. careamics/config/architectures/register_model.py +0 -103
  90. careamics/config/configuration_model.py +0 -603
  91. careamics/config/fcn_algorithm_model.py +0 -152
  92. careamics/config/references/__init__.py +0 -45
  93. careamics/config/references/algorithm_descriptions.py +0 -132
  94. careamics/config/references/references.py +0 -39
  95. careamics/config/transformations/transform_union.py +0 -20
  96. {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/WHEEL +0 -0
  97. {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/entry_points.txt +0 -0
  98. {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,266 @@
1
+ """N2V configuration."""
2
+
3
+ from bioimageio.spec.generic.v0_3 import CiteEntry
4
+ from pydantic import model_validator
5
+ from typing_extensions import Self
6
+
7
+ from careamics.config.algorithms import N2VAlgorithm
8
+ from careamics.config.configuration import Configuration
9
+ from careamics.config.data.n2v_data_model import N2VDataConfig
10
+ from careamics.config.support import SupportedPixelManipulation
11
+
12
+ N2V = "Noise2Void"
13
+ N2V2 = "N2V2"
14
+ STRUCT_N2V = "StructN2V"
15
+ STRUCT_N2V2 = "StructN2V2"
16
+
17
+ N2V_REF = CiteEntry(
18
+ text='Krull, A., Buchholz, T.O. and Jug, F., 2019. "Noise2Void - Learning '
19
+ 'denoising from single noisy images". In Proceedings of the IEEE/CVF '
20
+ "conference on computer vision and pattern recognition (pp. 2129-2137).",
21
+ doi="10.1109/cvpr.2019.00223",
22
+ )
23
+
24
+ N2V2_REF = CiteEntry(
25
+ text="Höck, E., Buchholz, T.O., Brachmann, A., Jug, F. and Freytag, A., "
26
+ '2022. "N2V2 - Fixing Noise2Void checkerboard artifacts with modified '
27
+ 'sampling strategies and a tweaked network architecture". In European '
28
+ "Conference on Computer Vision (pp. 503-518).",
29
+ doi="10.1007/978-3-031-25069-9_33",
30
+ )
31
+
32
+ STRUCTN2V_REF = CiteEntry(
33
+ text="Broaddus, C., Krull, A., Weigert, M., Schmidt, U. and Myers, G., 2020."
34
+ '"Removing structured noise with self-supervised blind-spot '
35
+ 'networks". In 2020 IEEE 17th International Symposium on Biomedical '
36
+ "Imaging (ISBI) (pp. 159-163).",
37
+ doi="10.1109/isbi45749.2020.9098336",
38
+ )
39
+
40
+ N2V_DESCRIPTION = (
41
+ "Noise2Void is a UNet-based self-supervised algorithm that "
42
+ "uses blind-spot training to denoise images. In short, in every "
43
+ "patches during training, random pixels are selected and their "
44
+ "value replaced by a neighboring pixel value. The network is then "
45
+ "trained to predict the original pixel value. The algorithm "
46
+ "relies on the continuity of the signal (neighboring pixels have "
47
+ "similar values) and the pixel-wise independence of the noise "
48
+ "(the noise in a pixel is not correlated with the noise in "
49
+ "neighboring pixels)."
50
+ )
51
+
52
+ N2V2_DESCRIPTION = (
53
+ "N2V2 is a variant of Noise2Void. "
54
+ + N2V_DESCRIPTION
55
+ + "\nN2V2 introduces blur-pool layers and removed skip "
56
+ "connections in the UNet architecture to remove checkboard "
57
+ "artefacts, a common artefacts ocurring in Noise2Void."
58
+ )
59
+
60
+ STR_N2V_DESCRIPTION = (
61
+ "StructN2V is a variant of Noise2Void. "
62
+ + N2V_DESCRIPTION
63
+ + "\nStructN2V uses a linear mask (horizontal or vertical) to replace "
64
+ "the pixel values of neighbors of the masked pixels by a random "
65
+ "value. Such masking allows removing 1D structured noise from the "
66
+ "the images, the main failure case of the original N2V."
67
+ )
68
+
69
+ STR_N2V2_DESCRIPTION = (
70
+ "StructN2V2 is a a variant of Noise2Void that uses both "
71
+ "structN2V and N2V2. "
72
+ + N2V_DESCRIPTION
73
+ + "\nStructN2V2 uses a linear mask (horizontal or vertical) to replace "
74
+ "the pixel values of neighbors of the masked pixels by a random "
75
+ "value. Such masking allows removing 1D structured noise from the "
76
+ "the images, the main failure case of the original N2V."
77
+ "\nN2V2 introduces blur-pool layers and removed skip connections in "
78
+ "the UNet architecture to remove checkboard artefacts, a common "
79
+ "artefacts ocurring in Noise2Void."
80
+ )
81
+
82
+
83
+ class N2VConfiguration(Configuration):
84
+ """N2V configuration."""
85
+
86
+ algorithm_config: N2VAlgorithm
87
+
88
+ data_config: N2VDataConfig
89
+
90
+ @model_validator(mode="after")
91
+ def validate_n2v2(self) -> Self:
92
+ """Validate that the N2V2 strategy and models are set correctly.
93
+
94
+ Returns
95
+ -------
96
+ Self
97
+ The validated configuration.
98
+
99
+
100
+ Raises
101
+ ------
102
+ ValueError
103
+ If N2V2 is used with the wrong pixel manipulation strategy.
104
+ """
105
+ if self.algorithm_config.model.n2v2:
106
+ if (
107
+ self.data_config.get_masking_strategy()
108
+ != SupportedPixelManipulation.MEDIAN.value
109
+ ):
110
+ raise ValueError(
111
+ f"N2V2 can only be used with the "
112
+ f"{SupportedPixelManipulation.MEDIAN} pixel manipulation strategy"
113
+ f". Change the N2VManipulate transform strategy."
114
+ )
115
+ else:
116
+ if (
117
+ self.data_config.get_masking_strategy()
118
+ != SupportedPixelManipulation.UNIFORM.value
119
+ ):
120
+ raise ValueError(
121
+ f"N2V can only be used with the "
122
+ f"{SupportedPixelManipulation.UNIFORM} pixel manipulation strategy"
123
+ f". Change the N2VManipulate transform strategy."
124
+ )
125
+ return self
126
+
127
+ def set_n2v2(self, use_n2v2: bool) -> None:
128
+ """
129
+ Set the configuration to use N2V2 or the vanilla Noise2Void.
130
+
131
+ Parameters
132
+ ----------
133
+ use_n2v2 : bool
134
+ Whether to use N2V2.
135
+ """
136
+ self.data_config.set_n2v2(use_n2v2)
137
+ self.algorithm_config.model.n2v2 = use_n2v2
138
+
139
+ def get_algorithm_friendly_name(self) -> str:
140
+ """
141
+ Get the friendly name of the algorithm.
142
+
143
+ Returns
144
+ -------
145
+ str
146
+ Friendly name.
147
+ """
148
+ use_n2v2 = self.algorithm_config.model.n2v2
149
+ use_structN2V = self.data_config.is_using_struct_n2v()
150
+
151
+ if use_n2v2 and use_structN2V:
152
+ return STRUCT_N2V2
153
+ elif use_n2v2:
154
+ return N2V2
155
+ elif use_structN2V:
156
+ return STRUCT_N2V
157
+ else:
158
+ return N2V
159
+
160
+ def get_algorithm_keywords(self) -> list[str]:
161
+ """
162
+ Get algorithm keywords.
163
+
164
+ Returns
165
+ -------
166
+ list[str]
167
+ List of keywords.
168
+ """
169
+ use_n2v2 = self.algorithm_config.model.n2v2
170
+ use_structN2V = self.data_config.is_using_struct_n2v()
171
+
172
+ keywords = [
173
+ "denoising",
174
+ "restoration",
175
+ "UNet",
176
+ "3D" if "Z" in self.data_config.axes else "2D",
177
+ "CAREamics",
178
+ "pytorch",
179
+ N2V,
180
+ ]
181
+
182
+ if use_n2v2:
183
+ keywords.append(N2V2)
184
+ if use_structN2V:
185
+ keywords.append(STRUCT_N2V)
186
+
187
+ return keywords
188
+
189
+ def get_algorithm_references(self) -> str:
190
+ """
191
+ Get the algorithm references.
192
+
193
+ This is used to generate the README of the BioImage Model Zoo export.
194
+
195
+ Returns
196
+ -------
197
+ str
198
+ Algorithm references.
199
+ """
200
+ use_n2v2 = self.algorithm_config.model.n2v2
201
+ use_structN2V = self.data_config.is_using_struct_n2v()
202
+
203
+ references = [
204
+ N2V_REF.text + " doi: " + N2V_REF.doi,
205
+ N2V2_REF.text + " doi: " + N2V2_REF.doi,
206
+ STRUCTN2V_REF.text + " doi: " + STRUCTN2V_REF.doi,
207
+ ]
208
+
209
+ # return the (struct)N2V(2) references
210
+ if use_n2v2 and use_structN2V:
211
+ return "\n".join(references)
212
+ elif use_n2v2:
213
+ references.pop(-1)
214
+ return "\n".join(references)
215
+ elif use_structN2V:
216
+ references.pop(-2)
217
+ return "\n".join(references)
218
+ else:
219
+ return references[0]
220
+
221
+ def get_algorithm_citations(self) -> list[CiteEntry]:
222
+ """
223
+ Return a list of citation entries of the current algorithm.
224
+
225
+ This is used to generate the model description for the BioImage Model Zoo.
226
+
227
+ Returns
228
+ -------
229
+ List[CiteEntry]
230
+ List of citation entries.
231
+ """
232
+ use_n2v2 = self.algorithm_config.model.n2v2
233
+ use_structN2V = self.data_config.is_using_struct_n2v()
234
+
235
+ references = [N2V_REF]
236
+
237
+ if use_n2v2:
238
+ references.append(N2V2_REF)
239
+
240
+ if use_structN2V:
241
+ references.append(STRUCTN2V_REF)
242
+
243
+ return references
244
+
245
+ def get_algorithm_description(self) -> str:
246
+ """
247
+ Return a description of the algorithm.
248
+
249
+ This method is used to generate the README of the BioImage Model Zoo export.
250
+
251
+ Returns
252
+ -------
253
+ str
254
+ Description of the algorithm.
255
+ """
256
+ use_n2v2 = self.algorithm_config.model.n2v2
257
+ use_structN2V = self.data_config.is_using_struct_n2v()
258
+
259
+ if use_n2v2 and use_structN2V:
260
+ return STR_N2V2_DESCRIPTION
261
+ elif use_n2v2:
262
+ return N2V2_DESCRIPTION
263
+ elif use_structN2V:
264
+ return STR_N2V_DESCRIPTION
265
+ else:
266
+ return N2V_DESCRIPTION
@@ -1,7 +1,7 @@
1
1
  """Noise models config."""
2
2
 
3
3
  from pathlib import Path
4
- from typing import Literal, Optional, Union
4
+ from typing import Annotated, Literal, Optional, Union
5
5
 
6
6
  import numpy as np
7
7
  import torch
@@ -12,7 +12,6 @@ from pydantic import (
12
12
  PlainSerializer,
13
13
  PlainValidator,
14
14
  )
15
- from typing_extensions import Annotated
16
15
 
17
16
  from careamics.utils.serializers import _array_to_json, _to_numpy
18
17
 
@@ -5,17 +5,17 @@ corresponding configuration options in the Pydantic models.
5
5
  """
6
6
 
7
7
  __all__ = [
8
- "SupportedArchitecture",
9
8
  "SupportedActivation",
10
- "SupportedOptimizer",
11
- "SupportedScheduler",
12
- "SupportedLoss",
13
9
  "SupportedAlgorithm",
14
- "SupportedPixelManipulation",
15
- "SupportedTransform",
10
+ "SupportedArchitecture",
16
11
  "SupportedData",
17
- "SupportedStructAxis",
18
12
  "SupportedLogger",
13
+ "SupportedLoss",
14
+ "SupportedOptimizer",
15
+ "SupportedPixelManipulation",
16
+ "SupportedScheduler",
17
+ "SupportedStructAxis",
18
+ "SupportedTransform",
19
19
  ]
20
20
 
21
21
 
@@ -25,9 +25,6 @@ class SupportedAlgorithm(str, BaseEnum):
25
25
  DENOISPLIT = "denoisplit"
26
26
  """An image splitting and denoising approach based on ladder VAE architectures."""
27
27
 
28
- CUSTOM = "custom"
29
- """Custom algorithm, used for cases where a custom architecture is provided."""
30
-
31
28
  # PN2V = "pn2v"
32
29
  # HDN = "hdn"
33
30
  # SEG = "segmentation"
@@ -11,7 +11,3 @@ class SupportedArchitecture(str, BaseEnum):
11
11
 
12
12
  LVAE = "LVAE"
13
13
  """Ladder Variational Autoencoder used for muSplit and denoiSplit."""
14
-
15
- CUSTOM = "custom"
16
- """Keyword used for custom architectures provided by users and only compatible
17
- with `FCNAlgorithmConfig` configuration."""
@@ -1,18 +1,24 @@
1
1
  """CAREamics transformation Pydantic models."""
2
2
 
3
3
  __all__ = [
4
+ "N2V_TRANSFORMS_UNION",
5
+ "NORM_AND_SPATIAL_UNION",
6
+ "SPATIAL_TRANSFORMS_UNION",
4
7
  "N2VManipulateModel",
5
- "XYFlipModel",
6
8
  "NormalizeModel",
7
- "XYRandomRotate90Model",
8
9
  "TransformModel",
9
- "TRANSFORMS_UNION",
10
+ "XYFlipModel",
11
+ "XYRandomRotate90Model",
10
12
  ]
11
13
 
12
14
 
13
15
  from .n2v_manipulate_model import N2VManipulateModel
14
16
  from .normalize_model import NormalizeModel
15
17
  from .transform_model import TransformModel
16
- from .transform_union import TRANSFORMS_UNION
18
+ from .transform_unions import (
19
+ N2V_TRANSFORMS_UNION,
20
+ NORM_AND_SPATIAL_UNION,
21
+ SPATIAL_TRANSFORMS_UNION,
22
+ )
17
23
  from .xy_flip_model import XYFlipModel
18
24
  from .xy_random_rotate90_model import XYRandomRotate90Model
@@ -1,6 +1,6 @@
1
1
  """Parent model for the transforms."""
2
2
 
3
- from typing import Any, Dict
3
+ from typing import Any
4
4
 
5
5
  from pydantic import BaseModel, ConfigDict
6
6
 
@@ -23,7 +23,7 @@ class TransformModel(BaseModel):
23
23
 
24
24
  name: str
25
25
 
26
- def model_dump(self, **kwargs) -> Dict[str, Any]:
26
+ def model_dump(self, **kwargs) -> dict[str, Any]:
27
27
  """
28
28
  Return the model as a dictionary.
29
29
 
@@ -34,7 +34,7 @@ class TransformModel(BaseModel):
34
34
 
35
35
  Returns
36
36
  -------
37
- Dict[str, Any]
37
+ {str: Any}
38
38
  Dictionary representation of the model.
39
39
  """
40
40
  model_dict = super().model_dump(**kwargs)
@@ -0,0 +1,42 @@
1
+ """Type used to represent all transformations users can create."""
2
+
3
+ from typing import Annotated, Union
4
+
5
+ from pydantic import Discriminator
6
+
7
+ from .n2v_manipulate_model import N2VManipulateModel
8
+ from .normalize_model import NormalizeModel
9
+ from .xy_flip_model import XYFlipModel
10
+ from .xy_random_rotate90_model import XYRandomRotate90Model
11
+
12
+ NORM_AND_SPATIAL_UNION = Annotated[
13
+ Union[
14
+ NormalizeModel,
15
+ XYFlipModel,
16
+ XYRandomRotate90Model,
17
+ N2VManipulateModel,
18
+ ],
19
+ Discriminator("name"), # used to tell the different transform models apart
20
+ ]
21
+ """All transforms including normalization."""
22
+
23
+
24
+ SPATIAL_TRANSFORMS_UNION = Annotated[
25
+ Union[
26
+ XYFlipModel,
27
+ XYRandomRotate90Model,
28
+ ],
29
+ Discriminator("name"), # used to tell the different transform models apart
30
+ ]
31
+ """Available spatial transforms in CAREamics."""
32
+
33
+
34
+ N2V_TRANSFORMS_UNION = Annotated[
35
+ Union[
36
+ XYFlipModel,
37
+ XYRandomRotate90Model,
38
+ N2VManipulateModel,
39
+ ],
40
+ Discriminator("name"), # used to tell the different transform models apart
41
+ ]
42
+ """Available N2V-compatible transforms in CAREamics."""
@@ -4,7 +4,7 @@ Validator functions.
4
4
  These functions are used to validate dimensions and axes of inputs.
5
5
  """
6
6
 
7
- from typing import List, Optional, Tuple, Union
7
+ from typing import Optional, Union
8
8
 
9
9
  _AXES = "STCZYX"
10
10
 
@@ -79,14 +79,14 @@ def value_ge_than_8_power_of_2(
79
79
 
80
80
 
81
81
  def patch_size_ge_than_8_power_of_2(
82
- patch_list: Optional[Union[List[int], Union[Tuple[int, ...]]]],
82
+ patch_list: Optional[Union[list[int], Union[tuple[int, ...]]]],
83
83
  ) -> None:
84
84
  """
85
85
  Validate that each entry is greater or equal than 8 and a power of 2.
86
86
 
87
87
  Parameters
88
88
  ----------
89
- patch_list : Optional[Union[List[int]]]
89
+ patch_list : list or tuple of int, or None
90
90
  Patch size.
91
91
 
92
92
  Raises
@@ -4,9 +4,9 @@ __all__ = [
4
4
  "InMemoryDataset",
5
5
  "InMemoryPredDataset",
6
6
  "InMemoryTiledPredDataset",
7
- "PathIterableDataset",
8
- "IterableTiledPredDataset",
9
7
  "IterablePredDataset",
8
+ "IterableTiledPredDataset",
9
+ "PathIterableDataset",
10
10
  ]
11
11
 
12
12
  from .in_memory_dataset import InMemoryDataset
@@ -1,13 +1,13 @@
1
1
  """Files and arrays utils used in the datasets."""
2
2
 
3
3
  __all__ = [
4
- "reshape_array",
4
+ "WelfordStatistics",
5
5
  "compute_normalization_stats",
6
6
  "get_files_size",
7
+ "iterate_over_files",
7
8
  "list_files",
9
+ "reshape_array",
8
10
  "validate_source_target_files",
9
- "iterate_over_files",
10
- "WelfordStatistics",
11
11
  ]
12
12
 
13
13
 
@@ -1,7 +1,5 @@
1
1
  """Dataset utilities."""
2
2
 
3
- from typing import List, Tuple
4
-
5
3
  import numpy as np
6
4
 
7
5
  from careamics.utils.logging import get_logger
@@ -10,14 +8,14 @@ logger = get_logger(__name__)
10
8
 
11
9
 
12
10
  def _get_shape_order(
13
- shape_in: Tuple[int, ...], axes_in: str, ref_axes: str = "STCZYX"
14
- ) -> Tuple[Tuple[int, ...], str, List[int]]:
11
+ shape_in: tuple[int, ...], axes_in: str, ref_axes: str = "STCZYX"
12
+ ) -> tuple[tuple[int, ...], str, list[int]]:
15
13
  """
16
14
  Compute a new shape for the array based on the reference axes.
17
15
 
18
16
  Parameters
19
17
  ----------
20
- shape_in : Tuple[int, ...]
18
+ shape_in : tuple[int, ...]
21
19
  Input shape.
22
20
  axes_in : str
23
21
  Input axes.
@@ -26,7 +24,7 @@ def _get_shape_order(
26
24
 
27
25
  Returns
28
26
  -------
29
- Tuple[Tuple[int, ...], str, List[int]]
27
+ tuple[tuple[int, ...], str, list[int]]
30
28
  New shape, new axes, indices of axes in the new axes order.
31
29
  """
32
30
  indices = [axes_in.find(k) for k in ref_axes]
@@ -2,7 +2,7 @@
2
2
 
3
3
  from fnmatch import fnmatch
4
4
  from pathlib import Path
5
- from typing import List, Union
5
+ from typing import Union
6
6
 
7
7
  import numpy as np
8
8
 
@@ -12,12 +12,12 @@ from careamics.utils.logging import get_logger
12
12
  logger = get_logger(__name__)
13
13
 
14
14
 
15
- def get_files_size(files: List[Path]) -> float:
15
+ def get_files_size(files: list[Path]) -> float:
16
16
  """Get files size in MB.
17
17
 
18
18
  Parameters
19
19
  ----------
20
- files : List[Path]
20
+ files : list of pathlib.Path
21
21
  List of files.
22
22
 
23
23
  Returns
@@ -32,7 +32,7 @@ def list_files(
32
32
  data_path: Union[str, Path],
33
33
  data_type: Union[str, SupportedData],
34
34
  extension_filter: str = "",
35
- ) -> List[Path]:
35
+ ) -> list[Path]:
36
36
  """List recursively files in `data_path` and return a sorted list.
37
37
 
38
38
  If `data_path` is a file, its name is validated against the `data_type` using
@@ -55,8 +55,8 @@ def list_files(
55
55
 
56
56
  Returns
57
57
  -------
58
- List[Path]
59
- List of pathlib.Path objects.
58
+ list[Path]
59
+ list of pathlib.Path objects.
60
60
 
61
61
  Raises
62
62
  ------
@@ -105,7 +105,7 @@ def list_files(
105
105
  return files
106
106
 
107
107
 
108
- def validate_source_target_files(src_files: List[Path], tar_files: List[Path]) -> None:
108
+ def validate_source_target_files(src_files: list[Path], tar_files: list[Path]) -> None:
109
109
  """
110
110
  Validate source and target path lists.
111
111
 
@@ -113,9 +113,9 @@ def validate_source_target_files(src_files: List[Path], tar_files: List[Path]) -
113
113
 
114
114
  Parameters
115
115
  ----------
116
- src_files : List[Path]
116
+ src_files : list of pathlib.Path
117
117
  List of source files.
118
- tar_files : List[Path]
118
+ tar_files : list of pathlib.Path
119
119
  List of target files.
120
120
 
121
121
  Raises
@@ -2,13 +2,14 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ from collections.abc import Generator
5
6
  from pathlib import Path
6
- from typing import Callable, Generator, Optional, Union
7
+ from typing import Callable, Optional, Union
7
8
 
8
9
  from numpy.typing import NDArray
9
10
  from torch.utils.data import get_worker_info
10
11
 
11
- from careamics.config import DataConfig, InferenceConfig
12
+ from careamics.config import GeneralDataConfig, InferenceConfig
12
13
  from careamics.file_io.read import read_tiff
13
14
  from careamics.utils.logging import get_logger
14
15
 
@@ -18,7 +19,7 @@ logger = get_logger(__name__)
18
19
 
19
20
 
20
21
  def iterate_over_files(
21
- data_config: Union[DataConfig, InferenceConfig],
22
+ data_config: Union[GeneralDataConfig, InferenceConfig],
22
23
  data_files: list[Path],
23
24
  target_files: Optional[list[Path]] = None,
24
25
  read_source_func: Callable = read_tiff,
@@ -9,13 +9,9 @@ from typing import Any, Callable, Optional, Union
9
9
  import numpy as np
10
10
  from torch.utils.data import Dataset
11
11
 
12
- from careamics.file_io.read import read_tiff
13
- from careamics.transforms import Compose
14
-
15
- from ..config import DataConfig
16
- from ..config.transformations import NormalizeModel
17
- from ..utils.logging import get_logger
18
- from .patching.patching import (
12
+ from careamics.config import GeneralDataConfig, N2VDataConfig
13
+ from careamics.config.transformations import NormalizeModel
14
+ from careamics.dataset.patching.patching import (
19
15
  PatchedOutput,
20
16
  Stats,
21
17
  prepare_patches_supervised,
@@ -23,6 +19,9 @@ from .patching.patching import (
23
19
  prepare_patches_unsupervised,
24
20
  prepare_patches_unsupervised_array,
25
21
  )
22
+ from careamics.file_io.read import read_tiff
23
+ from careamics.transforms import Compose
24
+ from careamics.utils.logging import get_logger
26
25
 
27
26
  logger = get_logger(__name__)
28
27
 
@@ -47,7 +46,7 @@ class InMemoryDataset(Dataset):
47
46
 
48
47
  def __init__(
49
48
  self,
50
- data_config: DataConfig,
49
+ data_config: GeneralDataConfig,
51
50
  inputs: Union[np.ndarray, list[Path]],
52
51
  input_target: Optional[Union[np.ndarray, list[Path]]] = None,
53
52
  read_source_func: Callable = read_tiff,
@@ -58,7 +57,7 @@ class InMemoryDataset(Dataset):
58
57
 
59
58
  Parameters
60
59
  ----------
61
- data_config : DataConfig
60
+ data_config : GeneralDataConfig
62
61
  Data configuration.
63
62
  inputs : numpy.ndarray or list[pathlib.Path]
64
63
  Input data.
@@ -124,7 +123,7 @@ class InMemoryDataset(Dataset):
124
123
  target_stds=self.target_stats.stds,
125
124
  )
126
125
  ]
127
- + self.data_config.transforms,
126
+ + list(self.data_config.transforms),
128
127
  )
129
128
 
130
129
  def _prepare_patches(self, supervised: bool) -> PatchedOutput:
@@ -219,12 +218,12 @@ class InMemoryDataset(Dataset):
219
218
 
220
219
  return self.patch_transform(patch=patch, target=target)
221
220
 
222
- elif self.data_config.has_n2v_manipulate(): # TODO not compatible with HDN
221
+ elif isinstance(self.data_config, N2VDataConfig):
223
222
  return self.patch_transform(patch=patch)
224
223
  else:
225
224
  raise ValueError(
226
225
  "Something went wrong! No target provided (not supervised training) "
227
- "and no N2V manipulation (no N2V training)."
226
+ "while the algorithm is not Noise2Void."
228
227
  )
229
228
 
230
229
  def get_data_statistics(self) -> tuple[list[float], list[float]]: