careamics 0.0.15__py3-none-any.whl → 0.0.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (59)
  1. careamics/careamist.py +6 -12
  2. careamics/cli/conf.py +18 -3
  3. careamics/config/__init__.py +8 -0
  4. careamics/config/algorithms/__init__.py +4 -0
  5. careamics/config/algorithms/hdn_algorithm_model.py +103 -0
  6. careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
  7. careamics/config/algorithms/n2v_algorithm_model.py +1 -2
  8. careamics/config/algorithms/vae_algorithm_model.py +51 -16
  9. careamics/config/architectures/lvae_model.py +12 -8
  10. careamics/config/callback_model.py +7 -3
  11. careamics/config/configuration.py +9 -8
  12. careamics/config/configuration_factories.py +843 -29
  13. careamics/config/data/data_model.py +1 -2
  14. careamics/config/data/ng_data_model.py +1 -2
  15. careamics/config/inference_model.py +1 -2
  16. careamics/config/likelihood_model.py +2 -2
  17. careamics/config/loss_model.py +6 -2
  18. careamics/config/nm_model.py +26 -1
  19. careamics/config/optimizer_models.py +1 -2
  20. careamics/config/support/supported_algorithms.py +5 -3
  21. careamics/config/support/supported_losses.py +5 -2
  22. careamics/config/training_model.py +6 -36
  23. careamics/config/transformations/normalize_model.py +1 -2
  24. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
  25. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
  26. careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
  27. careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
  28. careamics/file_io/read/__init__.py +0 -1
  29. careamics/lightning/__init__.py +16 -2
  30. careamics/lightning/callbacks/__init__.py +2 -0
  31. careamics/lightning/callbacks/data_stats_callback.py +23 -0
  32. careamics/lightning/lightning_module.py +161 -61
  33. careamics/lightning/microsplit_data_module.py +631 -0
  34. careamics/lightning/predict_data_module.py +8 -1
  35. careamics/lightning/train_data_module.py +19 -8
  36. careamics/losses/__init__.py +7 -1
  37. careamics/losses/loss_factory.py +9 -1
  38. careamics/losses/lvae/losses.py +85 -0
  39. careamics/lvae_training/dataset/__init__.py +8 -8
  40. careamics/lvae_training/dataset/config.py +56 -44
  41. careamics/lvae_training/dataset/lc_dataset.py +18 -12
  42. careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
  43. careamics/lvae_training/dataset/multich_dataset.py +24 -18
  44. careamics/lvae_training/dataset/multifile_dataset.py +6 -6
  45. careamics/model_io/bmz_io.py +9 -5
  46. careamics/models/lvae/likelihoods.py +30 -14
  47. careamics/models/lvae/lvae.py +2 -2
  48. careamics/models/lvae/noise_models.py +20 -14
  49. careamics/prediction_utils/__init__.py +8 -2
  50. careamics/prediction_utils/prediction_outputs.py +48 -3
  51. careamics/prediction_utils/stitch_prediction.py +71 -0
  52. careamics/transforms/xy_random_rotate90.py +1 -1
  53. {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/METADATA +18 -15
  54. {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/RECORD +57 -55
  55. careamics/dataset/zarr_dataset.py +0 -151
  56. careamics/file_io/read/zarr.py +0 -60
  57. {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/WHEEL +0 -0
  58. {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/entry_points.txt +0 -0
  59. {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/licenses/LICENSE +0 -0
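
The largest change is in careamics/config/configuration_factories.py: the factory functions now give `num_epochs` a default of 100, add a `num_steps` argument, and accept a `trainer_params` dict that is merged into the Lightning Trainer configuration. A minimal sketch of the new call, based on the diff below (the argument values and the `accelerator` option are illustrative, not part of this diff):

from careamics.config import create_n2v_configuration

# num_epochs -> max_epochs, num_steps -> limit_train_batches, trainer_params is merged in
config = create_n2v_configuration(
    experiment_name="n2v_experiment",
    data_type="array",
    axes="YX",
    patch_size=[64, 64],
    batch_size=32,
    num_epochs=100,
    num_steps=200,
    trainer_params={"accelerator": "gpu"},
)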
careamics/config/configuration_factories.py

@@ -5,9 +5,20 @@ from typing import Annotated, Any, Literal, Union
 
 from pydantic import Field, TypeAdapter
 
-from careamics.config.algorithms import CAREAlgorithm, N2NAlgorithm, N2VAlgorithm
-from careamics.config.architectures import UNetModel
+from careamics.config.algorithms import (
+    CAREAlgorithm,
+    MicroSplitAlgorithm,
+    N2NAlgorithm,
+    N2VAlgorithm,
+)
+from careamics.config.architectures import LVAEModel, UNetModel
 from careamics.config.data import DataConfig, NGDataConfig
+from careamics.config.likelihood_model import (
+    GaussianLikelihoodConfig,
+    NMLikelihoodConfig,
+)
+from careamics.config.loss_model import LVAELossConfig
+from careamics.config.nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig
 from careamics.config.support import (
     SupportedArchitecture,
     SupportedPixelManipulation,
@@ -20,6 +31,7 @@ from careamics.config.transformations import (
     XYFlipModel,
     XYRandomRotate90Model,
 )
+from careamics.lvae_training.dataset.config import MicroSplitDataConfig
 
 from .configuration import Configuration
 
@@ -224,7 +236,7 @@ def _create_algorithm_configuration(
 def _create_data_configuration(
     data_type: Literal["array", "tiff", "czi", "custom"],
     axes: str,
-    patch_size: list[int],
+    patch_size: Sequence[int],
     batch_size: int,
     augmentations: Union[list[SPATIAL_TRANSFORMS_UNION]],
     train_dataloader_params: dict[str, Any] | None = None,
@@ -277,6 +289,66 @@ def _create_data_configuration(
     return DataConfig(**data)
 
 
+def _create_microsplit_data_configuration(
+    data_type: Literal["array", "tiff", "custom"],
+    axes: str,
+    patch_size: Sequence[int],
+    grid_size: int,
+    multiscale_count: int,
+    batch_size: int,
+    augmentations: Union[list[SPATIAL_TRANSFORMS_UNION]],
+    train_dataloader_params: dict[str, Any] | None = None,
+    val_dataloader_params: dict[str, Any] | None = None,
+) -> DataConfig:
+    """
+    Create a dictionary with the parameters of the MicroSplit data model.
+
+    Parameters
+    ----------
+    data_type : {"array", "tiff", "custom"}
+        Type of the data.
+    axes : str
+        Axes of the data.
+    patch_size : list of int
+        Size of the patches along the spatial dimensions.
+    grid_size : int
+        Grid size used by the MicroSplit dataset for patch extraction.
+    multiscale_count : int
+        Number of lateral context (low-resolution) inputs.
+    batch_size : int
+        Batch size.
+    augmentations : list of transforms
+        List of transforms to apply.
+    train_dataloader_params : dict
+        Parameters for the training dataloader, see PyTorch notes, by default None.
+    val_dataloader_params : dict
+        Parameters for the validation dataloader, see PyTorch notes, by default None.
+
+    Returns
+    -------
+    MicroSplitDataConfig
+        Data model with the specified parameters.
+    """
+    # data model
+    data = {
+        "data_type": data_type,
+        "axes": axes,
+        "image_size": patch_size,
+        "grid_size": grid_size,
+        "multiscale_lowres_count": multiscale_count,
+        "batch_size": batch_size,
+        "transforms": augmentations,
+    }
+    # Don't override defaults set in DataConfig class
+    if train_dataloader_params is not None:
+        # DataConfig enforces the presence of `shuffle` key in the dataloader parameters
+        if "shuffle" not in train_dataloader_params:
+            train_dataloader_params["shuffle"] = True
+
+        data["train_dataloader_params"] = train_dataloader_params
+
+    if val_dataloader_params is not None:
+        data["val_dataloader_params"] = val_dataloader_params
+
+    return MicroSplitDataConfig(**data)
+
+
 def _create_ng_data_configuration(
     data_type: Literal["array", "tiff", "custom"],
     axes: str,
@@ -357,7 +429,7 @@ def _create_ng_data_configuration(
 
 
 def _create_training_configuration(
-    num_epochs: int,
+    trainer_params: dict,
     logger: Literal["wandb", "tensorboard", "none"],
     checkpoint_params: dict[str, Any] | None = None,
 ) -> TrainingConfig:
@@ -366,8 +438,8 @@
 
     Parameters
     ----------
-    num_epochs : int
-        Number of epochs.
+    trainer_params : dict
+        Parameters for the Lightning Trainer class, see PyTorch Lightning documentation.
     logger : {"wandb", "tensorboard", "none"}
         Logger to use.
     checkpoint_params : dict, default=None
@@ -380,7 +452,7 @@
         Training model with the specified parameters.
     """
     return TrainingConfig(
-        num_epochs=num_epochs,
+        lightning_trainer_config=trainer_params,
         logger=None if logger == "none" else logger,
         checkpoint_callback={} if checkpoint_params is None else checkpoint_params,
     )
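
TrainingConfig therefore no longer has a dedicated num_epochs field; the epoch budget travels inside a generic Lightning kwargs dict. A minimal sketch of what the factories now build, assuming TrainingConfig is exported from careamics.config (field names are taken from the hunk above; the extra keys are illustrative Lightning Trainer options):

from careamics.config import TrainingConfig

training_config = TrainingConfig(
    lightning_trainer_config={
        "max_epochs": 100,           # was num_epochs in 0.0.15
        "limit_train_batches": 200,  # set via the new num_steps argument
        "gradient_clip_val": 0.5,    # any other Lightning Trainer kwarg
    },
    logger="tensorboard",
    checkpoint_callback={},
)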
@@ -392,9 +464,9 @@ def _create_supervised_config_dict(
     experiment_name: str,
     data_type: Literal["array", "tiff", "czi", "custom"],
     axes: str,
-    patch_size: list[int],
+    patch_size: Sequence[int],
     batch_size: int,
-    num_epochs: int,
+    trainer_params: dict | None = None,
     augmentations: list[SPATIAL_TRANSFORMS_UNION] | None = None,
     independent_channels: bool = True,
     loss: Literal["mae", "mse"] = "mae",
@@ -409,6 +481,8 @@
     train_dataloader_params: dict[str, Any] | None = None,
     val_dataloader_params: dict[str, Any] | None = None,
     checkpoint_params: dict[str, Any] | None = None,
+    num_epochs: int | None = None,
+    num_steps: int | None = None,
 ) -> dict:
     """
     Create a configuration for training CARE or Noise2Noise.
@@ -427,8 +501,8 @@ def _create_supervised_config_dict(
         Size of the patches along the spatial dimensions (e.g. [64, 64]).
     batch_size : int
         Batch size.
-    num_epochs : int
-        Number of epochs.
+    trainer_params : dict
+        Parameters for the Lightning Trainer class, see PyTorch Lightning documentation.
     augmentations : list of transforms, default=None
         List of transforms to apply, either both or one of XYFlipModel and
         XYRandomRotate90Model. By default, it applies both XYFlip (on X and Y)
@@ -461,6 +535,13 @@
     checkpoint_params : dict, default=None
         Parameters for the checkpoint callback, see PyTorch Lightning documentation
         (`ModelCheckpoint`) for the list of available parameters.
+    num_epochs : int or None, default=None
+        Number of epochs to train for. If provided, this will be added to
+        trainer_params.
+    num_steps : int or None, default=None
+        Number of batches in 1 epoch. If provided, this will be added to trainer_params.
+        Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
+        documentation for more details.
 
     Returns
     -------
@@ -518,9 +599,18 @@
         val_dataloader_params=val_dataloader_params,
     )
 
+    # Handle trainer parameters with num_epochs and num_steps
+    final_trainer_params = {} if trainer_params is None else trainer_params.copy()
+
+    # Add num_epochs and num_steps if provided
+    if num_epochs is not None:
+        final_trainer_params["max_epochs"] = num_epochs
+    if num_steps is not None:
+        final_trainer_params["limit_train_batches"] = num_steps
+
     # training
     training_params = _create_training_configuration(
-        num_epochs=num_epochs,
+        trainer_params=final_trainer_params,
         logger=logger,
         checkpoint_params=checkpoint_params,
     )
@@ -537,15 +627,17 @@ def create_care_configuration(
     experiment_name: str,
     data_type: Literal["array", "tiff", "czi", "custom"],
     axes: str,
-    patch_size: list[int],
+    patch_size: Sequence[int],
     batch_size: int,
-    num_epochs: int,
+    num_epochs: int = 100,
+    num_steps: int | None = None,
     augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]] | None = None,
     independent_channels: bool = True,
     loss: Literal["mae", "mse"] = "mae",
     n_channels_in: int | None = None,
     n_channels_out: int | None = None,
     logger: Literal["wandb", "tensorboard", "none"] = "none",
+    trainer_params: dict | None = None,
     model_params: dict | None = None,
     optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
     optimizer_params: dict[str, Any] | None = None,
@@ -588,8 +680,13 @@ def create_care_configuration(
         Size of the patches along the spatial dimensions (e.g. [64, 64]).
     batch_size : int
         Batch size.
-    num_epochs : int
-        Number of epochs.
+    num_epochs : int, default=100
+        Number of epochs to train for. If provided, this will be added to
+        trainer_params.
+    num_steps : int, optional
+        Number of batches in 1 epoch. If provided, this will be added to trainer_params.
+        Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
+        documentation for more details.
     augmentations : list of transforms, default=None
         List of transforms to apply, either both or one of XYFlipModel and
         XYRandomRotate90Model. By default, it applies both XYFlip (on X and Y)
@@ -604,6 +701,8 @@ def create_care_configuration(
         Number of channels out.
     logger : Literal["wandb", "tensorboard", "none"], default="none"
         Logger to use.
+    trainer_params : dict, optional
+        Parameters for the trainer class, see PyTorch Lightning documentation.
     model_params : dict, default=None
         UNetModel parameters.
     optimizer : Literal["Adam", "Adamax", "SGD"], default="Adam"
@@ -644,6 +743,16 @@ def create_care_configuration(
     ...     num_epochs=100
     ... )
 
+    You can also limit the number of batches per epoch:
+    >>> config = create_care_configuration(
+    ...     experiment_name="care_experiment",
+    ...     data_type="array",
+    ...     axes="YX",
+    ...     patch_size=[64, 64],
+    ...     batch_size=32,
+    ...     num_steps=100  # limit to 100 batches per epoch
+    ... )
+
     To disable transforms, simply set `augmentations` to an empty list:
     >>> config = create_care_configuration(
     ...     experiment_name="care_experiment",
@@ -730,13 +839,13 @@ def create_care_configuration(
             axes=axes,
             patch_size=patch_size,
             batch_size=batch_size,
-            num_epochs=num_epochs,
             augmentations=augmentations,
             independent_channels=independent_channels,
             loss=loss,
             n_channels_in=n_channels_in,
             n_channels_out=n_channels_out,
             logger=logger,
+            trainer_params=trainer_params,
             model_params=model_params,
             optimizer=optimizer,
             optimizer_params=optimizer_params,
@@ -745,6 +854,8 @@ def create_care_configuration(
             train_dataloader_params=train_dataloader_params,
             val_dataloader_params=val_dataloader_params,
             checkpoint_params=checkpoint_params,
+            num_epochs=num_epochs,
+            num_steps=num_steps,
         )
     )
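
A hedged sketch of using the new arguments of create_care_configuration together; the Lightning options inside trainer_params ("accelerator", "precision") and the commented CAREamist training call are illustrative assumptions, not part of this diff:

from careamics import CAREamist
from careamics.config import create_care_configuration

config = create_care_configuration(
    experiment_name="care_experiment",
    data_type="array",
    axes="YX",
    patch_size=[64, 64],
    batch_size=32,
    num_epochs=50,
    num_steps=200,  # caps each epoch at 200 batches
    trainer_params={"accelerator": "gpu", "precision": "16-mixed"},
)

careamist = CAREamist(source=config)
# careamist.train(train_source=train_array, val_source=val_array)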
@@ -753,15 +864,17 @@ def create_n2n_configuration(
     experiment_name: str,
     data_type: Literal["array", "tiff", "czi", "custom"],
     axes: str,
-    patch_size: list[int],
+    patch_size: Sequence[int],
     batch_size: int,
-    num_epochs: int,
+    num_epochs: int = 100,
+    num_steps: int | None = None,
     augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]] | None = None,
     independent_channels: bool = True,
     loss: Literal["mae", "mse"] = "mae",
     n_channels_in: int | None = None,
     n_channels_out: int | None = None,
     logger: Literal["wandb", "tensorboard", "none"] = "none",
+    trainer_params: dict | None = None,
     model_params: dict | None = None,
     optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
     optimizer_params: dict[str, Any] | None = None,
@@ -804,8 +917,13 @@ def create_n2n_configuration(
         Size of the patches along the spatial dimensions (e.g. [64, 64]).
     batch_size : int
         Batch size.
-    num_epochs : int
-        Number of epochs.
+    num_epochs : int, default=100
+        Number of epochs to train for. If provided, this will be added to
+        trainer_params.
+    num_steps : int, optional
+        Number of batches in 1 epoch. If provided, this will be added to trainer_params.
+        Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
+        documentation for more details.
     augmentations : list of transforms, default=None
         List of transforms to apply, either both or one of XYFlipModel and
         XYRandomRotate90Model. By default, it applies both XYFlip (on X and Y)
@@ -820,6 +938,8 @@ def create_n2n_configuration(
         Number of channels out.
     logger : Literal["wandb", "tensorboard", "none"], optional
         Logger to use, by default "none".
+    trainer_params : dict, optional
+        Parameters for the trainer class, see PyTorch Lightning documentation.
     model_params : dict, default=None
         UNetModel parameters.
     optimizer : Literal["Adam", "Adamax", "SGD"], default="Adam"
@@ -860,6 +980,16 @@ def create_n2n_configuration(
     ...     num_epochs=100
     ... )
 
+    You can also limit the number of batches per epoch:
+    >>> config = create_n2n_configuration(
+    ...     experiment_name="n2n_experiment",
+    ...     data_type="array",
+    ...     axes="YX",
+    ...     patch_size=[64, 64],
+    ...     batch_size=32,
+    ...     num_steps=100  # limit to 100 batches per epoch
+    ... )
+
     To disable transforms, simply set `augmentations` to an empty list:
     >>> config = create_n2n_configuration(
     ...     experiment_name="n2n_experiment",
@@ -871,8 +1001,7 @@ def create_n2n_configuration(
     ...     augmentations=[]
     ... )
 
-    A list of transforms can be passed to the `augmentations` parameter to replace the
-    default augmentations:
+    A list of transforms can be passed to the `augmentations` parameter:
     >>> from careamics.config.transformations import XYFlipModel
     >>> config = create_n2n_configuration(
     ...     experiment_name="n2n_experiment",
@@ -946,7 +1075,7 @@ def create_n2n_configuration(
             axes=axes,
             patch_size=patch_size,
             batch_size=batch_size,
-            num_epochs=num_epochs,
+            trainer_params=trainer_params,
             augmentations=augmentations,
             independent_channels=independent_channels,
             loss=loss,
@@ -961,6 +1090,8 @@ def create_n2n_configuration(
             train_dataloader_params=train_dataloader_params,
            val_dataloader_params=val_dataloader_params,
            checkpoint_params=checkpoint_params,
+            num_epochs=num_epochs,
+            num_steps=num_steps,
        )
    )
 
@@ -969,9 +1100,10 @@ def create_n2v_configuration(
     experiment_name: str,
     data_type: Literal["array", "tiff", "czi", "custom"],
     axes: str,
-    patch_size: list[int],
+    patch_size: Sequence[int],
     batch_size: int,
-    num_epochs: int,
+    num_epochs: int = 100,
+    num_steps: int | None = None,
     augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]] | None = None,
     independent_channels: bool = True,
     use_n2v2: bool = False,
@@ -980,6 +1112,7 @@ def create_n2v_configuration(
     masked_pixel_percentage: float = 0.2,
     struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
     struct_n2v_span: int = 5,
+    trainer_params: dict | None = None,
     logger: Literal["wandb", "tensorboard", "none"] = "none",
     model_params: dict | None = None,
     optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
@@ -1043,8 +1176,13 @@ def create_n2v_configuration(
         Size of the patches along the spatial dimensions (e.g. [64, 64]).
     batch_size : int
         Batch size.
-    num_epochs : int
-        Number of epochs.
+    num_epochs : int, default=100
+        Number of epochs to train for. If provided, this will be added to
+        trainer_params.
+    num_steps : int, optional
+        Number of batches in 1 epoch. If provided, this will be added to trainer_params.
+        Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
+        documentation for more details.
     augmentations : list of transforms, default=None
         List of transforms to apply, either both or one of XYFlipModel and
         XYRandomRotate90Model. By default, it applies both XYFlip (on X and Y)
@@ -1063,6 +1201,8 @@ def create_n2v_configuration(
         Axis along which to apply structN2V mask, by default "none".
     struct_n2v_span : int, optional
         Span of the structN2V mask, by default 5.
+    trainer_params : dict, optional
+        Parameters for the trainer class, see PyTorch Lightning documentation.
     logger : Literal["wandb", "tensorboard", "none"], optional
         Logger to use, by default "none".
     model_params : dict, default=None
@@ -1105,6 +1245,16 @@ def create_n2v_configuration(
     ...     num_epochs=100
     ... )
 
+    You can also limit the number of batches per epoch:
+    >>> config = create_n2v_configuration(
+    ...     experiment_name="n2v_experiment",
+    ...     data_type="array",
+    ...     axes="YX",
+    ...     patch_size=[64, 64],
+    ...     batch_size=32,
+    ...     num_steps=100  # limit to 100 batches per epoch
+    ... )
+
     To disable transforms, simply set `augmentations` to an empty list:
     >>> config = create_n2v_configuration(
     ...     experiment_name="n2v_experiment",
@@ -1261,8 +1411,17 @@ def create_n2v_configuration(
     )
 
     # training
+    # Handle trainer parameters with num_epochs and num_steps
+    final_trainer_params = {} if trainer_params is None else trainer_params.copy()
+
+    # Add num_epochs and num_steps if provided
+    if num_epochs is not None:
+        final_trainer_params["max_epochs"] = num_epochs
+    if num_steps is not None:
+        final_trainer_params["limit_train_batches"] = num_steps
+
     training_params = _create_training_configuration(
-        num_epochs=num_epochs,
+        trainer_params=final_trainer_params,
         logger=logger,
         checkpoint_params=checkpoint_params,
     )
@@ -1273,3 +1432,658 @@ def create_n2v_configuration(
         data_config=data_params,
         training_config=training_params,
     )
+
+
+def _create_vae_configuration(
+    input_shape: Sequence[int],
+    encoder_conv_strides: tuple[int, ...],
+    decoder_conv_strides: tuple[int, ...],
+    multiscale_count: int,
+    z_dims: tuple[int, ...],
+    output_channels: int,
+    encoder_n_filters: int,
+    decoder_n_filters: int,
+    encoder_dropout: float,
+    decoder_dropout: float,
+    nonlinearity: Literal[
+        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
+    ],
+    predict_logvar: Literal[None, "pixelwise"],
+    analytical_kl: bool,
+) -> LVAEModel:
+    """Create the LVAE model configuration used by the VAE-based algorithms.
+
+    Parameters
+    ----------
+    input_shape : tuple[int, ...]
+        Shape of the input patch (Z, Y, X) or (Y, X) if the data is 2D.
+    encoder_conv_strides : tuple[int, ...]
+        Strides of the encoder convolutional layers, length also defines 2D or 3D.
+    decoder_conv_strides : tuple[int, ...]
+        Strides of the decoder convolutional layers, length also defines 2D or 3D.
+    multiscale_count : int
+        Number of lateral context layers, specific to MicroSplit.
+    z_dims : tuple[int, ...]
+        Number of hierarchies in the LVAE model.
+    output_channels : int
+        Number of output channels.
+    encoder_n_filters : int
+        Number of filters in the convolutional layers of the encoder.
+    decoder_n_filters : int
+        Number of filters in the convolutional layers of the decoder.
+    encoder_dropout : float
+        Dropout rate for the encoder.
+    decoder_dropout : float
+        Dropout rate for the decoder.
+    nonlinearity : Literal
+        Type of nonlinearity function to use.
+    predict_logvar : Literal[None, "pixelwise"]
+        Type of log variance prediction.
+    analytical_kl : bool
+        Whether to use analytical KL divergence.
+
+    Returns
+    -------
+    LVAEModel
+        LVAE model with the specified parameters.
+    """
+    return LVAEModel(
+        architecture=SupportedArchitecture.LVAE.value,
+        input_shape=input_shape,
+        encoder_conv_strides=encoder_conv_strides,
+        decoder_conv_strides=decoder_conv_strides,
+        multiscale_count=multiscale_count,
+        z_dims=z_dims,
+        output_channels=output_channels,
+        encoder_n_filters=encoder_n_filters,
+        decoder_n_filters=decoder_n_filters,
+        encoder_dropout=encoder_dropout,
+        decoder_dropout=decoder_dropout,
+        nonlinearity=nonlinearity,
+        predict_logvar=predict_logvar,
+        analytical_kl=analytical_kl,
+    )
+
+
+def _create_vae_based_algorithm(
+    algorithm: Literal["hdn", "microsplit"],
+    loss: LVAELossConfig,
+    input_shape: Sequence[int],
+    encoder_conv_strides: tuple[int, ...],
+    decoder_conv_strides: tuple[int, ...],
+    multiscale_count: int,
+    z_dims: tuple[int, ...],
+    output_channels: int,
+    encoder_n_filters: int,
+    decoder_n_filters: int,
+    encoder_dropout: float,
+    decoder_dropout: float,
+    nonlinearity: Literal[
+        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
+    ],
+    predict_logvar: Literal[None, "pixelwise"],
+    analytical_kl: bool,
+    gaussian_likelihood: GaussianLikelihoodConfig | None = None,
+    nm_likelihood: NMLikelihoodConfig | None = None,
+) -> dict:
+    """
+    Create a dictionary with the parameters of the VAE-based algorithm model.
+
+    Parameters
+    ----------
+    algorithm : Literal["hdn", "microsplit"]
+        The algorithm type.
+    loss : LVAELossConfig
+        The loss configuration.
+    input_shape : tuple[int, ...]
+        The shape of the input data.
+    encoder_conv_strides : tuple[int, ...]
+        The strides of the encoder convolutional layers.
+    decoder_conv_strides : tuple[int, ...]
+        The strides of the decoder convolutional layers.
+    multiscale_count : int
+        The number of multiscale layers.
+    z_dims : tuple[int, ...]
+        The dimensions of the latent space.
+    output_channels : int
+        The number of output channels.
+    encoder_n_filters : int
+        The number of filters in the encoder.
+    decoder_n_filters : int
+        The number of filters in the decoder.
+    encoder_dropout : float
+        The dropout rate for the encoder.
+    decoder_dropout : float
+        The dropout rate for the decoder.
+    nonlinearity : Literal
+        The nonlinearity function to use.
+    predict_logvar : Literal[None, "pixelwise"]
+        The type of log variance prediction.
+    analytical_kl : bool
+        Whether to use analytical KL divergence.
+    gaussian_likelihood : Optional[GaussianLikelihoodConfig], optional
+        The Gaussian likelihood model, by default None.
+    nm_likelihood : Optional[NMLikelihoodConfig], optional
+        The noise model likelihood model, by default None.
+
+    Returns
+    -------
+    dict
+        A dictionary with the parameters of the VAE-based algorithm model.
+    """
+    network_model = _create_vae_configuration(
+        input_shape=input_shape,
+        encoder_conv_strides=encoder_conv_strides,
+        decoder_conv_strides=decoder_conv_strides,
+        multiscale_count=multiscale_count,
+        z_dims=z_dims,
+        output_channels=output_channels,
+        encoder_n_filters=encoder_n_filters,
+        decoder_n_filters=decoder_n_filters,
+        encoder_dropout=encoder_dropout,
+        decoder_dropout=decoder_dropout,
+        nonlinearity=nonlinearity,
+        predict_logvar=predict_logvar,
+        analytical_kl=analytical_kl,
+    )
+    assert gaussian_likelihood or nm_likelihood, "Likelihood model must be specified"
+    return {
+        "algorithm": algorithm,
+        "loss": loss,
+        "model": network_model,
+        "gaussian_likelihood": gaussian_likelihood,
+        "noise_model_likelihood": nm_likelihood,
+    }
+
+
+def get_likelihood_config(
+    loss_type: Literal["musplit", "denoisplit", "denoisplit_musplit"],
+    # TODO remove different microsplit loss types, refac
+    predict_logvar: Literal["pixelwise"] | None = None,
+    logvar_lowerbound: float | None = -5.0,
+    nm_paths: list[str] | None = None,
+    data_stats: tuple[float, float] | None = None,
+) -> tuple[
+    GaussianLikelihoodConfig | None,
+    MultiChannelNMConfig | None,
+    NMLikelihoodConfig | None,
+]:
+    """Get the likelihood configuration for split models.
+
+    Parameters
+    ----------
+    loss_type : Literal["musplit", "denoisplit", "denoisplit_musplit"]
+        The type of loss function to use.
+    predict_logvar : Literal["pixelwise"] | None, optional
+        Type of log variance prediction, by default None.
+        Required when loss_type is "musplit" or "denoisplit_musplit".
+    logvar_lowerbound : float | None, optional
+        Lower bound for the log variance, by default -5.0.
+        Used when loss_type is "musplit" or "denoisplit_musplit".
+    nm_paths : list[str] | None, optional
+        Paths to the noise model files, by default None.
+        Required when loss_type is "denoisplit" or "denoisplit_musplit".
+    data_stats : tuple[float, float] | None, optional
+        Data statistics (mean, std), by default None.
+        Required when loss_type is "denoisplit" or "denoisplit_musplit".
+
+    Returns
+    -------
+    tuple[GaussianLikelihoodConfig | None, MultiChannelNMConfig | None,
+    NMLikelihoodConfig | None]
+        A tuple containing the likelihood and noise model configurations for the
+        specified loss type.
+
+        - GaussianLikelihoodConfig: Gaussian likelihood configuration for musplit losses
+        - MultiChannelNMConfig: Multi-channel noise model configuration for denoisplit
+          losses
+        - NMLikelihoodConfig: Noise model likelihood configuration for denoisplit losses
+
+    Raises
+    ------
+    ValueError
+        If required parameters are missing for the specified loss_type.
+    """
+    # gaussian likelihood
+    if loss_type in ["musplit", "denoisplit_musplit"]:
+        # if predict_logvar is None:
+        #     raise ValueError(f"predict_logvar is required for loss_type '{loss_type}'")
+        # TODO validators should be in pydantic models
+        gaussian_lik_config = GaussianLikelihoodConfig(
+            predict_logvar=predict_logvar,
+            logvar_lowerbound=logvar_lowerbound,
+        )
+    else:
+        gaussian_lik_config = None
+
+    # noise model likelihood
+    if loss_type in ["denoisplit", "denoisplit_musplit"]:
+        # if nm_paths is None:
+        #     raise ValueError(f"nm_paths is required for loss_type '{loss_type}'")
+        # if data_stats is None:
+        #     raise ValueError(f"data_stats is required for loss_type '{loss_type}'")
+        # TODO validators should be in pydantic models
+        gmm_list = []
+        if nm_paths is not None:
+            for NM_path in nm_paths:
+                gmm_list.append(
+                    GaussianMixtureNMConfig(
+                        model_type="GaussianMixtureNoiseModel",
+                        path=NM_path,
+                    )
+                )
+        noise_model_config = MultiChannelNMConfig(noise_models=gmm_list)
+        nm_lik_config = NMLikelihoodConfig()  # TODO this config isn't needed probably
+    else:
+        noise_model_config = None
+        nm_lik_config = None
+
+    return gaussian_lik_config, noise_model_config, nm_lik_config
+
+
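A hedged usage sketch of get_likelihood_config as added above; the noise-model file names and the data statistics are placeholders, not values from this diff:

gaussian_lik, noise_model, nm_lik = get_likelihood_config(
    loss_type="denoisplit_musplit",
    predict_logvar="pixelwise",
    logvar_lowerbound=-5.0,
    nm_paths=["noise_model_ch0.npz", "noise_model_ch1.npz"],  # placeholder paths
    data_stats=(data_mean, data_std),  # placeholder (mean, std) of the training data
)
# For loss_type="musplit" only the Gaussian likelihood is returned;
# the two noise-model configurations come back as None.
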
+# TODO wrap parameters into model, loss etc
+# TODO refac likelihood configs to make it 1. Can it be done ?
+def create_hdn_configuration(
+    experiment_name: str,
+    data_type: Literal["array", "tiff", "custom"],
+    axes: str,
+    patch_size: Sequence[int],
+    batch_size: int,
+    num_epochs: int = 100,
+    num_steps: int | None = None,
+    encoder_conv_strides: tuple[int, ...] = (2, 2),
+    decoder_conv_strides: tuple[int, ...] = (2, 2),
+    multiscale_count: int = 1,
+    z_dims: tuple[int, ...] = (128, 128),
+    output_channels: int = 1,
+    encoder_n_filters: int = 32,
+    decoder_n_filters: int = 32,
+    encoder_dropout: float = 0.0,
+    decoder_dropout: float = 0.0,
+    nonlinearity: Literal[
+        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
+    ] = "ReLU",
+    analytical_kl: bool = False,
+    predict_logvar: Literal["pixelwise"] | None = None,
+    logvar_lowerbound: Union[float, None] = None,
+    logger: Literal["wandb", "tensorboard", "none"] = "none",
+    trainer_params: dict | None = None,
+    augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]] | None = None,
+    train_dataloader_params: dict[str, Any] | None = None,
+    val_dataloader_params: dict[str, Any] | None = None,
+) -> Configuration:
+    """
+    Create a configuration for training HDN.
+
+    If "Z" is present in `axes`, then `patch_size` must be a list of length 3, otherwise
+    2.
+
+    If "C" is present in `axes`, then you need to set `n_channels_in` to the number of
+    channels. Likewise, if you set the number of channels, then "C" must be present in
+    `axes`.
+
+    To set the number of output channels, use the `n_channels_out` parameter. If it is
+    not specified, it will be assumed to be equal to `n_channels_in`.
+
+    By default, all channels are trained independently. To train all channels together,
+    set `independent_channels` to False.
+
+    By setting `augmentations` to `None`, the default transformations (flip in X and Y,
+    rotations by 90 degrees in the XY plane) are applied. Rather than the default
+    transforms, a list of transforms can be passed to the `augmentations` parameter. To
+    disable the transforms, simply pass an empty list.
+
+    # TODO revisit the necessity of model_params
+
+    Parameters
+    ----------
+    experiment_name : str
+        Name of the experiment.
+    data_type : Literal["array", "tiff", "custom"]
+        Type of the data.
+    axes : str
+        Axes of the data (e.g. SYX).
+    patch_size : List[int]
+        Size of the patches along the spatial dimensions (e.g. [64, 64]).
+    batch_size : int
+        Batch size.
+    num_epochs : int, default=100
+        Number of epochs to train for. If provided, this will be added to
+        trainer_params.
+    num_steps : int, optional
+        Number of batches in 1 epoch. If provided, this will be added to trainer_params.
+        Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
+        documentation for more details.
+    encoder_conv_strides : tuple[int, ...], optional
+        Strides for the encoder convolutional layers, by default (2, 2).
+    decoder_conv_strides : tuple[int, ...], optional
+        Strides for the decoder convolutional layers, by default (2, 2).
+    multiscale_count : int, optional
+        Number of scales in the multiscale architecture, by default 1.
+    z_dims : tuple[int, ...], optional
+        Dimensions of the latent space, by default (128, 128).
+    output_channels : int, optional
+        Number of output channels, by default 1.
+    encoder_n_filters : int, optional
+        Number of filters in the encoder, by default 32.
+    decoder_n_filters : int, optional
+        Number of filters in the decoder, by default 32.
+    encoder_dropout : float, optional
+        Dropout rate for the encoder, by default 0.0.
+    decoder_dropout : float, optional
+        Dropout rate for the decoder, by default 0.0.
+    nonlinearity : Literal, optional
+        Nonlinearity function to use, by default "ReLU".
+    analytical_kl : bool, optional
+        Whether to use analytical KL divergence, by default False.
+    predict_logvar : Literal[None, "pixelwise"], optional
+        Type of log variance prediction, by default None.
+    logvar_lowerbound : Union[float, None], optional
+        Lower bound for the log variance, by default None.
+    logger : Literal["wandb", "tensorboard", "none"], optional
+        Logger to use for training, by default "none".
+    trainer_params : dict, optional
+        Parameters for the trainer class, see PyTorch Lightning documentation.
+    augmentations : Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]], optional
+        List of augmentations to apply, by default None.
+    train_dataloader_params : Optional[dict[str, Any]], optional
+        Parameters for the training dataloader, by default None.
+    val_dataloader_params : Optional[dict[str, Any]], optional
+        Parameters for the validation dataloader, by default None.
+
+    Returns
+    -------
+    Configuration
+        The configuration object for training HDN.
+
+    Examples
+    --------
+    Minimum example:
+    >>> config = create_hdn_configuration(
+    ...     experiment_name="hdn_experiment",
+    ...     data_type="array",
+    ...     axes="YX",
+    ...     patch_size=[64, 64],
+    ...     batch_size=32,
+    ...     num_epochs=100
+    ... )
+
+    You can also limit the number of batches per epoch:
+    >>> config = create_hdn_configuration(
+    ...     experiment_name="hdn_experiment",
+    ...     data_type="array",
+    ...     axes="YX",
+    ...     patch_size=[64, 64],
+    ...     batch_size=32,
+    ...     num_steps=100  # limit to 100 batches per epoch
+    ... )
+    """
+    transform_list = _list_spatial_augmentations(augmentations)
+
+    loss_config = LVAELossConfig(
+        loss_type="hdn", denoisplit_weight=1, musplit_weight=0
+    )  # TODO what are the correct defaults for HDN?
+
+    gaussian_likelihood = GaussianLikelihoodConfig(
+        predict_logvar=predict_logvar, logvar_lowerbound=logvar_lowerbound
+    )
+
+    # algorithm & model
+    algorithm_params = _create_vae_based_algorithm(
+        algorithm="hdn",
+        loss=loss_config,
+        input_shape=patch_size,
+        encoder_conv_strides=encoder_conv_strides,
+        decoder_conv_strides=decoder_conv_strides,
+        multiscale_count=multiscale_count,
+        z_dims=z_dims,
+        output_channels=output_channels,
+        encoder_n_filters=encoder_n_filters,
+        decoder_n_filters=decoder_n_filters,
+        encoder_dropout=encoder_dropout,
+        decoder_dropout=decoder_dropout,
+        nonlinearity=nonlinearity,
+        predict_logvar=predict_logvar,
+        analytical_kl=analytical_kl,
+        gaussian_likelihood=gaussian_likelihood,
+        nm_likelihood=None,
+    )
+
+    # data
+    data_params = _create_data_configuration(
+        data_type=data_type,
+        axes=axes,
+        patch_size=patch_size,
+        batch_size=batch_size,
+        augmentations=transform_list,
+        train_dataloader_params=train_dataloader_params,
+        val_dataloader_params=val_dataloader_params,
+    )
+
+    # Handle trainer parameters with num_epochs and num_steps
+    final_trainer_params = {} if trainer_params is None else trainer_params.copy()
+
+    # Add num_epochs and num_steps if provided
+    if num_epochs is not None:
+        final_trainer_params["max_epochs"] = num_epochs
+    if num_steps is not None:
+        final_trainer_params["limit_train_batches"] = num_steps
+
+    # training
+    training_params = _create_training_configuration(
+        trainer_params=final_trainer_params,
+        logger=logger,
+    )
+
+    return Configuration(
+        experiment_name=experiment_name,
+        algorithm_config=algorithm_params,
+        data_config=data_params,
+        training_config=training_params,
+    )
+
+
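A hedged sketch of create_hdn_configuration with its non-default options; all values are illustrative and the "accelerator" key is just an example Lightning Trainer kwarg:

config = create_hdn_configuration(
    experiment_name="hdn_experiment",
    data_type="array",
    axes="SYX",
    patch_size=[64, 64],
    batch_size=16,
    num_epochs=50,
    predict_logvar="pixelwise",  # let the Gaussian likelihood predict a per-pixel variance
    logvar_lowerbound=-5.0,
    trainer_params={"accelerator": "gpu"},
)
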
+def create_microsplit_configuration(
+    experiment_name: str,
+    data_type: Literal["array", "tiff", "custom"],
+    axes: str,
+    patch_size: Sequence[int],
+    batch_size: int,
+    num_epochs: int = 100,
+    num_steps: int | None = None,
+    encoder_conv_strides: tuple[int, ...] = (2, 2),
+    decoder_conv_strides: tuple[int, ...] = (2, 2),
+    multiscale_count: int = 3,
+    grid_size: int = 32,  # TODO most likely can be derived from patch size
+    z_dims: tuple[int, ...] = (128, 128),
+    output_channels: int = 1,
+    encoder_n_filters: int = 32,
+    decoder_n_filters: int = 32,
+    encoder_dropout: float = 0.0,
+    decoder_dropout: float = 0.0,
+    nonlinearity: Literal[
+        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
+    ] = "ReLU",
+    analytical_kl: bool = False,
+    predict_logvar: Literal["pixelwise"] = "pixelwise",
+    logvar_lowerbound: Union[float, None] = None,
+    logger: Literal["wandb", "tensorboard", "none"] = "none",
+    trainer_params: dict | None = None,
+    augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]] | None = None,
+    nm_paths: list[str] | None = None,
+    data_stats: tuple[float, float] | None = None,
+    train_dataloader_params: dict[str, Any] | None = None,
+    val_dataloader_params: dict[str, Any] | None = None,
+) -> Configuration:
+    """
+    Create a configuration for training MicroSplit.
+
+    Parameters
+    ----------
+    experiment_name : str
+        Name of the experiment.
+    data_type : Literal["array", "tiff", "custom"]
+        Type of the data.
+    axes : str
+        Axes of the data (e.g. SYX).
+    patch_size : Sequence[int]
+        Size of the patches along the spatial dimensions (e.g. [64, 64]).
+    batch_size : int
+        Batch size.
+    num_epochs : int, default=100
+        Number of epochs to train for. If provided, this will be added to
+        trainer_params.
+    num_steps : int, optional
+        Number of batches in 1 epoch. If provided, this will be added to trainer_params.
+        Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
+        documentation for more details.
+    encoder_conv_strides : tuple[int, ...], optional
+        Strides for the encoder convolutional layers, by default (2, 2).
+    decoder_conv_strides : tuple[int, ...], optional
+        Strides for the decoder convolutional layers, by default (2, 2).
+    multiscale_count : int, optional
+        Number of multiscale levels, by default 3.
+    grid_size : int, optional
+        Grid size used by the MicroSplit dataset for patch extraction, by default 32.
+    z_dims : tuple[int, ...], optional
+        List of latent dimensions for each hierarchy level in the LVAE,
+        by default (128, 128).
+    output_channels : int, optional
+        Number of output channels for the model, by default 1.
+    encoder_n_filters : int, optional
+        Number of filters in the encoder, by default 32.
+    decoder_n_filters : int, optional
+        Number of filters in the decoder, by default 32.
+    encoder_dropout : float, optional
+        Dropout rate for the encoder, by default 0.0.
+    decoder_dropout : float, optional
+        Dropout rate for the decoder, by default 0.0.
+    nonlinearity : Literal, optional
+        Nonlinearity to use in the model, by default "ReLU".
+    analytical_kl : bool, optional
+        Whether to use analytical KL divergence, by default False.
+    predict_logvar : Literal["pixelwise"] | None, optional
+        Type of log-variance prediction, by default "pixelwise".
+    logvar_lowerbound : Union[float, None], optional
+        Lower bound for the log variance, by default None.
+    logger : Literal["wandb", "tensorboard", "none"], optional
+        Logger to use for training, by default "none".
+    trainer_params : dict, optional
+        Parameters for the trainer class, see PyTorch Lightning documentation.
+    augmentations : list[Union[XYFlipModel, XYRandomRotate90Model]] | None, optional
+        List of augmentations to apply, by default None.
+    nm_paths : list[str] | None, optional
+        Paths to the noise model files, by default None.
+    data_stats : tuple[float, float] | None, optional
+        Data statistics (mean, std), by default None.
+    train_dataloader_params : dict[str, Any] | None, optional
+        Parameters for the training dataloader, by default None.
+    val_dataloader_params : dict[str, Any] | None, optional
+        Parameters for the validation dataloader, by default None.
+
+    Returns
+    -------
+    Configuration
+        A configuration object for the microsplit algorithm.
+
+    Examples
+    --------
+    Minimum example:
+    # >>> config = create_microsplit_configuration(
+    # ...     experiment_name="microsplit_experiment",
+    # ...     data_type="array",
+    # ...     axes="YX",
+    # ...     patch_size=[64, 64],
+    # ...     batch_size=32,
+    # ...     num_epochs=100
+    # ... )
+
+    # You can also limit the number of batches per epoch:
+    # >>> config = create_microsplit_configuration(
+    # ...     experiment_name="microsplit_experiment",
+    # ...     data_type="array",
+    # ...     axes="YX",
+    # ...     patch_size=[64, 64],
+    # ...     batch_size=32,
+    # ...     num_steps=100  # limit to 100 batches per epoch
+    # ... )
+    """
+    transform_list = _list_spatial_augmentations(augmentations)
+
+    loss_config = LVAELossConfig(
+        loss_type="denoisplit_musplit", denoisplit_weight=0.9, musplit_weight=0.1
+    )  # TODO losses need to be refactored! just for example. Add validator if sum to 1
+
+    # Create likelihood configurations
+    gaussian_likelihood_config, noise_model_config, nm_likelihood_config = (
+        get_likelihood_config(
+            loss_type="denoisplit_musplit",
+            predict_logvar=predict_logvar,
+            logvar_lowerbound=logvar_lowerbound,
+            nm_paths=nm_paths,
+            data_stats=data_stats,
+        )
+    )
+
+    # Create the LVAE model
+    network_model = _create_vae_configuration(
+        input_shape=patch_size,
+        encoder_conv_strides=encoder_conv_strides,
+        decoder_conv_strides=decoder_conv_strides,
+        multiscale_count=multiscale_count,
+        z_dims=z_dims,
+        output_channels=output_channels,
+        encoder_n_filters=encoder_n_filters,
+        decoder_n_filters=decoder_n_filters,
+        encoder_dropout=encoder_dropout,
+        decoder_dropout=decoder_dropout,
+        nonlinearity=nonlinearity,
+        predict_logvar=predict_logvar,
+        analytical_kl=analytical_kl,
+    )
+
+    # Create the MicroSplit algorithm configuration
+    algorithm_params = {
+        "algorithm": "microsplit",
+        "loss": loss_config,
+        "model": network_model,
+        "gaussian_likelihood": gaussian_likelihood_config,
+        "noise_model": noise_model_config,
+        "noise_model_likelihood": nm_likelihood_config,
+    }
+
+    # Convert to MicroSplitAlgorithm instance
+    algorithm_config = MicroSplitAlgorithm(**algorithm_params)
+
+    # data
+    data_params = _create_microsplit_data_configuration(
+        data_type=data_type,
+        axes=axes,
+        patch_size=patch_size,
+        grid_size=grid_size,
+        multiscale_count=multiscale_count,
+        batch_size=batch_size,
+        augmentations=transform_list,
+        train_dataloader_params=train_dataloader_params,
+        val_dataloader_params=val_dataloader_params,
+    )
+
+    # Handle trainer parameters with num_epochs and num_steps
+    final_trainer_params = {} if trainer_params is None else trainer_params.copy()
+
+    # Add num_epochs and num_steps if provided
+    if num_epochs is not None:
+        final_trainer_params["max_epochs"] = num_epochs
+    if num_steps is not None:
+        final_trainer_params["limit_train_batches"] = num_steps
+
+    # training
+    training_params = _create_training_configuration(
+        trainer_params=final_trainer_params,
+        logger=logger,
+    )
+
+    return Configuration(
+        experiment_name=experiment_name,
+        algorithm_config=algorithm_config,
+        data_config=data_params,
+        training_config=training_params,
+    )
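
A hedged sketch based on the signature above; the noise-model paths and the data statistics are placeholders that would normally come from pre-trained noise models and from the training data, and the comment values are illustrative:

config = create_microsplit_configuration(
    experiment_name="microsplit_experiment",
    data_type="array",
    axes="SYX",
    patch_size=[64, 64],
    batch_size=32,
    num_epochs=100,
    multiscale_count=3,
    grid_size=32,
    nm_paths=["noise_model_ch0.npz", "noise_model_ch1.npz"],  # placeholder paths
    data_stats=(mean, std),  # placeholder (mean, std) of the training data
    output_channels=2,       # e.g. two unmixed output channels
)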