careamics 0.0.9__py3-none-any.whl → 0.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +0 -4
- careamics/careamist.py +0 -1
- careamics/config/__init__.py +1 -13
- careamics/config/algorithms/care_algorithm_model.py +84 -0
- careamics/config/algorithms/n2n_algorithm_model.py +85 -0
- careamics/config/algorithms/n2v_algorithm_model.py +269 -1
- careamics/config/configuration.py +21 -13
- careamics/config/configuration_factories.py +179 -187
- careamics/config/configuration_io.py +2 -2
- careamics/config/data/__init__.py +1 -4
- careamics/config/data/data_model.py +46 -62
- careamics/config/support/supported_transforms.py +1 -1
- careamics/config/transformations/__init__.py +0 -2
- careamics/config/transformations/n2v_manipulate_model.py +15 -0
- careamics/config/transformations/transform_unions.py +0 -13
- careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
- careamics/dataset/in_memory_dataset.py +3 -10
- careamics/dataset/in_memory_pred_dataset.py +3 -5
- careamics/dataset/in_memory_tiled_pred_dataset.py +2 -2
- careamics/dataset/iterable_dataset.py +2 -2
- careamics/dataset/iterable_pred_dataset.py +3 -5
- careamics/dataset/iterable_tiled_pred_dataset.py +3 -3
- careamics/dataset_ng/dataset/__init__.py +3 -0
- careamics/dataset_ng/dataset/dataset.py +184 -0
- careamics/dataset_ng/demo_dataset.ipynb +271 -0
- careamics/dataset_ng/demo_patch_extractor.py +53 -0
- careamics/dataset_ng/demo_patch_extractor_factory.py +37 -0
- careamics/dataset_ng/patch_extractor/__init__.py +10 -0
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +111 -0
- careamics/dataset_ng/patch_extractor/image_stack/__init__.py +9 -0
- careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +53 -0
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +55 -0
- careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +163 -0
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +140 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +29 -0
- careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +208 -0
- careamics/dataset_ng/patching_strategies/__init__.py +11 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +82 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +338 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +75 -0
- careamics/lightning/lightning_module.py +78 -27
- careamics/lightning/train_data_module.py +8 -39
- careamics/losses/fcn/losses.py +17 -10
- careamics/model_io/bioimage/bioimage_utils.py +5 -3
- careamics/model_io/bioimage/model_description.py +3 -3
- careamics/model_io/bmz_io.py +2 -2
- careamics/model_io/model_io_utils.py +2 -2
- careamics/transforms/__init__.py +2 -1
- careamics/transforms/compose.py +5 -15
- careamics/transforms/n2v_manipulate_torch.py +143 -0
- careamics/transforms/pixel_manipulation.py +1 -0
- careamics/transforms/pixel_manipulation_torch.py +418 -0
- careamics/utils/version.py +38 -0
- {careamics-0.0.9.dist-info → careamics-0.0.10.dist-info}/METADATA +7 -8
- {careamics-0.0.9.dist-info → careamics-0.0.10.dist-info}/RECORD +58 -41
- careamics/config/care_configuration.py +0 -100
- careamics/config/data/n2v_data_model.py +0 -193
- careamics/config/n2n_configuration.py +0 -101
- careamics/config/n2v_configuration.py +0 -266
- {careamics-0.0.9.dist-info → careamics-0.0.10.dist-info}/WHEEL +0 -0
- {careamics-0.0.9.dist-info → careamics-0.0.10.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.9.dist-info → careamics-0.0.10.dist-info}/licenses/LICENSE +0 -0
|
@@ -2,76 +2,25 @@
|
|
|
2
2
|
|
|
3
3
|
from typing import Annotated, Any, Literal, Optional, Union
|
|
4
4
|
|
|
5
|
-
from pydantic import
|
|
5
|
+
from pydantic import Field, TypeAdapter
|
|
6
6
|
|
|
7
7
|
from careamics.config.algorithms import CAREAlgorithm, N2NAlgorithm, N2VAlgorithm
|
|
8
8
|
from careamics.config.architectures import UNetModel
|
|
9
|
-
from careamics.config.
|
|
10
|
-
from careamics.config.configuration import Configuration
|
|
11
|
-
from careamics.config.data import DataConfig, N2VDataConfig
|
|
12
|
-
from careamics.config.n2n_configuration import N2NConfiguration
|
|
13
|
-
from careamics.config.n2v_configuration import N2VConfiguration
|
|
9
|
+
from careamics.config.data import DataConfig
|
|
14
10
|
from careamics.config.support import (
|
|
15
|
-
SupportedAlgorithm,
|
|
16
11
|
SupportedArchitecture,
|
|
17
12
|
SupportedPixelManipulation,
|
|
18
13
|
SupportedTransform,
|
|
19
14
|
)
|
|
20
15
|
from careamics.config.training_model import TrainingConfig
|
|
21
16
|
from careamics.config.transformations import (
|
|
22
|
-
N2V_TRANSFORMS_UNION,
|
|
23
17
|
SPATIAL_TRANSFORMS_UNION,
|
|
24
18
|
N2VManipulateModel,
|
|
25
19
|
XYFlipModel,
|
|
26
20
|
XYRandomRotate90Model,
|
|
27
21
|
)
|
|
28
22
|
|
|
29
|
-
|
|
30
|
-
def _algorithm_config_discriminator(value: Union[dict, Configuration]) -> str:
|
|
31
|
-
"""Discriminate algorithm-specific configurations based on the algorithm.
|
|
32
|
-
|
|
33
|
-
Parameters
|
|
34
|
-
----------
|
|
35
|
-
value : Any
|
|
36
|
-
Value to discriminate.
|
|
37
|
-
|
|
38
|
-
Returns
|
|
39
|
-
-------
|
|
40
|
-
str
|
|
41
|
-
Discriminator value.
|
|
42
|
-
"""
|
|
43
|
-
if isinstance(value, dict):
|
|
44
|
-
return value["algorithm_config"]["algorithm"]
|
|
45
|
-
return value.algorithm_config.algorithm
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
def configuration_factory(
|
|
49
|
-
configuration: dict[str, Any]
|
|
50
|
-
) -> Union[N2VConfiguration, N2NConfiguration, CAREConfiguration]:
|
|
51
|
-
"""
|
|
52
|
-
Create a configuration for training CAREamics.
|
|
53
|
-
|
|
54
|
-
Parameters
|
|
55
|
-
----------
|
|
56
|
-
configuration : dict
|
|
57
|
-
Configuration dictionary.
|
|
58
|
-
|
|
59
|
-
Returns
|
|
60
|
-
-------
|
|
61
|
-
N2VConfiguration or N2NConfiguration or CAREConfiguration
|
|
62
|
-
Configuration for training CAREamics.
|
|
63
|
-
"""
|
|
64
|
-
adapter: TypeAdapter = TypeAdapter(
|
|
65
|
-
Annotated[
|
|
66
|
-
Union[
|
|
67
|
-
Annotated[N2VConfiguration, Tag(SupportedAlgorithm.N2V.value)],
|
|
68
|
-
Annotated[N2NConfiguration, Tag(SupportedAlgorithm.N2N.value)],
|
|
69
|
-
Annotated[CAREConfiguration, Tag(SupportedAlgorithm.CARE.value)],
|
|
70
|
-
],
|
|
71
|
-
Discriminator(_algorithm_config_discriminator),
|
|
72
|
-
]
|
|
73
|
-
)
|
|
74
|
-
return adapter.validate_python(configuration)
|
|
23
|
+
from .configuration import Configuration
|
|
75
24
|
|
|
76
25
|
|
|
77
26
|
def algorithm_factory(
|
|
@@ -90,28 +39,15 @@ def algorithm_factory(
|
|
|
90
39
|
N2VAlgorithm or N2NAlgorithm or CAREAlgorithm
|
|
91
40
|
Algorithm model for training CAREamics.
|
|
92
41
|
"""
|
|
93
|
-
adapter: TypeAdapter = TypeAdapter(
|
|
42
|
+
adapter: TypeAdapter = TypeAdapter(
|
|
43
|
+
Annotated[
|
|
44
|
+
Union[N2VAlgorithm, N2NAlgorithm, CAREAlgorithm],
|
|
45
|
+
Field(discriminator="algorithm"),
|
|
46
|
+
]
|
|
47
|
+
)
|
|
94
48
|
return adapter.validate_python(algorithm)
|
|
95
49
|
|
|
96
50
|
|
|
97
|
-
def data_factory(data: dict[str, Any]) -> Union[DataConfig, N2VDataConfig]:
|
|
98
|
-
"""
|
|
99
|
-
Create a data model for training CAREamics.
|
|
100
|
-
|
|
101
|
-
Parameters
|
|
102
|
-
----------
|
|
103
|
-
data : dict
|
|
104
|
-
Data dictionary.
|
|
105
|
-
|
|
106
|
-
Returns
|
|
107
|
-
-------
|
|
108
|
-
DataConfig or N2VDataConfig
|
|
109
|
-
Data model for training CAREamics.
|
|
110
|
-
"""
|
|
111
|
-
adapter: TypeAdapter = TypeAdapter(Union[DataConfig, N2VDataConfig])
|
|
112
|
-
return adapter.validate_python(data)
|
|
113
|
-
|
|
114
|
-
|
|
115
51
|
def _list_spatial_augmentations(
|
|
116
52
|
augmentations: Optional[list[SPATIAL_TRANSFORMS_UNION]],
|
|
117
53
|
) -> list[SPATIAL_TRANSFORMS_UNION]:
|
|
@@ -208,70 +144,42 @@ def _create_unet_configuration(
|
|
|
208
144
|
)
|
|
209
145
|
|
|
210
146
|
|
|
211
|
-
def
|
|
212
|
-
algorithm: Literal["n2v", "care", "n2n"],
|
|
213
|
-
experiment_name: str,
|
|
214
|
-
data_type: Literal["array", "tiff", "custom"],
|
|
147
|
+
def _create_algorithm_configuration(
|
|
215
148
|
axes: str,
|
|
216
|
-
|
|
217
|
-
batch_size: int,
|
|
218
|
-
num_epochs: int,
|
|
219
|
-
augmentations: Union[list[N2V_TRANSFORMS_UNION], list[SPATIAL_TRANSFORMS_UNION]],
|
|
220
|
-
independent_channels: bool,
|
|
149
|
+
algorithm: Literal["n2v", "care", "n2n"],
|
|
221
150
|
loss: Literal["n2v", "mae", "mse"],
|
|
151
|
+
independent_channels: bool,
|
|
222
152
|
n_channels_in: int,
|
|
223
153
|
n_channels_out: int,
|
|
224
|
-
logger: Literal["wandb", "tensorboard", "none"],
|
|
225
154
|
use_n2v2: bool = False,
|
|
226
155
|
model_params: Optional[dict] = None,
|
|
227
|
-
|
|
228
|
-
val_dataloader_params: Optional[dict[str, Any]] = None,
|
|
229
|
-
) -> Configuration:
|
|
156
|
+
) -> dict:
|
|
230
157
|
"""
|
|
231
|
-
Create a
|
|
158
|
+
Create a dictionary with the parameters of the algorithm model.
|
|
232
159
|
|
|
233
160
|
Parameters
|
|
234
161
|
----------
|
|
162
|
+
axes : str
|
|
163
|
+
Axes of the data.
|
|
235
164
|
algorithm : {"n2v", "care", "n2n"}
|
|
236
165
|
Algorithm to use.
|
|
237
|
-
experiment_name : str
|
|
238
|
-
Name of the experiment.
|
|
239
|
-
data_type : {"array", "tiff", "custom"}
|
|
240
|
-
Type of the data.
|
|
241
|
-
axes : str
|
|
242
|
-
Axes of the data (e.g. SYX).
|
|
243
|
-
patch_size : list of int
|
|
244
|
-
Size of the patches along the spatial dimensions (e.g. [64, 64]).
|
|
245
|
-
batch_size : int
|
|
246
|
-
Batch size.
|
|
247
|
-
num_epochs : int
|
|
248
|
-
Number of epochs.
|
|
249
|
-
augmentations : list of transforms
|
|
250
|
-
List of transforms to apply, either both or one of XYFlipModel and
|
|
251
|
-
XYRandomRotate90Model.
|
|
252
|
-
independent_channels : bool
|
|
253
|
-
Whether to train all channels independently.
|
|
254
166
|
loss : {"n2v", "mae", "mse"}
|
|
255
167
|
Loss function to use.
|
|
168
|
+
independent_channels : bool
|
|
169
|
+
Whether to train all channels independently.
|
|
256
170
|
n_channels_in : int
|
|
257
|
-
Number of channels
|
|
171
|
+
Number of input channels.
|
|
258
172
|
n_channels_out : int
|
|
259
|
-
Number of channels
|
|
260
|
-
logger : {"wandb", "tensorboard", "none"}
|
|
261
|
-
Logger to use.
|
|
173
|
+
Number of output channels.
|
|
262
174
|
use_n2v2 : bool, optional
|
|
263
175
|
Whether to use N2V2, by default False.
|
|
264
176
|
model_params : dict
|
|
265
177
|
UNetModel parameters.
|
|
266
|
-
train_dataloader_params : dict
|
|
267
|
-
Parameters for the training dataloader, see PyTorch notes, by default None.
|
|
268
|
-
val_dataloader_params : dict
|
|
269
|
-
Parameters for the validation dataloader, see PyTorch notes, by default None.
|
|
270
178
|
|
|
271
179
|
Returns
|
|
272
180
|
-------
|
|
273
|
-
|
|
274
|
-
|
|
181
|
+
dict
|
|
182
|
+
Algorithm model as dictionnary with the specified parameters.
|
|
275
183
|
"""
|
|
276
184
|
# model
|
|
277
185
|
unet_model = _create_unet_configuration(
|
|
@@ -283,13 +191,47 @@ def _create_configuration(
|
|
|
283
191
|
model_params=model_params,
|
|
284
192
|
)
|
|
285
193
|
|
|
286
|
-
|
|
287
|
-
algorithm_config = {
|
|
194
|
+
return {
|
|
288
195
|
"algorithm": algorithm,
|
|
289
196
|
"loss": loss,
|
|
290
197
|
"model": unet_model,
|
|
291
198
|
}
|
|
292
199
|
|
|
200
|
+
|
|
201
|
+
def _create_data_configuration(
|
|
202
|
+
data_type: Literal["array", "tiff", "custom"],
|
|
203
|
+
axes: str,
|
|
204
|
+
patch_size: list[int],
|
|
205
|
+
batch_size: int,
|
|
206
|
+
augmentations: Union[list[SPATIAL_TRANSFORMS_UNION]],
|
|
207
|
+
train_dataloader_params: Optional[dict[str, Any]] = None,
|
|
208
|
+
val_dataloader_params: Optional[dict[str, Any]] = None,
|
|
209
|
+
) -> DataConfig:
|
|
210
|
+
"""
|
|
211
|
+
Create a dictionary with the parameters of the data model.
|
|
212
|
+
|
|
213
|
+
Parameters
|
|
214
|
+
----------
|
|
215
|
+
data_type : {"array", "tiff", "custom"}
|
|
216
|
+
Type of the data.
|
|
217
|
+
axes : str
|
|
218
|
+
Axes of the data.
|
|
219
|
+
patch_size : list of int
|
|
220
|
+
Size of the patches along the spatial dimensions.
|
|
221
|
+
batch_size : int
|
|
222
|
+
Batch size.
|
|
223
|
+
augmentations : list of transforms
|
|
224
|
+
List of transforms to apply.
|
|
225
|
+
train_dataloader_params : dict
|
|
226
|
+
Parameters for the training dataloader, see PyTorch notes, by default None.
|
|
227
|
+
val_dataloader_params : dict
|
|
228
|
+
Parameters for the validation dataloader, see PyTorch notes, by default None.
|
|
229
|
+
|
|
230
|
+
Returns
|
|
231
|
+
-------
|
|
232
|
+
DataConfig
|
|
233
|
+
Data model with the specified parameters.
|
|
234
|
+
"""
|
|
293
235
|
# data model
|
|
294
236
|
data = {
|
|
295
237
|
"data_type": data_type,
|
|
@@ -300,30 +242,44 @@ def _create_configuration(
|
|
|
300
242
|
}
|
|
301
243
|
# Don't override defaults set in DataConfig class
|
|
302
244
|
if train_dataloader_params is not None:
|
|
245
|
+
# DataConfig enforces the presence of `shuffle` key in the dataloader parameters
|
|
246
|
+
if "shuffle" not in train_dataloader_params:
|
|
247
|
+
train_dataloader_params["shuffle"] = True
|
|
248
|
+
|
|
303
249
|
data["train_dataloader_params"] = train_dataloader_params
|
|
250
|
+
|
|
304
251
|
if val_dataloader_params is not None:
|
|
305
252
|
data["val_dataloader_params"] = val_dataloader_params
|
|
306
253
|
|
|
307
|
-
|
|
308
|
-
|
|
254
|
+
return DataConfig(**data)
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def _create_training_configuration(
|
|
258
|
+
num_epochs: int, logger: Literal["wandb", "tensorboard", "none"]
|
|
259
|
+
) -> TrainingConfig:
|
|
260
|
+
"""
|
|
261
|
+
Create a dictionary with the parameters of the training model.
|
|
262
|
+
|
|
263
|
+
Parameters
|
|
264
|
+
----------
|
|
265
|
+
num_epochs : int
|
|
266
|
+
Number of epochs.
|
|
267
|
+
logger : {"wandb", "tensorboard", "none"}
|
|
268
|
+
Logger to use.
|
|
269
|
+
|
|
270
|
+
Returns
|
|
271
|
+
-------
|
|
272
|
+
TrainingConfig
|
|
273
|
+
Training model with the specified parameters.
|
|
274
|
+
"""
|
|
275
|
+
return TrainingConfig(
|
|
309
276
|
num_epochs=num_epochs,
|
|
310
|
-
batch_size=batch_size,
|
|
311
277
|
logger=None if logger == "none" else logger,
|
|
312
278
|
)
|
|
313
279
|
|
|
314
|
-
# create configuration
|
|
315
|
-
configuration = {
|
|
316
|
-
"experiment_name": experiment_name,
|
|
317
|
-
"algorithm_config": algorithm_config,
|
|
318
|
-
"data_config": data,
|
|
319
|
-
"training_config": training,
|
|
320
|
-
}
|
|
321
|
-
|
|
322
|
-
return configuration_factory(configuration)
|
|
323
|
-
|
|
324
280
|
|
|
325
281
|
# TODO reconsider naming once we officially support LVAE approaches
|
|
326
|
-
def
|
|
282
|
+
def _create_supervised_config_dict(
|
|
327
283
|
algorithm: Literal["care", "n2n"],
|
|
328
284
|
experiment_name: str,
|
|
329
285
|
data_type: Literal["array", "tiff", "custom"],
|
|
@@ -331,7 +287,7 @@ def _create_supervised_configuration(
|
|
|
331
287
|
patch_size: list[int],
|
|
332
288
|
batch_size: int,
|
|
333
289
|
num_epochs: int,
|
|
334
|
-
augmentations: Optional[list[
|
|
290
|
+
augmentations: Optional[list[SPATIAL_TRANSFORMS_UNION]] = None,
|
|
335
291
|
independent_channels: bool = True,
|
|
336
292
|
loss: Literal["mae", "mse"] = "mae",
|
|
337
293
|
n_channels_in: Optional[int] = None,
|
|
@@ -340,7 +296,7 @@ def _create_supervised_configuration(
|
|
|
340
296
|
model_params: Optional[dict] = None,
|
|
341
297
|
train_dataloader_params: Optional[dict[str, Any]] = None,
|
|
342
298
|
val_dataloader_params: Optional[dict[str, Any]] = None,
|
|
343
|
-
) ->
|
|
299
|
+
) -> dict:
|
|
344
300
|
"""
|
|
345
301
|
Create a configuration for training CARE or Noise2Noise.
|
|
346
302
|
|
|
@@ -411,25 +367,41 @@ def _create_supervised_configuration(
|
|
|
411
367
|
# augmentations
|
|
412
368
|
spatial_transform_list = _list_spatial_augmentations(augmentations)
|
|
413
369
|
|
|
414
|
-
|
|
370
|
+
# algorithm
|
|
371
|
+
algorithm_params = _create_algorithm_configuration(
|
|
372
|
+
axes=axes,
|
|
415
373
|
algorithm=algorithm,
|
|
416
|
-
|
|
374
|
+
loss=loss,
|
|
375
|
+
independent_channels=independent_channels,
|
|
376
|
+
n_channels_in=n_channels_in,
|
|
377
|
+
n_channels_out=n_channels_out,
|
|
378
|
+
model_params=model_params,
|
|
379
|
+
)
|
|
380
|
+
|
|
381
|
+
# data
|
|
382
|
+
data_params = _create_data_configuration(
|
|
417
383
|
data_type=data_type,
|
|
418
384
|
axes=axes,
|
|
419
385
|
patch_size=patch_size,
|
|
420
386
|
batch_size=batch_size,
|
|
421
|
-
num_epochs=num_epochs,
|
|
422
387
|
augmentations=spatial_transform_list,
|
|
423
|
-
independent_channels=independent_channels,
|
|
424
|
-
loss=loss,
|
|
425
|
-
n_channels_in=n_channels_in,
|
|
426
|
-
n_channels_out=n_channels_out,
|
|
427
|
-
logger=logger,
|
|
428
|
-
model_params=model_params,
|
|
429
388
|
train_dataloader_params=train_dataloader_params,
|
|
430
389
|
val_dataloader_params=val_dataloader_params,
|
|
431
390
|
)
|
|
432
391
|
|
|
392
|
+
# training
|
|
393
|
+
training_params = _create_training_configuration(
|
|
394
|
+
num_epochs=num_epochs,
|
|
395
|
+
logger=logger,
|
|
396
|
+
)
|
|
397
|
+
|
|
398
|
+
return {
|
|
399
|
+
"experiment_name": experiment_name,
|
|
400
|
+
"algorithm_config": algorithm_params,
|
|
401
|
+
"data_config": data_params,
|
|
402
|
+
"training_config": training_params,
|
|
403
|
+
}
|
|
404
|
+
|
|
433
405
|
|
|
434
406
|
def create_care_configuration(
|
|
435
407
|
experiment_name: str,
|
|
@@ -580,23 +552,25 @@ def create_care_configuration(
|
|
|
580
552
|
... n_channels_out=1 # if applicable
|
|
581
553
|
... )
|
|
582
554
|
"""
|
|
583
|
-
return
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
555
|
+
return Configuration(
|
|
556
|
+
**_create_supervised_config_dict(
|
|
557
|
+
algorithm="care",
|
|
558
|
+
experiment_name=experiment_name,
|
|
559
|
+
data_type=data_type,
|
|
560
|
+
axes=axes,
|
|
561
|
+
patch_size=patch_size,
|
|
562
|
+
batch_size=batch_size,
|
|
563
|
+
num_epochs=num_epochs,
|
|
564
|
+
augmentations=augmentations,
|
|
565
|
+
independent_channels=independent_channels,
|
|
566
|
+
loss=loss,
|
|
567
|
+
n_channels_in=n_channels_in,
|
|
568
|
+
n_channels_out=n_channels_out,
|
|
569
|
+
logger=logger,
|
|
570
|
+
model_params=model_params,
|
|
571
|
+
train_dataloader_params=train_dataloader_params,
|
|
572
|
+
val_dataloader_params=val_dataloader_params,
|
|
573
|
+
)
|
|
600
574
|
)
|
|
601
575
|
|
|
602
576
|
|
|
@@ -749,23 +723,25 @@ def create_n2n_configuration(
|
|
|
749
723
|
... n_channels_out=1 # if applicable
|
|
750
724
|
... )
|
|
751
725
|
"""
|
|
752
|
-
return
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
726
|
+
return Configuration(
|
|
727
|
+
**_create_supervised_config_dict(
|
|
728
|
+
algorithm="n2n",
|
|
729
|
+
experiment_name=experiment_name,
|
|
730
|
+
data_type=data_type,
|
|
731
|
+
axes=axes,
|
|
732
|
+
patch_size=patch_size,
|
|
733
|
+
batch_size=batch_size,
|
|
734
|
+
num_epochs=num_epochs,
|
|
735
|
+
augmentations=augmentations,
|
|
736
|
+
independent_channels=independent_channels,
|
|
737
|
+
loss=loss,
|
|
738
|
+
n_channels_in=n_channels_in,
|
|
739
|
+
n_channels_out=n_channels_out,
|
|
740
|
+
logger=logger,
|
|
741
|
+
model_params=model_params,
|
|
742
|
+
train_dataloader_params=train_dataloader_params,
|
|
743
|
+
val_dataloader_params=val_dataloader_params,
|
|
744
|
+
)
|
|
769
745
|
)
|
|
770
746
|
|
|
771
747
|
|
|
@@ -995,24 +971,40 @@ def create_n2v_configuration(
|
|
|
995
971
|
struct_mask_axis=struct_n2v_axis,
|
|
996
972
|
struct_mask_span=struct_n2v_span,
|
|
997
973
|
)
|
|
998
|
-
transform_list: list[N2V_TRANSFORMS_UNION] = spatial_transforms + [n2v_transform]
|
|
999
974
|
|
|
1000
|
-
|
|
1001
|
-
|
|
1002
|
-
experiment_name=experiment_name,
|
|
1003
|
-
data_type=data_type,
|
|
975
|
+
# algorithm
|
|
976
|
+
algorithm_params = _create_algorithm_configuration(
|
|
1004
977
|
axes=axes,
|
|
1005
|
-
|
|
1006
|
-
batch_size=batch_size,
|
|
1007
|
-
num_epochs=num_epochs,
|
|
1008
|
-
augmentations=transform_list,
|
|
1009
|
-
independent_channels=independent_channels,
|
|
978
|
+
algorithm="n2v",
|
|
1010
979
|
loss="n2v",
|
|
1011
|
-
|
|
980
|
+
independent_channels=independent_channels,
|
|
1012
981
|
n_channels_in=n_channels,
|
|
1013
982
|
n_channels_out=n_channels,
|
|
1014
|
-
|
|
983
|
+
use_n2v2=use_n2v2,
|
|
1015
984
|
model_params=model_params,
|
|
985
|
+
)
|
|
986
|
+
algorithm_params["n2v_config"] = n2v_transform
|
|
987
|
+
|
|
988
|
+
# data
|
|
989
|
+
data_params = _create_data_configuration(
|
|
990
|
+
data_type=data_type,
|
|
991
|
+
axes=axes,
|
|
992
|
+
patch_size=patch_size,
|
|
993
|
+
batch_size=batch_size,
|
|
994
|
+
augmentations=spatial_transforms,
|
|
1016
995
|
train_dataloader_params=train_dataloader_params,
|
|
1017
996
|
val_dataloader_params=val_dataloader_params,
|
|
1018
997
|
)
|
|
998
|
+
|
|
999
|
+
# training
|
|
1000
|
+
training_params = _create_training_configuration(
|
|
1001
|
+
num_epochs=num_epochs,
|
|
1002
|
+
logger=logger,
|
|
1003
|
+
)
|
|
1004
|
+
|
|
1005
|
+
return Configuration(
|
|
1006
|
+
experiment_name=experiment_name,
|
|
1007
|
+
algorithm_config=algorithm_params,
|
|
1008
|
+
data_config=data_params,
|
|
1009
|
+
training_config=training_params,
|
|
1010
|
+
)
|
|
@@ -5,7 +5,7 @@ from typing import Union
|
|
|
5
5
|
|
|
6
6
|
import yaml
|
|
7
7
|
|
|
8
|
-
from careamics.config import Configuration
|
|
8
|
+
from careamics.config import Configuration
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
def load_configuration(path: Union[str, Path]) -> Configuration:
|
|
@@ -35,7 +35,7 @@ def load_configuration(path: Union[str, Path]) -> Configuration:
|
|
|
35
35
|
|
|
36
36
|
dictionary = yaml.load(Path(path).open("r"), Loader=yaml.SafeLoader)
|
|
37
37
|
|
|
38
|
-
return
|
|
38
|
+
return Configuration(**dictionary)
|
|
39
39
|
|
|
40
40
|
|
|
41
41
|
def save_configuration(config: Configuration, path: Union[str, Path]) -> Path:
|
|
@@ -19,7 +19,7 @@ from pydantic import (
|
|
|
19
19
|
)
|
|
20
20
|
from typing_extensions import Self
|
|
21
21
|
|
|
22
|
-
from ..transformations import
|
|
22
|
+
from ..transformations import XYFlipModel, XYRandomRotate90Model
|
|
23
23
|
from ..validators import check_axes_validity, patch_size_ge_than_8_power_of_2
|
|
24
24
|
|
|
25
25
|
|
|
@@ -46,8 +46,46 @@ Float = Annotated[float, PlainSerializer(np_float_to_scientific_str, return_type
|
|
|
46
46
|
"""Annotated float type, used to serialize floats to strings."""
|
|
47
47
|
|
|
48
48
|
|
|
49
|
-
class
|
|
50
|
-
"""
|
|
49
|
+
class DataConfig(BaseModel):
|
|
50
|
+
"""Data configuration.
|
|
51
|
+
|
|
52
|
+
If std is specified, mean must be specified as well. Note that setting the std first
|
|
53
|
+
and then the mean (if they were both `None` before) will raise a validation error.
|
|
54
|
+
Prefer instead `set_mean_and_std` to set both at once. Means and stds are expected
|
|
55
|
+
to be lists of floats, one for each channel. For supervised tasks, the mean and std
|
|
56
|
+
of the target could be different from the input data.
|
|
57
|
+
|
|
58
|
+
All supported transforms are defined in the SupportedTransform enum.
|
|
59
|
+
|
|
60
|
+
Examples
|
|
61
|
+
--------
|
|
62
|
+
Minimum example:
|
|
63
|
+
|
|
64
|
+
>>> data = DataConfig(
|
|
65
|
+
... data_type="array", # defined in SupportedData
|
|
66
|
+
... patch_size=[128, 128],
|
|
67
|
+
... batch_size=4,
|
|
68
|
+
... axes="YX"
|
|
69
|
+
... )
|
|
70
|
+
|
|
71
|
+
To change the image_means and image_stds of the data:
|
|
72
|
+
>>> data.set_means_and_stds(image_means=[214.3], image_stds=[84.5])
|
|
73
|
+
|
|
74
|
+
One can pass also a list of transformations, by keyword, using the
|
|
75
|
+
SupportedTransform value:
|
|
76
|
+
>>> from careamics.config.support import SupportedTransform
|
|
77
|
+
>>> data = DataConfig(
|
|
78
|
+
... data_type="tiff",
|
|
79
|
+
... patch_size=[128, 128],
|
|
80
|
+
... batch_size=4,
|
|
81
|
+
... axes="YX",
|
|
82
|
+
... transforms=[
|
|
83
|
+
... {
|
|
84
|
+
... "name": "XYFlip",
|
|
85
|
+
... }
|
|
86
|
+
... ]
|
|
87
|
+
... )
|
|
88
|
+
"""
|
|
51
89
|
|
|
52
90
|
# Pydantic class configuration
|
|
53
91
|
model_config = ConfigDict(
|
|
@@ -88,10 +126,7 @@ class GeneralDataConfig(BaseModel):
|
|
|
88
126
|
"""Standard deviations of the target data across channels, used for
|
|
89
127
|
normalization."""
|
|
90
128
|
|
|
91
|
-
|
|
92
|
-
# complaining, this is important for instance to differentiate N2VDataConfig and
|
|
93
|
-
# DataConfig
|
|
94
|
-
transforms: Sequence[N2V_TRANSFORMS_UNION] = Field(
|
|
129
|
+
transforms: Sequence[Union[XYFlipModel, XYRandomRotate90Model]] = Field(
|
|
95
130
|
default=[
|
|
96
131
|
XYFlipModel(),
|
|
97
132
|
XYRandomRotate90Model(),
|
|
@@ -104,7 +139,9 @@ class GeneralDataConfig(BaseModel):
|
|
|
104
139
|
train_dataloader_params: dict[str, Any] = Field(
|
|
105
140
|
default={"shuffle": True}, validate_default=True
|
|
106
141
|
)
|
|
107
|
-
"""Dictionary of PyTorch training dataloader parameters.
|
|
142
|
+
"""Dictionary of PyTorch training dataloader parameters. The dataloader parameters,
|
|
143
|
+
should include the `shuffle` key, which is set to `True` by default. We strongly
|
|
144
|
+
recommend to keep it as `True` to ensure the best training results."""
|
|
108
145
|
|
|
109
146
|
val_dataloader_params: dict[str, Any] = Field(default={})
|
|
110
147
|
"""Dictionary of PyTorch validation dataloader parameters."""
|
|
@@ -207,7 +244,7 @@ class GeneralDataConfig(BaseModel):
|
|
|
207
244
|
):
|
|
208
245
|
warn(
|
|
209
246
|
"Dataloader parameters include `shuffle=False`, this will be passed to "
|
|
210
|
-
"the training dataloader and may
|
|
247
|
+
"the training dataloader and may lead to lower quality results.",
|
|
211
248
|
stacklevel=1,
|
|
212
249
|
)
|
|
213
250
|
return train_dataloader_params
|
|
@@ -363,56 +400,3 @@ class GeneralDataConfig(BaseModel):
|
|
|
363
400
|
Patch size.
|
|
364
401
|
"""
|
|
365
402
|
self._update(axes=axes, patch_size=patch_size)
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
class DataConfig(GeneralDataConfig):
|
|
369
|
-
"""
|
|
370
|
-
Data configuration.
|
|
371
|
-
|
|
372
|
-
If std is specified, mean must be specified as well. Note that setting the std first
|
|
373
|
-
and then the mean (if they were both `None` before) will raise a validation error.
|
|
374
|
-
Prefer instead `set_mean_and_std` to set both at once. Means and stds are expected
|
|
375
|
-
to be lists of floats, one for each channel. For supervised tasks, the mean and std
|
|
376
|
-
of the target could be different from the input data.
|
|
377
|
-
|
|
378
|
-
All supported transforms are defined in the SupportedTransform enum.
|
|
379
|
-
|
|
380
|
-
Examples
|
|
381
|
-
--------
|
|
382
|
-
Minimum example:
|
|
383
|
-
|
|
384
|
-
>>> data = DataConfig(
|
|
385
|
-
... data_type="array", # defined in SupportedData
|
|
386
|
-
... patch_size=[128, 128],
|
|
387
|
-
... batch_size=4,
|
|
388
|
-
... axes="YX"
|
|
389
|
-
... )
|
|
390
|
-
|
|
391
|
-
To change the image_means and image_stds of the data:
|
|
392
|
-
>>> data.set_means_and_stds(image_means=[214.3], image_stds=[84.5])
|
|
393
|
-
|
|
394
|
-
One can pass also a list of transformations, by keyword, using the
|
|
395
|
-
SupportedTransform value:
|
|
396
|
-
>>> from careamics.config.support import SupportedTransform
|
|
397
|
-
>>> data = DataConfig(
|
|
398
|
-
... data_type="tiff",
|
|
399
|
-
... patch_size=[128, 128],
|
|
400
|
-
... batch_size=4,
|
|
401
|
-
... axes="YX",
|
|
402
|
-
... transforms=[
|
|
403
|
-
... {
|
|
404
|
-
... "name": "XYFlip",
|
|
405
|
-
... }
|
|
406
|
-
... ]
|
|
407
|
-
... )
|
|
408
|
-
"""
|
|
409
|
-
|
|
410
|
-
transforms: Sequence[Union[XYFlipModel, XYRandomRotate90Model]] = Field(
|
|
411
|
-
default=[
|
|
412
|
-
XYFlipModel(),
|
|
413
|
-
XYRandomRotate90Model(),
|
|
414
|
-
],
|
|
415
|
-
validate_default=True,
|
|
416
|
-
)
|
|
417
|
-
"""List of transformations to apply to the data, available transforms are defined
|
|
418
|
-
in SupportedTransform. This excludes N2V specific transformations."""
|