careamics 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic according to the registry.

Files changed (74)
  1. careamics/careamist.py +4 -3
  2. careamics/cli/utils.py +1 -1
  3. careamics/config/algorithms/n2v_algorithm_model.py +1 -1
  4. careamics/config/architectures/unet_model.py +3 -0
  5. careamics/config/callback_model.py +23 -34
  6. careamics/config/configuration.py +47 -1
  7. careamics/config/configuration_factories.py +288 -23
  8. careamics/config/data/__init__.py +2 -0
  9. careamics/config/data/data_model.py +3 -3
  10. careamics/config/data/ng_data_model.py +381 -0
  11. careamics/config/data/patching_strategies/__init__.py +14 -0
  12. careamics/config/data/patching_strategies/_overlapping_patched_model.py +103 -0
  13. careamics/config/data/patching_strategies/_patched_model.py +56 -0
  14. careamics/config/data/patching_strategies/random_patching_model.py +21 -0
  15. careamics/config/data/patching_strategies/sequential_patching_model.py +25 -0
  16. careamics/config/data/patching_strategies/tiled_patching_model.py +40 -0
  17. careamics/config/data/patching_strategies/whole_patching_model.py +12 -0
  18. careamics/config/inference_model.py +6 -3
  19. careamics/config/support/supported_data.py +7 -0
  20. careamics/config/support/supported_patching_strategies.py +22 -0
  21. careamics/config/validators/validator_utils.py +4 -3
  22. careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
  23. careamics/dataset/in_memory_dataset.py +2 -1
  24. careamics/dataset/iterable_dataset.py +2 -2
  25. careamics/dataset/iterable_pred_dataset.py +2 -2
  26. careamics/dataset/iterable_tiled_pred_dataset.py +2 -2
  27. careamics/dataset/patching/patching.py +3 -2
  28. careamics/dataset/tiling/lvae_tiled_patching.py +16 -6
  29. careamics/dataset/tiling/tiled_patching.py +2 -1
  30. careamics/dataset_ng/dataset.py +46 -50
  31. careamics/dataset_ng/demos/bsd68_demo.ipynb +28 -23
  32. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +1 -1
  33. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +1 -1
  34. careamics/dataset_ng/demos/demo_datamodule.ipynb +50 -46
  35. careamics/dataset_ng/demos/demo_dataset.ipynb +32 -49
  36. careamics/dataset_ng/factory.py +58 -15
  37. careamics/dataset_ng/legacy_interoperability.py +3 -1
  38. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +1 -1
  39. careamics/dataset_ng/patch_extractor/image_stack/__init__.py +2 -0
  40. careamics/dataset_ng/patch_extractor/image_stack/czi_image_stack.py +360 -0
  41. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -1
  42. careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +43 -1
  43. careamics/dataset_ng/patching_strategies/random_patching.py +4 -2
  44. careamics/dataset_ng/patching_strategies/sequential_patching.py +5 -5
  45. careamics/dataset_ng/patching_strategies/tiling_strategy.py +2 -1
  46. careamics/file_io/read/get_func.py +2 -1
  47. careamics/lightning/dataset_ng/__init__.py +1 -0
  48. careamics/lightning/dataset_ng/data_module.py +218 -28
  49. careamics/lightning/dataset_ng/lightning_modules/care_module.py +44 -5
  50. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +42 -3
  51. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +73 -4
  52. careamics/lightning/lightning_module.py +2 -1
  53. careamics/lightning/predict_data_module.py +2 -1
  54. careamics/lightning/train_data_module.py +2 -1
  55. careamics/losses/loss_factory.py +2 -1
  56. careamics/lvae_training/dataset/multicrop_dset.py +1 -1
  57. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  58. careamics/model_io/bioimage/model_description.py +1 -1
  59. careamics/model_io/bmz_io.py +1 -1
  60. careamics/model_io/model_io_utils.py +2 -2
  61. careamics/models/activation.py +2 -1
  62. careamics/models/unet.py +16 -10
  63. careamics/prediction_utils/prediction_outputs.py +1 -1
  64. careamics/prediction_utils/stitch_prediction.py +1 -1
  65. careamics/transforms/n2v_manipulate_torch.py +15 -9
  66. careamics/transforms/pixel_manipulation_torch.py +59 -92
  67. careamics/utils/lightning_utils.py +2 -2
  68. careamics/utils/metrics.py +2 -1
  69. careamics/utils/torch_utils.py +23 -0
  70. {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/METADATA +10 -9
  71. {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/RECORD +74 -63
  72. {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/WHEEL +0 -0
  73. {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/entry_points.txt +0 -0
  74. {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/licenses/LICENSE +0 -0
careamics/config/configuration_factories.py

@@ -1,12 +1,13 @@
 """Convenience functions to create configurations for training and inference."""

+from collections.abc import Sequence
 from typing import Annotated, Any, Literal, Optional, Union

 from pydantic import Field, TypeAdapter

 from careamics.config.algorithms import CAREAlgorithm, N2NAlgorithm, N2VAlgorithm
 from careamics.config.architectures import UNetModel
-from careamics.config.data import DataConfig
+from careamics.config.data import DataConfig, NGDataConfig
 from careamics.config.support import (
     SupportedArchitecture,
     SupportedPixelManipulation,
@@ -24,7 +25,7 @@ from .configuration import Configuration


 def algorithm_factory(
-    algorithm: dict[str, Any]
+    algorithm: dict[str, Any],
 ) -> Union[N2VAlgorithm, N2NAlgorithm, CAREAlgorithm]:
     """
     Create an algorithm model for training CAREamics.
@@ -49,7 +50,7 @@ def algorithm_factory(


 def _list_spatial_augmentations(
-    augmentations: Optional[list[SPATIAL_TRANSFORMS_UNION]],
+    augmentations: Optional[list[SPATIAL_TRANSFORMS_UNION]] = None,
 ) -> list[SPATIAL_TRANSFORMS_UNION]:
     """
     List the augmentations to apply.
@@ -153,6 +154,10 @@ def _create_algorithm_configuration(
     n_channels_out: int,
     use_n2v2: bool = False,
     model_params: Optional[dict] = None,
+    optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
+    optimizer_params: Optional[dict[str, Any]] = None,
+    lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
+    lr_scheduler_params: Optional[dict[str, Any]] = None,
 ) -> dict:
     """
     Create a dictionary with the parameters of the algorithm model.
@@ -171,10 +176,20 @@ def _create_algorithm_configuration(
         Number of input channels.
     n_channels_out : int
         Number of output channels.
-    use_n2v2 : bool, optional
-        Whether to use N2V2, by default False.
-    model_params : dict
+    use_n2v2 : bool, default=false
+        Whether to use N2V2.
+    model_params : dict, default=None
         UNetModel parameters.
+    optimizer : {"Adam", "Adamax", "SGD"}, default="Adam"
+        Optimizer to use.
+    optimizer_params : dict, default=None
+        Parameters for the optimizer, see PyTorch documentation for more details.
+    lr_scheduler : {"ReduceLROnPlateau", "StepLR"}, default="ReduceLROnPlateau"
+        Learning rate scheduler to use.
+    lr_scheduler_params : dict, default=None
+        Parameters for the learning rate scheduler, see PyTorch documentation for more
+        details.
+

     Returns
     -------
@@ -195,11 +210,19 @@ def _create_algorithm_configuration(
         "algorithm": algorithm,
         "loss": loss,
         "model": unet_model,
+        "optimizer": {
+            "name": optimizer,
+            "parameters": {} if optimizer_params is None else optimizer_params,
+        },
+        "lr_scheduler": {
+            "name": lr_scheduler,
+            "parameters": {} if lr_scheduler_params is None else lr_scheduler_params,
+        },
     }


 def _create_data_configuration(
-    data_type: Literal["array", "tiff", "custom"],
+    data_type: Literal["array", "tiff", "czi", "custom"],
     axes: str,
     patch_size: list[int],
     batch_size: int,
@@ -212,7 +235,7 @@ def _create_data_configuration(

     Parameters
     ----------
-    data_type : {"array", "tiff", "custom"}
+    data_type : {"array", "tiff", "czi", "custom"}
         Type of the data.
     axes : str
         Axes of the data.
@@ -254,8 +277,89 @@ def _create_data_configuration(
     return DataConfig(**data)


+def _create_ng_data_configuration(
+    data_type: Literal["array", "tiff", "custom"],
+    axes: str,
+    patch_size: Sequence[int],
+    batch_size: int,
+    augmentations: list[SPATIAL_TRANSFORMS_UNION],
+    patch_overlaps: Optional[Sequence[int]] = None,
+    train_dataloader_params: Optional[dict[str, Any]] = None,
+    val_dataloader_params: Optional[dict[str, Any]] = None,
+    test_dataloader_params: Optional[dict[str, Any]] = None,
+    seed: Optional[int] = None,
+) -> NGDataConfig:
+    """
+    Create a dictionary with the parameters of the data model.
+
+    Parameters
+    ----------
+    data_type : {"array", "tiff", "custom"}
+        Type of the data.
+    axes : str
+        Axes of the data.
+    patch_size : list of int
+        Size of the patches along the spatial dimensions.
+    batch_size : int
+        Batch size.
+    augmentations : list of transforms
+        List of transforms to apply.
+    patch_overlaps : Sequence of int, default=None
+        Overlaps between patches in each spatial dimension, only used with "sequential"
+        patching. If `None`, no overlap is applied. The overlap must be smaller than
+        the patch size in each spatial dimension, and the number of dimensions be either
+        2 or 3.
+    train_dataloader_params : dict
+        Parameters for the training dataloader, see PyTorch notes, by default None.
+    val_dataloader_params : dict
+        Parameters for the validation dataloader, see PyTorch notes, by default None.
+    test_dataloader_params : dict
+        Parameters for the test dataloader, see PyTorch notes, by default None.
+    seed : int, default=None
+        Random seed for reproducibility. If `None`, no seed is set.
+
+    Returns
+    -------
+    NGDataConfig
+        Next-Generation Data model with the specified parameters.
+    """
+    # data model
+    data = {
+        "data_type": data_type,
+        "axes": axes,
+        "batch_size": batch_size,
+        "transforms": augmentations,
+        "seed": seed,
+    }
+    # don't override defaults set in DataConfig class
+    if train_dataloader_params is not None:
+        # the presence of `shuffle` key in the dataloader parameters is enforced
+        # by the NGDataConfig class
+        if "shuffle" not in train_dataloader_params:
+            train_dataloader_params["shuffle"] = True
+
+        data["train_dataloader_params"] = train_dataloader_params
+
+    if val_dataloader_params is not None:
+        data["val_dataloader_params"] = val_dataloader_params
+
+    if test_dataloader_params is not None:
+        data["test_dataloader_params"] = test_dataloader_params
+
+    # add training patching
+    data["patching"] = {
+        "name": "random",
+        "patch_size": patch_size,
+        "overlaps": patch_overlaps,
+    }
+
+    return NGDataConfig(**data)
+
+
 def _create_training_configuration(
-    num_epochs: int, logger: Literal["wandb", "tensorboard", "none"]
+    num_epochs: int,
+    logger: Literal["wandb", "tensorboard", "none"],
+    checkpoint_params: Optional[dict[str, Any]] = None,
 ) -> TrainingConfig:
     """
     Create a dictionary with the parameters of the training model.
@@ -266,6 +370,9 @@ def _create_training_configuration(
         Number of epochs.
     logger : {"wandb", "tensorboard", "none"}
        Logger to use.
+    checkpoint_params : dict, default=None
+        Parameters for the checkpoint callback, see PyTorch Lightning documentation
+        (`ModelCheckpoint`) for the list of available parameters.

     Returns
     -------
@@ -275,6 +382,7 @@ def _create_training_configuration(
     return TrainingConfig(
         num_epochs=num_epochs,
         logger=None if logger == "none" else logger,
+        checkpoint_callback={} if checkpoint_params is None else checkpoint_params,
     )


@@ -282,7 +390,7 @@ def _create_training_configuration(
 def _create_supervised_config_dict(
     algorithm: Literal["care", "n2n"],
     experiment_name: str,
-    data_type: Literal["array", "tiff", "custom"],
+    data_type: Literal["array", "tiff", "czi", "custom"],
     axes: str,
     patch_size: list[int],
     batch_size: int,
@@ -294,8 +402,13 @@ def _create_supervised_config_dict(
     n_channels_out: Optional[int] = None,
     logger: Literal["wandb", "tensorboard", "none"] = "none",
     model_params: Optional[dict] = None,
+    optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
+    optimizer_params: Optional[dict[str, Any]] = None,
+    lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
+    lr_scheduler_params: Optional[dict[str, Any]] = None,
     train_dataloader_params: Optional[dict[str, Any]] = None,
     val_dataloader_params: Optional[dict[str, Any]] = None,
+    checkpoint_params: Optional[dict[str, Any]] = None,
 ) -> dict:
     """
     Create a configuration for training CARE or Noise2Noise.
@@ -306,7 +419,7 @@ def _create_supervised_config_dict(
         Algorithm to use.
     experiment_name : str
         Name of the experiment.
-    data_type : Literal["array", "tiff", "custom"]
+    data_type : Literal["array", "tiff", "czi", "custom"]
         Type of the data.
     axes : str
         Axes of the data (e.g. SYX).
@@ -330,12 +443,24 @@ def _create_supervised_config_dict(
         Number of channels out.
     logger : Literal["wandb", "tensorboard", "none"], optional
         Logger to use, by default "none".
-    model_params : dict, optional
-        UNetModel parameters, by default {}.
+    model_params : dict, default=None
+        UNetModel parameters.
+    optimizer : {"Adam", "Adamax", "SGD"}, default="Adam"
+        Optimizer to use.
+    optimizer_params : dict, default=None
+        Parameters for the optimizer, see PyTorch documentation for more details.
+    lr_scheduler : {"ReduceLROnPlateau", "StepLR"}, default="ReduceLROnPlateau"
+        Learning rate scheduler to use.
+    lr_scheduler_params : dict, default=None
+        Parameters for the learning rate scheduler, see PyTorch documentation for more
+        details.
     train_dataloader_params : dict
         Parameters for the training dataloader, see PyTorch notes, by default None.
     val_dataloader_params : dict
         Parameters for the validation dataloader, see PyTorch notes, by default None.
+    checkpoint_params : dict, default=None
+        Parameters for the checkpoint callback, see PyTorch Lightning documentation
+        (`ModelCheckpoint`) for the list of available parameters.

     Returns
     -------
@@ -376,6 +501,10 @@ def _create_supervised_config_dict(
         n_channels_in=n_channels_in,
         n_channels_out=n_channels_out,
         model_params=model_params,
+        optimizer=optimizer,
+        optimizer_params=optimizer_params,
+        lr_scheduler=lr_scheduler,
+        lr_scheduler_params=lr_scheduler_params,
     )

     # data
@@ -393,6 +522,7 @@ def _create_supervised_config_dict(
     training_params = _create_training_configuration(
         num_epochs=num_epochs,
         logger=logger,
+        checkpoint_params=checkpoint_params,
     )

     return {
@@ -405,7 +535,7 @@ def _create_supervised_config_dict(

 def create_care_configuration(
     experiment_name: str,
-    data_type: Literal["array", "tiff", "custom"],
+    data_type: Literal["array", "tiff", "czi", "custom"],
     axes: str,
     patch_size: list[int],
     batch_size: int,
@@ -417,8 +547,13 @@ def create_care_configuration(
     n_channels_out: Optional[int] = None,
     logger: Literal["wandb", "tensorboard", "none"] = "none",
     model_params: Optional[dict] = None,
+    optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
+    optimizer_params: Optional[dict[str, Any]] = None,
+    lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
+    lr_scheduler_params: Optional[dict[str, Any]] = None,
     train_dataloader_params: Optional[dict[str, Any]] = None,
     val_dataloader_params: Optional[dict[str, Any]] = None,
+    checkpoint_params: Optional[dict[str, Any]] = None,
 ) -> Configuration:
     """
     Create a configuration for training CARE.
@@ -445,7 +580,7 @@ def create_care_configuration(
     ----------
     experiment_name : str
         Name of the experiment.
-    data_type : Literal["array", "tiff", "custom"]
+    data_type : Literal["array", "tiff", "czi", "custom"]
         Type of the data.
     axes : str
         Axes of the data (e.g. SYX).
@@ -471,6 +606,15 @@ def create_care_configuration(
         Logger to use.
     model_params : dict, default=None
         UNetModel parameters.
+    optimizer : Literal["Adam", "Adamax", "SGD"], default="Adam"
+        Optimizer to use.
+    optimizer_params : dict, default=None
+        Parameters for the optimizer, see PyTorch documentation for more details.
+    lr_scheduler : Literal["ReduceLROnPlateau", "StepLR"], default="ReduceLROnPlateau"
+        Learning rate scheduler to use.
+    lr_scheduler_params : dict, default=None
+        Parameters for the learning rate scheduler, see PyTorch documentation for more
+        details.
     train_dataloader_params : dict, optional
         Parameters for the training dataloader, see the PyTorch docs for `DataLoader`.
         If left as `None`, the dict `{"shuffle": True}` will be used, this is set in
@@ -479,6 +623,9 @@ def create_care_configuration(
         Parameters for the validation dataloader, see PyTorch the docs for `DataLoader`.
         If left as `None`, the empty dict `{}` will be used, this is set in the
         `GeneralDataConfig`.
+    checkpoint_params : dict, default=None
+        Parameters for the checkpoint callback, see PyTorch Lightning documentation
+        (`ModelCheckpoint`) for the list of available parameters.

     Returns
     -------
@@ -551,6 +698,29 @@ def create_care_configuration(
     ...     n_channels_in=3,
     ...     n_channels_out=1 # if applicable
     ... )
+
+    If you would like to train on CZI files, use `"czi"` as `data_type` and `"SCYX"` as
+    `axes` for 2-D or `"SCZYX"` for 3-D denoising. Note that `"SCYX"` can also be used
+    for 3-D data but spatial context along the Z dimension will then not be taken into
+    account.
+    >>> config_2d = create_care_configuration(
+    ...     experiment_name="care_experiment",
+    ...     data_type="czi",
+    ...     axes="SCYX",
+    ...     patch_size=[64, 64],
+    ...     batch_size=32,
+    ...     num_epochs=100,
+    ...     n_channels_in=1,
+    ... )
+    >>> config_3d = create_care_configuration(
+    ...     experiment_name="care_experiment",
+    ...     data_type="czi",
+    ...     axes="SCZYX",
+    ...     patch_size=[16, 64, 64],
+    ...     batch_size=16,
+    ...     num_epochs=100,
+    ...     n_channels_in=1,
+    ... )
     """
     return Configuration(
         **_create_supervised_config_dict(
@@ -568,15 +738,20 @@ def create_care_configuration(
             n_channels_out=n_channels_out,
             logger=logger,
             model_params=model_params,
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            lr_scheduler=lr_scheduler,
+            lr_scheduler_params=lr_scheduler_params,
             train_dataloader_params=train_dataloader_params,
             val_dataloader_params=val_dataloader_params,
+            checkpoint_params=checkpoint_params,
         )
     )


 def create_n2n_configuration(
     experiment_name: str,
-    data_type: Literal["array", "tiff", "custom"],
+    data_type: Literal["array", "tiff", "czi", "custom"],
     axes: str,
     patch_size: list[int],
     batch_size: int,
@@ -588,8 +763,13 @@ def create_n2n_configuration(
     n_channels_out: Optional[int] = None,
     logger: Literal["wandb", "tensorboard", "none"] = "none",
     model_params: Optional[dict] = None,
+    optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
+    optimizer_params: Optional[dict[str, Any]] = None,
+    lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
+    lr_scheduler_params: Optional[dict[str, Any]] = None,
     train_dataloader_params: Optional[dict[str, Any]] = None,
     val_dataloader_params: Optional[dict[str, Any]] = None,
+    checkpoint_params: Optional[dict[str, Any]] = None,
 ) -> Configuration:
     """
     Create a configuration for training Noise2Noise.
@@ -616,7 +796,7 @@ def create_n2n_configuration(
     ----------
     experiment_name : str
         Name of the experiment.
-    data_type : Literal["array", "tiff", "custom"]
+    data_type : Literal["array", "tiff", "czi", "custom"]
         Type of the data.
     axes : str
         Axes of the data (e.g. SYX).
@@ -640,8 +820,17 @@ def create_n2n_configuration(
         Number of channels out.
     logger : Literal["wandb", "tensorboard", "none"], optional
         Logger to use, by default "none".
-    model_params : dict, optional
-        UNetModel parameters, by default {}.
+    model_params : dict, default=None
+        UNetModel parameters.
+    optimizer : Literal["Adam", "Adamax", "SGD"], default="Adam"
+        Optimizer to use.
+    optimizer_params : dict, default=None
+        Parameters for the optimizer, see PyTorch documentation for more details.
+    lr_scheduler : Literal["ReduceLROnPlateau", "StepLR"], default="ReduceLROnPlateau"
+        Learning rate scheduler to use.
+    lr_scheduler_params : dict, default=None
+        Parameters for the learning rate scheduler, see PyTorch documentation for more
+        details.
     train_dataloader_params : dict, optional
         Parameters for the training dataloader, see the PyTorch docs for `DataLoader`.
         If left as `None`, the dict `{"shuffle": True}` will be used, this is set in
@@ -650,6 +839,9 @@ def create_n2n_configuration(
         Parameters for the validation dataloader, see PyTorch the docs for `DataLoader`.
         If left as `None`, the empty dict `{}` will be used, this is set in the
         `GeneralDataConfig`.
+    checkpoint_params : dict, default=None
+        Parameters for the checkpoint callback, see PyTorch Lightning documentation
+        (`ModelCheckpoint`) for the list of available parameters.

     Returns
     -------
@@ -722,6 +914,29 @@ def create_n2n_configuration(
     ...     n_channels_in=3,
     ...     n_channels_out=1 # if applicable
     ... )
+
+    If you would like to train on CZI files, use `"czi"` as `data_type` and `"SCYX"` as
+    `axes` for 2-D or `"SCZYX"` for 3-D denoising. Note that `"SCYX"` can also be used
+    for 3-D data but spatial context along the Z dimension will then not be taken into
+    account.
+    >>> config_2d = create_n2n_configuration(
+    ...     experiment_name="n2n_experiment",
+    ...     data_type="czi",
+    ...     axes="SCYX",
+    ...     patch_size=[64, 64],
+    ...     batch_size=32,
+    ...     num_epochs=100,
+    ...     n_channels_in=1,
+    ... )
+    >>> config_3d = create_n2n_configuration(
+    ...     experiment_name="n2n_experiment",
+    ...     data_type="czi",
+    ...     axes="SCZYX",
+    ...     patch_size=[16, 64, 64],
+    ...     batch_size=16,
+    ...     num_epochs=100,
+    ...     n_channels_in=1,
+    ... )
     """
     return Configuration(
         **_create_supervised_config_dict(
@@ -739,15 +954,20 @@ def create_n2n_configuration(
             n_channels_out=n_channels_out,
             logger=logger,
             model_params=model_params,
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            lr_scheduler=lr_scheduler,
+            lr_scheduler_params=lr_scheduler_params,
             train_dataloader_params=train_dataloader_params,
             val_dataloader_params=val_dataloader_params,
+            checkpoint_params=checkpoint_params,
         )
     )


 def create_n2v_configuration(
     experiment_name: str,
-    data_type: Literal["array", "tiff", "custom"],
+    data_type: Literal["array", "tiff", "czi", "custom"],
     axes: str,
     patch_size: list[int],
     batch_size: int,
@@ -762,8 +982,13 @@ def create_n2v_configuration(
     struct_n2v_span: int = 5,
     logger: Literal["wandb", "tensorboard", "none"] = "none",
     model_params: Optional[dict] = None,
+    optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
+    optimizer_params: Optional[dict[str, Any]] = None,
+    lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
+    lr_scheduler_params: Optional[dict[str, Any]] = None,
     train_dataloader_params: Optional[dict[str, Any]] = None,
     val_dataloader_params: Optional[dict[str, Any]] = None,
+    checkpoint_params: Optional[dict[str, Any]] = None,
 ) -> Configuration:
     """
     Create a configuration for training Noise2Void.
@@ -810,7 +1035,7 @@ def create_n2v_configuration(
     ----------
     experiment_name : str
         Name of the experiment.
-    data_type : Literal["array", "tiff", "custom"]
+    data_type : Literal["array", "tiff", "czi", "custom"]
         Type of the data.
     axes : str
         Axes of the data (e.g. SYX).
@@ -840,8 +1065,17 @@ def create_n2v_configuration(
         Span of the structN2V mask, by default 5.
     logger : Literal["wandb", "tensorboard", "none"], optional
         Logger to use, by default "none".
-    model_params : dict, optional
-        UNetModel parameters, by default None.
+    model_params : dict, default=None
+        UNetModel parameters.
+    optimizer : Literal["Adam", "Adamax", "SGD"], default="Adam"
+        Optimizer to use.
+    optimizer_params : dict, default=None
+        Parameters for the optimizer, see PyTorch documentation for more details.
+    lr_scheduler : Literal["ReduceLROnPlateau", "StepLR"], default="ReduceLROnPlateau"
+        Learning rate scheduler to use.
+    lr_scheduler_params : dict, default=None
+        Parameters for the learning rate scheduler, see PyTorch documentation for more
+        details.
     train_dataloader_params : dict, optional
         Parameters for the training dataloader, see the PyTorch docs for `DataLoader`.
         If left as `None`, the dict `{"shuffle": True}` will be used, this is set in
@@ -850,6 +1084,9 @@ def create_n2v_configuration(
         Parameters for the validation dataloader, see PyTorch the docs for `DataLoader`.
         If left as `None`, the empty dict `{}` will be used, this is set in the
         `GeneralDataConfig`.
+    checkpoint_params : dict, default=None
+        Parameters for the checkpoint callback, see PyTorch Lightning documentation
+        (`ModelCheckpoint`) for the list of available parameters.

     Returns
     -------
@@ -942,6 +1179,29 @@ def create_n2v_configuration(
     ...     independent_channels=False,
     ...     n_channels=3
     ... )
+
+    If you would like to train on CZI files, use `"czi"` as `data_type` and `"SCYX"` as
+    `axes` for 2-D or `"SCZYX"` for 3-D denoising. Note that `"SCYX"` can also be used
+    for 3-D data but spatial context along the Z dimension will then not be taken into
+    account.
+    >>> config_2d = create_n2v_configuration(
+    ...     experiment_name="n2v_experiment",
+    ...     data_type="czi",
+    ...     axes="SCYX",
+    ...     patch_size=[64, 64],
+    ...     batch_size=32,
+    ...     num_epochs=100,
+    ...     n_channels=1,
+    ... )
+    >>> config_3d = create_n2v_configuration(
+    ...     experiment_name="n2v_experiment",
+    ...     data_type="czi",
+    ...     axes="SCZYX",
+    ...     patch_size=[16, 64, 64],
+    ...     batch_size=16,
+    ...     num_epochs=100,
+    ...     n_channels=1,
+    ... )
     """
     # if there are channels, we need to specify their number
     if "C" in axes and n_channels is None:
@@ -982,6 +1242,10 @@ def create_n2v_configuration(
         n_channels_out=n_channels,
         use_n2v2=use_n2v2,
         model_params=model_params,
+        optimizer=optimizer,
+        optimizer_params=optimizer_params,
+        lr_scheduler=lr_scheduler,
+        lr_scheduler_params=lr_scheduler_params,
     )
     algorithm_params["n2v_config"] = n2v_transform

@@ -1000,6 +1264,7 @@ def create_n2v_configuration(
     training_params = _create_training_configuration(
         num_epochs=num_epochs,
         logger=logger,
+        checkpoint_params=checkpoint_params,
     )

     return Configuration(
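
The hunks above thread new `optimizer`, `optimizer_params`, `lr_scheduler`, `lr_scheduler_params` and `checkpoint_params` arguments through the configuration factories. A minimal sketch of how a caller might exercise them follows; the keyword names are taken from the diff, while the dictionary contents are ordinary torch.optim.SGD, ReduceLROnPlateau and Lightning ModelCheckpoint options chosen purely for illustration, not values required by CAREamics.

# Sketch only: keyword names come from the factory signatures above; the
# nested dictionaries use standard torch.optim / lr_scheduler / Lightning
# ModelCheckpoint options and are illustrative assumptions.
from careamics.config.configuration_factories import create_n2v_configuration

config = create_n2v_configuration(
    experiment_name="n2v_sgd_experiment",
    data_type="tiff",
    axes="SYX",
    patch_size=[64, 64],
    batch_size=32,
    num_epochs=100,
    optimizer="SGD",
    optimizer_params={"lr": 1e-3, "momentum": 0.9},
    lr_scheduler="ReduceLROnPlateau",
    lr_scheduler_params={"factor": 0.5, "patience": 5},
    checkpoint_params={"save_top_k": 3, "every_n_epochs": 1},
)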
careamics/config/data/__init__.py

@@ -2,6 +2,8 @@

 __all__ = [
     "DataConfig",
+    "NGDataConfig",
 ]

 from .data_model import DataConfig
+from .ng_data_model import NGDataConfig
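
With `NGDataConfig` now re-exported from `careamics.config.data`, the `_create_ng_data_configuration` helper shown earlier indicates which fields the model accepts. The sketch below mirrors that helper under the assumption that the same fields can be passed directly; the full schema lives in `ng_data_model.py`, which is not shown in this excerpt.

# Sketch mirroring _create_ng_data_configuration above; field names are taken
# from the dictionary built by that helper, not from ng_data_model.py itself.
from careamics.config.data import NGDataConfig

ng_data = NGDataConfig(
    data_type="tiff",
    axes="SYX",
    batch_size=16,
    patching={
        "name": "random",       # the strategy the helper uses for training
        "patch_size": [64, 64],
        "overlaps": None,       # only meaningful for "sequential" patching
    },
    train_dataloader_params={"shuffle": True, "num_workers": 4},
    seed=42,
)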
careamics/config/data/data_model.py

@@ -95,9 +95,9 @@ class DataConfig(BaseModel):
     )

     # Dataset configuration
-    data_type: Literal["array", "tiff", "custom"]
-    """Type of input data, numpy.ndarray (array) or paths (tiff and custom), as defined
-    in SupportedData."""
+    data_type: Literal["array", "tiff", "czi", "custom"]
+    """Type of input data, numpy.ndarray (array) or paths (tiff, czi, and custom), as
+    defined in SupportedData."""

     axes: str
     """Axes of the data, as defined in SupportedAxes."""