careamics 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of careamics might be problematic.
- careamics/careamist.py +11 -14
- careamics/cli/conf.py +18 -3
- careamics/config/__init__.py +8 -0
- careamics/config/algorithms/__init__.py +4 -0
- careamics/config/algorithms/hdn_algorithm_model.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
- careamics/config/algorithms/n2v_algorithm_model.py +1 -2
- careamics/config/algorithms/vae_algorithm_model.py +51 -16
- careamics/config/architectures/lvae_model.py +12 -8
- careamics/config/callback_model.py +7 -3
- careamics/config/configuration.py +15 -63
- careamics/config/configuration_factories.py +853 -29
- careamics/config/data/data_model.py +50 -11
- careamics/config/data/ng_data_model.py +168 -4
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_model.py +16 -0
- careamics/config/data/patch_filter/mask_filter_model.py +17 -0
- careamics/config/data/patch_filter/max_filter_model.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_model.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_model.py +15 -0
- careamics/config/inference_model.py +1 -2
- careamics/config/likelihood_model.py +2 -2
- careamics/config/loss_model.py +6 -2
- careamics/config/nm_model.py +26 -1
- careamics/config/optimizer_models.py +1 -2
- careamics/config/support/supported_algorithms.py +5 -3
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_losses.py +5 -2
- careamics/config/training_model.py +6 -36
- careamics/config/transformations/normalize_model.py +1 -2
- careamics/dataset_ng/dataset.py +57 -5
- careamics/dataset_ng/factory.py +101 -18
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
- careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +94 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +95 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/file_io/read/__init__.py +0 -1
- careamics/lightning/__init__.py +16 -2
- careamics/lightning/callbacks/__init__.py +2 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/dataset_ng/data_module.py +79 -2
- careamics/lightning/lightning_module.py +162 -61
- careamics/lightning/microsplit_data_module.py +636 -0
- careamics/lightning/predict_data_module.py +8 -1
- careamics/lightning/train_data_module.py +19 -8
- careamics/losses/__init__.py +7 -1
- careamics/losses/loss_factory.py +9 -1
- careamics/losses/lvae/losses.py +85 -0
- careamics/lvae_training/dataset/__init__.py +8 -8
- careamics/lvae_training/dataset/config.py +56 -44
- careamics/lvae_training/dataset/lc_dataset.py +18 -12
- careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
- careamics/lvae_training/dataset/multich_dataset.py +24 -18
- careamics/lvae_training/dataset/multifile_dataset.py +6 -6
- careamics/lvae_training/eval_utils.py +46 -24
- careamics/model_io/bmz_io.py +9 -5
- careamics/models/lvae/likelihoods.py +31 -14
- careamics/models/lvae/lvae.py +2 -2
- careamics/models/lvae/noise_models.py +20 -14
- careamics/prediction_utils/__init__.py +8 -2
- careamics/prediction_utils/prediction_outputs.py +49 -3
- careamics/prediction_utils/stitch_prediction.py +83 -1
- careamics/transforms/xy_random_rotate90.py +1 -1
- careamics/utils/version.py +4 -4
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/METADATA +19 -22
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/RECORD +77 -60
- careamics/dataset/zarr_dataset.py +0 -151
- careamics/file_io/read/zarr.py +0 -60
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/WHEEL +0 -0
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/licenses/LICENSE +0 -0
careamics/config/configuration_factories.py
@@ -5,9 +5,20 @@ from typing import Annotated, Any, Literal, Union
 
 from pydantic import Field, TypeAdapter
 
-from careamics.config.algorithms import
-
+from careamics.config.algorithms import (
+    CAREAlgorithm,
+    MicroSplitAlgorithm,
+    N2NAlgorithm,
+    N2VAlgorithm,
+)
+from careamics.config.architectures import LVAEModel, UNetModel
 from careamics.config.data import DataConfig, NGDataConfig
+from careamics.config.likelihood_model import (
+    GaussianLikelihoodConfig,
+    NMLikelihoodConfig,
+)
+from careamics.config.loss_model import LVAELossConfig
+from careamics.config.nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig
 from careamics.config.support import (
     SupportedArchitecture,
     SupportedPixelManipulation,
@@ -20,6 +31,7 @@ from careamics.config.transformations import (
     XYFlipModel,
     XYRandomRotate90Model,
 )
+from careamics.lvae_training.dataset.config import MicroSplitDataConfig
 
 from .configuration import Configuration
 
@@ -224,7 +236,7 @@ def _create_algorithm_configuration(
 def _create_data_configuration(
     data_type: Literal["array", "tiff", "czi", "custom"],
     axes: str,
-    patch_size:
+    patch_size: Sequence[int],
     batch_size: int,
     augmentations: Union[list[SPATIAL_TRANSFORMS_UNION]],
     train_dataloader_params: dict[str, Any] | None = None,
@@ -277,6 +289,70 @@ def _create_data_configuration(
     return DataConfig(**data)
 
 
+def _create_microsplit_data_configuration(
+    data_type: Literal["array", "tiff", "custom"],
+    axes: str,
+    patch_size: Sequence[int],
+    grid_size: int,
+    multiscale_count: int,
+    batch_size: int,
+    augmentations: Union[list[SPATIAL_TRANSFORMS_UNION]],
+    train_dataloader_params: dict[str, Any] | None = None,
+    val_dataloader_params: dict[str, Any] | None = None,
+) -> DataConfig:
+    """
+    Create a dictionary with the parameters of the data model.
+
+    Parameters
+    ----------
+    data_type : {"array", "tiff", "czi", "custom"}
+        Type of the data.
+    axes : str
+        Axes of the data.
+    patch_size : list of int
+        Size of the patches along the spatial dimensions.
+    grid_size : int
+        Grid size for patch extraction.
+    multiscale_count : int
+        Number of LC scales.
+    batch_size : int
+        Batch size.
+    augmentations : list of transforms
+        List of transforms to apply.
+    train_dataloader_params : dict
+        Parameters for the training dataloader, see PyTorch notes, by default None.
+    val_dataloader_params : dict
+        Parameters for the validation dataloader, see PyTorch notes, by default None.
+
+    Returns
+    -------
+    DataConfig
+        Data model with the specified parameters.
+    """
+    # data model
+    data = {
+        "data_type": data_type,
+        "axes": axes,
+        "image_size": patch_size,
+        "grid_size": grid_size,
+        "multiscale_lowres_count": multiscale_count,
+        "batch_size": batch_size,
+        "transforms": augmentations,
+    }
+    # Don't override defaults set in DataConfig class
+    if train_dataloader_params is not None:
+        # DataConfig enforces the presence of `shuffle` key in the dataloader parameters
+        if "shuffle" not in train_dataloader_params:
+            train_dataloader_params["shuffle"] = True
+
+        data["train_dataloader_params"] = train_dataloader_params
+
+    if val_dataloader_params is not None:
+        data["val_dataloader_params"] = val_dataloader_params
+
+    return MicroSplitDataConfig(**data)
+
+
 def _create_ng_data_configuration(
     data_type: Literal["array", "tiff", "custom"],
     axes: str,
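The mapping performed by this new helper is small: the factory arguments are renamed onto the MicroSplit data model's fields. A minimal sketch of the equivalent direct construction, assuming MicroSplitDataConfig accepts exactly the keys built in the dict above (all values here are illustrative):

from careamics.lvae_training.dataset.config import MicroSplitDataConfig

# Mirrors the dict built by _create_microsplit_data_configuration, without dataloader params.
data_config = MicroSplitDataConfig(
    data_type="tiff",
    axes="SYX",
    image_size=[64, 64],        # patch_size
    grid_size=32,
    multiscale_lowres_count=3,  # multiscale_count, i.e. number of LC scales
    batch_size=16,
    transforms=[],              # augmentations
)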
@@ -357,7 +433,7 @@ def _create_ng_data_configuration(
 
 
 def _create_training_configuration(
-
+    trainer_params: dict,
     logger: Literal["wandb", "tensorboard", "none"],
     checkpoint_params: dict[str, Any] | None = None,
 ) -> TrainingConfig:
@@ -366,8 +442,8 @@
 
     Parameters
     ----------
-
-
+    trainer_params : dict
+        Parameters for Lightning Trainer class, see PyTorch Lightning documentation.
     logger : {"wandb", "tensorboard", "none"}
         Logger to use.
     checkpoint_params : dict, default=None
@@ -380,7 +456,7 @@
         Training model with the specified parameters.
     """
     return TrainingConfig(
-
+        lightning_trainer_config=trainer_params,
         logger=None if logger == "none" else logger,
         checkpoint_callback={} if checkpoint_params is None else checkpoint_params,
     )
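For orientation, the training model now receives the whole trainer dictionary under a single field. A minimal sketch of what _create_training_configuration builds after this change, assuming TrainingConfig is defined in careamics/config/training_model.py (listed in the changed files above) and using only the fields visible in this hunk:

# Sketch only: mirrors the call shown in the hunk, not an additional API.
from careamics.config.training_model import TrainingConfig  # import path assumed

training_config = TrainingConfig(
    lightning_trainer_config={"max_epochs": 100, "limit_train_batches": 50},
    logger=None,             # "none" is mapped to None before this call
    checkpoint_callback={},  # empty dict when no checkpoint_params are given
)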
@@ -392,9 +468,9 @@ def _create_supervised_config_dict(
     experiment_name: str,
     data_type: Literal["array", "tiff", "czi", "custom"],
     axes: str,
-    patch_size:
+    patch_size: Sequence[int],
     batch_size: int,
-
+    trainer_params: dict | None = None,
     augmentations: list[SPATIAL_TRANSFORMS_UNION] | None = None,
     independent_channels: bool = True,
     loss: Literal["mae", "mse"] = "mae",
@@ -409,6 +485,8 @@
     train_dataloader_params: dict[str, Any] | None = None,
     val_dataloader_params: dict[str, Any] | None = None,
     checkpoint_params: dict[str, Any] | None = None,
+    num_epochs: int | None = None,
+    num_steps: int | None = None,
 ) -> dict:
     """
     Create a configuration for training CARE or Noise2Noise.
@@ -427,8 +505,8 @@
         Size of the patches along the spatial dimensions (e.g. [64, 64]).
     batch_size : int
         Batch size.
-
-
+    trainer_params : dict
+        Parameters for the training configuration.
     augmentations : list of transforms, default=None
         List of transforms to apply, either both or one of XYFlipModel and
         XYRandomRotate90Model. By default, it applies both XYFlip (on X and Y)
@@ -461,6 +539,13 @@
     checkpoint_params : dict, default=None
         Parameters for the checkpoint callback, see PyTorch Lightning documentation
         (`ModelCheckpoint`) for the list of available parameters.
+    num_epochs : int or None, default=None
+        Number of epochs to train for. If provided, this will be added to
+        trainer_params.
+    num_steps : int or None, default=None
+        Number of batches in 1 epoch. If provided, this will be added to trainer_params.
+        Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
+        documentation for more details.
 
     Returns
     -------
@@ -518,9 +603,18 @@
         val_dataloader_params=val_dataloader_params,
     )
 
+    # Handle trainer parameters with num_epochs and num_steps
+    final_trainer_params = {} if trainer_params is None else trainer_params.copy()
+
+    # Add num_epochs and num_steps if provided
+    if num_epochs is not None:
+        final_trainer_params["max_epochs"] = num_epochs
+    if num_steps is not None:
+        final_trainer_params["limit_train_batches"] = num_steps
+
     # training
     training_params = _create_training_configuration(
-
+        trainer_params=final_trainer_params,
         logger=logger,
         checkpoint_params=checkpoint_params,
     )
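Note the precedence this hunk implies: a caller-supplied trainer_params dictionary is copied first, and num_epochs / num_steps then overwrite the corresponding Lightning keys. A small, self-contained illustration of that merge order (the "precision" key is a made-up extra entry, not something this diff defines):

# Illustration of the merge order shown above; plain Python, no careamics imports needed.
trainer_params = {"max_epochs": 10, "precision": "16-mixed"}  # hypothetical user dict
num_epochs, num_steps = 100, 50

final_trainer_params = trainer_params.copy()
if num_epochs is not None:
    final_trainer_params["max_epochs"] = num_epochs          # 100 replaces the user's 10
if num_steps is not None:
    final_trainer_params["limit_train_batches"] = num_steps  # becomes the Lightning flag

assert final_trainer_params == {
    "max_epochs": 100,
    "precision": "16-mixed",
    "limit_train_batches": 50,
}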
@@ -537,15 +631,17 @@ def create_care_configuration(
     experiment_name: str,
     data_type: Literal["array", "tiff", "czi", "custom"],
     axes: str,
-    patch_size:
+    patch_size: Sequence[int],
     batch_size: int,
-    num_epochs: int,
+    num_epochs: int = 100,
+    num_steps: int | None = None,
     augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]] | None = None,
     independent_channels: bool = True,
     loss: Literal["mae", "mse"] = "mae",
     n_channels_in: int | None = None,
     n_channels_out: int | None = None,
     logger: Literal["wandb", "tensorboard", "none"] = "none",
+    trainer_params: dict | None = None,
     model_params: dict | None = None,
     optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
     optimizer_params: dict[str, Any] | None = None,
@@ -588,8 +684,13 @@ def create_care_configuration(
         Size of the patches along the spatial dimensions (e.g. [64, 64]).
     batch_size : int
         Batch size.
-    num_epochs : int
-        Number of epochs.
+    num_epochs : int, default=100
+        Number of epochs to train for. If provided, this will be added to
+        trainer_params.
+    num_steps : int, optional
+        Number of batches in 1 epoch. If provided, this will be added to trainer_params.
+        Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
+        documentation for more details.
     augmentations : list of transforms, default=None
         List of transforms to apply, either both or one of XYFlipModel and
         XYRandomRotate90Model. By default, it applies both XYFlip (on X and Y)
@@ -604,6 +705,8 @@ def create_care_configuration(
         Number of channels out.
     logger : Literal["wandb", "tensorboard", "none"], default="none"
         Logger to use.
+    trainer_params : dict, optional
+        Parameters for the trainer class, see PyTorch Lightning documentation.
     model_params : dict, default=None
         UNetModel parameters.
     optimizer : Literal["Adam", "Adamax", "SGD"], default="Adam"
@@ -644,6 +747,16 @@ def create_care_configuration(
     ...     num_epochs=100
     ... )
 
+    You can also limit the number of batches per epoch:
+    >>> config = create_care_configuration(
+    ...     experiment_name="care_experiment",
+    ...     data_type="array",
+    ...     axes="YX",
+    ...     patch_size=[64, 64],
+    ...     batch_size=32,
+    ...     num_steps=100 # limit to 100 batches per epoch
+    ... )
+
     To disable transforms, simply set `augmentations` to an empty list:
     >>> config = create_care_configuration(
     ...     experiment_name="care_experiment",
@@ -730,13 +843,13 @@ def create_care_configuration(
             axes=axes,
             patch_size=patch_size,
             batch_size=batch_size,
-            num_epochs=num_epochs,
             augmentations=augmentations,
             independent_channels=independent_channels,
             loss=loss,
             n_channels_in=n_channels_in,
             n_channels_out=n_channels_out,
             logger=logger,
+            trainer_params=trainer_params,
             model_params=model_params,
             optimizer=optimizer,
             optimizer_params=optimizer_params,
@@ -745,6 +858,8 @@ def create_care_configuration(
             train_dataloader_params=train_dataloader_params,
             val_dataloader_params=val_dataloader_params,
             checkpoint_params=checkpoint_params,
+            num_epochs=num_epochs,
+            num_steps=num_steps,
         )
     )
 
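Taken together, the CARE factory now accepts the epoch/step shortcuts alongside a raw trainer dictionary. A hedged usage sketch based only on the signature added above; the import path and the "accelerator" key are assumptions (the docstring only says the dict is forwarded to the Lightning Trainer):

from careamics.config import create_care_configuration  # import path assumed

config = create_care_configuration(
    experiment_name="care_experiment",
    data_type="array",
    axes="YX",
    patch_size=[64, 64],
    batch_size=32,
    num_epochs=50,                          # ends up as trainer_params["max_epochs"]
    num_steps=200,                          # ends up as trainer_params["limit_train_batches"]
    trainer_params={"accelerator": "gpu"},  # extra Lightning Trainer options (illustrative)
)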
@@ -753,15 +868,17 @@ def create_n2n_configuration(
     experiment_name: str,
     data_type: Literal["array", "tiff", "czi", "custom"],
     axes: str,
-    patch_size:
+    patch_size: Sequence[int],
     batch_size: int,
-    num_epochs: int,
+    num_epochs: int = 100,
+    num_steps: int | None = None,
     augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]] | None = None,
     independent_channels: bool = True,
     loss: Literal["mae", "mse"] = "mae",
     n_channels_in: int | None = None,
     n_channels_out: int | None = None,
     logger: Literal["wandb", "tensorboard", "none"] = "none",
+    trainer_params: dict | None = None,
     model_params: dict | None = None,
     optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
     optimizer_params: dict[str, Any] | None = None,
@@ -804,8 +921,13 @@ def create_n2n_configuration(
         Size of the patches along the spatial dimensions (e.g. [64, 64]).
     batch_size : int
         Batch size.
-    num_epochs : int
-        Number of epochs.
+    num_epochs : int, default=100
+        Number of epochs to train for. If provided, this will be added to
+        trainer_params.
+    num_steps : int, optional
+        Number of batches in 1 epoch. If provided, this will be added to trainer_params.
+        Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
+        documentation for more details.
     augmentations : list of transforms, default=None
         List of transforms to apply, either both or one of XYFlipModel and
         XYRandomRotate90Model. By default, it applies both XYFlip (on X and Y)
@@ -820,6 +942,8 @@ def create_n2n_configuration(
         Number of channels out.
     logger : Literal["wandb", "tensorboard", "none"], optional
         Logger to use, by default "none".
+    trainer_params : dict, optional
+        Parameters for the trainer class, see PyTorch Lightning documentation.
     model_params : dict, default=None
         UNetModel parameters.
     optimizer : Literal["Adam", "Adamax", "SGD"], default="Adam"
@@ -860,6 +984,16 @@ def create_n2n_configuration(
     ...     num_epochs=100
     ... )
 
+    You can also limit the number of batches per epoch:
+    >>> config = create_n2n_configuration(
+    ...     experiment_name="n2n_experiment",
+    ...     data_type="array",
+    ...     axes="YX",
+    ...     patch_size=[64, 64],
+    ...     batch_size=32,
+    ...     num_steps=100 # limit to 100 batches per epoch
+    ... )
+
     To disable transforms, simply set `augmentations` to an empty list:
     >>> config = create_n2n_configuration(
     ...     experiment_name="n2n_experiment",
@@ -871,8 +1005,7 @@ def create_n2n_configuration(
     ...     augmentations=[]
     ... )
 
-    A list of transforms can be passed to the `augmentations` parameter
-    default augmentations:
+    A list of transforms can be passed to the `augmentations` parameter:
     >>> from careamics.config.transformations import XYFlipModel
     >>> config = create_n2n_configuration(
     ...     experiment_name="n2n_experiment",
@@ -946,7 +1079,7 @@ def create_n2n_configuration(
             axes=axes,
             patch_size=patch_size,
             batch_size=batch_size,
-
+            trainer_params=trainer_params,
             augmentations=augmentations,
             independent_channels=independent_channels,
             loss=loss,
@@ -961,6 +1094,8 @@ def create_n2n_configuration(
             train_dataloader_params=train_dataloader_params,
             val_dataloader_params=val_dataloader_params,
             checkpoint_params=checkpoint_params,
+            num_epochs=num_epochs,
+            num_steps=num_steps,
         )
     )
 
@@ -969,9 +1104,10 @@ def create_n2v_configuration(
     experiment_name: str,
     data_type: Literal["array", "tiff", "czi", "custom"],
     axes: str,
-    patch_size:
+    patch_size: Sequence[int],
     batch_size: int,
-    num_epochs: int,
+    num_epochs: int = 100,
+    num_steps: int | None = None,
     augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]] | None = None,
     independent_channels: bool = True,
     use_n2v2: bool = False,
@@ -980,6 +1116,7 @@ def create_n2v_configuration(
     masked_pixel_percentage: float = 0.2,
     struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
     struct_n2v_span: int = 5,
+    trainer_params: dict | None = None,
     logger: Literal["wandb", "tensorboard", "none"] = "none",
     model_params: dict | None = None,
     optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
@@ -1043,8 +1180,13 @@ def create_n2v_configuration(
         Size of the patches along the spatial dimensions (e.g. [64, 64]).
     batch_size : int
         Batch size.
-    num_epochs : int
-        Number of epochs.
+    num_epochs : int, default=100
+        Number of epochs to train for. If provided, this will be added to
+        trainer_params.
+    num_steps : int, optional
+        Number of batches in 1 epoch. If provided, this will be added to trainer_params.
+        Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
+        documentation for more details.
     augmentations : list of transforms, default=None
         List of transforms to apply, either both or one of XYFlipModel and
         XYRandomRotate90Model. By default, it applies both XYFlip (on X and Y)
@@ -1063,6 +1205,8 @@ def create_n2v_configuration(
         Axis along which to apply structN2V mask, by default "none".
     struct_n2v_span : int, optional
         Span of the structN2V mask, by default 5.
+    trainer_params : dict, optional
+        Parameters for the trainer, see the relevant documentation.
     logger : Literal["wandb", "tensorboard", "none"], optional
         Logger to use, by default "none".
     model_params : dict, default=None
@@ -1105,6 +1249,16 @@ def create_n2v_configuration(
     ...     num_epochs=100
     ... )
 
+    You can also limit the number of batches per epoch:
+    >>> config = create_n2v_configuration(
+    ...     experiment_name="n2v_experiment",
+    ...     data_type="array",
+    ...     axes="YX",
+    ...     patch_size=[64, 64],
+    ...     batch_size=32,
+    ...     num_steps=100 # limit to 100 batches per epoch
+    ... )
+
     To disable transforms, simply set `augmentations` to an empty list:
     >>> config = create_n2v_configuration(
     ...     experiment_name="n2v_experiment",
@@ -1261,8 +1415,17 @@ def create_n2v_configuration(
     )
 
     # training
+    # Handle trainer parameters with num_epochs and nun_steps
+    final_trainer_params = {} if trainer_params is None else trainer_params.copy()
+
+    # Add num_epochs and nun_steps if provided
+    if num_epochs is not None:
+        final_trainer_params["max_epochs"] = num_epochs
+    if num_steps is not None:
+        final_trainer_params["limit_train_batches"] = num_steps
+
     training_params = _create_training_configuration(
-
+        trainer_params=final_trainer_params,
         logger=logger,
         checkpoint_params=checkpoint_params,
    )
@@ -1273,3 +1436,664 @@ def create_n2v_configuration(
         data_config=data_params,
         training_config=training_params,
     )
+
+
+def _create_vae_configuration(
+    input_shape: Sequence[int],
+    encoder_conv_strides: tuple[int, ...],
+    decoder_conv_strides: tuple[int, ...],
+    multiscale_count: int,
+    z_dims: tuple[int, ...],
+    output_channels: int,
+    encoder_n_filters: int,
+    decoder_n_filters: int,
+    encoder_dropout: float,
+    decoder_dropout: float,
+    nonlinearity: Literal[
+        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
+    ],
+    predict_logvar: Literal[None, "pixelwise"],
+    analytical_kl: bool,
+) -> LVAEModel:
+    """Create a dictionary with the parameters of the vae based algorithm model.
+
+    Parameters
+    ----------
+    input_shape : tuple[int, ...]
+        Shape of the input patch (Z, Y, X) or (Y, X) if the data is 2D.
+    encoder_conv_strides : tuple[int, ...]
+        Strides of the encoder convolutional layers, length also defines 2D or 3D.
+    decoder_conv_strides : tuple[int, ...]
+        Strides of the decoder convolutional layers, length also defines 2D or 3D.
+    multiscale_count : int
+        Number of lateral context layers, specific to MicroSplit.
+    z_dims : tuple[int, ...]
+        Number of hierarchies in the LVAE model.
+    output_channels : int
+        Number of output channels.
+    encoder_n_filters : int
+        Number of filters in the convolutional layers of the encoder.
+    decoder_n_filters : int
+        Number of filters in the convolutional layers of the decoder.
+    encoder_dropout : float
+        Dropout rate for the encoder.
+    decoder_dropout : float
+        Dropout rate for the decoder.
+    nonlinearity : Literal
+        Type of nonlinearity function to use.
+    predict_logvar : Literal # TODO needs review
+        _description_.
+    analytical_kl : bool # TODO needs clarification
+        _description_.
+
+    Returns
+    -------
+    LVAEModel
+        LVAE model with the specified parameters.
+    """
+    return LVAEModel(
+        architecture=SupportedArchitecture.LVAE.value,
+        input_shape=input_shape,
+        encoder_conv_strides=encoder_conv_strides,
+        decoder_conv_strides=decoder_conv_strides,
+        multiscale_count=multiscale_count,
+        z_dims=z_dims,
+        output_channels=output_channels,
+        encoder_n_filters=encoder_n_filters,
+        decoder_n_filters=decoder_n_filters,
+        encoder_dropout=encoder_dropout,
+        decoder_dropout=decoder_dropout,
+        nonlinearity=nonlinearity,
+        predict_logvar=predict_logvar,
+        analytical_kl=analytical_kl,
+    )
+
+
+def _create_vae_based_algorithm(
+    algorithm: Literal["hdn", "microsplit"],
+    loss: LVAELossConfig,
+    input_shape: Sequence[int],
+    encoder_conv_strides: tuple[int, ...],
+    decoder_conv_strides: tuple[int, ...],
+    multiscale_count: int,
+    z_dims: tuple[int, ...],
+    output_channels: int,
+    encoder_n_filters: int,
+    decoder_n_filters: int,
+    encoder_dropout: float,
+    decoder_dropout: float,
+    nonlinearity: Literal[
+        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
+    ],
+    predict_logvar: Literal[None, "pixelwise"],
+    analytical_kl: bool,
+    gaussian_likelihood: GaussianLikelihoodConfig | None = None,
+    nm_likelihood: NMLikelihoodConfig | None = None,
+) -> dict:
+    """
+    Create a dictionary with the parameters of the VAE-based algorithm model.
+
+    Parameters
+    ----------
+    algorithm : Literal["hdn"]
+        The algorithm type.
+    loss : Literal["hdn"]
+        The loss function type.
+    input_shape : tuple[int, ...]
+        The shape of the input data.
+    encoder_conv_strides : list[int]
+        The strides of the encoder convolutional layers.
+    decoder_conv_strides : list[int]
+        The strides of the decoder convolutional layers.
+    multiscale_count : int
+        The number of multiscale layers.
+    z_dims : list[int]
+        The dimensions of the latent space.
+    output_channels : int
+        The number of output channels.
+    encoder_n_filters : int
+        The number of filters in the encoder.
+    decoder_n_filters : int
+        The number of filters in the decoder.
+    encoder_dropout : float
+        The dropout rate for the encoder.
+    decoder_dropout : float
+        The dropout rate for the decoder.
+    nonlinearity : Literal
+        The nonlinearity function to use.
+    predict_logvar : Literal[None, "pixelwise"]
+        The type of log variance prediction.
+    analytical_kl : bool
+        Whether to use analytical KL divergence.
+    gaussian_likelihood : Optional[GaussianLikelihoodConfig], optional
+        The Gaussian likelihood model, by default None.
+    nm_likelihood : Optional[NMLikelihoodConfig], optional
+        The noise model likelihood model, by default None.
+
+    Returns
+    -------
+    dict
+        A dictionary with the parameters of the VAE-based algorithm model.
+    """
+    network_model = _create_vae_configuration(
+        input_shape=input_shape,
+        encoder_conv_strides=encoder_conv_strides,
+        decoder_conv_strides=decoder_conv_strides,
+        multiscale_count=multiscale_count,
+        z_dims=z_dims,
+        output_channels=output_channels,
+        encoder_n_filters=encoder_n_filters,
+        decoder_n_filters=decoder_n_filters,
+        encoder_dropout=encoder_dropout,
+        decoder_dropout=decoder_dropout,
+        nonlinearity=nonlinearity,
+        predict_logvar=predict_logvar,
+        analytical_kl=analytical_kl,
+    )
+    assert gaussian_likelihood or nm_likelihood, "Likelihood model must be specified"
+    return {
+        "algorithm": algorithm,
+        "loss": loss,
+        "model": network_model,
+        "gaussian_likelihood": gaussian_likelihood,
+        "noise_model_likelihood": nm_likelihood,
+    }
+
+
+def get_likelihood_config(
+    loss_type: Literal["musplit", "denoisplit", "denoisplit_musplit"],
+    # TODO remove different microsplit loss types, refac
+    predict_logvar: Literal["pixelwise"] | None = None,
+    logvar_lowerbound: float | None = -5.0,
+    nm_paths: list[str] | None = None,
+    data_stats: tuple[float, float] | None = None,
+) -> tuple[
+    GaussianLikelihoodConfig | None,
+    MultiChannelNMConfig | None,
+    NMLikelihoodConfig | None,
+]:
+    """Get the likelihood configuration for split models.
+
+    Returns a tuple containing the following optional entries:
+    - GaussianLikelihoodConfig: Gaussian likelihood configuration for musplit losses
+    - MultiChannelNMConfig: Multi-channel noise model configuration for denoisplit
+      losses
+    - NMLikelihoodConfig: Noise model likelihood configuration for denoisplit losses
+
+    Parameters
+    ----------
+    loss_type : Literal["musplit", "denoisplit", "denoisplit_musplit"]
+        The type of loss function to use.
+    predict_logvar : Literal["pixelwise"] | None, optional
+        Type of log variance prediction, by default None.
+        Required when loss_type is "musplit" or "denoisplit_musplit".
+    logvar_lowerbound : float | None, optional
+        Lower bound for the log variance, by default -5.0.
+        Used when loss_type is "musplit" or "denoisplit_musplit".
+    nm_paths : list[str] | None, optional
+        Paths to the noise model files, by default None.
+        Required when loss_type is "denoisplit" or "denoisplit_musplit".
+    data_stats : tuple[float, float] | None, optional
+        Data statistics (mean, std), by default None.
+        Required when loss_type is "denoisplit" or "denoisplit_musplit".
+
+    Returns
+    -------
+    GaussianLikelihoodConfig or None
+        Configuration for the Gaussian likelihood model.
+    MultiChannelNMConfig or None
+        Configuration for the multi-channel noise model.
+    NMLikelihoodConfig or None
+        Configuration for the noise model likelihood.
+
+    Raises
+    ------
+    ValueError
+        If required parameters are missing for the specified loss_type.
+    """
+    # gaussian likelihood
+    if loss_type in ["musplit", "denoisplit_musplit"]:
+        # if predict_logvar is None:
+        #     raise ValueError(f"predict_logvar is required for loss_type '{loss_type}'")
+        # TODO validators should be in pydantic models
+        gaussian_lik_config = GaussianLikelihoodConfig(
+            predict_logvar=predict_logvar,
+            logvar_lowerbound=logvar_lowerbound,
+        )
+    else:
+        gaussian_lik_config = None
+
+    # noise model likelihood
+    if loss_type in ["denoisplit", "denoisplit_musplit"]:
+        # if nm_paths is None:
+        #     raise ValueError(f"nm_paths is required for loss_type '{loss_type}'")
+        # if data_stats is None:
+        #     raise ValueError(f"data_stats is required for loss_type '{loss_type}'")
+        # TODO validators should be in pydantic models
+        gmm_list = []
+        if nm_paths is not None:
+            for NM_path in nm_paths:
+                gmm_list.append(
+                    GaussianMixtureNMConfig(
+                        model_type="GaussianMixtureNoiseModel",
+                        path=NM_path,
+                    )
+                )
+        noise_model_config = MultiChannelNMConfig(noise_models=gmm_list)
+        nm_lik_config = NMLikelihoodConfig()  # TODO this config isn't needed probably
+    else:
+        noise_model_config = None
+        nm_lik_config = None
+
+    return gaussian_lik_config, noise_model_config, nm_lik_config
+
+
+# TODO wrap parameters into model, loss etc
+# TODO refac likelihood configs to make it 1. Can it be done ?
+def create_hdn_configuration(
+    experiment_name: str,
+    data_type: Literal["array", "tiff", "custom"],
+    axes: str,
+    patch_size: Sequence[int],
+    batch_size: int,
+    num_epochs: int = 100,
+    num_steps: int | None = None,
+    encoder_conv_strides: tuple[int, ...] = (2, 2),
+    decoder_conv_strides: tuple[int, ...] = (2, 2),
+    multiscale_count: int = 1,
+    z_dims: tuple[int, ...] = (128, 128),
+    output_channels: int = 1,
+    encoder_n_filters: int = 32,
+    decoder_n_filters: int = 32,
+    encoder_dropout: float = 0.0,
+    decoder_dropout: float = 0.0,
+    nonlinearity: Literal[
+        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
+    ] = "ReLU",
+    analytical_kl: bool = False,
+    predict_logvar: Literal["pixelwise"] | None = None,
+    logvar_lowerbound: Union[float, None] = None,
+    logger: Literal["wandb", "tensorboard", "none"] = "none",
+    trainer_params: dict | None = None,
+    augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]] | None = None,
+    train_dataloader_params: dict[str, Any] | None = None,
+    val_dataloader_params: dict[str, Any] | None = None,
+) -> Configuration:
+    """
+    Create a configuration for training HDN.
+
+    If "Z" is present in `axes`, then `path_size` must be a list of length 3, otherwise
+    2.
+
+    If "C" is present in `axes`, then you need to set `n_channels_in` to the number of
+    channels. Likewise, if you set the number of channels, then "C" must be present in
+    `axes`.
+
+    To set the number of output channels, use the `n_channels_out` parameter. If it is
+    not specified, it will be assumed to be equal to `n_channels_in`.
+
+    By default, all channels are trained independently. To train all channels together,
+    set `independent_channels` to False.
+
+    By setting `augmentations` to `None`, the default transformations (flip in X and Y,
+    rotations by 90 degrees in the XY plane) are applied. Rather than the default
+    transforms, a list of transforms can be passed to the `augmentations` parameter. To
+    disable the transforms, simply pass an empty list.
+
+    # TODO revisit the necessity of model_params
+
+    Parameters
+    ----------
+    experiment_name : str
+        Name of the experiment.
+    data_type : Literal["array", "tiff", "custom"]
+        Type of the data.
+    axes : str
+        Axes of the data (e.g. SYX).
+    patch_size : List[int]
+        Size of the patches along the spatial dimensions (e.g. [64, 64]).
+    batch_size : int
+        Batch size.
+    num_epochs : int, default=100
+        Number of epochs to train for. If provided, this will be added to
+        trainer_params.
+    num_steps : int, optional
+        Number of batches in 1 epoch. If provided, this will be added to trainer_params.
+        Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
+        documentation for more details.
+    encoder_conv_strides : tuple[int, ...], optional
+        Strides for the encoder convolutional layers, by default (2, 2).
+    decoder_conv_strides : tuple[int, ...], optional
+        Strides for the decoder convolutional layers, by default (2, 2).
+    multiscale_count : int, optional
+        Number of scales in the multiscale architecture, by default 1.
+    z_dims : tuple[int, ...], optional
+        Dimensions of the latent space, by default (128, 128).
+    output_channels : int, optional
+        Number of output channels, by default 1.
+    encoder_n_filters : int, optional
+        Number of filters in the encoder, by default 32.
+    decoder_n_filters : int, optional
+        Number of filters in the decoder, by default 32.
+    encoder_dropout : float, optional
+        Dropout rate for the encoder, by default 0.0.
+    decoder_dropout : float, optional
+        Dropout rate for the decoder, by default 0.0.
+    nonlinearity : Literal, optional
+        Nonlinearity function to use, by default "ReLU".
+    analytical_kl : bool, optional
+        Whether to use analytical KL divergence, by default False.
+    predict_logvar : Literal[None, "pixelwise"], optional
+        Type of log variance prediction, by default None.
+    logvar_lowerbound : Union[float, None], optional
+        Lower bound for the log variance, by default None.
+    logger : Literal["wandb", "tensorboard", "none"], optional
+        Logger to use for training, by default "none".
+    trainer_params : dict, optional
+        Parameters for the trainer class, see PyTorch Lightning documentation.
+    augmentations : Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]], optional
+        List of augmentations to apply, by default None.
+    train_dataloader_params : Optional[dict[str, Any]], optional
+        Parameters for the training dataloader, by default None.
+    val_dataloader_params : Optional[dict[str, Any]], optional
+        Parameters for the validation dataloader, by default None.
+
+    Returns
+    -------
+    Configuration
+        The configuration object for training HDN.
+
+    Examples
+    --------
+    Minimum example:
+    >>> config = create_hdn_configuration(
+    ...     experiment_name="hdn_experiment",
+    ...     data_type="array",
+    ...     axes="YX",
+    ...     patch_size=[64, 64],
+    ...     batch_size=32,
+    ...     num_epochs=100
+    ... )
+
+    You can also limit the number of batches per epoch:
+    >>> config = create_hdn_configuration(
+    ...     experiment_name="hdn_experiment",
+    ...     data_type="array",
+    ...     axes="YX",
+    ...     patch_size=[64, 64],
+    ...     batch_size=32,
+    ...     num_steps=100 # limit to 100 batches per epoch
+    ... )
+    """
+    transform_list = _list_spatial_augmentations(augmentations)
+
+    loss_config = LVAELossConfig(
+        loss_type="hdn", denoisplit_weight=1, musplit_weight=0
+    )  # TODO what are the correct defaults for HDN?
+
+    gaussian_likelihood = GaussianLikelihoodConfig(
+        predict_logvar=predict_logvar, logvar_lowerbound=logvar_lowerbound
+    )
+
+    # algorithm & model
+    algorithm_params = _create_vae_based_algorithm(
+        algorithm="hdn",
+        loss=loss_config,
+        input_shape=patch_size,
+        encoder_conv_strides=encoder_conv_strides,
+        decoder_conv_strides=decoder_conv_strides,
+        multiscale_count=multiscale_count,
+        z_dims=z_dims,
+        output_channels=output_channels,
+        encoder_n_filters=encoder_n_filters,
+        decoder_n_filters=decoder_n_filters,
+        encoder_dropout=encoder_dropout,
+        decoder_dropout=decoder_dropout,
+        nonlinearity=nonlinearity,
+        predict_logvar=predict_logvar,
+        analytical_kl=analytical_kl,
+        gaussian_likelihood=gaussian_likelihood,
+        nm_likelihood=None,
+    )
+
+    # data
+    data_params = _create_data_configuration(
+        data_type=data_type,
+        axes=axes,
+        patch_size=patch_size,
+        batch_size=batch_size,
+        augmentations=transform_list,
+        train_dataloader_params=train_dataloader_params,
+        val_dataloader_params=val_dataloader_params,
+    )
+
+    # Handle trainer parameters with num_epochs and num_steps
+    final_trainer_params = {} if trainer_params is None else trainer_params.copy()
+
+    # Add num_epochs and num_steps if provided
+    if num_epochs is not None:
+        final_trainer_params["max_epochs"] = num_epochs
+    if num_steps is not None:
+        final_trainer_params["limit_train_batches"] = num_steps
+
+    # training
+    training_params = _create_training_configuration(
+        trainer_params=final_trainer_params,
+        logger=logger,
+    )
+
+    return Configuration(
+        experiment_name=experiment_name,
+        algorithm_config=algorithm_params,
+        data_config=data_params,
+        training_config=training_params,
+    )
+
+
+def create_microsplit_configuration(
+    experiment_name: str,
+    data_type: Literal["array", "tiff", "custom"],
+    axes: str,
+    patch_size: Sequence[int],
+    batch_size: int,
+    num_epochs: int = 100,
+    num_steps: int | None = None,
+    encoder_conv_strides: tuple[int, ...] = (2, 2),
+    decoder_conv_strides: tuple[int, ...] = (2, 2),
+    multiscale_count: int = 3,
+    grid_size: int = 32,  # TODO most likely can be derived from patch size
+    z_dims: tuple[int, ...] = (128, 128),
+    output_channels: int = 1,
+    encoder_n_filters: int = 32,
+    decoder_n_filters: int = 32,
+    encoder_dropout: float = 0.0,
+    decoder_dropout: float = 0.0,
+    nonlinearity: Literal[
+        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
+    ] = "ReLU",  # TODO do we need all these?
+    analytical_kl: bool = False,
+    predict_logvar: Literal["pixelwise"] = "pixelwise",
+    logvar_lowerbound: Union[float, None] = None,
+    logger: Literal["wandb", "tensorboard", "none"] = "none",
+    trainer_params: dict | None = None,
+    augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]] | None = None,
+    nm_paths: list[str] | None = None,
+    data_stats: tuple[float, float] | None = None,
+    train_dataloader_params: dict[str, Any] | None = None,
+    val_dataloader_params: dict[str, Any] | None = None,
+) -> Configuration:
+    """
+    Create a configuration for training MicroSplit.
+
+    Parameters
+    ----------
+    experiment_name : str
+        Name of the experiment.
+    data_type : Literal["array", "tiff", "custom"]
+        Type of the data.
+    axes : str
+        Axes of the data (e.g. SYX).
+    patch_size : Sequence[int]
+        Size of the patches along the spatial dimensions (e.g. [64, 64]).
+    batch_size : int
+        Batch size.
+    num_epochs : int, default=100
+        Number of epochs to train for. If provided, this will be added to
+        trainer_params.
+    num_steps : int, optional
+        Number of batches in 1 epoch. If provided, this will be added to trainer_params.
+        Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
+        documentation for more details.
+    encoder_conv_strides : tuple[int, ...], optional
+        Strides for the encoder convolutional layers, by default (2, 2).
+    decoder_conv_strides : tuple[int, ...], optional
+        Strides for the decoder convolutional layers, by default (2, 2).
+    multiscale_count : int, optional
+        Number of multiscale levels, by default 1.
+    grid_size : int, optional
+        Size of the grid for the lateral context, by default 32.
+    z_dims : tuple[int, ...], optional
+        List of latent dimensions for each hierarchy level in the LVAE, by default
+        (128, 128).
+    output_channels : int, optional
+        Number of output channels for the model, by default 1.
+    encoder_n_filters : int, optional
+        Number of filters in the encoder, by default 32.
+    decoder_n_filters : int, optional
+        Number of filters in the decoder, by default 32.
+    encoder_dropout : float, optional
+        Dropout rate for the encoder, by default 0.0.
+    decoder_dropout : float, optional
+        Dropout rate for the decoder, by default 0.0.
+    nonlinearity : Literal, optional
+        Nonlinearity to use in the model, by default "ReLU".
+    analytical_kl : bool, optional
+        Whether to use analytical KL divergence, by default False.
+    predict_logvar : Literal["pixelwise"] | None, optional
+        Type of log-variance prediction, by default None.
+    logvar_lowerbound : Union[float, None], optional
+        Lower bound for the log variance, by default None.
+    logger : Literal["wandb", "tensorboard", "none"], optional
+        Logger to use for training, by default "none".
+    trainer_params : dict, optional
+        Parameters for the trainer class, see PyTorch Lightning documentation.
+    augmentations : list[Union[XYFlipModel, XYRandomRotate90Model]] | None, optional
+        List of augmentations to apply, by default None.
+    nm_paths : list[str] | None, optional
+        Paths to the noise model files, by default None.
+    data_stats : tuple[float, float] | None, optional
+        Data statistics (mean, std), by default None.
+    train_dataloader_params : dict[str, Any] | None, optional
+        Parameters for the training dataloader, by default None.
+    val_dataloader_params : dict[str, Any] | None, optional
+        Parameters for the validation dataloader, by default None.
+
+    Returns
+    -------
+    Configuration
+        A configuration object for the microsplit algorithm.
+
+    Examples
+    --------
+    Minimum example:
+    # >>> config = create_microsplit_configuration(
+    # ...     experiment_name="microsplit_experiment",
+    # ...     data_type="array",
+    # ...     axes="YX",
+    # ...     patch_size=[64, 64],
+    # ...     batch_size=32,
+    # ...     num_epochs=100
+
+    # ... )
+
+    # You can also limit the number of batches per epoch:
+    # >>> config = create_microsplit_configuration(
+    # ...     experiment_name="microsplit_experiment",
+    # ...     data_type="array",
+    # ...     axes="YX",
+    # ...     patch_size=[64, 64],
+    # ...     batch_size=32,
+    # ...     num_steps=100 # limit to 100 batches per epoch
+    # ... )
+    """
+    transform_list = _list_spatial_augmentations(augmentations)
+
+    loss_config = LVAELossConfig(
+        loss_type="denoisplit_musplit", denoisplit_weight=0.9, musplit_weight=0.1
+    )  # TODO losses need to be refactored! just for example. Add validator if sum to 1
+
+    # Create likelihood configurations
+    gaussian_likelihood_config, noise_model_config, nm_likelihood_config = (
+        get_likelihood_config(
+            loss_type="denoisplit_musplit",
+            predict_logvar=predict_logvar,
+            logvar_lowerbound=logvar_lowerbound,
+            nm_paths=nm_paths,
+            data_stats=data_stats,
+        )
+    )
+
+    # Create the LVAE model
+    network_model = _create_vae_configuration(
+        input_shape=patch_size,
+        encoder_conv_strides=encoder_conv_strides,
+        decoder_conv_strides=decoder_conv_strides,
+        multiscale_count=multiscale_count,
+        z_dims=z_dims,
+        output_channels=output_channels,
+        encoder_n_filters=encoder_n_filters,
+        decoder_n_filters=decoder_n_filters,
+        encoder_dropout=encoder_dropout,
+        decoder_dropout=decoder_dropout,
+        nonlinearity=nonlinearity,
+        predict_logvar=predict_logvar,
+        analytical_kl=analytical_kl,
+    )
+
+    # Create the MicroSplit algorithm configuration
+    algorithm_params = {
+        "algorithm": "microsplit",
+        "loss": loss_config,
+        "model": network_model,
+        "gaussian_likelihood": gaussian_likelihood_config,
+        "noise_model": noise_model_config,
+        "noise_model_likelihood": nm_likelihood_config,
+    }
+
+    # Convert to MicroSplitAlgorithm instance
+    algorithm_config = MicroSplitAlgorithm(**algorithm_params)
+
+    # data
+    data_params = _create_microsplit_data_configuration(
+        data_type=data_type,
+        axes=axes,
+        patch_size=patch_size,
+        grid_size=grid_size,
+        multiscale_count=multiscale_count,
+        batch_size=batch_size,
+        augmentations=transform_list,
+        train_dataloader_params=train_dataloader_params,
+        val_dataloader_params=val_dataloader_params,
+    )
+
+    # Handle trainer parameters with num_epochs and num_steps
+    final_trainer_params = {} if trainer_params is None else trainer_params.copy()
+
+    # Add num_epochs and num_steps if provided
+    if num_epochs is not None:
+        final_trainer_params["max_epochs"] = num_epochs
+    if num_steps is not None:
+        final_trainer_params["limit_train_batches"] = num_steps
+
+    # training
+    training_params = _create_training_configuration(
+        trainer_params=final_trainer_params,
+        logger=logger,
+    )
+
+    return Configuration(
+        experiment_name=experiment_name,
+        algorithm_config=algorithm_config,
+        data_config=data_params,
+        training_config=training_params,
+    )
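Because the MicroSplit docstring examples ship commented out in this release, the call below is a sketch based only on the new create_microsplit_configuration signature above. The noise-model paths are placeholders, and data_stats follows the (mean, std) convention stated in the docstring:

# Hedged sketch; "noise_model_ch1.npz" / "noise_model_ch2.npz" are placeholder paths.
config = create_microsplit_configuration(
    experiment_name="microsplit_experiment",
    data_type="tiff",
    axes="SYX",
    patch_size=[64, 64],
    batch_size=32,
    output_channels=2,
    nm_paths=["noise_model_ch1.npz", "noise_model_ch2.npz"],
    data_stats=(0.0, 1.0),  # (mean, std) of the training data
    num_steps=100,          # limit each epoch to 100 batches
)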