careamics 0.0.12__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +4 -3
- careamics/cli/utils.py +1 -1
- careamics/config/algorithms/n2v_algorithm_model.py +1 -1
- careamics/config/architectures/unet_model.py +3 -0
- careamics/config/callback_model.py +23 -34
- careamics/config/configuration.py +47 -1
- careamics/config/configuration_factories.py +288 -23
- careamics/config/data/__init__.py +2 -0
- careamics/config/data/data_model.py +3 -3
- careamics/config/data/ng_data_model.py +381 -0
- careamics/config/data/patching_strategies/__init__.py +14 -0
- careamics/config/data/patching_strategies/_overlapping_patched_model.py +103 -0
- careamics/config/data/patching_strategies/_patched_model.py +56 -0
- careamics/config/data/patching_strategies/random_patching_model.py +21 -0
- careamics/config/data/patching_strategies/sequential_patching_model.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_model.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_model.py +12 -0
- careamics/config/inference_model.py +6 -3
- careamics/config/support/supported_data.py +7 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/validators/validator_utils.py +4 -3
- careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
- careamics/dataset/in_memory_dataset.py +2 -1
- careamics/dataset/iterable_dataset.py +2 -2
- careamics/dataset/iterable_pred_dataset.py +2 -2
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -2
- careamics/dataset/patching/patching.py +3 -2
- careamics/dataset/tiling/lvae_tiled_patching.py +16 -6
- careamics/dataset/tiling/tiled_patching.py +2 -1
- careamics/dataset_ng/dataset.py +46 -50
- careamics/dataset_ng/demos/bsd68_demo.ipynb +28 -23
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +1 -1
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +1 -1
- careamics/dataset_ng/demos/demo_datamodule.ipynb +50 -46
- careamics/dataset_ng/demos/demo_dataset.ipynb +32 -49
- careamics/dataset_ng/factory.py +58 -15
- careamics/dataset_ng/legacy_interoperability.py +3 -1
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +1 -1
- careamics/dataset_ng/patch_extractor/image_stack/__init__.py +2 -0
- careamics/dataset_ng/patch_extractor/image_stack/czi_image_stack.py +360 -0
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -1
- careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +43 -1
- careamics/dataset_ng/patching_strategies/random_patching.py +3 -1
- careamics/dataset_ng/patching_strategies/sequential_patching.py +5 -5
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +2 -1
- careamics/file_io/read/get_func.py +2 -1
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/data_module.py +218 -28
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +44 -5
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +42 -3
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +73 -4
- careamics/lightning/lightning_module.py +2 -1
- careamics/lightning/predict_data_module.py +2 -1
- careamics/lightning/train_data_module.py +2 -1
- careamics/losses/loss_factory.py +2 -1
- careamics/lvae_training/dataset/multicrop_dset.py +1 -1
- careamics/model_io/bioimage/bioimage_utils.py +1 -1
- careamics/model_io/bioimage/model_description.py +1 -1
- careamics/model_io/bmz_io.py +1 -1
- careamics/model_io/model_io_utils.py +2 -2
- careamics/models/activation.py +2 -1
- careamics/prediction_utils/prediction_outputs.py +1 -1
- careamics/prediction_utils/stitch_prediction.py +1 -1
- careamics/transforms/n2v_manipulate_torch.py +15 -9
- careamics/transforms/pixel_manipulation_torch.py +59 -92
- careamics/utils/lightning_utils.py +2 -2
- careamics/utils/metrics.py +2 -1
- careamics/utils/torch_utils.py +23 -0
- {careamics-0.0.12.dist-info → careamics-0.0.13.dist-info}/METADATA +10 -9
- {careamics-0.0.12.dist-info → careamics-0.0.13.dist-info}/RECORD +73 -62
- {careamics-0.0.12.dist-info → careamics-0.0.13.dist-info}/WHEEL +0 -0
- {careamics-0.0.12.dist-info → careamics-0.0.13.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.12.dist-info → careamics-0.0.13.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,12 +1,13 @@
|
|
|
1
1
|
"""Convenience functions to create configurations for training and inference."""
|
|
2
2
|
|
|
3
|
+
from collections.abc import Sequence
|
|
3
4
|
from typing import Annotated, Any, Literal, Optional, Union
|
|
4
5
|
|
|
5
6
|
from pydantic import Field, TypeAdapter
|
|
6
7
|
|
|
7
8
|
from careamics.config.algorithms import CAREAlgorithm, N2NAlgorithm, N2VAlgorithm
|
|
8
9
|
from careamics.config.architectures import UNetModel
|
|
9
|
-
from careamics.config.data import DataConfig
|
|
10
|
+
from careamics.config.data import DataConfig, NGDataConfig
|
|
10
11
|
from careamics.config.support import (
|
|
11
12
|
SupportedArchitecture,
|
|
12
13
|
SupportedPixelManipulation,
|
|
@@ -24,7 +25,7 @@ from .configuration import Configuration
|
|
|
24
25
|
|
|
25
26
|
|
|
26
27
|
def algorithm_factory(
|
|
27
|
-
algorithm: dict[str, Any]
|
|
28
|
+
algorithm: dict[str, Any],
|
|
28
29
|
) -> Union[N2VAlgorithm, N2NAlgorithm, CAREAlgorithm]:
|
|
29
30
|
"""
|
|
30
31
|
Create an algorithm model for training CAREamics.
|
|
@@ -49,7 +50,7 @@ def algorithm_factory(
|
|
|
49
50
|
|
|
50
51
|
|
|
51
52
|
def _list_spatial_augmentations(
|
|
52
|
-
augmentations: Optional[list[SPATIAL_TRANSFORMS_UNION]],
|
|
53
|
+
augmentations: Optional[list[SPATIAL_TRANSFORMS_UNION]] = None,
|
|
53
54
|
) -> list[SPATIAL_TRANSFORMS_UNION]:
|
|
54
55
|
"""
|
|
55
56
|
List the augmentations to apply.
|
|
@@ -153,6 +154,10 @@ def _create_algorithm_configuration(
|
|
|
153
154
|
n_channels_out: int,
|
|
154
155
|
use_n2v2: bool = False,
|
|
155
156
|
model_params: Optional[dict] = None,
|
|
157
|
+
optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
|
|
158
|
+
optimizer_params: Optional[dict[str, Any]] = None,
|
|
159
|
+
lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
|
|
160
|
+
lr_scheduler_params: Optional[dict[str, Any]] = None,
|
|
156
161
|
) -> dict:
|
|
157
162
|
"""
|
|
158
163
|
Create a dictionary with the parameters of the algorithm model.
|
|
@@ -171,10 +176,20 @@ def _create_algorithm_configuration(
|
|
|
171
176
|
Number of input channels.
|
|
172
177
|
n_channels_out : int
|
|
173
178
|
Number of output channels.
|
|
174
|
-
use_n2v2 : bool,
|
|
175
|
-
Whether to use N2V2
|
|
176
|
-
model_params : dict
|
|
179
|
+
use_n2v2 : bool, default=false
|
|
180
|
+
Whether to use N2V2.
|
|
181
|
+
model_params : dict, default=None
|
|
177
182
|
UNetModel parameters.
|
|
183
|
+
optimizer : {"Adam", "Adamax", "SGD"}, default="Adam"
|
|
184
|
+
Optimizer to use.
|
|
185
|
+
optimizer_params : dict, default=None
|
|
186
|
+
Parameters for the optimizer, see PyTorch documentation for more details.
|
|
187
|
+
lr_scheduler : {"ReduceLROnPlateau", "StepLR"}, default="ReduceLROnPlateau"
|
|
188
|
+
Learning rate scheduler to use.
|
|
189
|
+
lr_scheduler_params : dict, default=None
|
|
190
|
+
Parameters for the learning rate scheduler, see PyTorch documentation for more
|
|
191
|
+
details.
|
|
192
|
+
|
|
178
193
|
|
|
179
194
|
Returns
|
|
180
195
|
-------
|
|
@@ -195,11 +210,19 @@ def _create_algorithm_configuration(
|
|
|
195
210
|
"algorithm": algorithm,
|
|
196
211
|
"loss": loss,
|
|
197
212
|
"model": unet_model,
|
|
213
|
+
"optimizer": {
|
|
214
|
+
"name": optimizer,
|
|
215
|
+
"parameters": {} if optimizer_params is None else optimizer_params,
|
|
216
|
+
},
|
|
217
|
+
"lr_scheduler": {
|
|
218
|
+
"name": lr_scheduler,
|
|
219
|
+
"parameters": {} if lr_scheduler_params is None else lr_scheduler_params,
|
|
220
|
+
},
|
|
198
221
|
}
|
|
199
222
|
|
|
200
223
|
|
|
201
224
|
def _create_data_configuration(
|
|
202
|
-
data_type: Literal["array", "tiff", "custom"],
|
|
225
|
+
data_type: Literal["array", "tiff", "czi", "custom"],
|
|
203
226
|
axes: str,
|
|
204
227
|
patch_size: list[int],
|
|
205
228
|
batch_size: int,
|
|
@@ -212,7 +235,7 @@ def _create_data_configuration(
|
|
|
212
235
|
|
|
213
236
|
Parameters
|
|
214
237
|
----------
|
|
215
|
-
data_type : {"array", "tiff", "custom"}
|
|
238
|
+
data_type : {"array", "tiff", "czi", "custom"}
|
|
216
239
|
Type of the data.
|
|
217
240
|
axes : str
|
|
218
241
|
Axes of the data.
|
|
@@ -254,8 +277,89 @@ def _create_data_configuration(
|
|
|
254
277
|
return DataConfig(**data)
|
|
255
278
|
|
|
256
279
|
|
|
280
|
+
def _create_ng_data_configuration(
    data_type: Literal["array", "tiff", "custom"],
    axes: str,
    patch_size: Sequence[int],
    batch_size: int,
    augmentations: list[SPATIAL_TRANSFORMS_UNION],
    patch_overlaps: Optional[Sequence[int]] = None,
    train_dataloader_params: Optional[dict[str, Any]] = None,
    val_dataloader_params: Optional[dict[str, Any]] = None,
    test_dataloader_params: Optional[dict[str, Any]] = None,
    seed: Optional[int] = None,
) -> NGDataConfig:
    """
    Create a next-generation data model (`NGDataConfig`) for training.

    Parameters
    ----------
    data_type : {"array", "tiff", "custom"}
        Type of the data.
    axes : str
        Axes of the data.
    patch_size : Sequence of int
        Size of the patches along the spatial dimensions.
    batch_size : int
        Batch size.
    augmentations : list of transforms
        List of transforms to apply.
    patch_overlaps : Sequence of int, default=None
        Overlaps between patches in each spatial dimension, forwarded as the
        `overlaps` entry of the patching configuration. If `None`, no overlap is
        applied. The overlap must be smaller than the patch size in each spatial
        dimension, and the number of dimensions must be either 2 or 3.
    train_dataloader_params : dict, default=None
        Parameters for the training dataloader, see PyTorch notes. If provided
        without a `shuffle` key, `shuffle=True` is added to a copy; the caller's
        dictionary is never modified.
    val_dataloader_params : dict, default=None
        Parameters for the validation dataloader, see PyTorch notes.
    test_dataloader_params : dict, default=None
        Parameters for the test dataloader, see PyTorch notes.
    seed : int, default=None
        Random seed for reproducibility. If `None`, no seed is set.

    Returns
    -------
    NGDataConfig
        Next-Generation Data model with the specified parameters.
    """
    data: dict[str, Any] = {
        "data_type": data_type,
        "axes": axes,
        "batch_size": batch_size,
        "transforms": augmentations,
        "seed": seed,
    }

    # Dataloader parameters are only added when provided, so that the defaults
    # defined in the NGDataConfig class are not overridden.
    if train_dataloader_params is not None:
        # Work on a shallow copy so the caller's dictionary is not mutated. The
        # presence of the `shuffle` key is enforced by the NGDataConfig class,
        # so default to shuffling the training data when it is missing.
        train_params = dict(train_dataloader_params)
        train_params.setdefault("shuffle", True)
        data["train_dataloader_params"] = train_params

    if val_dataloader_params is not None:
        data["val_dataloader_params"] = val_dataloader_params

    if test_dataloader_params is not None:
        data["test_dataloader_params"] = test_dataloader_params

    # Training patching strategy.
    # NOTE(review): this always selects "random" patching while forwarding
    # `patch_overlaps`; earlier documentation described overlaps as a
    # "sequential"-only option. Confirm the random patching model accepts/uses
    # the `overlaps` field.
    data["patching"] = {
        "name": "random",
        "patch_size": patch_size,
        "overlaps": patch_overlaps,
    }

    return NGDataConfig(**data)
|
|
357
|
+
|
|
358
|
+
|
|
257
359
|
def _create_training_configuration(
|
|
258
|
-
num_epochs: int,
|
|
360
|
+
num_epochs: int,
|
|
361
|
+
logger: Literal["wandb", "tensorboard", "none"],
|
|
362
|
+
checkpoint_params: Optional[dict[str, Any]] = None,
|
|
259
363
|
) -> TrainingConfig:
|
|
260
364
|
"""
|
|
261
365
|
Create a dictionary with the parameters of the training model.
|
|
@@ -266,6 +370,9 @@ def _create_training_configuration(
|
|
|
266
370
|
Number of epochs.
|
|
267
371
|
logger : {"wandb", "tensorboard", "none"}
|
|
268
372
|
Logger to use.
|
|
373
|
+
checkpoint_params : dict, default=None
|
|
374
|
+
Parameters for the checkpoint callback, see PyTorch Lightning documentation
|
|
375
|
+
(`ModelCheckpoint`) for the list of available parameters.
|
|
269
376
|
|
|
270
377
|
Returns
|
|
271
378
|
-------
|
|
@@ -275,6 +382,7 @@ def _create_training_configuration(
|
|
|
275
382
|
return TrainingConfig(
|
|
276
383
|
num_epochs=num_epochs,
|
|
277
384
|
logger=None if logger == "none" else logger,
|
|
385
|
+
checkpoint_callback={} if checkpoint_params is None else checkpoint_params,
|
|
278
386
|
)
|
|
279
387
|
|
|
280
388
|
|
|
@@ -282,7 +390,7 @@ def _create_training_configuration(
|
|
|
282
390
|
def _create_supervised_config_dict(
|
|
283
391
|
algorithm: Literal["care", "n2n"],
|
|
284
392
|
experiment_name: str,
|
|
285
|
-
data_type: Literal["array", "tiff", "custom"],
|
|
393
|
+
data_type: Literal["array", "tiff", "czi", "custom"],
|
|
286
394
|
axes: str,
|
|
287
395
|
patch_size: list[int],
|
|
288
396
|
batch_size: int,
|
|
@@ -294,8 +402,13 @@ def _create_supervised_config_dict(
|
|
|
294
402
|
n_channels_out: Optional[int] = None,
|
|
295
403
|
logger: Literal["wandb", "tensorboard", "none"] = "none",
|
|
296
404
|
model_params: Optional[dict] = None,
|
|
405
|
+
optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
|
|
406
|
+
optimizer_params: Optional[dict[str, Any]] = None,
|
|
407
|
+
lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
|
|
408
|
+
lr_scheduler_params: Optional[dict[str, Any]] = None,
|
|
297
409
|
train_dataloader_params: Optional[dict[str, Any]] = None,
|
|
298
410
|
val_dataloader_params: Optional[dict[str, Any]] = None,
|
|
411
|
+
checkpoint_params: Optional[dict[str, Any]] = None,
|
|
299
412
|
) -> dict:
|
|
300
413
|
"""
|
|
301
414
|
Create a configuration for training CARE or Noise2Noise.
|
|
@@ -306,7 +419,7 @@ def _create_supervised_config_dict(
|
|
|
306
419
|
Algorithm to use.
|
|
307
420
|
experiment_name : str
|
|
308
421
|
Name of the experiment.
|
|
309
|
-
data_type : Literal["array", "tiff", "custom"]
|
|
422
|
+
data_type : Literal["array", "tiff", "czi", "custom"]
|
|
310
423
|
Type of the data.
|
|
311
424
|
axes : str
|
|
312
425
|
Axes of the data (e.g. SYX).
|
|
@@ -330,12 +443,24 @@ def _create_supervised_config_dict(
|
|
|
330
443
|
Number of channels out.
|
|
331
444
|
logger : Literal["wandb", "tensorboard", "none"], optional
|
|
332
445
|
Logger to use, by default "none".
|
|
333
|
-
model_params : dict,
|
|
334
|
-
UNetModel parameters
|
|
446
|
+
model_params : dict, default=None
|
|
447
|
+
UNetModel parameters.
|
|
448
|
+
optimizer : {"Adam", "Adamax", "SGD"}, default="Adam"
|
|
449
|
+
Optimizer to use.
|
|
450
|
+
optimizer_params : dict, default=None
|
|
451
|
+
Parameters for the optimizer, see PyTorch documentation for more details.
|
|
452
|
+
lr_scheduler : {"ReduceLROnPlateau", "StepLR"}, default="ReduceLROnPlateau"
|
|
453
|
+
Learning rate scheduler to use.
|
|
454
|
+
lr_scheduler_params : dict, default=None
|
|
455
|
+
Parameters for the learning rate scheduler, see PyTorch documentation for more
|
|
456
|
+
details.
|
|
335
457
|
train_dataloader_params : dict
|
|
336
458
|
Parameters for the training dataloader, see PyTorch notes, by default None.
|
|
337
459
|
val_dataloader_params : dict
|
|
338
460
|
Parameters for the validation dataloader, see PyTorch notes, by default None.
|
|
461
|
+
checkpoint_params : dict, default=None
|
|
462
|
+
Parameters for the checkpoint callback, see PyTorch Lightning documentation
|
|
463
|
+
(`ModelCheckpoint`) for the list of available parameters.
|
|
339
464
|
|
|
340
465
|
Returns
|
|
341
466
|
-------
|
|
@@ -376,6 +501,10 @@ def _create_supervised_config_dict(
|
|
|
376
501
|
n_channels_in=n_channels_in,
|
|
377
502
|
n_channels_out=n_channels_out,
|
|
378
503
|
model_params=model_params,
|
|
504
|
+
optimizer=optimizer,
|
|
505
|
+
optimizer_params=optimizer_params,
|
|
506
|
+
lr_scheduler=lr_scheduler,
|
|
507
|
+
lr_scheduler_params=lr_scheduler_params,
|
|
379
508
|
)
|
|
380
509
|
|
|
381
510
|
# data
|
|
@@ -393,6 +522,7 @@ def _create_supervised_config_dict(
|
|
|
393
522
|
training_params = _create_training_configuration(
|
|
394
523
|
num_epochs=num_epochs,
|
|
395
524
|
logger=logger,
|
|
525
|
+
checkpoint_params=checkpoint_params,
|
|
396
526
|
)
|
|
397
527
|
|
|
398
528
|
return {
|
|
@@ -405,7 +535,7 @@ def _create_supervised_config_dict(
|
|
|
405
535
|
|
|
406
536
|
def create_care_configuration(
|
|
407
537
|
experiment_name: str,
|
|
408
|
-
data_type: Literal["array", "tiff", "custom"],
|
|
538
|
+
data_type: Literal["array", "tiff", "czi", "custom"],
|
|
409
539
|
axes: str,
|
|
410
540
|
patch_size: list[int],
|
|
411
541
|
batch_size: int,
|
|
@@ -417,8 +547,13 @@ def create_care_configuration(
|
|
|
417
547
|
n_channels_out: Optional[int] = None,
|
|
418
548
|
logger: Literal["wandb", "tensorboard", "none"] = "none",
|
|
419
549
|
model_params: Optional[dict] = None,
|
|
550
|
+
optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
|
|
551
|
+
optimizer_params: Optional[dict[str, Any]] = None,
|
|
552
|
+
lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
|
|
553
|
+
lr_scheduler_params: Optional[dict[str, Any]] = None,
|
|
420
554
|
train_dataloader_params: Optional[dict[str, Any]] = None,
|
|
421
555
|
val_dataloader_params: Optional[dict[str, Any]] = None,
|
|
556
|
+
checkpoint_params: Optional[dict[str, Any]] = None,
|
|
422
557
|
) -> Configuration:
|
|
423
558
|
"""
|
|
424
559
|
Create a configuration for training CARE.
|
|
@@ -445,7 +580,7 @@ def create_care_configuration(
|
|
|
445
580
|
----------
|
|
446
581
|
experiment_name : str
|
|
447
582
|
Name of the experiment.
|
|
448
|
-
data_type : Literal["array", "tiff", "custom"]
|
|
583
|
+
data_type : Literal["array", "tiff", "czi", "custom"]
|
|
449
584
|
Type of the data.
|
|
450
585
|
axes : str
|
|
451
586
|
Axes of the data (e.g. SYX).
|
|
@@ -471,6 +606,15 @@ def create_care_configuration(
|
|
|
471
606
|
Logger to use.
|
|
472
607
|
model_params : dict, default=None
|
|
473
608
|
UNetModel parameters.
|
|
609
|
+
optimizer : Literal["Adam", "Adamax", "SGD"], default="Adam"
|
|
610
|
+
Optimizer to use.
|
|
611
|
+
optimizer_params : dict, default=None
|
|
612
|
+
Parameters for the optimizer, see PyTorch documentation for more details.
|
|
613
|
+
lr_scheduler : Literal["ReduceLROnPlateau", "StepLR"], default="ReduceLROnPlateau"
|
|
614
|
+
Learning rate scheduler to use.
|
|
615
|
+
lr_scheduler_params : dict, default=None
|
|
616
|
+
Parameters for the learning rate scheduler, see PyTorch documentation for more
|
|
617
|
+
details.
|
|
474
618
|
train_dataloader_params : dict, optional
|
|
475
619
|
Parameters for the training dataloader, see the PyTorch docs for `DataLoader`.
|
|
476
620
|
If left as `None`, the dict `{"shuffle": True}` will be used, this is set in
|
|
@@ -479,6 +623,9 @@ def create_care_configuration(
|
|
|
479
623
|
Parameters for the validation dataloader, see PyTorch the docs for `DataLoader`.
|
|
480
624
|
If left as `None`, the empty dict `{}` will be used, this is set in the
|
|
481
625
|
`GeneralDataConfig`.
|
|
626
|
+
checkpoint_params : dict, default=None
|
|
627
|
+
Parameters for the checkpoint callback, see PyTorch Lightning documentation
|
|
628
|
+
(`ModelCheckpoint`) for the list of available parameters.
|
|
482
629
|
|
|
483
630
|
Returns
|
|
484
631
|
-------
|
|
@@ -551,6 +698,29 @@ def create_care_configuration(
|
|
|
551
698
|
... n_channels_in=3,
|
|
552
699
|
... n_channels_out=1 # if applicable
|
|
553
700
|
... )
|
|
701
|
+
|
|
702
|
+
If you would like to train on CZI files, use `"czi"` as `data_type` and `"SCYX"` as
|
|
703
|
+
`axes` for 2-D or `"SCZYX"` for 3-D denoising. Note that `"SCYX"` can also be used
|
|
704
|
+
for 3-D data but spatial context along the Z dimension will then not be taken into
|
|
705
|
+
account.
|
|
706
|
+
>>> config_2d = create_care_configuration(
|
|
707
|
+
... experiment_name="care_experiment",
|
|
708
|
+
... data_type="czi",
|
|
709
|
+
... axes="SCYX",
|
|
710
|
+
... patch_size=[64, 64],
|
|
711
|
+
... batch_size=32,
|
|
712
|
+
... num_epochs=100,
|
|
713
|
+
... n_channels_in=1,
|
|
714
|
+
... )
|
|
715
|
+
>>> config_3d = create_care_configuration(
|
|
716
|
+
... experiment_name="care_experiment",
|
|
717
|
+
... data_type="czi",
|
|
718
|
+
... axes="SCZYX",
|
|
719
|
+
... patch_size=[16, 64, 64],
|
|
720
|
+
... batch_size=16,
|
|
721
|
+
... num_epochs=100,
|
|
722
|
+
... n_channels_in=1,
|
|
723
|
+
... )
|
|
554
724
|
"""
|
|
555
725
|
return Configuration(
|
|
556
726
|
**_create_supervised_config_dict(
|
|
@@ -568,15 +738,20 @@ def create_care_configuration(
|
|
|
568
738
|
n_channels_out=n_channels_out,
|
|
569
739
|
logger=logger,
|
|
570
740
|
model_params=model_params,
|
|
741
|
+
optimizer=optimizer,
|
|
742
|
+
optimizer_params=optimizer_params,
|
|
743
|
+
lr_scheduler=lr_scheduler,
|
|
744
|
+
lr_scheduler_params=lr_scheduler_params,
|
|
571
745
|
train_dataloader_params=train_dataloader_params,
|
|
572
746
|
val_dataloader_params=val_dataloader_params,
|
|
747
|
+
checkpoint_params=checkpoint_params,
|
|
573
748
|
)
|
|
574
749
|
)
|
|
575
750
|
|
|
576
751
|
|
|
577
752
|
def create_n2n_configuration(
|
|
578
753
|
experiment_name: str,
|
|
579
|
-
data_type: Literal["array", "tiff", "custom"],
|
|
754
|
+
data_type: Literal["array", "tiff", "czi", "custom"],
|
|
580
755
|
axes: str,
|
|
581
756
|
patch_size: list[int],
|
|
582
757
|
batch_size: int,
|
|
@@ -588,8 +763,13 @@ def create_n2n_configuration(
|
|
|
588
763
|
n_channels_out: Optional[int] = None,
|
|
589
764
|
logger: Literal["wandb", "tensorboard", "none"] = "none",
|
|
590
765
|
model_params: Optional[dict] = None,
|
|
766
|
+
optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
|
|
767
|
+
optimizer_params: Optional[dict[str, Any]] = None,
|
|
768
|
+
lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
|
|
769
|
+
lr_scheduler_params: Optional[dict[str, Any]] = None,
|
|
591
770
|
train_dataloader_params: Optional[dict[str, Any]] = None,
|
|
592
771
|
val_dataloader_params: Optional[dict[str, Any]] = None,
|
|
772
|
+
checkpoint_params: Optional[dict[str, Any]] = None,
|
|
593
773
|
) -> Configuration:
|
|
594
774
|
"""
|
|
595
775
|
Create a configuration for training Noise2Noise.
|
|
@@ -616,7 +796,7 @@ def create_n2n_configuration(
|
|
|
616
796
|
----------
|
|
617
797
|
experiment_name : str
|
|
618
798
|
Name of the experiment.
|
|
619
|
-
data_type : Literal["array", "tiff", "custom"]
|
|
799
|
+
data_type : Literal["array", "tiff", "czi", "custom"]
|
|
620
800
|
Type of the data.
|
|
621
801
|
axes : str
|
|
622
802
|
Axes of the data (e.g. SYX).
|
|
@@ -640,8 +820,17 @@ def create_n2n_configuration(
|
|
|
640
820
|
Number of channels out.
|
|
641
821
|
logger : Literal["wandb", "tensorboard", "none"], optional
|
|
642
822
|
Logger to use, by default "none".
|
|
643
|
-
model_params : dict,
|
|
644
|
-
UNetModel parameters
|
|
823
|
+
model_params : dict, default=None
|
|
824
|
+
UNetModel parameters.
|
|
825
|
+
optimizer : Literal["Adam", "Adamax", "SGD"], default="Adam"
|
|
826
|
+
Optimizer to use.
|
|
827
|
+
optimizer_params : dict, default=None
|
|
828
|
+
Parameters for the optimizer, see PyTorch documentation for more details.
|
|
829
|
+
lr_scheduler : Literal["ReduceLROnPlateau", "StepLR"], default="ReduceLROnPlateau"
|
|
830
|
+
Learning rate scheduler to use.
|
|
831
|
+
lr_scheduler_params : dict, default=None
|
|
832
|
+
Parameters for the learning rate scheduler, see PyTorch documentation for more
|
|
833
|
+
details.
|
|
645
834
|
train_dataloader_params : dict, optional
|
|
646
835
|
Parameters for the training dataloader, see the PyTorch docs for `DataLoader`.
|
|
647
836
|
If left as `None`, the dict `{"shuffle": True}` will be used, this is set in
|
|
@@ -650,6 +839,9 @@ def create_n2n_configuration(
|
|
|
650
839
|
Parameters for the validation dataloader, see PyTorch the docs for `DataLoader`.
|
|
651
840
|
If left as `None`, the empty dict `{}` will be used, this is set in the
|
|
652
841
|
`GeneralDataConfig`.
|
|
842
|
+
checkpoint_params : dict, default=None
|
|
843
|
+
Parameters for the checkpoint callback, see PyTorch Lightning documentation
|
|
844
|
+
(`ModelCheckpoint`) for the list of available parameters.
|
|
653
845
|
|
|
654
846
|
Returns
|
|
655
847
|
-------
|
|
@@ -722,6 +914,29 @@ def create_n2n_configuration(
|
|
|
722
914
|
... n_channels_in=3,
|
|
723
915
|
... n_channels_out=1 # if applicable
|
|
724
916
|
... )
|
|
917
|
+
|
|
918
|
+
If you would like to train on CZI files, use `"czi"` as `data_type` and `"SCYX"` as
|
|
919
|
+
`axes` for 2-D or `"SCZYX"` for 3-D denoising. Note that `"SCYX"` can also be used
|
|
920
|
+
for 3-D data but spatial context along the Z dimension will then not be taken into
|
|
921
|
+
account.
|
|
922
|
+
>>> config_2d = create_n2n_configuration(
|
|
923
|
+
... experiment_name="n2n_experiment",
|
|
924
|
+
... data_type="czi",
|
|
925
|
+
... axes="SCYX",
|
|
926
|
+
... patch_size=[64, 64],
|
|
927
|
+
... batch_size=32,
|
|
928
|
+
... num_epochs=100,
|
|
929
|
+
... n_channels_in=1,
|
|
930
|
+
... )
|
|
931
|
+
>>> config_3d = create_n2n_configuration(
|
|
932
|
+
... experiment_name="n2n_experiment",
|
|
933
|
+
... data_type="czi",
|
|
934
|
+
... axes="SCZYX",
|
|
935
|
+
... patch_size=[16, 64, 64],
|
|
936
|
+
... batch_size=16,
|
|
937
|
+
... num_epochs=100,
|
|
938
|
+
... n_channels_in=1,
|
|
939
|
+
... )
|
|
725
940
|
"""
|
|
726
941
|
return Configuration(
|
|
727
942
|
**_create_supervised_config_dict(
|
|
@@ -739,15 +954,20 @@ def create_n2n_configuration(
|
|
|
739
954
|
n_channels_out=n_channels_out,
|
|
740
955
|
logger=logger,
|
|
741
956
|
model_params=model_params,
|
|
957
|
+
optimizer=optimizer,
|
|
958
|
+
optimizer_params=optimizer_params,
|
|
959
|
+
lr_scheduler=lr_scheduler,
|
|
960
|
+
lr_scheduler_params=lr_scheduler_params,
|
|
742
961
|
train_dataloader_params=train_dataloader_params,
|
|
743
962
|
val_dataloader_params=val_dataloader_params,
|
|
963
|
+
checkpoint_params=checkpoint_params,
|
|
744
964
|
)
|
|
745
965
|
)
|
|
746
966
|
|
|
747
967
|
|
|
748
968
|
def create_n2v_configuration(
|
|
749
969
|
experiment_name: str,
|
|
750
|
-
data_type: Literal["array", "tiff", "custom"],
|
|
970
|
+
data_type: Literal["array", "tiff", "czi", "custom"],
|
|
751
971
|
axes: str,
|
|
752
972
|
patch_size: list[int],
|
|
753
973
|
batch_size: int,
|
|
@@ -762,8 +982,13 @@ def create_n2v_configuration(
|
|
|
762
982
|
struct_n2v_span: int = 5,
|
|
763
983
|
logger: Literal["wandb", "tensorboard", "none"] = "none",
|
|
764
984
|
model_params: Optional[dict] = None,
|
|
985
|
+
optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
|
|
986
|
+
optimizer_params: Optional[dict[str, Any]] = None,
|
|
987
|
+
lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
|
|
988
|
+
lr_scheduler_params: Optional[dict[str, Any]] = None,
|
|
765
989
|
train_dataloader_params: Optional[dict[str, Any]] = None,
|
|
766
990
|
val_dataloader_params: Optional[dict[str, Any]] = None,
|
|
991
|
+
checkpoint_params: Optional[dict[str, Any]] = None,
|
|
767
992
|
) -> Configuration:
|
|
768
993
|
"""
|
|
769
994
|
Create a configuration for training Noise2Void.
|
|
@@ -810,7 +1035,7 @@ def create_n2v_configuration(
|
|
|
810
1035
|
----------
|
|
811
1036
|
experiment_name : str
|
|
812
1037
|
Name of the experiment.
|
|
813
|
-
data_type : Literal["array", "tiff", "custom"]
|
|
1038
|
+
data_type : Literal["array", "tiff", "czi", "custom"]
|
|
814
1039
|
Type of the data.
|
|
815
1040
|
axes : str
|
|
816
1041
|
Axes of the data (e.g. SYX).
|
|
@@ -840,8 +1065,17 @@ def create_n2v_configuration(
|
|
|
840
1065
|
Span of the structN2V mask, by default 5.
|
|
841
1066
|
logger : Literal["wandb", "tensorboard", "none"], optional
|
|
842
1067
|
Logger to use, by default "none".
|
|
843
|
-
model_params : dict,
|
|
844
|
-
UNetModel parameters
|
|
1068
|
+
model_params : dict, default=None
|
|
1069
|
+
UNetModel parameters.
|
|
1070
|
+
optimizer : Literal["Adam", "Adamax", "SGD"], default="Adam"
|
|
1071
|
+
Optimizer to use.
|
|
1072
|
+
optimizer_params : dict, default=None
|
|
1073
|
+
Parameters for the optimizer, see PyTorch documentation for more details.
|
|
1074
|
+
lr_scheduler : Literal["ReduceLROnPlateau", "StepLR"], default="ReduceLROnPlateau"
|
|
1075
|
+
Learning rate scheduler to use.
|
|
1076
|
+
lr_scheduler_params : dict, default=None
|
|
1077
|
+
Parameters for the learning rate scheduler, see PyTorch documentation for more
|
|
1078
|
+
details.
|
|
845
1079
|
train_dataloader_params : dict, optional
|
|
846
1080
|
Parameters for the training dataloader, see the PyTorch docs for `DataLoader`.
|
|
847
1081
|
If left as `None`, the dict `{"shuffle": True}` will be used, this is set in
|
|
@@ -850,6 +1084,9 @@ def create_n2v_configuration(
|
|
|
850
1084
|
Parameters for the validation dataloader, see PyTorch the docs for `DataLoader`.
|
|
851
1085
|
If left as `None`, the empty dict `{}` will be used, this is set in the
|
|
852
1086
|
`GeneralDataConfig`.
|
|
1087
|
+
checkpoint_params : dict, default=None
|
|
1088
|
+
Parameters for the checkpoint callback, see PyTorch Lightning documentation
|
|
1089
|
+
(`ModelCheckpoint`) for the list of available parameters.
|
|
853
1090
|
|
|
854
1091
|
Returns
|
|
855
1092
|
-------
|
|
@@ -942,6 +1179,29 @@ def create_n2v_configuration(
|
|
|
942
1179
|
... independent_channels=False,
|
|
943
1180
|
... n_channels=3
|
|
944
1181
|
... )
|
|
1182
|
+
|
|
1183
|
+
If you would like to train on CZI files, use `"czi"` as `data_type` and `"SCYX"` as
|
|
1184
|
+
`axes` for 2-D or `"SCZYX"` for 3-D denoising. Note that `"SCYX"` can also be used
|
|
1185
|
+
for 3-D data but spatial context along the Z dimension will then not be taken into
|
|
1186
|
+
account.
|
|
1187
|
+
>>> config_2d = create_n2v_configuration(
|
|
1188
|
+
... experiment_name="n2v_experiment",
|
|
1189
|
+
... data_type="czi",
|
|
1190
|
+
... axes="SCYX",
|
|
1191
|
+
... patch_size=[64, 64],
|
|
1192
|
+
... batch_size=32,
|
|
1193
|
+
... num_epochs=100,
|
|
1194
|
+
... n_channels=1,
|
|
1195
|
+
... )
|
|
1196
|
+
>>> config_3d = create_n2v_configuration(
|
|
1197
|
+
... experiment_name="n2v_experiment",
|
|
1198
|
+
... data_type="czi",
|
|
1199
|
+
... axes="SCZYX",
|
|
1200
|
+
... patch_size=[16, 64, 64],
|
|
1201
|
+
... batch_size=16,
|
|
1202
|
+
... num_epochs=100,
|
|
1203
|
+
... n_channels=1,
|
|
1204
|
+
... )
|
|
945
1205
|
"""
|
|
946
1206
|
# if there are channels, we need to specify their number
|
|
947
1207
|
if "C" in axes and n_channels is None:
|
|
@@ -982,6 +1242,10 @@ def create_n2v_configuration(
|
|
|
982
1242
|
n_channels_out=n_channels,
|
|
983
1243
|
use_n2v2=use_n2v2,
|
|
984
1244
|
model_params=model_params,
|
|
1245
|
+
optimizer=optimizer,
|
|
1246
|
+
optimizer_params=optimizer_params,
|
|
1247
|
+
lr_scheduler=lr_scheduler,
|
|
1248
|
+
lr_scheduler_params=lr_scheduler_params,
|
|
985
1249
|
)
|
|
986
1250
|
algorithm_params["n2v_config"] = n2v_transform
|
|
987
1251
|
|
|
@@ -1000,6 +1264,7 @@ def create_n2v_configuration(
|
|
|
1000
1264
|
training_params = _create_training_configuration(
|
|
1001
1265
|
num_epochs=num_epochs,
|
|
1002
1266
|
logger=logger,
|
|
1267
|
+
checkpoint_params=checkpoint_params,
|
|
1003
1268
|
)
|
|
1004
1269
|
|
|
1005
1270
|
return Configuration(
|
|
@@ -95,9 +95,9 @@ class DataConfig(BaseModel):
|
|
|
95
95
|
)
|
|
96
96
|
|
|
97
97
|
# Dataset configuration
|
|
98
|
-
data_type: Literal["array", "tiff", "custom"]
|
|
99
|
-
"""Type of input data, numpy.ndarray (array) or paths (tiff and custom), as
|
|
100
|
-
in SupportedData."""
|
|
98
|
+
data_type: Literal["array", "tiff", "czi", "custom"]
|
|
99
|
+
"""Type of input data, numpy.ndarray (array) or paths (tiff, czi, and custom), as
|
|
100
|
+
defined in SupportedData."""
|
|
101
101
|
|
|
102
102
|
axes: str
|
|
103
103
|
"""Axes of the data, as defined in SupportedAxes."""
|