careamics 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (79)
  1. careamics/careamist.py +11 -14
  2. careamics/cli/conf.py +18 -3
  3. careamics/config/__init__.py +8 -0
  4. careamics/config/algorithms/__init__.py +4 -0
  5. careamics/config/algorithms/hdn_algorithm_model.py +103 -0
  6. careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
  7. careamics/config/algorithms/n2v_algorithm_model.py +1 -2
  8. careamics/config/algorithms/vae_algorithm_model.py +51 -16
  9. careamics/config/architectures/lvae_model.py +12 -8
  10. careamics/config/callback_model.py +7 -3
  11. careamics/config/configuration.py +15 -63
  12. careamics/config/configuration_factories.py +853 -29
  13. careamics/config/data/data_model.py +50 -11
  14. careamics/config/data/ng_data_model.py +168 -4
  15. careamics/config/data/patch_filter/__init__.py +15 -0
  16. careamics/config/data/patch_filter/filter_model.py +16 -0
  17. careamics/config/data/patch_filter/mask_filter_model.py +17 -0
  18. careamics/config/data/patch_filter/max_filter_model.py +15 -0
  19. careamics/config/data/patch_filter/meanstd_filter_model.py +18 -0
  20. careamics/config/data/patch_filter/shannon_filter_model.py +15 -0
  21. careamics/config/inference_model.py +1 -2
  22. careamics/config/likelihood_model.py +2 -2
  23. careamics/config/loss_model.py +6 -2
  24. careamics/config/nm_model.py +26 -1
  25. careamics/config/optimizer_models.py +1 -2
  26. careamics/config/support/supported_algorithms.py +5 -3
  27. careamics/config/support/supported_filters.py +17 -0
  28. careamics/config/support/supported_losses.py +5 -2
  29. careamics/config/training_model.py +6 -36
  30. careamics/config/transformations/normalize_model.py +1 -2
  31. careamics/dataset_ng/dataset.py +57 -5
  32. careamics/dataset_ng/factory.py +101 -18
  33. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
  34. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
  35. careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
  36. careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
  37. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  38. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  39. careamics/dataset_ng/patch_filter/filter_factory.py +94 -0
  40. careamics/dataset_ng/patch_filter/mask_filter.py +95 -0
  41. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  42. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  43. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  44. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  45. careamics/file_io/read/__init__.py +0 -1
  46. careamics/lightning/__init__.py +16 -2
  47. careamics/lightning/callbacks/__init__.py +2 -0
  48. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  49. careamics/lightning/dataset_ng/data_module.py +79 -2
  50. careamics/lightning/lightning_module.py +162 -61
  51. careamics/lightning/microsplit_data_module.py +636 -0
  52. careamics/lightning/predict_data_module.py +8 -1
  53. careamics/lightning/train_data_module.py +19 -8
  54. careamics/losses/__init__.py +7 -1
  55. careamics/losses/loss_factory.py +9 -1
  56. careamics/losses/lvae/losses.py +85 -0
  57. careamics/lvae_training/dataset/__init__.py +8 -8
  58. careamics/lvae_training/dataset/config.py +56 -44
  59. careamics/lvae_training/dataset/lc_dataset.py +18 -12
  60. careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
  61. careamics/lvae_training/dataset/multich_dataset.py +24 -18
  62. careamics/lvae_training/dataset/multifile_dataset.py +6 -6
  63. careamics/lvae_training/eval_utils.py +46 -24
  64. careamics/model_io/bmz_io.py +9 -5
  65. careamics/models/lvae/likelihoods.py +31 -14
  66. careamics/models/lvae/lvae.py +2 -2
  67. careamics/models/lvae/noise_models.py +20 -14
  68. careamics/prediction_utils/__init__.py +8 -2
  69. careamics/prediction_utils/prediction_outputs.py +49 -3
  70. careamics/prediction_utils/stitch_prediction.py +83 -1
  71. careamics/transforms/xy_random_rotate90.py +1 -1
  72. careamics/utils/version.py +4 -4
  73. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/METADATA +19 -22
  74. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/RECORD +77 -60
  75. careamics/dataset/zarr_dataset.py +0 -151
  76. careamics/file_io/read/zarr.py +0 -60
  77. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/WHEEL +0 -0
  78. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/entry_points.txt +0 -0
  79. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/licenses/LICENSE +0 -0
@@ -5,9 +5,20 @@ from typing import Annotated, Any, Literal, Union
 
 from pydantic import Field, TypeAdapter
 
-from careamics.config.algorithms import CAREAlgorithm, N2NAlgorithm, N2VAlgorithm
-from careamics.config.architectures import UNetModel
+from careamics.config.algorithms import (
+    CAREAlgorithm,
+    MicroSplitAlgorithm,
+    N2NAlgorithm,
+    N2VAlgorithm,
+)
+from careamics.config.architectures import LVAEModel, UNetModel
 from careamics.config.data import DataConfig, NGDataConfig
+from careamics.config.likelihood_model import (
+    GaussianLikelihoodConfig,
+    NMLikelihoodConfig,
+)
+from careamics.config.loss_model import LVAELossConfig
+from careamics.config.nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig
 from careamics.config.support import (
     SupportedArchitecture,
     SupportedPixelManipulation,
@@ -20,6 +31,7 @@ from careamics.config.transformations import (
     XYFlipModel,
     XYRandomRotate90Model,
 )
+from careamics.lvae_training.dataset.config import MicroSplitDataConfig
 
 from .configuration import Configuration
 
@@ -224,7 +236,7 @@ def _create_algorithm_configuration(
 def _create_data_configuration(
     data_type: Literal["array", "tiff", "czi", "custom"],
     axes: str,
-    patch_size: list[int],
+    patch_size: Sequence[int],
     batch_size: int,
     augmentations: Union[list[SPATIAL_TRANSFORMS_UNION]],
     train_dataloader_params: dict[str, Any] | None = None,
@@ -277,6 +289,70 @@ def _create_data_configuration(
     return DataConfig(**data)
 
 
+def _create_microsplit_data_configuration(
+    data_type: Literal["array", "tiff", "custom"],
+    axes: str,
+    patch_size: Sequence[int],
+    grid_size: int,
+    multiscale_count: int,
+    batch_size: int,
+    augmentations: Union[list[SPATIAL_TRANSFORMS_UNION]],
+    train_dataloader_params: dict[str, Any] | None = None,
+    val_dataloader_params: dict[str, Any] | None = None,
+) -> DataConfig:
+    """
+    Create a dictionary with the parameters of the data model.
+
+    Parameters
+    ----------
+    data_type : {"array", "tiff", "custom"}
+        Type of the data.
+    axes : str
+        Axes of the data.
+    patch_size : list of int
+        Size of the patches along the spatial dimensions.
+    grid_size : int
+        Grid size for patch extraction.
+    multiscale_count : int
+        Number of LC scales.
+    batch_size : int
+        Batch size.
+    augmentations : list of transforms
+        List of transforms to apply.
+    train_dataloader_params : dict
+        Parameters for the training dataloader, see PyTorch notes, by default None.
+    val_dataloader_params : dict
+        Parameters for the validation dataloader, see PyTorch notes, by default None.
+
+    Returns
+    -------
+    DataConfig
+        Data model with the specified parameters.
+    """
+    # data model
+    data = {
+        "data_type": data_type,
+        "axes": axes,
+        "image_size": patch_size,
+        "grid_size": grid_size,
+        "multiscale_lowres_count": multiscale_count,
+        "batch_size": batch_size,
+        "transforms": augmentations,
+    }
+    # Don't override defaults set in DataConfig class
+    if train_dataloader_params is not None:
+        # DataConfig enforces the presence of `shuffle` key in the dataloader parameters
+        if "shuffle" not in train_dataloader_params:
+            train_dataloader_params["shuffle"] = True
+
+        data["train_dataloader_params"] = train_dataloader_params
+
+    if val_dataloader_params is not None:
+        data["val_dataloader_params"] = val_dataloader_params
+
+    return MicroSplitDataConfig(**data)
+
+
 def _create_ng_data_configuration(
     data_type: Literal["array", "tiff", "custom"],
     axes: str,
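Note: the helper above remaps the factory argument names onto the MicroSplitDataConfig fields (`patch_size` -> `image_size`, `multiscale_count` -> `multiscale_lowres_count`). A minimal, hedged sketch of the equivalent direct construction, assuming MicroSplitDataConfig accepts exactly the fields passed in the diff above:

>>> from careamics.lvae_training.dataset.config import MicroSplitDataConfig
>>> data_config = MicroSplitDataConfig(
...     data_type="array",
...     axes="SYX",
...     image_size=[64, 64],            # factory argument `patch_size`
...     grid_size=32,
...     multiscale_lowres_count=3,      # factory argument `multiscale_count`
...     batch_size=32,
...     transforms=[],
... )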
@@ -357,7 +433,7 @@ def _create_ng_data_configuration(
 
 
 def _create_training_configuration(
-    num_epochs: int,
+    trainer_params: dict,
     logger: Literal["wandb", "tensorboard", "none"],
     checkpoint_params: dict[str, Any] | None = None,
 ) -> TrainingConfig:
@@ -366,8 +442,8 @@ def _create_training_configuration(
 
     Parameters
     ----------
-    num_epochs : int
-        Number of epochs.
+    trainer_params : dict
+        Parameters for the Lightning Trainer class, see PyTorch Lightning documentation.
     logger : {"wandb", "tensorboard", "none"}
         Logger to use.
     checkpoint_params : dict, default=None
@@ -380,7 +456,7 @@ def _create_training_configuration(
         Training model with the specified parameters.
     """
     return TrainingConfig(
-        num_epochs=num_epochs,
+        lightning_trainer_config=trainer_params,
         logger=None if logger == "none" else logger,
         checkpoint_callback={} if checkpoint_params is None else checkpoint_params,
     )
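The training section of the configuration now carries arbitrary Lightning Trainer arguments instead of a bare epoch count. A hedged sketch of what the factories build internally (field names are taken from the diff above; the example Trainer keys are illustrative, and TrainingConfig is assumed to remain importable from careamics.config as in previous releases):

>>> from careamics.config import TrainingConfig
>>> training = TrainingConfig(
...     lightning_trainer_config={"max_epochs": 50, "precision": "16-mixed"},
...     logger=None,
...     checkpoint_callback={},
... )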
@@ -392,9 +468,9 @@ def _create_supervised_config_dict(
     experiment_name: str,
     data_type: Literal["array", "tiff", "czi", "custom"],
     axes: str,
-    patch_size: list[int],
+    patch_size: Sequence[int],
     batch_size: int,
-    num_epochs: int,
+    trainer_params: dict | None = None,
     augmentations: list[SPATIAL_TRANSFORMS_UNION] | None = None,
     independent_channels: bool = True,
     loss: Literal["mae", "mse"] = "mae",
@@ -409,6 +485,8 @@ def _create_supervised_config_dict(
     train_dataloader_params: dict[str, Any] | None = None,
     val_dataloader_params: dict[str, Any] | None = None,
     checkpoint_params: dict[str, Any] | None = None,
+    num_epochs: int | None = None,
+    num_steps: int | None = None,
 ) -> dict:
     """
     Create a configuration for training CARE or Noise2Noise.
@@ -427,8 +505,8 @@ def _create_supervised_config_dict(
         Size of the patches along the spatial dimensions (e.g. [64, 64]).
     batch_size : int
         Batch size.
-    num_epochs : int
-        Number of epochs.
+    trainer_params : dict
+        Parameters for the Lightning Trainer, see PyTorch Lightning documentation.
     augmentations : list of transforms, default=None
         List of transforms to apply, either both or one of XYFlipModel and
         XYRandomRotate90Model. By default, it applies both XYFlip (on X and Y)
@@ -461,6 +539,13 @@ def _create_supervised_config_dict(
     checkpoint_params : dict, default=None
         Parameters for the checkpoint callback, see PyTorch Lightning documentation
         (`ModelCheckpoint`) for the list of available parameters.
+    num_epochs : int or None, default=None
+        Number of epochs to train for. If provided, this is added to
+        `trainer_params` as `max_epochs`.
+    num_steps : int or None, default=None
+        Number of batches per epoch. If provided, this is added to
+        `trainer_params` as `limit_train_batches`; see the PyTorch Lightning
+        Trainer documentation for more details.
 
     Returns
     -------
@@ -518,9 +603,18 @@ def _create_supervised_config_dict(
         val_dataloader_params=val_dataloader_params,
     )
 
+    # Handle trainer parameters with num_epochs and num_steps
+    final_trainer_params = {} if trainer_params is None else trainer_params.copy()
+
+    # Add num_epochs and num_steps if provided
+    if num_epochs is not None:
+        final_trainer_params["max_epochs"] = num_epochs
+    if num_steps is not None:
+        final_trainer_params["limit_train_batches"] = num_steps
+
     # training
     training_params = _create_training_configuration(
-        num_epochs=num_epochs,
+        trainer_params=final_trainer_params,
         logger=logger,
         checkpoint_params=checkpoint_params,
     )
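Note on precedence: the merge above copies `trainer_params` and then writes `max_epochs` / `limit_train_batches` on top, so explicit `num_epochs` and `num_steps` always win over the corresponding keys in `trainer_params` (and since the public factories default `num_epochs` to 100, a `max_epochs` entry in `trainer_params` is effectively only honoured if `num_epochs` is set to match). A small illustrative sketch in plain Python:

>>> trainer_params = {"max_epochs": 10, "accelerator": "gpu"}
>>> final_trainer_params = trainer_params.copy()
>>> final_trainer_params["max_epochs"] = 50             # num_epochs
>>> final_trainer_params["limit_train_batches"] = 100   # num_steps
>>> final_trainer_params
{'max_epochs': 50, 'accelerator': 'gpu', 'limit_train_batches': 100}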
@@ -537,15 +631,17 @@ def create_care_configuration(
     experiment_name: str,
     data_type: Literal["array", "tiff", "czi", "custom"],
     axes: str,
-    patch_size: list[int],
+    patch_size: Sequence[int],
     batch_size: int,
-    num_epochs: int,
+    num_epochs: int = 100,
+    num_steps: int | None = None,
     augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]] | None = None,
     independent_channels: bool = True,
     loss: Literal["mae", "mse"] = "mae",
     n_channels_in: int | None = None,
     n_channels_out: int | None = None,
     logger: Literal["wandb", "tensorboard", "none"] = "none",
+    trainer_params: dict | None = None,
     model_params: dict | None = None,
     optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
     optimizer_params: dict[str, Any] | None = None,
@@ -588,8 +684,13 @@ def create_care_configuration(
         Size of the patches along the spatial dimensions (e.g. [64, 64]).
     batch_size : int
         Batch size.
-    num_epochs : int
-        Number of epochs.
+    num_epochs : int, default=100
+        Number of epochs to train for. This value is added to `trainer_params`
+        as `max_epochs`.
+    num_steps : int, optional
+        Number of batches per epoch. If provided, this is added to
+        `trainer_params` as `limit_train_batches`; see the PyTorch Lightning
+        Trainer documentation for more details.
     augmentations : list of transforms, default=None
         List of transforms to apply, either both or one of XYFlipModel and
         XYRandomRotate90Model. By default, it applies both XYFlip (on X and Y)
@@ -604,6 +705,8 @@ def create_care_configuration(
         Number of channels out.
     logger : Literal["wandb", "tensorboard", "none"], default="none"
         Logger to use.
+    trainer_params : dict, optional
+        Parameters for the trainer class, see PyTorch Lightning documentation.
     model_params : dict, default=None
         UNetModel parameters.
     optimizer : Literal["Adam", "Adamax", "SGD"], default="Adam"
@@ -644,6 +747,16 @@ def create_care_configuration(
     ...     num_epochs=100
     ... )
 
+    You can also limit the number of batches per epoch:
+    >>> config = create_care_configuration(
+    ...     experiment_name="care_experiment",
+    ...     data_type="array",
+    ...     axes="YX",
+    ...     patch_size=[64, 64],
+    ...     batch_size=32,
+    ...     num_steps=100  # limit to 100 batches per epoch
+    ... )
+
     To disable transforms, simply set `augmentations` to an empty list:
     >>> config = create_care_configuration(
     ...     experiment_name="care_experiment",
@@ -730,13 +843,13 @@ def create_care_configuration(
             axes=axes,
             patch_size=patch_size,
             batch_size=batch_size,
-            num_epochs=num_epochs,
             augmentations=augmentations,
             independent_channels=independent_channels,
             loss=loss,
             n_channels_in=n_channels_in,
             n_channels_out=n_channels_out,
             logger=logger,
+            trainer_params=trainer_params,
             model_params=model_params,
             optimizer=optimizer,
             optimizer_params=optimizer_params,
@@ -745,6 +858,8 @@ def create_care_configuration(
             train_dataloader_params=train_dataloader_params,
             val_dataloader_params=val_dataloader_params,
             checkpoint_params=checkpoint_params,
+            num_epochs=num_epochs,
+            num_steps=num_steps,
         )
     )
 
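Taken together, a hedged usage sketch of the new `trainer_params` hook in create_care_configuration (argument values are illustrative; `accelerator` and `precision` are standard Lightning Trainer options):

>>> config = create_care_configuration(
...     experiment_name="care_experiment",
...     data_type="array",
...     axes="YX",
...     patch_size=[64, 64],
...     batch_size=32,
...     num_epochs=50,
...     trainer_params={"accelerator": "gpu", "precision": "16-mixed"},
... )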
@@ -753,15 +868,17 @@ def create_n2n_configuration(
753
868
  experiment_name: str,
754
869
  data_type: Literal["array", "tiff", "czi", "custom"],
755
870
  axes: str,
756
- patch_size: list[int],
871
+ patch_size: Sequence[int],
757
872
  batch_size: int,
758
- num_epochs: int,
873
+ num_epochs: int = 100,
874
+ num_steps: int | None = None,
759
875
  augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]] | None = None,
760
876
  independent_channels: bool = True,
761
877
  loss: Literal["mae", "mse"] = "mae",
762
878
  n_channels_in: int | None = None,
763
879
  n_channels_out: int | None = None,
764
880
  logger: Literal["wandb", "tensorboard", "none"] = "none",
881
+ trainer_params: dict | None = None,
765
882
  model_params: dict | None = None,
766
883
  optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
767
884
  optimizer_params: dict[str, Any] | None = None,
@@ -804,8 +921,13 @@ def create_n2n_configuration(
804
921
  Size of the patches along the spatial dimensions (e.g. [64, 64]).
805
922
  batch_size : int
806
923
  Batch size.
807
- num_epochs : int
808
- Number of epochs.
924
+ num_epochs : int, default=100
925
+ Number of epochs to train for. If provided, this will be added to
926
+ trainer_params.
927
+ num_steps : int, optional
928
+ Number of batches in 1 epoch. If provided, this will be added to trainer_params.
929
+ Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
930
+ documentation for more details.
809
931
  augmentations : list of transforms, default=None
810
932
  List of transforms to apply, either both or one of XYFlipModel and
811
933
  XYRandomRotate90Model. By default, it applies both XYFlip (on X and Y)
@@ -820,6 +942,8 @@ def create_n2n_configuration(
820
942
  Number of channels out.
821
943
  logger : Literal["wandb", "tensorboard", "none"], optional
822
944
  Logger to use, by default "none".
945
+ trainer_params : dict, optional
946
+ Parameters for the trainer class, see PyTorch Lightning documentation.
823
947
  model_params : dict, default=None
824
948
  UNetModel parameters.
825
949
  optimizer : Literal["Adam", "Adamax", "SGD"], default="Adam"
@@ -860,6 +984,16 @@ def create_n2n_configuration(
860
984
  ... num_epochs=100
861
985
  ... )
862
986
 
987
+ You can also limit the number of batches per epoch:
988
+ >>> config = create_n2n_configuration(
989
+ ... experiment_name="n2n_experiment",
990
+ ... data_type="array",
991
+ ... axes="YX",
992
+ ... patch_size=[64, 64],
993
+ ... batch_size=32,
994
+ ... num_steps=100 # limit to 100 batches per epoch
995
+ ... )
996
+
863
997
  To disable transforms, simply set `augmentations` to an empty list:
864
998
  >>> config = create_n2n_configuration(
865
999
  ... experiment_name="n2n_experiment",
@@ -871,8 +1005,7 @@ def create_n2n_configuration(
871
1005
  ... augmentations=[]
872
1006
  ... )
873
1007
 
874
- A list of transforms can be passed to the `augmentations` parameter to replace the
875
- default augmentations:
1008
+ A list of transforms can be passed to the `augmentations` parameter:
876
1009
  >>> from careamics.config.transformations import XYFlipModel
877
1010
  >>> config = create_n2n_configuration(
878
1011
  ... experiment_name="n2n_experiment",
@@ -946,7 +1079,7 @@ def create_n2n_configuration(
946
1079
  axes=axes,
947
1080
  patch_size=patch_size,
948
1081
  batch_size=batch_size,
949
- num_epochs=num_epochs,
1082
+ trainer_params=trainer_params,
950
1083
  augmentations=augmentations,
951
1084
  independent_channels=independent_channels,
952
1085
  loss=loss,
@@ -961,6 +1094,8 @@ def create_n2n_configuration(
961
1094
  train_dataloader_params=train_dataloader_params,
962
1095
  val_dataloader_params=val_dataloader_params,
963
1096
  checkpoint_params=checkpoint_params,
1097
+ num_epochs=num_epochs,
1098
+ num_steps=num_steps,
964
1099
  )
965
1100
  )
966
1101
 
@@ -969,9 +1104,10 @@ def create_n2v_configuration(
969
1104
  experiment_name: str,
970
1105
  data_type: Literal["array", "tiff", "czi", "custom"],
971
1106
  axes: str,
972
- patch_size: list[int],
1107
+ patch_size: Sequence[int],
973
1108
  batch_size: int,
974
- num_epochs: int,
1109
+ num_epochs: int = 100,
1110
+ num_steps: int | None = None,
975
1111
  augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]] | None = None,
976
1112
  independent_channels: bool = True,
977
1113
  use_n2v2: bool = False,
@@ -980,6 +1116,7 @@ def create_n2v_configuration(
980
1116
  masked_pixel_percentage: float = 0.2,
981
1117
  struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
982
1118
  struct_n2v_span: int = 5,
1119
+ trainer_params: dict | None = None,
983
1120
  logger: Literal["wandb", "tensorboard", "none"] = "none",
984
1121
  model_params: dict | None = None,
985
1122
  optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
@@ -1043,8 +1180,13 @@ def create_n2v_configuration(
1043
1180
  Size of the patches along the spatial dimensions (e.g. [64, 64]).
1044
1181
  batch_size : int
1045
1182
  Batch size.
1046
- num_epochs : int
1047
- Number of epochs.
1183
+ num_epochs : int, default=100
1184
+ Number of epochs to train for. If provided, this will be added to
1185
+ trainer_params.
1186
+ num_steps : int, optional
1187
+ Number of batches in 1 epoch. If provided, this will be added to trainer_params.
1188
+ Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
1189
+ documentation for more details.
1048
1190
  augmentations : list of transforms, default=None
1049
1191
  List of transforms to apply, either both or one of XYFlipModel and
1050
1192
  XYRandomRotate90Model. By default, it applies both XYFlip (on X and Y)
@@ -1063,6 +1205,8 @@ def create_n2v_configuration(
1063
1205
  Axis along which to apply structN2V mask, by default "none".
1064
1206
  struct_n2v_span : int, optional
1065
1207
  Span of the structN2V mask, by default 5.
1208
+ trainer_params : dict, optional
1209
+ Parameters for the trainer class, see PyTorch Lightning documentation.
1066
1210
  logger : Literal["wandb", "tensorboard", "none"], optional
1067
1211
  Logger to use, by default "none".
1068
1212
  model_params : dict, default=None
@@ -1105,6 +1249,16 @@ def create_n2v_configuration(
1105
1249
  ... num_epochs=100
1106
1250
  ... )
1107
1251
 
1252
+ You can also limit the number of batches per epoch:
1253
+ >>> config = create_n2v_configuration(
1254
+ ... experiment_name="n2v_experiment",
1255
+ ... data_type="array",
1256
+ ... axes="YX",
1257
+ ... patch_size=[64, 64],
1258
+ ... batch_size=32,
1259
+ ... num_steps=100 # limit to 100 batches per epoch
1260
+ ... )
1261
+
1108
1262
  To disable transforms, simply set `augmentations` to an empty list:
1109
1263
  >>> config = create_n2v_configuration(
1110
1264
  ... experiment_name="n2v_experiment",
@@ -1261,8 +1415,17 @@ def create_n2v_configuration(
     )
 
     # training
+    # Handle trainer parameters with num_epochs and num_steps
+    final_trainer_params = {} if trainer_params is None else trainer_params.copy()
+
+    # Add num_epochs and num_steps if provided
+    if num_epochs is not None:
+        final_trainer_params["max_epochs"] = num_epochs
+    if num_steps is not None:
+        final_trainer_params["limit_train_batches"] = num_steps
+
     training_params = _create_training_configuration(
-        num_epochs=num_epochs,
+        trainer_params=final_trainer_params,
         logger=logger,
         checkpoint_params=checkpoint_params,
     )
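As with CARE and Noise2Noise, the epoch/step arguments and `trainer_params` can be combined; a hedged sketch (values are illustrative):

>>> config = create_n2v_configuration(
...     experiment_name="n2v_experiment",
...     data_type="array",
...     axes="YX",
...     patch_size=[64, 64],
...     batch_size=32,
...     num_steps=100,
...     trainer_params={"accelerator": "gpu"},
... )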
@@ -1273,3 +1436,664 @@ def create_n2v_configuration(
1273
1436
  data_config=data_params,
1274
1437
  training_config=training_params,
1275
1438
  )
1439
+
1440
+
1441
+ def _create_vae_configuration(
1442
+ input_shape: Sequence[int],
1443
+ encoder_conv_strides: tuple[int, ...],
1444
+ decoder_conv_strides: tuple[int, ...],
1445
+ multiscale_count: int,
1446
+ z_dims: tuple[int, ...],
1447
+ output_channels: int,
1448
+ encoder_n_filters: int,
1449
+ decoder_n_filters: int,
1450
+ encoder_dropout: float,
1451
+ decoder_dropout: float,
1452
+ nonlinearity: Literal[
1453
+ "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
1454
+ ],
1455
+ predict_logvar: Literal[None, "pixelwise"],
1456
+ analytical_kl: bool,
1457
+ ) -> LVAEModel:
1458
+ """Create a dictionary with the parameters of the vae based algorithm model.
1459
+
1460
+ Parameters
1461
+ ----------
1462
+ input_shape : tuple[int, ...]
1463
+ Shape of the input patch (Z, Y, X) or (Y, X) if the data is 2D.
1464
+ encoder_conv_strides : tuple[int, ...]
1465
+ Strides of the encoder convolutional layers, length also defines 2D or 3D.
1466
+ decoder_conv_strides : tuple[int, ...]
1467
+ Strides of the decoder convolutional layers, length also defines 2D or 3D.
1468
+ multiscale_count : int
1469
+ Number of lateral context layers, specific to MicroSplit.
1470
+ z_dims : tuple[int, ...]
1471
+ Number of hierarchies in the LVAE model.
1472
+ output_channels : int
1473
+ Number of output channels.
1474
+ encoder_n_filters : int
1475
+ Number of filters in the convolutional layers of the encoder.
1476
+ decoder_n_filters : int
1477
+ Number of filters in the convolutional layers of the decoder.
1478
+ encoder_dropout : float
1479
+ Dropout rate for the encoder.
1480
+ decoder_dropout : float
1481
+ Dropout rate for the decoder.
1482
+ nonlinearity : Literal
1483
+ Type of nonlinearity function to use.
1484
+ predict_logvar : Literal[None, "pixelwise"]
+ Type of log variance prediction.
+ analytical_kl : bool
+ Whether to use analytical KL divergence.
1488
+
1489
+ Returns
1490
+ -------
1491
+ LVAEModel
1492
+ LVAE model with the specified parameters.
1493
+ """
1494
+ return LVAEModel(
1495
+ architecture=SupportedArchitecture.LVAE.value,
1496
+ input_shape=input_shape,
1497
+ encoder_conv_strides=encoder_conv_strides,
1498
+ decoder_conv_strides=decoder_conv_strides,
1499
+ multiscale_count=multiscale_count,
1500
+ z_dims=z_dims,
1501
+ output_channels=output_channels,
1502
+ encoder_n_filters=encoder_n_filters,
1503
+ decoder_n_filters=decoder_n_filters,
1504
+ encoder_dropout=encoder_dropout,
1505
+ decoder_dropout=decoder_dropout,
1506
+ nonlinearity=nonlinearity,
1507
+ predict_logvar=predict_logvar,
1508
+ analytical_kl=analytical_kl,
1509
+ )
1510
+
1511
+
1512
+ def _create_vae_based_algorithm(
1513
+ algorithm: Literal["hdn", "microsplit"],
1514
+ loss: LVAELossConfig,
1515
+ input_shape: Sequence[int],
1516
+ encoder_conv_strides: tuple[int, ...],
1517
+ decoder_conv_strides: tuple[int, ...],
1518
+ multiscale_count: int,
1519
+ z_dims: tuple[int, ...],
1520
+ output_channels: int,
1521
+ encoder_n_filters: int,
1522
+ decoder_n_filters: int,
1523
+ encoder_dropout: float,
1524
+ decoder_dropout: float,
1525
+ nonlinearity: Literal[
1526
+ "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
1527
+ ],
1528
+ predict_logvar: Literal[None, "pixelwise"],
1529
+ analytical_kl: bool,
1530
+ gaussian_likelihood: GaussianLikelihoodConfig | None = None,
1531
+ nm_likelihood: NMLikelihoodConfig | None = None,
1532
+ ) -> dict:
1533
+ """
1534
+ Create a dictionary with the parameters of the VAE-based algorithm model.
1535
+
1536
+ Parameters
1537
+ ----------
1538
+ algorithm : Literal["hdn"]
1539
+ The algorithm type.
1540
+ loss : Literal["hdn"]
1541
+ The loss function type.
1542
+ input_shape : tuple[int, ...]
1543
+ The shape of the input data.
1544
+ encoder_conv_strides : list[int]
1545
+ The strides of the encoder convolutional layers.
1546
+ decoder_conv_strides : list[int]
1547
+ The strides of the decoder convolutional layers.
1548
+ multiscale_count : int
1549
+ The number of multiscale layers.
1550
+ z_dims : list[int]
1551
+ The dimensions of the latent space.
1552
+ output_channels : int
1553
+ The number of output channels.
1554
+ encoder_n_filters : int
1555
+ The number of filters in the encoder.
1556
+ decoder_n_filters : int
1557
+ The number of filters in the decoder.
1558
+ encoder_dropout : float
1559
+ The dropout rate for the encoder.
1560
+ decoder_dropout : float
1561
+ The dropout rate for the decoder.
1562
+ nonlinearity : Literal
1563
+ The nonlinearity function to use.
1564
+ predict_logvar : Literal[None, "pixelwise"]
1565
+ The type of log variance prediction.
1566
+ analytical_kl : bool
1567
+ Whether to use analytical KL divergence.
1568
+ gaussian_likelihood : Optional[GaussianLikelihoodConfig], optional
1569
+ The Gaussian likelihood model, by default None.
1570
+ nm_likelihood : Optional[NMLikelihoodConfig], optional
1571
+ The noise model likelihood model, by default None.
1572
+
1573
+ Returns
1574
+ -------
1575
+ dict
1576
+ A dictionary with the parameters of the VAE-based algorithm model.
1577
+ """
1578
+ network_model = _create_vae_configuration(
1579
+ input_shape=input_shape,
1580
+ encoder_conv_strides=encoder_conv_strides,
1581
+ decoder_conv_strides=decoder_conv_strides,
1582
+ multiscale_count=multiscale_count,
1583
+ z_dims=z_dims,
1584
+ output_channels=output_channels,
1585
+ encoder_n_filters=encoder_n_filters,
1586
+ decoder_n_filters=decoder_n_filters,
1587
+ encoder_dropout=encoder_dropout,
1588
+ decoder_dropout=decoder_dropout,
1589
+ nonlinearity=nonlinearity,
1590
+ predict_logvar=predict_logvar,
1591
+ analytical_kl=analytical_kl,
1592
+ )
1593
+ assert gaussian_likelihood or nm_likelihood, "Likelihood model must be specified"
1594
+ return {
1595
+ "algorithm": algorithm,
1596
+ "loss": loss,
1597
+ "model": network_model,
1598
+ "gaussian_likelihood": gaussian_likelihood,
1599
+ "noise_model_likelihood": nm_likelihood,
1600
+ }
1601
+
1602
+
1603
+ def get_likelihood_config(
1604
+ loss_type: Literal["musplit", "denoisplit", "denoisplit_musplit"],
1605
+ # TODO remove different microsplit loss types, refac
1606
+ predict_logvar: Literal["pixelwise"] | None = None,
1607
+ logvar_lowerbound: float | None = -5.0,
1608
+ nm_paths: list[str] | None = None,
1609
+ data_stats: tuple[float, float] | None = None,
1610
+ ) -> tuple[
1611
+ GaussianLikelihoodConfig | None,
1612
+ MultiChannelNMConfig | None,
1613
+ NMLikelihoodConfig | None,
1614
+ ]:
1615
+ """Get the likelihood configuration for split models.
1616
+
1617
+ Returns a tuple containing the following optional entries:
1618
+ - GaussianLikelihoodConfig: Gaussian likelihood configuration for musplit losses
1619
+ - MultiChannelNMConfig: Multi-channel noise model configuration for denoisplit
1620
+ losses
1621
+ - NMLikelihoodConfig: Noise model likelihood configuration for denoisplit losses
1622
+
1623
+ Parameters
1624
+ ----------
1625
+ loss_type : Literal["musplit", "denoisplit", "denoisplit_musplit"]
1626
+ The type of loss function to use.
1627
+ predict_logvar : Literal["pixelwise"] | None, optional
1628
+ Type of log variance prediction, by default None.
1629
+ Required when loss_type is "musplit" or "denoisplit_musplit".
1630
+ logvar_lowerbound : float | None, optional
1631
+ Lower bound for the log variance, by default -5.0.
1632
+ Used when loss_type is "musplit" or "denoisplit_musplit".
1633
+ nm_paths : list[str] | None, optional
1634
+ Paths to the noise model files, by default None.
1635
+ Required when loss_type is "denoisplit" or "denoisplit_musplit".
1636
+ data_stats : tuple[float, float] | None, optional
1637
+ Data statistics (mean, std), by default None.
1638
+ Required when loss_type is "denoisplit" or "denoisplit_musplit".
1639
+
1640
+ Returns
1641
+ -------
1642
+ GaussianLikelihoodConfig or None
1643
+ Configuration for the Gaussian likelihood model.
1644
+ MultiChannelNMConfig or None
1645
+ Configuration for the multi-channel noise model.
1646
+ NMLikelihoodConfig or None
1647
+ Configuration for the noise model likelihood.
1648
+
1649
+ Raises
1650
+ ------
1651
+ ValueError
1652
+ If required parameters are missing for the specified loss_type.
1653
+ """
1654
+ # gaussian likelihood
1655
+ if loss_type in ["musplit", "denoisplit_musplit"]:
1656
+ # if predict_logvar is None:
1657
+ # raise ValueError(f"predict_logvar is required for loss_type '{loss_type}'")
1658
+ # TODO validators should be in pydantic models
1659
+ gaussian_lik_config = GaussianLikelihoodConfig(
1660
+ predict_logvar=predict_logvar,
1661
+ logvar_lowerbound=logvar_lowerbound,
1662
+ )
1663
+ else:
1664
+ gaussian_lik_config = None
1665
+
1666
+ # noise model likelihood
1667
+ if loss_type in ["denoisplit", "denoisplit_musplit"]:
1668
+ # if nm_paths is None:
1669
+ # raise ValueError(f"nm_paths is required for loss_type '{loss_type}'")
1670
+ # if data_stats is None:
1671
+ # raise ValueError(f"data_stats is required for loss_type '{loss_type}'")
1672
+ # TODO validators should be in pydantic models
1673
+ gmm_list = []
1674
+ if nm_paths is not None:
1675
+ for NM_path in nm_paths:
1676
+ gmm_list.append(
1677
+ GaussianMixtureNMConfig(
1678
+ model_type="GaussianMixtureNoiseModel",
1679
+ path=NM_path,
1680
+ )
1681
+ )
1682
+ noise_model_config = MultiChannelNMConfig(noise_models=gmm_list)
1683
+ nm_lik_config = NMLikelihoodConfig() # TODO this config isn't needed probably
1684
+ else:
1685
+ noise_model_config = None
1686
+ nm_lik_config = None
1687
+
1688
+ return gaussian_lik_config, noise_model_config, nm_lik_config
1689
+
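For the denoisplit-style losses, the helper above returns up to three configurations at once. A hedged usage sketch (the noise-model paths and data statistics are placeholders):

>>> gaussian_lik, noise_model, nm_lik = get_likelihood_config(
...     loss_type="denoisplit_musplit",
...     predict_logvar="pixelwise",
...     logvar_lowerbound=-5.0,
...     nm_paths=["noise_model_ch0.npz", "noise_model_ch1.npz"],
...     data_stats=(0.0, 1.0),
... )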
1690
+
1691
+ # TODO wrap parameters into model, loss etc
1692
+ # TODO refac likelihood configs to make it 1. Can it be done ?
1693
+ def create_hdn_configuration(
1694
+ experiment_name: str,
1695
+ data_type: Literal["array", "tiff", "custom"],
1696
+ axes: str,
1697
+ patch_size: Sequence[int],
1698
+ batch_size: int,
1699
+ num_epochs: int = 100,
1700
+ num_steps: int | None = None,
1701
+ encoder_conv_strides: tuple[int, ...] = (2, 2),
1702
+ decoder_conv_strides: tuple[int, ...] = (2, 2),
1703
+ multiscale_count: int = 1,
1704
+ z_dims: tuple[int, ...] = (128, 128),
1705
+ output_channels: int = 1,
1706
+ encoder_n_filters: int = 32,
1707
+ decoder_n_filters: int = 32,
1708
+ encoder_dropout: float = 0.0,
1709
+ decoder_dropout: float = 0.0,
1710
+ nonlinearity: Literal[
1711
+ "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
1712
+ ] = "ReLU",
1713
+ analytical_kl: bool = False,
1714
+ predict_logvar: Literal["pixelwise"] | None = None,
1715
+ logvar_lowerbound: Union[float, None] = None,
1716
+ logger: Literal["wandb", "tensorboard", "none"] = "none",
1717
+ trainer_params: dict | None = None,
1718
+ augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]] | None = None,
1719
+ train_dataloader_params: dict[str, Any] | None = None,
1720
+ val_dataloader_params: dict[str, Any] | None = None,
1721
+ ) -> Configuration:
1722
+ """
1723
+ Create a configuration for training HDN.
1724
+
1725
+ If "Z" is present in `axes`, then `path_size` must be a list of length 3, otherwise
1726
+ 2.
1727
+
1728
+ If "C" is present in `axes`, then you need to set `n_channels_in` to the number of
1729
+ channels. Likewise, if you set the number of channels, then "C" must be present in
1730
+ `axes`.
1731
+
1732
+ To set the number of output channels, use the `n_channels_out` parameter. If it is
1733
+ not specified, it will be assumed to be equal to `n_channels_in`.
1734
+
1735
+ By default, all channels are trained independently. To train all channels together,
1736
+ set `independent_channels` to False.
1737
+
1738
+ By setting `augmentations` to `None`, the default transformations (flip in X and Y,
1739
+ rotations by 90 degrees in the XY plane) are applied. Rather than the default
1740
+ transforms, a list of transforms can be passed to the `augmentations` parameter. To
1741
+ disable the transforms, simply pass an empty list.
1742
+
1743
+ # TODO revisit the necessity of model_params
1744
+
1745
+ Parameters
1746
+ ----------
1747
+ experiment_name : str
1748
+ Name of the experiment.
1749
+ data_type : Literal["array", "tiff", "custom"]
1750
+ Type of the data.
1751
+ axes : str
1752
+ Axes of the data (e.g. SYX).
1753
+ patch_size : List[int]
1754
+ Size of the patches along the spatial dimensions (e.g. [64, 64]).
1755
+ batch_size : int
1756
+ Batch size.
1757
+ num_epochs : int, default=100
1758
+ Number of epochs to train for. If provided, this will be added to
1759
+ trainer_params.
1760
+ num_steps : int, optional
1761
+ Number of batches in 1 epoch. If provided, this will be added to trainer_params.
1762
+ Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
1763
+ documentation for more details.
1764
+ encoder_conv_strides : tuple[int, ...], optional
1765
+ Strides for the encoder convolutional layers, by default (2, 2).
1766
+ decoder_conv_strides : tuple[int, ...], optional
1767
+ Strides for the decoder convolutional layers, by default (2, 2).
1768
+ multiscale_count : int, optional
1769
+ Number of scales in the multiscale architecture, by default 1.
1770
+ z_dims : tuple[int, ...], optional
1771
+ Dimensions of the latent space, by default (128, 128).
1772
+ output_channels : int, optional
1773
+ Number of output channels, by default 1.
1774
+ encoder_n_filters : int, optional
1775
+ Number of filters in the encoder, by default 32.
1776
+ decoder_n_filters : int, optional
1777
+ Number of filters in the decoder, by default 32.
1778
+ encoder_dropout : float, optional
1779
+ Dropout rate for the encoder, by default 0.0.
1780
+ decoder_dropout : float, optional
1781
+ Dropout rate for the decoder, by default 0.0.
1782
+ nonlinearity : Literal, optional
1783
+ Nonlinearity function to use, by default "ReLU".
1784
+ analytical_kl : bool, optional
1785
+ Whether to use analytical KL divergence, by default False.
1786
+ predict_logvar : Literal[None, "pixelwise"], optional
1787
+ Type of log variance prediction, by default None.
1788
+ logvar_lowerbound : Union[float, None], optional
1789
+ Lower bound for the log variance, by default None.
1790
+ logger : Literal["wandb", "tensorboard", "none"], optional
1791
+ Logger to use for training, by default "none".
1792
+ trainer_params : dict, optional
1793
+ Parameters for the trainer class, see PyTorch Lightning documentation.
1794
+ augmentations : Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]], optional
1795
+ List of augmentations to apply, by default None.
1796
+ train_dataloader_params : Optional[dict[str, Any]], optional
1797
+ Parameters for the training dataloader, by default None.
1798
+ val_dataloader_params : Optional[dict[str, Any]], optional
1799
+ Parameters for the validation dataloader, by default None.
1800
+
1801
+ Returns
1802
+ -------
1803
+ Configuration
1804
+ The configuration object for training HDN.
1805
+
1806
+ Examples
1807
+ --------
1808
+ Minimum example:
1809
+ >>> config = create_hdn_configuration(
1810
+ ... experiment_name="hdn_experiment",
1811
+ ... data_type="array",
1812
+ ... axes="YX",
1813
+ ... patch_size=[64, 64],
1814
+ ... batch_size=32,
1815
+ ... num_epochs=100
1816
+ ... )
1817
+
1818
+ You can also limit the number of batches per epoch:
1819
+ >>> config = create_hdn_configuration(
1820
+ ... experiment_name="hdn_experiment",
1821
+ ... data_type="array",
1822
+ ... axes="YX",
1823
+ ... patch_size=[64, 64],
1824
+ ... batch_size=32,
1825
+ ... num_steps=100 # limit to 100 batches per epoch
1826
+ ... )
1827
+ """
1828
+ transform_list = _list_spatial_augmentations(augmentations)
1829
+
1830
+ loss_config = LVAELossConfig(
1831
+ loss_type="hdn", denoisplit_weight=1, musplit_weight=0
1832
+ ) # TODO what are the correct defaults for HDN?
1833
+
1834
+ gaussian_likelihood = GaussianLikelihoodConfig(
1835
+ predict_logvar=predict_logvar, logvar_lowerbound=logvar_lowerbound
1836
+ )
1837
+
1838
+ # algorithm & model
1839
+ algorithm_params = _create_vae_based_algorithm(
1840
+ algorithm="hdn",
1841
+ loss=loss_config,
1842
+ input_shape=patch_size,
1843
+ encoder_conv_strides=encoder_conv_strides,
1844
+ decoder_conv_strides=decoder_conv_strides,
1845
+ multiscale_count=multiscale_count,
1846
+ z_dims=z_dims,
1847
+ output_channels=output_channels,
1848
+ encoder_n_filters=encoder_n_filters,
1849
+ decoder_n_filters=decoder_n_filters,
1850
+ encoder_dropout=encoder_dropout,
1851
+ decoder_dropout=decoder_dropout,
1852
+ nonlinearity=nonlinearity,
1853
+ predict_logvar=predict_logvar,
1854
+ analytical_kl=analytical_kl,
1855
+ gaussian_likelihood=gaussian_likelihood,
1856
+ nm_likelihood=None,
1857
+ )
1858
+
1859
+ # data
1860
+ data_params = _create_data_configuration(
1861
+ data_type=data_type,
1862
+ axes=axes,
1863
+ patch_size=patch_size,
1864
+ batch_size=batch_size,
1865
+ augmentations=transform_list,
1866
+ train_dataloader_params=train_dataloader_params,
1867
+ val_dataloader_params=val_dataloader_params,
1868
+ )
1869
+
1870
+ # Handle trainer parameters with num_epochs and num_steps
1871
+ final_trainer_params = {} if trainer_params is None else trainer_params.copy()
1872
+
1873
+ # Add num_epochs and num_steps if provided
1874
+ if num_epochs is not None:
1875
+ final_trainer_params["max_epochs"] = num_epochs
1876
+ if num_steps is not None:
1877
+ final_trainer_params["limit_train_batches"] = num_steps
1878
+
1879
+ # training
1880
+ training_params = _create_training_configuration(
1881
+ trainer_params=final_trainer_params,
1882
+ logger=logger,
1883
+ )
1884
+
1885
+ return Configuration(
1886
+ experiment_name=experiment_name,
1887
+ algorithm_config=algorithm_params,
1888
+ data_config=data_params,
1889
+ training_config=training_params,
1890
+ )
1891
+
1892
+
1893
+ def create_microsplit_configuration(
1894
+ experiment_name: str,
1895
+ data_type: Literal["array", "tiff", "custom"],
1896
+ axes: str,
1897
+ patch_size: Sequence[int],
1898
+ batch_size: int,
1899
+ num_epochs: int = 100,
1900
+ num_steps: int | None = None,
1901
+ encoder_conv_strides: tuple[int, ...] = (2, 2),
1902
+ decoder_conv_strides: tuple[int, ...] = (2, 2),
1903
+ multiscale_count: int = 3,
1904
+ grid_size: int = 32, # TODO most likely can be derived from patch size
1905
+ z_dims: tuple[int, ...] = (128, 128),
1906
+ output_channels: int = 1,
1907
+ encoder_n_filters: int = 32,
1908
+ decoder_n_filters: int = 32,
1909
+ encoder_dropout: float = 0.0,
1910
+ decoder_dropout: float = 0.0,
1911
+ nonlinearity: Literal[
1912
+ "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
1913
+ ] = "ReLU", # TODO do we need all these?
1914
+ analytical_kl: bool = False,
1915
+ predict_logvar: Literal["pixelwise"] = "pixelwise",
1916
+ logvar_lowerbound: Union[float, None] = None,
1917
+ logger: Literal["wandb", "tensorboard", "none"] = "none",
1918
+ trainer_params: dict | None = None,
1919
+ augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]] | None = None,
1920
+ nm_paths: list[str] | None = None,
1921
+ data_stats: tuple[float, float] | None = None,
1922
+ train_dataloader_params: dict[str, Any] | None = None,
1923
+ val_dataloader_params: dict[str, Any] | None = None,
1924
+ ) -> Configuration:
1925
+ """
1926
+ Create a configuration for training MicroSplit.
1927
+
1928
+ Parameters
1929
+ ----------
1930
+ experiment_name : str
1931
+ Name of the experiment.
1932
+ data_type : Literal["array", "tiff", "custom"]
1933
+ Type of the data.
1934
+ axes : str
1935
+ Axes of the data (e.g. SYX).
1936
+ patch_size : Sequence[int]
1937
+ Size of the patches along the spatial dimensions (e.g. [64, 64]).
1938
+ batch_size : int
1939
+ Batch size.
1940
+ num_epochs : int, default=100
1941
+ Number of epochs to train for. If provided, this will be added to
1942
+ trainer_params.
1943
+ num_steps : int, optional
1944
+ Number of batches in 1 epoch. If provided, this will be added to trainer_params.
1945
+ Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
1946
+ documentation for more details.
1947
+ encoder_conv_strides : tuple[int, ...], optional
1948
+ Strides for the encoder convolutional layers, by default (2, 2).
1949
+ decoder_conv_strides : tuple[int, ...], optional
1950
+ Strides for the decoder convolutional layers, by default (2, 2).
1951
+ multiscale_count : int, optional
1952
+ Number of multiscale levels, by default 3.
1953
+ grid_size : int, optional
1954
+ Size of the grid for the lateral context, by default 32.
1955
+ z_dims : tuple[int, ...], optional
1956
+ List of latent dimensions for each hierarchy level in the LVAE, by default
1957
+ (128, 128).
1958
+ output_channels : int, optional
1959
+ Number of output channels for the model, by default 1.
1960
+ encoder_n_filters : int, optional
1961
+ Number of filters in the encoder, by default 32.
1962
+ decoder_n_filters : int, optional
1963
+ Number of filters in the decoder, by default 32.
1964
+ encoder_dropout : float, optional
1965
+ Dropout rate for the encoder, by default 0.0.
1966
+ decoder_dropout : float, optional
1967
+ Dropout rate for the decoder, by default 0.0.
1968
+ nonlinearity : Literal, optional
1969
+ Nonlinearity to use in the model, by default "ReLU".
1970
+ analytical_kl : bool, optional
1971
+ Whether to use analytical KL divergence, by default False.
1972
+ predict_logvar : Literal["pixelwise"], optional
+ Type of log-variance prediction, by default "pixelwise".
1974
+ logvar_lowerbound : Union[float, None], optional
1975
+ Lower bound for the log variance, by default None.
1976
+ logger : Literal["wandb", "tensorboard", "none"], optional
1977
+ Logger to use for training, by default "none".
1978
+ trainer_params : dict, optional
1979
+ Parameters for the trainer class, see PyTorch Lightning documentation.
1980
+ augmentations : list[Union[XYFlipModel, XYRandomRotate90Model]] | None, optional
1981
+ List of augmentations to apply, by default None.
1982
+ nm_paths : list[str] | None, optional
1983
+ Paths to the noise model files, by default None.
1984
+ data_stats : tuple[float, float] | None, optional
1985
+ Data statistics (mean, std), by default None.
1986
+ train_dataloader_params : dict[str, Any] | None, optional
1987
+ Parameters for the training dataloader, by default None.
1988
+ val_dataloader_params : dict[str, Any] | None, optional
1989
+ Parameters for the validation dataloader, by default None.
1990
+
1991
+ Returns
1992
+ -------
1993
+ Configuration
1994
+ A configuration object for the microsplit algorithm.
1995
+
1996
+ Examples
1997
+ --------
1998
+ Minimum example:
1999
+ # >>> config = create_microsplit_configuration(
2000
+ # ... experiment_name="microsplit_experiment",
2001
+ # ... data_type="array",
2002
+ # ... axes="YX",
2003
+ # ... patch_size=[64, 64],
2004
+ # ... batch_size=32,
2005
+ # ... num_epochs=100
2006
+
2007
+ # ... )
2008
+
2009
+ # You can also limit the number of batches per epoch:
2010
+ # >>> config = create_microsplit_configuration(
2011
+ # ... experiment_name="microsplit_experiment",
2012
+ # ... data_type="array",
2013
+ # ... axes="YX",
2014
+ # ... patch_size=[64, 64],
2015
+ # ... batch_size=32,
2016
+ # ... num_steps=100 # limit to 100 batches per epoch
2017
+ # ... )
2018
+ """
2019
+ transform_list = _list_spatial_augmentations(augmentations)
2020
+
2021
+ loss_config = LVAELossConfig(
2022
+ loss_type="denoisplit_musplit", denoisplit_weight=0.9, musplit_weight=0.1
2023
+ ) # TODO losses need to be refactored! just for example. Add validator if sum to 1
2024
+
2025
+ # Create likelihood configurations
2026
+ gaussian_likelihood_config, noise_model_config, nm_likelihood_config = (
2027
+ get_likelihood_config(
2028
+ loss_type="denoisplit_musplit",
2029
+ predict_logvar=predict_logvar,
2030
+ logvar_lowerbound=logvar_lowerbound,
2031
+ nm_paths=nm_paths,
2032
+ data_stats=data_stats,
2033
+ )
2034
+ )
2035
+
2036
+ # Create the LVAE model
2037
+ network_model = _create_vae_configuration(
2038
+ input_shape=patch_size,
2039
+ encoder_conv_strides=encoder_conv_strides,
2040
+ decoder_conv_strides=decoder_conv_strides,
2041
+ multiscale_count=multiscale_count,
2042
+ z_dims=z_dims,
2043
+ output_channels=output_channels,
2044
+ encoder_n_filters=encoder_n_filters,
2045
+ decoder_n_filters=decoder_n_filters,
2046
+ encoder_dropout=encoder_dropout,
2047
+ decoder_dropout=decoder_dropout,
2048
+ nonlinearity=nonlinearity,
2049
+ predict_logvar=predict_logvar,
2050
+ analytical_kl=analytical_kl,
2051
+ )
2052
+
2053
+ # Create the MicroSplit algorithm configuration
2054
+ algorithm_params = {
2055
+ "algorithm": "microsplit",
2056
+ "loss": loss_config,
2057
+ "model": network_model,
2058
+ "gaussian_likelihood": gaussian_likelihood_config,
2059
+ "noise_model": noise_model_config,
2060
+ "noise_model_likelihood": nm_likelihood_config,
2061
+ }
2062
+
2063
+ # Convert to MicroSplitAlgorithm instance
2064
+ algorithm_config = MicroSplitAlgorithm(**algorithm_params)
2065
+
2066
+ # data
2067
+ data_params = _create_microsplit_data_configuration(
2068
+ data_type=data_type,
2069
+ axes=axes,
2070
+ patch_size=patch_size,
2071
+ grid_size=grid_size,
2072
+ multiscale_count=multiscale_count,
2073
+ batch_size=batch_size,
2074
+ augmentations=transform_list,
2075
+ train_dataloader_params=train_dataloader_params,
2076
+ val_dataloader_params=val_dataloader_params,
2077
+ )
2078
+
2079
+ # Handle trainer parameters with num_epochs and num_steps
2080
+ final_trainer_params = {} if trainer_params is None else trainer_params.copy()
2081
+
2082
+ # Add num_epochs and num_steps if provided
2083
+ if num_epochs is not None:
2084
+ final_trainer_params["max_epochs"] = num_epochs
2085
+ if num_steps is not None:
2086
+ final_trainer_params["limit_train_batches"] = num_steps
2087
+
2088
+ # training
2089
+ training_params = _create_training_configuration(
2090
+ trainer_params=final_trainer_params,
2091
+ logger=logger,
2092
+ )
2093
+
2094
+ return Configuration(
2095
+ experiment_name=experiment_name,
2096
+ algorithm_config=algorithm_config,
2097
+ data_config=data_params,
2098
+ training_config=training_params,
2099
+ )