careamics 0.0.15__py3-none-any.whl → 0.0.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +6 -12
- careamics/cli/conf.py +18 -3
- careamics/config/__init__.py +8 -0
- careamics/config/algorithms/__init__.py +4 -0
- careamics/config/algorithms/hdn_algorithm_model.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
- careamics/config/algorithms/n2v_algorithm_model.py +1 -2
- careamics/config/algorithms/vae_algorithm_model.py +51 -16
- careamics/config/architectures/lvae_model.py +12 -8
- careamics/config/callback_model.py +7 -3
- careamics/config/configuration.py +9 -8
- careamics/config/configuration_factories.py +843 -29
- careamics/config/data/data_model.py +1 -2
- careamics/config/data/ng_data_model.py +1 -2
- careamics/config/inference_model.py +1 -2
- careamics/config/likelihood_model.py +2 -2
- careamics/config/loss_model.py +6 -2
- careamics/config/nm_model.py +26 -1
- careamics/config/optimizer_models.py +1 -2
- careamics/config/support/supported_algorithms.py +5 -3
- careamics/config/support/supported_losses.py +5 -2
- careamics/config/training_model.py +6 -36
- careamics/config/transformations/normalize_model.py +1 -2
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
- careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
- careamics/file_io/read/__init__.py +0 -1
- careamics/lightning/__init__.py +16 -2
- careamics/lightning/callbacks/__init__.py +2 -0
- careamics/lightning/callbacks/data_stats_callback.py +23 -0
- careamics/lightning/lightning_module.py +161 -61
- careamics/lightning/microsplit_data_module.py +631 -0
- careamics/lightning/predict_data_module.py +8 -1
- careamics/lightning/train_data_module.py +19 -8
- careamics/losses/__init__.py +7 -1
- careamics/losses/loss_factory.py +9 -1
- careamics/losses/lvae/losses.py +85 -0
- careamics/lvae_training/dataset/__init__.py +8 -8
- careamics/lvae_training/dataset/config.py +56 -44
- careamics/lvae_training/dataset/lc_dataset.py +18 -12
- careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
- careamics/lvae_training/dataset/multich_dataset.py +24 -18
- careamics/lvae_training/dataset/multifile_dataset.py +6 -6
- careamics/model_io/bmz_io.py +9 -5
- careamics/models/lvae/likelihoods.py +30 -14
- careamics/models/lvae/lvae.py +2 -2
- careamics/models/lvae/noise_models.py +20 -14
- careamics/prediction_utils/__init__.py +8 -2
- careamics/prediction_utils/prediction_outputs.py +48 -3
- careamics/prediction_utils/stitch_prediction.py +71 -0
- careamics/transforms/xy_random_rotate90.py +1 -1
- {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/METADATA +18 -15
- {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/RECORD +57 -55
- careamics/dataset/zarr_dataset.py +0 -151
- careamics/file_io/read/zarr.py +0 -60
- {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/WHEEL +0 -0
- {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/licenses/LICENSE +0 -0
|
@@ -5,9 +5,20 @@ from typing import Annotated, Any, Literal, Union
|
|
|
5
5
|
|
|
6
6
|
from pydantic import Field, TypeAdapter
|
|
7
7
|
|
|
8
|
-
from careamics.config.algorithms import
|
|
9
|
-
|
|
8
|
+
from careamics.config.algorithms import (
|
|
9
|
+
CAREAlgorithm,
|
|
10
|
+
MicroSplitAlgorithm,
|
|
11
|
+
N2NAlgorithm,
|
|
12
|
+
N2VAlgorithm,
|
|
13
|
+
)
|
|
14
|
+
from careamics.config.architectures import LVAEModel, UNetModel
|
|
10
15
|
from careamics.config.data import DataConfig, NGDataConfig
|
|
16
|
+
from careamics.config.likelihood_model import (
|
|
17
|
+
GaussianLikelihoodConfig,
|
|
18
|
+
NMLikelihoodConfig,
|
|
19
|
+
)
|
|
20
|
+
from careamics.config.loss_model import LVAELossConfig
|
|
21
|
+
from careamics.config.nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig
|
|
11
22
|
from careamics.config.support import (
|
|
12
23
|
SupportedArchitecture,
|
|
13
24
|
SupportedPixelManipulation,
|
|
@@ -20,6 +31,7 @@ from careamics.config.transformations import (
|
|
|
20
31
|
XYFlipModel,
|
|
21
32
|
XYRandomRotate90Model,
|
|
22
33
|
)
|
|
34
|
+
from careamics.lvae_training.dataset.config import MicroSplitDataConfig
|
|
23
35
|
|
|
24
36
|
from .configuration import Configuration
|
|
25
37
|
|
|
@@ -224,7 +236,7 @@ def _create_algorithm_configuration(
|
|
|
224
236
|
def _create_data_configuration(
|
|
225
237
|
data_type: Literal["array", "tiff", "czi", "custom"],
|
|
226
238
|
axes: str,
|
|
227
|
-
patch_size:
|
|
239
|
+
patch_size: Sequence[int],
|
|
228
240
|
batch_size: int,
|
|
229
241
|
augmentations: Union[list[SPATIAL_TRANSFORMS_UNION]],
|
|
230
242
|
train_dataloader_params: dict[str, Any] | None = None,
|
|
@@ -277,6 +289,66 @@ def _create_data_configuration(
|
|
|
277
289
|
return DataConfig(**data)
|
|
278
290
|
|
|
279
291
|
|
|
292
|
+
def _create_microsplit_data_configuration(
|
|
293
|
+
data_type: Literal["array", "tiff", "custom"],
|
|
294
|
+
axes: str,
|
|
295
|
+
patch_size: Sequence[int],
|
|
296
|
+
grid_size: int,
|
|
297
|
+
multiscale_count: int,
|
|
298
|
+
batch_size: int,
|
|
299
|
+
augmentations: Union[list[SPATIAL_TRANSFORMS_UNION]],
|
|
300
|
+
train_dataloader_params: dict[str, Any] | None = None,
|
|
301
|
+
val_dataloader_params: dict[str, Any] | None = None,
|
|
302
|
+
) -> DataConfig:
|
|
303
|
+
"""
|
|
304
|
+
Create a dictionary with the parameters of the data model.
|
|
305
|
+
|
|
306
|
+
Parameters
|
|
307
|
+
----------
|
|
308
|
+
data_type : {"array", "tiff", "czi", "custom"}
|
|
309
|
+
Type of the data.
|
|
310
|
+
axes : str
|
|
311
|
+
Axes of the data.
|
|
312
|
+
patch_size : list of int
|
|
313
|
+
Size of the patches along the spatial dimensions.
|
|
314
|
+
batch_size : int
|
|
315
|
+
Batch size.
|
|
316
|
+
augmentations : list of transforms
|
|
317
|
+
List of transforms to apply.
|
|
318
|
+
train_dataloader_params : dict
|
|
319
|
+
Parameters for the training dataloader, see PyTorch notes, by default None.
|
|
320
|
+
val_dataloader_params : dict
|
|
321
|
+
Parameters for the validation dataloader, see PyTorch notes, by default None.
|
|
322
|
+
|
|
323
|
+
Returns
|
|
324
|
+
-------
|
|
325
|
+
DataConfig
|
|
326
|
+
Data model with the specified parameters.
|
|
327
|
+
"""
|
|
328
|
+
# data model
|
|
329
|
+
data = {
|
|
330
|
+
"data_type": data_type,
|
|
331
|
+
"axes": axes,
|
|
332
|
+
"image_size": patch_size,
|
|
333
|
+
"grid_size": grid_size,
|
|
334
|
+
"multiscale_lowres_count": multiscale_count,
|
|
335
|
+
"batch_size": batch_size,
|
|
336
|
+
"transforms": augmentations,
|
|
337
|
+
}
|
|
338
|
+
# Don't override defaults set in DataConfig class
|
|
339
|
+
if train_dataloader_params is not None:
|
|
340
|
+
# DataConfig enforces the presence of `shuffle` key in the dataloader parameters
|
|
341
|
+
if "shuffle" not in train_dataloader_params:
|
|
342
|
+
train_dataloader_params["shuffle"] = True
|
|
343
|
+
|
|
344
|
+
data["train_dataloader_params"] = train_dataloader_params
|
|
345
|
+
|
|
346
|
+
if val_dataloader_params is not None:
|
|
347
|
+
data["val_dataloader_params"] = val_dataloader_params
|
|
348
|
+
|
|
349
|
+
return MicroSplitDataConfig(**data)
|
|
350
|
+
|
|
351
|
+
|
|
280
352
|
def _create_ng_data_configuration(
|
|
281
353
|
data_type: Literal["array", "tiff", "custom"],
|
|
282
354
|
axes: str,
|
|
@@ -357,7 +429,7 @@ def _create_ng_data_configuration(
|
|
|
357
429
|
|
|
358
430
|
|
|
359
431
|
def _create_training_configuration(
|
|
360
|
-
|
|
432
|
+
trainer_params: dict,
|
|
361
433
|
logger: Literal["wandb", "tensorboard", "none"],
|
|
362
434
|
checkpoint_params: dict[str, Any] | None = None,
|
|
363
435
|
) -> TrainingConfig:
|
|
@@ -366,8 +438,8 @@ def _create_training_configuration(
|
|
|
366
438
|
|
|
367
439
|
Parameters
|
|
368
440
|
----------
|
|
369
|
-
|
|
370
|
-
|
|
441
|
+
trainer_params : dict
|
|
442
|
+
Parameters for Lightning Trainer class, see PyTorch Lightning documentation.
|
|
371
443
|
logger : {"wandb", "tensorboard", "none"}
|
|
372
444
|
Logger to use.
|
|
373
445
|
checkpoint_params : dict, default=None
|
|
@@ -380,7 +452,7 @@ def _create_training_configuration(
|
|
|
380
452
|
Training model with the specified parameters.
|
|
381
453
|
"""
|
|
382
454
|
return TrainingConfig(
|
|
383
|
-
|
|
455
|
+
lightning_trainer_config=trainer_params,
|
|
384
456
|
logger=None if logger == "none" else logger,
|
|
385
457
|
checkpoint_callback={} if checkpoint_params is None else checkpoint_params,
|
|
386
458
|
)
|
|
@@ -392,9 +464,9 @@ def _create_supervised_config_dict(
|
|
|
392
464
|
experiment_name: str,
|
|
393
465
|
data_type: Literal["array", "tiff", "czi", "custom"],
|
|
394
466
|
axes: str,
|
|
395
|
-
patch_size:
|
|
467
|
+
patch_size: Sequence[int],
|
|
396
468
|
batch_size: int,
|
|
397
|
-
|
|
469
|
+
trainer_params: dict | None = None,
|
|
398
470
|
augmentations: list[SPATIAL_TRANSFORMS_UNION] | None = None,
|
|
399
471
|
independent_channels: bool = True,
|
|
400
472
|
loss: Literal["mae", "mse"] = "mae",
|
|
@@ -409,6 +481,8 @@ def _create_supervised_config_dict(
|
|
|
409
481
|
train_dataloader_params: dict[str, Any] | None = None,
|
|
410
482
|
val_dataloader_params: dict[str, Any] | None = None,
|
|
411
483
|
checkpoint_params: dict[str, Any] | None = None,
|
|
484
|
+
num_epochs: int | None = None,
|
|
485
|
+
num_steps: int | None = None,
|
|
412
486
|
) -> dict:
|
|
413
487
|
"""
|
|
414
488
|
Create a configuration for training CARE or Noise2Noise.
|
|
@@ -427,8 +501,8 @@ def _create_supervised_config_dict(
|
|
|
427
501
|
Size of the patches along the spatial dimensions (e.g. [64, 64]).
|
|
428
502
|
batch_size : int
|
|
429
503
|
Batch size.
|
|
430
|
-
|
|
431
|
-
|
|
504
|
+
trainer_params : dict
|
|
505
|
+
Parameters for the training configuration.
|
|
432
506
|
augmentations : list of transforms, default=None
|
|
433
507
|
List of transforms to apply, either both or one of XYFlipModel and
|
|
434
508
|
XYRandomRotate90Model. By default, it applies both XYFlip (on X and Y)
|
|
@@ -461,6 +535,13 @@ def _create_supervised_config_dict(
|
|
|
461
535
|
checkpoint_params : dict, default=None
|
|
462
536
|
Parameters for the checkpoint callback, see PyTorch Lightning documentation
|
|
463
537
|
(`ModelCheckpoint`) for the list of available parameters.
|
|
538
|
+
num_epochs : int or None, default=None
|
|
539
|
+
Number of epochs to train for. If provided, this will be added to
|
|
540
|
+
trainer_params.
|
|
541
|
+
num_steps : int or None, default=None
|
|
542
|
+
Number of batches in 1 epoch. If provided, this will be added to trainer_params.
|
|
543
|
+
Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
|
|
544
|
+
documentation for more details.
|
|
464
545
|
|
|
465
546
|
Returns
|
|
466
547
|
-------
|
|
@@ -518,9 +599,18 @@ def _create_supervised_config_dict(
|
|
|
518
599
|
val_dataloader_params=val_dataloader_params,
|
|
519
600
|
)
|
|
520
601
|
|
|
602
|
+
# Handle trainer parameters with num_epochs and num_steps
|
|
603
|
+
final_trainer_params = {} if trainer_params is None else trainer_params.copy()
|
|
604
|
+
|
|
605
|
+
# Add num_epochs and num_steps if provided
|
|
606
|
+
if num_epochs is not None:
|
|
607
|
+
final_trainer_params["max_epochs"] = num_epochs
|
|
608
|
+
if num_steps is not None:
|
|
609
|
+
final_trainer_params["limit_train_batches"] = num_steps
|
|
610
|
+
|
|
521
611
|
# training
|
|
522
612
|
training_params = _create_training_configuration(
|
|
523
|
-
|
|
613
|
+
trainer_params=final_trainer_params,
|
|
524
614
|
logger=logger,
|
|
525
615
|
checkpoint_params=checkpoint_params,
|
|
526
616
|
)
|
|
@@ -537,15 +627,17 @@ def create_care_configuration(
|
|
|
537
627
|
experiment_name: str,
|
|
538
628
|
data_type: Literal["array", "tiff", "czi", "custom"],
|
|
539
629
|
axes: str,
|
|
540
|
-
patch_size:
|
|
630
|
+
patch_size: Sequence[int],
|
|
541
631
|
batch_size: int,
|
|
542
|
-
num_epochs: int,
|
|
632
|
+
num_epochs: int = 100,
|
|
633
|
+
num_steps: int | None = None,
|
|
543
634
|
augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]] | None = None,
|
|
544
635
|
independent_channels: bool = True,
|
|
545
636
|
loss: Literal["mae", "mse"] = "mae",
|
|
546
637
|
n_channels_in: int | None = None,
|
|
547
638
|
n_channels_out: int | None = None,
|
|
548
639
|
logger: Literal["wandb", "tensorboard", "none"] = "none",
|
|
640
|
+
trainer_params: dict | None = None,
|
|
549
641
|
model_params: dict | None = None,
|
|
550
642
|
optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
|
|
551
643
|
optimizer_params: dict[str, Any] | None = None,
|
|
@@ -588,8 +680,13 @@ def create_care_configuration(
|
|
|
588
680
|
Size of the patches along the spatial dimensions (e.g. [64, 64]).
|
|
589
681
|
batch_size : int
|
|
590
682
|
Batch size.
|
|
591
|
-
num_epochs : int
|
|
592
|
-
Number of epochs.
|
|
683
|
+
num_epochs : int, default=100
|
|
684
|
+
Number of epochs to train for. If provided, this will be added to
|
|
685
|
+
trainer_params.
|
|
686
|
+
num_steps : int, optional
|
|
687
|
+
Number of batches in 1 epoch. If provided, this will be added to trainer_params.
|
|
688
|
+
Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
|
|
689
|
+
documentation for more details.
|
|
593
690
|
augmentations : list of transforms, default=None
|
|
594
691
|
List of transforms to apply, either both or one of XYFlipModel and
|
|
595
692
|
XYRandomRotate90Model. By default, it applies both XYFlip (on X and Y)
|
|
@@ -604,6 +701,8 @@ def create_care_configuration(
|
|
|
604
701
|
Number of channels out.
|
|
605
702
|
logger : Literal["wandb", "tensorboard", "none"], default="none"
|
|
606
703
|
Logger to use.
|
|
704
|
+
trainer_params : dict, optional
|
|
705
|
+
Parameters for the trainer class, see PyTorch Lightning documentation.
|
|
607
706
|
model_params : dict, default=None
|
|
608
707
|
UNetModel parameters.
|
|
609
708
|
optimizer : Literal["Adam", "Adamax", "SGD"], default="Adam"
|
|
@@ -644,6 +743,16 @@ def create_care_configuration(
|
|
|
644
743
|
... num_epochs=100
|
|
645
744
|
... )
|
|
646
745
|
|
|
746
|
+
You can also limit the number of batches per epoch:
|
|
747
|
+
>>> config = create_care_configuration(
|
|
748
|
+
... experiment_name="care_experiment",
|
|
749
|
+
... data_type="array",
|
|
750
|
+
... axes="YX",
|
|
751
|
+
... patch_size=[64, 64],
|
|
752
|
+
... batch_size=32,
|
|
753
|
+
... num_steps=100 # limit to 100 batches per epoch
|
|
754
|
+
... )
|
|
755
|
+
|
|
647
756
|
To disable transforms, simply set `augmentations` to an empty list:
|
|
648
757
|
>>> config = create_care_configuration(
|
|
649
758
|
... experiment_name="care_experiment",
|
|
@@ -730,13 +839,13 @@ def create_care_configuration(
|
|
|
730
839
|
axes=axes,
|
|
731
840
|
patch_size=patch_size,
|
|
732
841
|
batch_size=batch_size,
|
|
733
|
-
num_epochs=num_epochs,
|
|
734
842
|
augmentations=augmentations,
|
|
735
843
|
independent_channels=independent_channels,
|
|
736
844
|
loss=loss,
|
|
737
845
|
n_channels_in=n_channels_in,
|
|
738
846
|
n_channels_out=n_channels_out,
|
|
739
847
|
logger=logger,
|
|
848
|
+
trainer_params=trainer_params,
|
|
740
849
|
model_params=model_params,
|
|
741
850
|
optimizer=optimizer,
|
|
742
851
|
optimizer_params=optimizer_params,
|
|
@@ -745,6 +854,8 @@ def create_care_configuration(
|
|
|
745
854
|
train_dataloader_params=train_dataloader_params,
|
|
746
855
|
val_dataloader_params=val_dataloader_params,
|
|
747
856
|
checkpoint_params=checkpoint_params,
|
|
857
|
+
num_epochs=num_epochs,
|
|
858
|
+
num_steps=num_steps,
|
|
748
859
|
)
|
|
749
860
|
)
|
|
750
861
|
|
|
@@ -753,15 +864,17 @@ def create_n2n_configuration(
|
|
|
753
864
|
experiment_name: str,
|
|
754
865
|
data_type: Literal["array", "tiff", "czi", "custom"],
|
|
755
866
|
axes: str,
|
|
756
|
-
patch_size:
|
|
867
|
+
patch_size: Sequence[int],
|
|
757
868
|
batch_size: int,
|
|
758
|
-
num_epochs: int,
|
|
869
|
+
num_epochs: int = 100,
|
|
870
|
+
num_steps: int | None = None,
|
|
759
871
|
augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]] | None = None,
|
|
760
872
|
independent_channels: bool = True,
|
|
761
873
|
loss: Literal["mae", "mse"] = "mae",
|
|
762
874
|
n_channels_in: int | None = None,
|
|
763
875
|
n_channels_out: int | None = None,
|
|
764
876
|
logger: Literal["wandb", "tensorboard", "none"] = "none",
|
|
877
|
+
trainer_params: dict | None = None,
|
|
765
878
|
model_params: dict | None = None,
|
|
766
879
|
optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
|
|
767
880
|
optimizer_params: dict[str, Any] | None = None,
|
|
@@ -804,8 +917,13 @@ def create_n2n_configuration(
|
|
|
804
917
|
Size of the patches along the spatial dimensions (e.g. [64, 64]).
|
|
805
918
|
batch_size : int
|
|
806
919
|
Batch size.
|
|
807
|
-
num_epochs : int
|
|
808
|
-
Number of epochs.
|
|
920
|
+
num_epochs : int, default=100
|
|
921
|
+
Number of epochs to train for. If provided, this will be added to
|
|
922
|
+
trainer_params.
|
|
923
|
+
num_steps : int, optional
|
|
924
|
+
Number of batches in 1 epoch. If provided, this will be added to trainer_params.
|
|
925
|
+
Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
|
|
926
|
+
documentation for more details.
|
|
809
927
|
augmentations : list of transforms, default=None
|
|
810
928
|
List of transforms to apply, either both or one of XYFlipModel and
|
|
811
929
|
XYRandomRotate90Model. By default, it applies both XYFlip (on X and Y)
|
|
@@ -820,6 +938,8 @@ def create_n2n_configuration(
|
|
|
820
938
|
Number of channels out.
|
|
821
939
|
logger : Literal["wandb", "tensorboard", "none"], optional
|
|
822
940
|
Logger to use, by default "none".
|
|
941
|
+
trainer_params : dict, optional
|
|
942
|
+
Parameters for the trainer class, see PyTorch Lightning documentation.
|
|
823
943
|
model_params : dict, default=None
|
|
824
944
|
UNetModel parameters.
|
|
825
945
|
optimizer : Literal["Adam", "Adamax", "SGD"], default="Adam"
|
|
@@ -860,6 +980,16 @@ def create_n2n_configuration(
|
|
|
860
980
|
... num_epochs=100
|
|
861
981
|
... )
|
|
862
982
|
|
|
983
|
+
You can also limit the number of batches per epoch:
|
|
984
|
+
>>> config = create_n2n_configuration(
|
|
985
|
+
... experiment_name="n2n_experiment",
|
|
986
|
+
... data_type="array",
|
|
987
|
+
... axes="YX",
|
|
988
|
+
... patch_size=[64, 64],
|
|
989
|
+
... batch_size=32,
|
|
990
|
+
... num_steps=100 # limit to 100 batches per epoch
|
|
991
|
+
... )
|
|
992
|
+
|
|
863
993
|
To disable transforms, simply set `augmentations` to an empty list:
|
|
864
994
|
>>> config = create_n2n_configuration(
|
|
865
995
|
... experiment_name="n2n_experiment",
|
|
@@ -871,8 +1001,7 @@ def create_n2n_configuration(
|
|
|
871
1001
|
... augmentations=[]
|
|
872
1002
|
... )
|
|
873
1003
|
|
|
874
|
-
A list of transforms can be passed to the `augmentations` parameter
|
|
875
|
-
default augmentations:
|
|
1004
|
+
A list of transforms can be passed to the `augmentations` parameter:
|
|
876
1005
|
>>> from careamics.config.transformations import XYFlipModel
|
|
877
1006
|
>>> config = create_n2n_configuration(
|
|
878
1007
|
... experiment_name="n2n_experiment",
|
|
@@ -946,7 +1075,7 @@ def create_n2n_configuration(
|
|
|
946
1075
|
axes=axes,
|
|
947
1076
|
patch_size=patch_size,
|
|
948
1077
|
batch_size=batch_size,
|
|
949
|
-
|
|
1078
|
+
trainer_params=trainer_params,
|
|
950
1079
|
augmentations=augmentations,
|
|
951
1080
|
independent_channels=independent_channels,
|
|
952
1081
|
loss=loss,
|
|
@@ -961,6 +1090,8 @@ def create_n2n_configuration(
|
|
|
961
1090
|
train_dataloader_params=train_dataloader_params,
|
|
962
1091
|
val_dataloader_params=val_dataloader_params,
|
|
963
1092
|
checkpoint_params=checkpoint_params,
|
|
1093
|
+
num_epochs=num_epochs,
|
|
1094
|
+
num_steps=num_steps,
|
|
964
1095
|
)
|
|
965
1096
|
)
|
|
966
1097
|
|
|
@@ -969,9 +1100,10 @@ def create_n2v_configuration(
|
|
|
969
1100
|
experiment_name: str,
|
|
970
1101
|
data_type: Literal["array", "tiff", "czi", "custom"],
|
|
971
1102
|
axes: str,
|
|
972
|
-
patch_size:
|
|
1103
|
+
patch_size: Sequence[int],
|
|
973
1104
|
batch_size: int,
|
|
974
|
-
num_epochs: int,
|
|
1105
|
+
num_epochs: int = 100,
|
|
1106
|
+
num_steps: int | None = None,
|
|
975
1107
|
augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]] | None = None,
|
|
976
1108
|
independent_channels: bool = True,
|
|
977
1109
|
use_n2v2: bool = False,
|
|
@@ -980,6 +1112,7 @@ def create_n2v_configuration(
|
|
|
980
1112
|
masked_pixel_percentage: float = 0.2,
|
|
981
1113
|
struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
|
|
982
1114
|
struct_n2v_span: int = 5,
|
|
1115
|
+
trainer_params: dict | None = None,
|
|
983
1116
|
logger: Literal["wandb", "tensorboard", "none"] = "none",
|
|
984
1117
|
model_params: dict | None = None,
|
|
985
1118
|
optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
|
|
@@ -1043,8 +1176,13 @@ def create_n2v_configuration(
|
|
|
1043
1176
|
Size of the patches along the spatial dimensions (e.g. [64, 64]).
|
|
1044
1177
|
batch_size : int
|
|
1045
1178
|
Batch size.
|
|
1046
|
-
num_epochs : int
|
|
1047
|
-
Number of epochs.
|
|
1179
|
+
num_epochs : int, default=100
|
|
1180
|
+
Number of epochs to train for. If provided, this will be added to
|
|
1181
|
+
trainer_params.
|
|
1182
|
+
num_steps : int, optional
|
|
1183
|
+
Number of batches in 1 epoch. If provided, this will be added to trainer_params.
|
|
1184
|
+
Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
|
|
1185
|
+
documentation for more details.
|
|
1048
1186
|
augmentations : list of transforms, default=None
|
|
1049
1187
|
List of transforms to apply, either both or one of XYFlipModel and
|
|
1050
1188
|
XYRandomRotate90Model. By default, it applies both XYFlip (on X and Y)
|
|
@@ -1063,6 +1201,8 @@ def create_n2v_configuration(
|
|
|
1063
1201
|
Axis along which to apply structN2V mask, by default "none".
|
|
1064
1202
|
struct_n2v_span : int, optional
|
|
1065
1203
|
Span of the structN2V mask, by default 5.
|
|
1204
|
+
trainer_params : dict, optional
|
|
1205
|
+
Parameters for the trainer, see the relevant documentation.
|
|
1066
1206
|
logger : Literal["wandb", "tensorboard", "none"], optional
|
|
1067
1207
|
Logger to use, by default "none".
|
|
1068
1208
|
model_params : dict, default=None
|
|
@@ -1105,6 +1245,16 @@ def create_n2v_configuration(
|
|
|
1105
1245
|
... num_epochs=100
|
|
1106
1246
|
... )
|
|
1107
1247
|
|
|
1248
|
+
You can also limit the number of batches per epoch:
|
|
1249
|
+
>>> config = create_n2v_configuration(
|
|
1250
|
+
... experiment_name="n2v_experiment",
|
|
1251
|
+
... data_type="array",
|
|
1252
|
+
... axes="YX",
|
|
1253
|
+
... patch_size=[64, 64],
|
|
1254
|
+
... batch_size=32,
|
|
1255
|
+
... num_steps=100 # limit to 100 batches per epoch
|
|
1256
|
+
... )
|
|
1257
|
+
|
|
1108
1258
|
To disable transforms, simply set `augmentations` to an empty list:
|
|
1109
1259
|
>>> config = create_n2v_configuration(
|
|
1110
1260
|
... experiment_name="n2v_experiment",
|
|
@@ -1261,8 +1411,17 @@ def create_n2v_configuration(
|
|
|
1261
1411
|
)
|
|
1262
1412
|
|
|
1263
1413
|
# training
|
|
1414
|
+
# Handle trainer parameters with num_epochs and nun_steps
|
|
1415
|
+
final_trainer_params = {} if trainer_params is None else trainer_params.copy()
|
|
1416
|
+
|
|
1417
|
+
# Add num_epochs and nun_steps if provided
|
|
1418
|
+
if num_epochs is not None:
|
|
1419
|
+
final_trainer_params["max_epochs"] = num_epochs
|
|
1420
|
+
if num_steps is not None:
|
|
1421
|
+
final_trainer_params["limit_train_batches"] = num_steps
|
|
1422
|
+
|
|
1264
1423
|
training_params = _create_training_configuration(
|
|
1265
|
-
|
|
1424
|
+
trainer_params=final_trainer_params,
|
|
1266
1425
|
logger=logger,
|
|
1267
1426
|
checkpoint_params=checkpoint_params,
|
|
1268
1427
|
)
|
|
@@ -1273,3 +1432,658 @@ def create_n2v_configuration(
|
|
|
1273
1432
|
data_config=data_params,
|
|
1274
1433
|
training_config=training_params,
|
|
1275
1434
|
)
|
|
1435
|
+
|
|
1436
|
+
|
|
1437
|
+
def _create_vae_configuration(
|
|
1438
|
+
input_shape: Sequence[int],
|
|
1439
|
+
encoder_conv_strides: tuple[int, ...],
|
|
1440
|
+
decoder_conv_strides: tuple[int, ...],
|
|
1441
|
+
multiscale_count: int,
|
|
1442
|
+
z_dims: tuple[int, ...],
|
|
1443
|
+
output_channels: int,
|
|
1444
|
+
encoder_n_filters: int,
|
|
1445
|
+
decoder_n_filters: int,
|
|
1446
|
+
encoder_dropout: float,
|
|
1447
|
+
decoder_dropout: float,
|
|
1448
|
+
nonlinearity: Literal[
|
|
1449
|
+
"None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
|
|
1450
|
+
],
|
|
1451
|
+
predict_logvar: Literal[None, "pixelwise"],
|
|
1452
|
+
analytical_kl: bool,
|
|
1453
|
+
) -> LVAEModel:
|
|
1454
|
+
"""Create a dictionary with the parameters of the vae based algorithm model.
|
|
1455
|
+
|
|
1456
|
+
Parameters
|
|
1457
|
+
----------
|
|
1458
|
+
input_shape : tuple[int, ...]
|
|
1459
|
+
Shape of the input patch (Z, Y, X) or (Y, X) if the data is 2D.
|
|
1460
|
+
encoder_conv_strides : tuple[int, ...]
|
|
1461
|
+
Strides of the encoder convolutional layers, length also defines 2D or 3D.
|
|
1462
|
+
decoder_conv_strides : tuple[int, ...]
|
|
1463
|
+
Strides of the decoder convolutional layers, length also defines 2D or 3D.
|
|
1464
|
+
multiscale_count : int
|
|
1465
|
+
Number of lateral context layers, specific to MicroSplit.
|
|
1466
|
+
z_dims : tuple[int, ...]
|
|
1467
|
+
Number of hierarchies in the LVAE model.
|
|
1468
|
+
output_channels : int
|
|
1469
|
+
Number of output channels.
|
|
1470
|
+
encoder_n_filters : int
|
|
1471
|
+
Number of filters in the convolutional layers of the encoder.
|
|
1472
|
+
decoder_n_filters : int
|
|
1473
|
+
Number of filters in the convolutional layers of the decoder.
|
|
1474
|
+
encoder_dropout : float
|
|
1475
|
+
Dropout rate for the encoder.
|
|
1476
|
+
decoder_dropout : float
|
|
1477
|
+
Dropout rate for the decoder.
|
|
1478
|
+
nonlinearity : Literal
|
|
1479
|
+
Type of nonlinearity function to use.
|
|
1480
|
+
predict_logvar : Literal # TODO needs review
|
|
1481
|
+
_description_.
|
|
1482
|
+
analytical_kl : bool # TODO needs clarification
|
|
1483
|
+
_description_.
|
|
1484
|
+
|
|
1485
|
+
Returns
|
|
1486
|
+
-------
|
|
1487
|
+
LVAEModel
|
|
1488
|
+
LVAE model with the specified parameters.
|
|
1489
|
+
"""
|
|
1490
|
+
return LVAEModel(
|
|
1491
|
+
architecture=SupportedArchitecture.LVAE.value,
|
|
1492
|
+
input_shape=input_shape,
|
|
1493
|
+
encoder_conv_strides=encoder_conv_strides,
|
|
1494
|
+
decoder_conv_strides=decoder_conv_strides,
|
|
1495
|
+
multiscale_count=multiscale_count,
|
|
1496
|
+
z_dims=z_dims,
|
|
1497
|
+
output_channels=output_channels,
|
|
1498
|
+
encoder_n_filters=encoder_n_filters,
|
|
1499
|
+
decoder_n_filters=decoder_n_filters,
|
|
1500
|
+
encoder_dropout=encoder_dropout,
|
|
1501
|
+
decoder_dropout=decoder_dropout,
|
|
1502
|
+
nonlinearity=nonlinearity,
|
|
1503
|
+
predict_logvar=predict_logvar,
|
|
1504
|
+
analytical_kl=analytical_kl,
|
|
1505
|
+
)
|
|
1506
|
+
|
|
1507
|
+
|
|
1508
|
+
def _create_vae_based_algorithm(
|
|
1509
|
+
algorithm: Literal["hdn", "microsplit"],
|
|
1510
|
+
loss: LVAELossConfig,
|
|
1511
|
+
input_shape: Sequence[int],
|
|
1512
|
+
encoder_conv_strides: tuple[int, ...],
|
|
1513
|
+
decoder_conv_strides: tuple[int, ...],
|
|
1514
|
+
multiscale_count: int,
|
|
1515
|
+
z_dims: tuple[int, ...],
|
|
1516
|
+
output_channels: int,
|
|
1517
|
+
encoder_n_filters: int,
|
|
1518
|
+
decoder_n_filters: int,
|
|
1519
|
+
encoder_dropout: float,
|
|
1520
|
+
decoder_dropout: float,
|
|
1521
|
+
nonlinearity: Literal[
|
|
1522
|
+
"None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
|
|
1523
|
+
],
|
|
1524
|
+
predict_logvar: Literal[None, "pixelwise"],
|
|
1525
|
+
analytical_kl: bool,
|
|
1526
|
+
gaussian_likelihood: GaussianLikelihoodConfig | None = None,
|
|
1527
|
+
nm_likelihood: NMLikelihoodConfig | None = None,
|
|
1528
|
+
) -> dict:
|
|
1529
|
+
"""
|
|
1530
|
+
Create a dictionary with the parameters of the VAE-based algorithm model.
|
|
1531
|
+
|
|
1532
|
+
Parameters
|
|
1533
|
+
----------
|
|
1534
|
+
algorithm : Literal["hdn"]
|
|
1535
|
+
The algorithm type.
|
|
1536
|
+
loss : Literal["hdn"]
|
|
1537
|
+
The loss function type.
|
|
1538
|
+
input_shape : tuple[int, ...]
|
|
1539
|
+
The shape of the input data.
|
|
1540
|
+
encoder_conv_strides : list[int]
|
|
1541
|
+
The strides of the encoder convolutional layers.
|
|
1542
|
+
decoder_conv_strides : list[int]
|
|
1543
|
+
The strides of the decoder convolutional layers.
|
|
1544
|
+
multiscale_count : int
|
|
1545
|
+
The number of multiscale layers.
|
|
1546
|
+
z_dims : list[int]
|
|
1547
|
+
The dimensions of the latent space.
|
|
1548
|
+
output_channels : int
|
|
1549
|
+
The number of output channels.
|
|
1550
|
+
encoder_n_filters : int
|
|
1551
|
+
The number of filters in the encoder.
|
|
1552
|
+
decoder_n_filters : int
|
|
1553
|
+
The number of filters in the decoder.
|
|
1554
|
+
encoder_dropout : float
|
|
1555
|
+
The dropout rate for the encoder.
|
|
1556
|
+
decoder_dropout : float
|
|
1557
|
+
The dropout rate for the decoder.
|
|
1558
|
+
nonlinearity : Literal
|
|
1559
|
+
The nonlinearity function to use.
|
|
1560
|
+
predict_logvar : Literal[None, "pixelwise"]
|
|
1561
|
+
The type of log variance prediction.
|
|
1562
|
+
analytical_kl : bool
|
|
1563
|
+
Whether to use analytical KL divergence.
|
|
1564
|
+
gaussian_likelihood : Optional[GaussianLikelihoodConfig], optional
|
|
1565
|
+
The Gaussian likelihood model, by default None.
|
|
1566
|
+
nm_likelihood : Optional[NMLikelihoodConfig], optional
|
|
1567
|
+
The noise model likelihood model, by default None.
|
|
1568
|
+
|
|
1569
|
+
Returns
|
|
1570
|
+
-------
|
|
1571
|
+
dict
|
|
1572
|
+
A dictionary with the parameters of the VAE-based algorithm model.
|
|
1573
|
+
"""
|
|
1574
|
+
network_model = _create_vae_configuration(
|
|
1575
|
+
input_shape=input_shape,
|
|
1576
|
+
encoder_conv_strides=encoder_conv_strides,
|
|
1577
|
+
decoder_conv_strides=decoder_conv_strides,
|
|
1578
|
+
multiscale_count=multiscale_count,
|
|
1579
|
+
z_dims=z_dims,
|
|
1580
|
+
output_channels=output_channels,
|
|
1581
|
+
encoder_n_filters=encoder_n_filters,
|
|
1582
|
+
decoder_n_filters=decoder_n_filters,
|
|
1583
|
+
encoder_dropout=encoder_dropout,
|
|
1584
|
+
decoder_dropout=decoder_dropout,
|
|
1585
|
+
nonlinearity=nonlinearity,
|
|
1586
|
+
predict_logvar=predict_logvar,
|
|
1587
|
+
analytical_kl=analytical_kl,
|
|
1588
|
+
)
|
|
1589
|
+
assert gaussian_likelihood or nm_likelihood, "Likelihood model must be specified"
|
|
1590
|
+
return {
|
|
1591
|
+
"algorithm": algorithm,
|
|
1592
|
+
"loss": loss,
|
|
1593
|
+
"model": network_model,
|
|
1594
|
+
"gaussian_likelihood": gaussian_likelihood,
|
|
1595
|
+
"noise_model_likelihood": nm_likelihood,
|
|
1596
|
+
}
|
|
1597
|
+
|
|
1598
|
+
|
|
1599
|
+
def get_likelihood_config(
    loss_type: Literal["musplit", "denoisplit", "denoisplit_musplit"],
    # TODO remove different microsplit loss types, refac
    predict_logvar: Literal["pixelwise"] | None = None,
    logvar_lowerbound: float | None = -5.0,
    nm_paths: list[str] | None = None,
    data_stats: tuple[float, float] | None = None,
) -> tuple[
    GaussianLikelihoodConfig | None,
    MultiChannelNMConfig | None,
    NMLikelihoodConfig | None,
]:
    """Get the likelihood configuration for split models.

    Parameters
    ----------
    loss_type : Literal["musplit", "denoisplit", "denoisplit_musplit"]
        The type of loss function to use. "musplit" and "denoisplit_musplit"
        produce a Gaussian likelihood configuration; "denoisplit" and
        "denoisplit_musplit" produce the noise model configurations. Any other
        value produces no configuration at all.
    predict_logvar : Literal["pixelwise"] | None, optional
        Type of log variance prediction, by default None. Forwarded to the
        Gaussian likelihood configuration for "musplit" and "denoisplit_musplit".
    logvar_lowerbound : float | None, optional
        Lower bound for the log variance, by default -5.0. Forwarded to the
        Gaussian likelihood configuration for "musplit" and "denoisplit_musplit".
    nm_paths : list[str] | None, optional
        Paths to the noise model files, by default None. For "denoisplit" and
        "denoisplit_musplit", one `GaussianMixtureNMConfig` is created per path.
        If None, the multi-channel noise model configuration is created with an
        empty list of noise models.
    data_stats : tuple[float, float] | None, optional
        Data statistics (mean, std), by default None. Currently not used by this
        function; accepted for API compatibility.

    Returns
    -------
    tuple[GaussianLikelihoodConfig | None, MultiChannelNMConfig | None,
    NMLikelihoodConfig | None]
        A tuple containing the likelihood and noise model configurations for the
        specified loss type. Entries that do not apply to `loss_type` are None.

        - GaussianLikelihoodConfig: Gaussian likelihood configuration for musplit
          losses
        - MultiChannelNMConfig: Multi-channel noise model configuration for
          denoisplit losses
        - NMLikelihoodConfig: Noise model likelihood configuration for denoisplit
          losses

    Notes
    -----
    This function does not itself check that the parameters relevant to
    `loss_type` were provided; missing values are forwarded unchanged, and any
    validation is left to the configuration models.
    """
    # TODO validators should be in pydantic models

    # gaussian likelihood
    if loss_type in ("musplit", "denoisplit_musplit"):
        gaussian_lik_config = GaussianLikelihoodConfig(
            predict_logvar=predict_logvar,
            logvar_lowerbound=logvar_lowerbound,
        )
    else:
        gaussian_lik_config = None

    # noise model likelihood
    if loss_type in ("denoisplit", "denoisplit_musplit"):
        # one Gaussian mixture noise model per provided path (none if no paths)
        gmm_list = [
            GaussianMixtureNMConfig(
                model_type="GaussianMixtureNoiseModel",
                path=nm_path,
            )
            for nm_path in ([] if nm_paths is None else nm_paths)
        ]
        noise_model_config = MultiChannelNMConfig(noise_models=gmm_list)
        nm_lik_config = NMLikelihoodConfig()  # TODO this config isn't needed probably
    else:
        noise_model_config = None
        nm_lik_config = None

    return gaussian_lik_config, noise_model_config, nm_lik_config
|
|
1682
|
+
|
|
1683
|
+
|
|
1684
|
+
# TODO wrap parameters into model, loss etc
|
|
1685
|
+
# TODO refac likelihood configs to make it 1. Can it be done ?
|
|
1686
|
+
def create_hdn_configuration(
    experiment_name: str,
    data_type: Literal["array", "tiff", "custom"],
    axes: str,
    patch_size: Sequence[int],
    batch_size: int,
    num_epochs: int = 100,
    num_steps: int | None = None,
    encoder_conv_strides: tuple[int, ...] = (2, 2),
    decoder_conv_strides: tuple[int, ...] = (2, 2),
    multiscale_count: int = 1,
    z_dims: tuple[int, ...] = (128, 128),
    output_channels: int = 1,
    encoder_n_filters: int = 32,
    decoder_n_filters: int = 32,
    encoder_dropout: float = 0.0,
    decoder_dropout: float = 0.0,
    nonlinearity: Literal[
        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
    ] = "ReLU",
    analytical_kl: bool = False,
    predict_logvar: Literal["pixelwise"] | None = None,
    logvar_lowerbound: Union[float, None] = None,
    logger: Literal["wandb", "tensorboard", "none"] = "none",
    trainer_params: dict | None = None,
    augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]] | None = None,
    train_dataloader_params: dict[str, Any] | None = None,
    val_dataloader_params: dict[str, Any] | None = None,
) -> Configuration:
    """
    Create a configuration for training HDN.

    If "Z" is present in `axes`, then `patch_size` must be a list of length 3,
    otherwise 2.

    By setting `augmentations` to `None`, the default transformations (flip in X and Y,
    rotations by 90 degrees in the XY plane) are applied. Rather than the default
    transforms, a list of transforms can be passed to the `augmentations` parameter. To
    disable the transforms, simply pass an empty list.

    The trainer parameters are built from a copy of `trainer_params` (the caller's
    dictionary is not modified). `max_epochs` is set to `num_epochs` whenever
    `num_epochs` is not None; since it defaults to 100, this overrides any
    `max_epochs` present in `trainer_params`. When `num_steps` is given,
    `limit_train_batches` is set to `num_steps`.

    The model uses an HDN loss and a Gaussian likelihood built from
    `predict_logvar` and `logvar_lowerbound`; no noise model likelihood is set.

    Parameters
    ----------
    experiment_name : str
        Name of the experiment.
    data_type : Literal["array", "tiff", "custom"]
        Type of the data.
    axes : str
        Axes of the data (e.g. SYX).
    patch_size : Sequence[int]
        Size of the patches along the spatial dimensions (e.g. [64, 64]). Also used
        as the input shape of the network.
    batch_size : int
        Batch size.
    num_epochs : int, default=100
        Number of epochs to train for. If not None, it is written to
        `max_epochs` in the trainer parameters, overriding any value there.
    num_steps : int, optional
        Number of batches in 1 epoch. If provided, this will be added to trainer_params.
        Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
        documentation for more details.
    encoder_conv_strides : tuple[int, ...], optional
        Strides for the encoder convolutional layers, by default (2, 2).
    decoder_conv_strides : tuple[int, ...], optional
        Strides for the decoder convolutional layers, by default (2, 2).
    multiscale_count : int, optional
        Number of scales in the multiscale architecture, by default 1.
    z_dims : tuple[int, ...], optional
        Dimensions of the latent space, by default (128, 128).
    output_channels : int, optional
        Number of output channels, by default 1.
    encoder_n_filters : int, optional
        Number of filters in the encoder, by default 32.
    decoder_n_filters : int, optional
        Number of filters in the decoder, by default 32.
    encoder_dropout : float, optional
        Dropout rate for the encoder, by default 0.0.
    decoder_dropout : float, optional
        Dropout rate for the decoder, by default 0.0.
    nonlinearity : Literal, optional
        Nonlinearity function to use, by default "ReLU".
    analytical_kl : bool, optional
        Whether to use analytical KL divergence, by default False.
    predict_logvar : Literal["pixelwise"] | None, optional
        Type of log variance prediction, by default None. Used both for the
        network and the Gaussian likelihood.
    logvar_lowerbound : Union[float, None], optional
        Lower bound for the log variance, by default None.
    logger : Literal["wandb", "tensorboard", "none"], optional
        Logger to use for training, by default "none".
    trainer_params : dict, optional
        Parameters for the trainer class, see PyTorch Lightning documentation.
    augmentations : Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]], optional
        List of augmentations to apply, by default None.
    train_dataloader_params : Optional[dict[str, Any]], optional
        Parameters for the training dataloader, by default None.
    val_dataloader_params : Optional[dict[str, Any]], optional
        Parameters for the validation dataloader, by default None.

    Returns
    -------
    Configuration
        The configuration object for training HDN.

    Examples
    --------
    Minimum example:
    >>> config = create_hdn_configuration(
    ...     experiment_name="hdn_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100
    ... )

    You can also limit the number of batches per epoch:
    >>> config = create_hdn_configuration(
    ...     experiment_name="hdn_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_steps=100  # limit to 100 batches per epoch
    ... )
    """
    # TODO revisit the necessity of model_params
    transform_list = _list_spatial_augmentations(augmentations)

    loss_config = LVAELossConfig(
        loss_type="hdn", denoisplit_weight=1, musplit_weight=0
    )  # TODO what are the correct defaults for HDN?

    gaussian_likelihood = GaussianLikelihoodConfig(
        predict_logvar=predict_logvar, logvar_lowerbound=logvar_lowerbound
    )

    # algorithm & model
    algorithm_params = _create_vae_based_algorithm(
        algorithm="hdn",
        loss=loss_config,
        input_shape=patch_size,
        encoder_conv_strides=encoder_conv_strides,
        decoder_conv_strides=decoder_conv_strides,
        multiscale_count=multiscale_count,
        z_dims=z_dims,
        output_channels=output_channels,
        encoder_n_filters=encoder_n_filters,
        decoder_n_filters=decoder_n_filters,
        encoder_dropout=encoder_dropout,
        decoder_dropout=decoder_dropout,
        nonlinearity=nonlinearity,
        predict_logvar=predict_logvar,
        analytical_kl=analytical_kl,
        gaussian_likelihood=gaussian_likelihood,
        nm_likelihood=None,
    )

    # data
    data_params = _create_data_configuration(
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        augmentations=transform_list,
        train_dataloader_params=train_dataloader_params,
        val_dataloader_params=val_dataloader_params,
    )

    # Handle trainer parameters with num_epochs and num_steps; copy so that the
    # caller's dictionary is never mutated
    final_trainer_params = {} if trainer_params is None else trainer_params.copy()

    # Add num_epochs and num_steps if provided (num_epochs takes precedence over
    # any max_epochs already in trainer_params)
    if num_epochs is not None:
        final_trainer_params["max_epochs"] = num_epochs
    if num_steps is not None:
        final_trainer_params["limit_train_batches"] = num_steps

    # training
    training_params = _create_training_configuration(
        trainer_params=final_trainer_params,
        logger=logger,
    )

    return Configuration(
        experiment_name=experiment_name,
        algorithm_config=algorithm_params,
        data_config=data_params,
        training_config=training_params,
    )
|
|
1884
|
+
|
|
1885
|
+
|
|
1886
|
+
def create_microsplit_configuration(
    experiment_name: str,
    data_type: Literal["array", "tiff", "custom"],
    axes: str,
    patch_size: Sequence[int],
    batch_size: int,
    num_epochs: int = 100,
    num_steps: int | None = None,
    encoder_conv_strides: tuple[int, ...] = (2, 2),
    decoder_conv_strides: tuple[int, ...] = (2, 2),
    multiscale_count: int = 3,
    grid_size: int = 32,  # TODO most likely can be derived from patch size
    z_dims: tuple[int, ...] = (128, 128),
    output_channels: int = 1,
    encoder_n_filters: int = 32,
    decoder_n_filters: int = 32,
    encoder_dropout: float = 0.0,
    decoder_dropout: float = 0.0,
    nonlinearity: Literal[
        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
    ] = "ReLU",
    analytical_kl: bool = False,
    predict_logvar: Literal["pixelwise"] = "pixelwise",
    logvar_lowerbound: Union[float, None] = None,
    logger: Literal["wandb", "tensorboard", "none"] = "none",
    trainer_params: dict | None = None,
    augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]] | None = None,
    nm_paths: list[str] | None = None,
    data_stats: tuple[float, float] | None = None,
    train_dataloader_params: dict[str, Any] | None = None,
    val_dataloader_params: dict[str, Any] | None = None,
) -> Configuration:
    """
    Create a configuration for training MicroSplit.

    The loss is a weighted combination of the denoiSplit (weight 0.9) and muSplit
    (weight 0.1) losses; the Gaussian likelihood and noise model configurations are
    built accordingly by `get_likelihood_config`.

    The trainer parameters are built from a copy of `trainer_params` (the caller's
    dictionary is not modified). `max_epochs` is set to `num_epochs` whenever
    `num_epochs` is not None; since it defaults to 100, this overrides any
    `max_epochs` present in `trainer_params`. When `num_steps` is given,
    `limit_train_batches` is set to `num_steps`.

    Parameters
    ----------
    experiment_name : str
        Name of the experiment.
    data_type : Literal["array", "tiff", "custom"]
        Type of the data.
    axes : str
        Axes of the data (e.g. SYX).
    patch_size : Sequence[int]
        Size of the patches along the spatial dimensions (e.g. [64, 64]). Also used
        as the input shape of the network.
    batch_size : int
        Batch size.
    num_epochs : int, default=100
        Number of epochs to train for. If not None, it is written to
        `max_epochs` in the trainer parameters, overriding any value there.
    num_steps : int, optional
        Number of batches in 1 epoch. If provided, this will be added to trainer_params.
        Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
        documentation for more details.
    encoder_conv_strides : tuple[int, ...], optional
        Strides for the encoder convolutional layers, by default (2, 2).
    decoder_conv_strides : tuple[int, ...], optional
        Strides for the decoder convolutional layers, by default (2, 2).
    multiscale_count : int, optional
        Number of multiscale levels, by default 3. Used both by the network and
        the data configuration.
    grid_size : int, optional
        Grid size forwarded to the MicroSplit data configuration, by default 32.
    z_dims : tuple[int, ...], optional
        List of latent dimensions for each hierarchy level in the LVAE, by default
        (128, 128).
    output_channels : int, optional
        Number of output channels for the model, by default 1.
    encoder_n_filters : int, optional
        Number of filters in the encoder, by default 32.
    decoder_n_filters : int, optional
        Number of filters in the decoder, by default 32.
    encoder_dropout : float, optional
        Dropout rate for the encoder, by default 0.0.
    decoder_dropout : float, optional
        Dropout rate for the decoder, by default 0.0.
    nonlinearity : Literal, optional
        Nonlinearity to use in the model, by default "ReLU".
    analytical_kl : bool, optional
        Whether to use analytical KL divergence, by default False.
    predict_logvar : Literal["pixelwise"], optional
        Type of log-variance prediction, by default "pixelwise". Used both for the
        network and the Gaussian likelihood.
    logvar_lowerbound : Union[float, None], optional
        Lower bound for the log variance, by default None.
    logger : Literal["wandb", "tensorboard", "none"], optional
        Logger to use for training, by default "none".
    trainer_params : dict, optional
        Parameters for the trainer class, see PyTorch Lightning documentation.
    augmentations : list[Union[XYFlipModel, XYRandomRotate90Model]] | None, optional
        List of augmentations to apply, by default None.
    nm_paths : list[str] | None, optional
        Paths to the noise model files, by default None. Forwarded to
        `get_likelihood_config` to build the noise model configuration.
    data_stats : tuple[float, float] | None, optional
        Data statistics (mean, std), by default None. Forwarded to
        `get_likelihood_config`.
    train_dataloader_params : dict[str, Any] | None, optional
        Parameters for the training dataloader, by default None.
    val_dataloader_params : dict[str, Any] | None, optional
        Parameters for the validation dataloader, by default None.

    Returns
    -------
    Configuration
        A configuration object for the microsplit algorithm.

    Examples
    --------
    The following examples are illustrative and are not executed as doctests.

    Minimum example::

        config = create_microsplit_configuration(
            experiment_name="microsplit_experiment",
            data_type="array",
            axes="YX",
            patch_size=[64, 64],
            batch_size=32,
            num_epochs=100,
        )

    You can also limit the number of batches per epoch::

        config = create_microsplit_configuration(
            experiment_name="microsplit_experiment",
            data_type="array",
            axes="YX",
            patch_size=[64, 64],
            batch_size=32,
            num_steps=100,  # limit to 100 batches per epoch
        )
    """
    transform_list = _list_spatial_augmentations(augmentations)

    loss_config = LVAELossConfig(
        loss_type="denoisplit_musplit", denoisplit_weight=0.9, musplit_weight=0.1
    )  # TODO losses need to be refactored! just for example. Add validator if sum to 1

    # Create likelihood configurations
    gaussian_likelihood_config, noise_model_config, nm_likelihood_config = (
        get_likelihood_config(
            loss_type="denoisplit_musplit",
            predict_logvar=predict_logvar,
            logvar_lowerbound=logvar_lowerbound,
            nm_paths=nm_paths,
            data_stats=data_stats,
        )
    )

    # Create the LVAE model
    network_model = _create_vae_configuration(
        input_shape=patch_size,
        encoder_conv_strides=encoder_conv_strides,
        decoder_conv_strides=decoder_conv_strides,
        multiscale_count=multiscale_count,
        z_dims=z_dims,
        output_channels=output_channels,
        encoder_n_filters=encoder_n_filters,
        decoder_n_filters=decoder_n_filters,
        encoder_dropout=encoder_dropout,
        decoder_dropout=decoder_dropout,
        nonlinearity=nonlinearity,
        predict_logvar=predict_logvar,
        analytical_kl=analytical_kl,
    )

    # Create the MicroSplit algorithm configuration
    algorithm_params = {
        "algorithm": "microsplit",
        "loss": loss_config,
        "model": network_model,
        "gaussian_likelihood": gaussian_likelihood_config,
        "noise_model": noise_model_config,
        "noise_model_likelihood": nm_likelihood_config,
    }

    # Convert to MicroSplitAlgorithm instance
    algorithm_config = MicroSplitAlgorithm(**algorithm_params)

    # data
    data_params = _create_microsplit_data_configuration(
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        grid_size=grid_size,
        multiscale_count=multiscale_count,
        batch_size=batch_size,
        augmentations=transform_list,
        train_dataloader_params=train_dataloader_params,
        val_dataloader_params=val_dataloader_params,
    )

    # Handle trainer parameters with num_epochs and num_steps; copy so that the
    # caller's dictionary is never mutated
    final_trainer_params = {} if trainer_params is None else trainer_params.copy()

    # Add num_epochs and num_steps if provided (num_epochs takes precedence over
    # any max_epochs already in trainer_params)
    if num_epochs is not None:
        final_trainer_params["max_epochs"] = num_epochs
    if num_steps is not None:
        final_trainer_params["limit_train_batches"] = num_steps

    # training
    training_params = _create_training_configuration(
        trainer_params=final_trainer_params,
        logger=logger,
    )

    return Configuration(
        experiment_name=experiment_name,
        algorithm_config=algorithm_config,
        data_config=data_params,
        training_config=training_params,
    )
|