careamics 0.1.0rc3__py3-none-any.whl → 0.1.0rc5__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in the supported registries. It is provided for informational purposes only.


Files changed (66)
  1. careamics/__init__.py +8 -6
  2. careamics/careamist.py +30 -29
  3. careamics/config/__init__.py +12 -9
  4. careamics/config/algorithm_model.py +5 -5
  5. careamics/config/architectures/unet_model.py +1 -0
  6. careamics/config/callback_model.py +1 -0
  7. careamics/config/configuration_example.py +87 -0
  8. careamics/config/configuration_factory.py +285 -78
  9. careamics/config/configuration_model.py +22 -23
  10. careamics/config/data_model.py +62 -160
  11. careamics/config/inference_model.py +20 -21
  12. careamics/config/references/algorithm_descriptions.py +1 -0
  13. careamics/config/references/references.py +1 -0
  14. careamics/config/support/supported_extraction_strategies.py +1 -0
  15. careamics/config/support/supported_optimizers.py +3 -3
  16. careamics/config/training_model.py +2 -1
  17. careamics/config/transformations/n2v_manipulate_model.py +2 -1
  18. careamics/config/transformations/nd_flip_model.py +7 -12
  19. careamics/config/transformations/normalize_model.py +2 -1
  20. careamics/config/transformations/transform_model.py +1 -0
  21. careamics/config/transformations/xy_random_rotate90_model.py +7 -9
  22. careamics/config/validators/validator_utils.py +1 -0
  23. careamics/conftest.py +1 -0
  24. careamics/dataset/dataset_utils/__init__.py +0 -1
  25. careamics/dataset/dataset_utils/dataset_utils.py +1 -0
  26. careamics/dataset/in_memory_dataset.py +17 -48
  27. careamics/dataset/iterable_dataset.py +16 -71
  28. careamics/dataset/patching/__init__.py +0 -7
  29. careamics/dataset/patching/patching.py +1 -0
  30. careamics/dataset/patching/sequential_patching.py +6 -6
  31. careamics/dataset/patching/tiled_patching.py +10 -6
  32. careamics/lightning_datamodule.py +123 -49
  33. careamics/lightning_module.py +7 -7
  34. careamics/lightning_prediction_datamodule.py +59 -48
  35. careamics/losses/__init__.py +0 -1
  36. careamics/losses/loss_factory.py +1 -0
  37. careamics/model_io/__init__.py +0 -1
  38. careamics/model_io/bioimage/_readme_factory.py +2 -1
  39. careamics/model_io/bioimage/bioimage_utils.py +1 -0
  40. careamics/model_io/bioimage/model_description.py +4 -3
  41. careamics/model_io/bmz_io.py +8 -7
  42. careamics/model_io/model_io_utils.py +4 -4
  43. careamics/models/layers.py +1 -0
  44. careamics/models/model_factory.py +1 -0
  45. careamics/models/unet.py +91 -17
  46. careamics/prediction/stitch_prediction.py +1 -0
  47. careamics/transforms/__init__.py +2 -23
  48. careamics/transforms/compose.py +98 -0
  49. careamics/transforms/n2v_manipulate.py +18 -23
  50. careamics/transforms/nd_flip.py +38 -64
  51. careamics/transforms/normalize.py +45 -34
  52. careamics/transforms/pixel_manipulation.py +2 -2
  53. careamics/transforms/transform.py +33 -0
  54. careamics/transforms/tta.py +2 -2
  55. careamics/transforms/xy_random_rotate90.py +41 -68
  56. careamics/utils/__init__.py +0 -1
  57. careamics/utils/context.py +1 -0
  58. careamics/utils/logging.py +1 -0
  59. careamics/utils/metrics.py +1 -0
  60. careamics/utils/torch_utils.py +1 -0
  61. {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc5.dist-info}/METADATA +16 -61
  62. careamics-0.1.0rc5.dist-info/RECORD +111 -0
  63. careamics/dataset/patching/patch_transform.py +0 -44
  64. careamics-0.1.0rc3.dist-info/RECORD +0 -109
  65. {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc5.dist-info}/WHEEL +0 -0
  66. {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc5.dist-info}/licenses/LICENSE +0 -0
careamics/__init__.py CHANGED
@@ -9,16 +9,18 @@ except PackageNotFoundError:
 
 __all__ = [
     "CAREamist",
-    "CAREamicsModule",
+    "CAREamicsModuleWrapper",
+    "CAREamicsPredictData",
+    "CAREamicsTrainData",
     "Configuration",
     "load_configuration",
     "save_configuration",
-    "CAREamicsTrainDataModule",
-    "CAREamicsPredictDataModule",
+    "TrainingDataWrapper",
+    "PredictDataWrapper",
 ]
 
 from .careamist import CAREamist
 from .config import Configuration, load_configuration, save_configuration
-from .lightning_datamodule import CAREamicsTrainDataModule
-from .lightning_module import CAREamicsModule
-from .lightning_prediction_datamodule import CAREamicsPredictDataModule
+from .lightning_datamodule import CAREamicsTrainData, TrainingDataWrapper
+from .lightning_module import CAREamicsModuleWrapper
+from .lightning_prediction_datamodule import CAREamicsPredictData, PredictDataWrapper
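
This release renames the public entry points. A minimal migration sketch in Python, based only on the old and new names visible in the `__all__` diff above (nothing about the call sites is assumed):

    # rc3 imports (removed in rc5):
    #   from careamics import CAREamicsModule, CAREamicsTrainDataModule, CAREamicsPredictDataModule
    # rc5 equivalents, per the new __all__:
    from careamics import (
        CAREamicsModuleWrapper,  # replaces CAREamicsModule
        TrainingDataWrapper,     # replaces CAREamicsTrainDataModule
        PredictDataWrapper,      # replaces CAREamicsPredictDataModule
    )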
careamics/careamist.py CHANGED
@@ -20,9 +20,9 @@ from careamics.config import (
 )
 from careamics.config.inference_model import TRANSFORMS_UNION
 from careamics.config.support import SupportedAlgorithm, SupportedData, SupportedLogger
-from careamics.lightning_datamodule import CAREamicsWood
-from careamics.lightning_module import CAREamicsKiln
-from careamics.lightning_prediction_datamodule import CAREamicsClay
+from careamics.lightning_datamodule import CAREamicsTrainData
+from careamics.lightning_module import CAREamicsModule
+from careamics.lightning_prediction_datamodule import CAREamicsPredictData
 from careamics.lightning_prediction_loop import CAREamicsPredictionLoop
 from careamics.model_io import export_to_bmz, load_pretrained
 from careamics.utils import check_path_exists, get_logger
@@ -73,8 +73,7 @@ class CAREamist:
         source: Union[Path, str],
         work_dir: Optional[str] = None,
         experiment_name: str = "CAREamics",
-    ) -> None:
-        ...
+    ) -> None: ...
 
     @overload
     def __init__(  # numpydoc ignore=GL08
@@ -82,8 +81,7 @@ class CAREamist:
         source: Configuration,
         work_dir: Optional[str] = None,
         experiment_name: str = "CAREamics",
-    ) -> None:
-        ...
+    ) -> None: ...
 
     def __init__(
         self,
@@ -140,7 +138,7 @@ class CAREamist:
             self.cfg = source
 
             # instantiate model
-            self.model = CAREamicsKiln(
+            self.model = CAREamicsModule(
                 algorithm_config=self.cfg.algorithm_config,
             )
 
@@ -156,7 +154,7 @@ class CAREamist:
             self.cfg = load_configuration(source)
 
             # instantiate model
-            self.model = CAREamicsKiln(
+            self.model = CAREamicsModule(
                 algorithm_config=self.cfg.algorithm_config,
             )
 
@@ -193,8 +191,8 @@ class CAREamist:
         self.trainer.predict_loop = CAREamicsPredictionLoop(self.trainer)
 
         # place holder for the datamodules
-        self.train_datamodule: Optional[CAREamicsWood] = None
-        self.pred_datamodule: Optional[CAREamicsClay] = None
+        self.train_datamodule: Optional[CAREamicsTrainData] = None
+        self.pred_datamodule: Optional[CAREamicsPredictData] = None
 
     def _define_callbacks(self) -> List[Callback]:
         """
@@ -227,7 +225,7 @@ class CAREamist:
     def train(
         self,
         *,
-        datamodule: Optional[CAREamicsWood] = None,
+        datamodule: Optional[CAREamicsTrainData] = None,
         train_source: Optional[Union[Path, str, np.ndarray]] = None,
         val_source: Optional[Union[Path, str, np.ndarray]] = None,
         train_target: Optional[Union[Path, str, np.ndarray]] = None,
@@ -360,7 +358,7 @@ class CAREamist:
                 f"instance (got {type(train_source)})."
             )
 
-    def _train_on_datamodule(self, datamodule: CAREamicsWood) -> None:
+    def _train_on_datamodule(self, datamodule: CAREamicsTrainData) -> None:
         """
         Train the model on the provided datamodule.
 
@@ -402,7 +400,7 @@ class CAREamist:
            Minimum number of patches to use for validation, by default 5.
        """
        # create datamodule
-       datamodule = CAREamicsWood(
+       datamodule = CAREamicsTrainData(
            data_config=self.cfg.data_config,
            train_data=train_data,
            val_data=val_data,
@@ -458,7 +456,7 @@ class CAREamist:
            path_to_val_target = check_path_exists(path_to_val_target)
 
        # create datamodule
-       datamodule = CAREamicsWood(
+       datamodule = CAREamicsTrainData(
            data_config=self.cfg.data_config,
            train_data=path_to_train_data,
            val_data=path_to_val_data,
@@ -475,11 +473,10 @@ class CAREamist:
     @overload
     def predict(  # numpydoc ignore=GL08
         self,
-        source: CAREamicsClay,
+        source: CAREamicsPredictData,
         *,
         checkpoint: Optional[Literal["best", "last"]] = None,
-    ) -> Union[list, np.ndarray]:
-        ...
+    ) -> Union[list, np.ndarray]: ...
 
     @overload
     def predict(  # numpydoc ignore=GL08
@@ -497,8 +494,7 @@ class CAREamist:
         read_source_func: Optional[Callable] = None,
         extension_filter: str = "",
         checkpoint: Optional[Literal["best", "last"]] = None,
-    ) -> Union[list, np.ndarray]:
-        ...
+    ) -> Union[list, np.ndarray]: ...
 
     @overload
     def predict(  # numpydoc ignore=GL08
@@ -514,12 +510,11 @@ class CAREamist:
         tta_transforms: bool = True,
         dataloader_params: Optional[Dict] = None,
         checkpoint: Optional[Literal["best", "last"]] = None,
-    ) -> Union[list, np.ndarray]:
-        ...
+    ) -> Union[list, np.ndarray]: ...
 
     def predict(
         self,
-        source: Union[CAREamicsClay, Path, str, np.ndarray],
+        source: Union[CAREamicsPredictData, Path, str, np.ndarray],
         *,
         batch_size: int = 1,
         tile_size: Optional[Tuple[int, ...]] = None,
@@ -548,6 +543,12 @@ class CAREamist:
        Test-time augmentation (TTA) can be switched off using the `tta_transforms`
        parameter.
 
+       Note that if you are using a UNet model and tiling, the tile size must be
+       divisible in every dimension by 2**d, where d is the depth of the model. This
+       avoids artefacts arising from the broken shift invariance induced by the
+       pooling layers of the UNet. If your image has less dimensions, as it may
+       happen in the Z dimension, consider padding your image.
+
        Parameters
        ----------
        source : Union[CAREamicsClay, Path, str, np.ndarray]
@@ -587,7 +588,7 @@ class CAREamist:
        ValueError
            If the input is not a CAREamicsClay instance, a path or a numpy array.
        """
-       if isinstance(source, CAREamicsClay):
+       if isinstance(source, CAREamicsPredictData):
            # record datamodule
            self.pred_datamodule = source
 
@@ -602,7 +603,7 @@ class CAREamist:
            )
            # create predict config, reuse training config if parameters missing
            prediction_config = create_inference_configuration(
-               training_configuration=self.cfg,
+               configuration=self.cfg,
                tile_size=tile_size,
                tile_overlap=tile_overlap,
                data_type=data_type,
@@ -623,8 +624,8 @@ class CAREamist:
            source_path = check_path_exists(source)
 
            # create datamodule
-           datamodule = CAREamicsClay(
-               prediction_config=prediction_config,
+           datamodule = CAREamicsPredictData(
+               pred_config=prediction_config,
                pred_data=source_path,
                read_source_func=read_source_func,
                extension_filter=extension_filter,
@@ -640,8 +641,8 @@ class CAREamist:
 
        elif isinstance(source, np.ndarray):
            # create datamodule
-           datamodule = CAREamicsClay(
-               prediction_config=prediction_config,
+           datamodule = CAREamicsPredictData(
+               pred_config=prediction_config,
                pred_data=source,
                dataloader_params=dataloader_params,
            )
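
The docstring added in this file states the tiling constraint for UNet models: every tile dimension must be divisible by 2**d, where d is the model depth. A short sketch of what that means in practice, assuming an already trained `CAREamist` instance named `careamist` with a depth-2 UNet (the array shape, tile size, and overlap values are illustrative, not prescribed by the diff):

    import numpy as np

    depth = 2                  # UNet depth from the algorithm configuration
    factor = 2 ** depth        # each tile dimension must be divisible by 4

    tile_size = (256, 256)     # 256 % 4 == 0, so this tile size is valid
    assert all(t % factor == 0 for t in tile_size)

    image = np.random.rand(512, 512).astype(np.float32)  # placeholder input
    prediction = careamist.predict(
        source=image,
        tile_size=tile_size,
        tile_overlap=(48, 48),  # overlap between adjacent tiles
    )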
careamics/config/__init__.py CHANGED
@@ -1,16 +1,17 @@
 """Configuration module."""
 
-
 __all__ = [
-    "AlgorithmModel",
-    "DataModel",
+    "AlgorithmConfig",
+    "DataConfig",
     "Configuration",
     "CheckpointModel",
-    "InferenceModel",
+    "InferenceConfig",
     "load_configuration",
     "save_configuration",
-    "TrainingModel",
+    "TrainingConfig",
     "create_n2v_configuration",
+    "create_n2n_configuration",
+    "create_care_configuration",
     "register_model",
     "CustomModel",
     "create_inference_configuration",
@@ -18,11 +19,13 @@ __all__ = [
     "ConfigurationInformation",
 ]
 
-from .algorithm_model import AlgorithmModel
+from .algorithm_model import AlgorithmConfig
 from .architectures import CustomModel, clear_custom_models, register_model
 from .callback_model import CheckpointModel
 from .configuration_factory import (
+    create_care_configuration,
     create_inference_configuration,
+    create_n2n_configuration,
     create_n2v_configuration,
 )
 from .configuration_model import (
@@ -30,6 +33,6 @@ from .configuration_model import (
     load_configuration,
     save_configuration,
 )
-from .data_model import DataModel
-from .inference_model import InferenceModel
-from .training_model import TrainingModel
+from .data_model import DataConfig
+from .inference_model import InferenceConfig
+from .training_model import TrainingConfig
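
Two configuration factories, `create_care_configuration` and `create_n2n_configuration`, join `create_n2v_configuration` in the public API. Their signatures are not shown in this diff; the sketch below assumes they mirror the experiment/data/training parameters of the existing N2V factory, so every keyword here is an assumption:

    from careamics.config import create_care_configuration

    # All parameter names below are assumed by analogy with
    # create_n2v_configuration; check the factory's signature before use.
    config = create_care_configuration(
        experiment_name="care_example",
        data_type="array",     # see SupportedData
        axes="YX",
        patch_size=(64, 64),
        batch_size=8,
        num_epochs=20,
    )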
careamics/config/algorithm_model.py CHANGED
@@ -10,7 +10,7 @@ from .architectures import CustomModel, UNetModel, VAEModel
 from .optimizer_models import LrSchedulerModel, OptimizerModel
 
 
-class AlgorithmModel(BaseModel):
+class AlgorithmConfig(BaseModel):
     """Algorithm configuration.
 
     This Pydantic model validates the parameters governing the components of the
@@ -45,7 +45,7 @@ class AlgorithmModel(BaseModel):
     Examples
     --------
     Minimum example:
-    >>> from careamics.config import AlgorithmModel
+    >>> from careamics.config import AlgorithmConfig
     >>> config_dict = {
     ...     "algorithm": "n2v",
     ...     "loss": "n2v",
@@ -53,11 +53,11 @@ class AlgorithmModel(BaseModel):
     ...         "architecture": "UNet",
     ...     }
     ... }
-    >>> config = AlgorithmModel(**config_dict)
+    >>> config = AlgorithmConfig(**config_dict)
 
     Using a custom model:
     >>> from torch import nn, ones
-    >>> from careamics.config import AlgorithmModel, register_model
+    >>> from careamics.config import AlgorithmConfig, register_model
     ...
     >>> @register_model(name="linear_model")
     ... class LinearModel(nn.Module):
@@ -80,7 +80,7 @@ class AlgorithmModel(BaseModel):
     ...         "out_features": 5,
     ...     }
     ... }
-    >>> config = AlgorithmModel(**config_dict)
+    >>> config = AlgorithmConfig(**config_dict)
     """
 
     # Pydantic class configuration
careamics/config/architectures/unet_model.py CHANGED
@@ -39,6 +39,7 @@ class UNetModel(ArchitectureModel):
         "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU"
     ] = Field(default="None", validate_default=True)
     n2v2: bool = Field(default=False, validate_default=True)
+    independent_channels: bool = Field(default=True, validate_default=True)
 
     @field_validator("num_channels_init")
     @classmethod
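
The new `independent_channels` field defaults to True. A minimal sketch instantiating the model config with the field set explicitly, reusing the `UNetModel` parameters that appear in `configuration_example.py` later in this diff:

    from careamics.config.architectures import UNetModel

    model = UNetModel(
        architecture="UNet",
        in_channels=1,
        num_classes=1,
        depth=2,
        num_channels_init=32,
        independent_channels=False,  # new in rc5; defaults to True
    )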
careamics/config/callback_model.py CHANGED
@@ -1,4 +1,5 @@
 """Checkpoint saving configuration."""
+
 from __future__ import annotations
 
 from datetime import timedelta
careamics/config/configuration_example.py ADDED
@@ -0,0 +1,87 @@
+from .algorithm_model import AlgorithmConfig
+from .architectures import UNetModel
+from .configuration_model import Configuration
+from .data_model import DataConfig
+from .optimizer_models import LrSchedulerModel, OptimizerModel
+from .support import (
+    SupportedActivation,
+    SupportedAlgorithm,
+    SupportedArchitecture,
+    SupportedData,
+    SupportedLogger,
+    SupportedLoss,
+    SupportedOptimizer,
+    SupportedPixelManipulation,
+    SupportedScheduler,
+    SupportedTransform,
+)
+from .training_model import TrainingConfig
+
+
+def full_configuration_example() -> Configuration:
+    """Returns a dictionnary representing a full configuration example.
+
+    Returns
+    -------
+    Configuration
+        Full configuration example.
+    """
+    experiment_name = "Full example"
+    algorithm_model = AlgorithmConfig(
+        algorithm=SupportedAlgorithm.N2V.value,
+        loss=SupportedLoss.N2V.value,
+        model=UNetModel(
+            architecture=SupportedArchitecture.UNET.value,
+            in_channels=1,
+            num_classes=1,
+            depth=2,
+            num_channels_init=32,
+            final_activation=SupportedActivation.NONE.value,
+            n2v2=True,
+        ),
+        optimizer=OptimizerModel(
+            name=SupportedOptimizer.ADAM.value, parameters={"lr": 0.0001}
+        ),
+        lr_scheduler=LrSchedulerModel(
+            name=SupportedScheduler.REDUCE_LR_ON_PLATEAU.value,
+        ),
+    )
+    data_model = DataConfig(
+        data_type=SupportedData.ARRAY.value,
+        patch_size=(256, 256),
+        batch_size=8,
+        axes="YX",
+        transforms=[
+            {
+                "name": SupportedTransform.NORMALIZE.value,
+            },
+            {
+                "name": SupportedTransform.NDFLIP.value,
+            },
+            {
+                "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
+            },
+            {
+                "name": SupportedTransform.N2V_MANIPULATE.value,
+                "roi_size": 11,
+                "masked_pixel_percentage": 0.2,
+                "strategy": SupportedPixelManipulation.MEDIAN.value,
+            },
+        ],
+        mean=0.485,
+        std=0.229,
+        dataloader_params={
+            "num_workers": 4,
+        },
+    )
+    training_model = TrainingConfig(
+        num_epochs=30,
+        logger=SupportedLogger.WANDB.value,
+    )
+
+    return Configuration(
+        experiment_name=experiment_name,
+        algorithm_config=algorithm_model,
+        data_config=data_model,
+        training_config=training_model,
+    )
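
A short usage sketch for the new helper; the module path is taken from the file location in this diff, and whether the function is re-exported elsewhere is not shown here:

    from careamics.config.configuration_example import full_configuration_example

    config = full_configuration_example()
    print(config.experiment_name)                      # "Full example"
    print(config.algorithm_config.model.architecture)  # "UNet"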