careamics 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (98)
  1. careamics/__init__.py +17 -2
  2. careamics/careamist.py +4 -3
  3. careamics/cli/conf.py +1 -2
  4. careamics/cli/main.py +1 -2
  5. careamics/cli/utils.py +3 -3
  6. careamics/config/__init__.py +47 -25
  7. careamics/config/algorithms/__init__.py +15 -0
  8. careamics/config/algorithms/care_algorithm_model.py +50 -0
  9. careamics/config/algorithms/n2n_algorithm_model.py +42 -0
  10. careamics/config/algorithms/n2v_algorithm_model.py +35 -0
  11. careamics/config/algorithms/unet_algorithm_model.py +88 -0
  12. careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +14 -12
  13. careamics/config/architectures/__init__.py +1 -11
  14. careamics/config/architectures/architecture_model.py +3 -3
  15. careamics/config/architectures/lvae_model.py +6 -1
  16. careamics/config/architectures/unet_model.py +1 -0
  17. careamics/config/care_configuration.py +100 -0
  18. careamics/config/configuration.py +354 -0
  19. careamics/config/{configuration_factory.py → configuration_factories.py} +103 -36
  20. careamics/config/configuration_io.py +85 -0
  21. careamics/config/data/__init__.py +10 -0
  22. careamics/config/{data_model.py → data/data_model.py} +58 -198
  23. careamics/config/data/n2v_data_model.py +193 -0
  24. careamics/config/likelihood_model.py +1 -2
  25. careamics/config/n2n_configuration.py +101 -0
  26. careamics/config/n2v_configuration.py +266 -0
  27. careamics/config/nm_model.py +1 -2
  28. careamics/config/support/__init__.py +7 -7
  29. careamics/config/support/supported_algorithms.py +0 -3
  30. careamics/config/support/supported_architectures.py +0 -4
  31. careamics/config/transformations/__init__.py +10 -4
  32. careamics/config/transformations/transform_model.py +3 -3
  33. careamics/config/transformations/transform_unions.py +42 -0
  34. careamics/config/validators/validator_utils.py +3 -3
  35. careamics/dataset/__init__.py +2 -2
  36. careamics/dataset/dataset_utils/__init__.py +3 -3
  37. careamics/dataset/dataset_utils/dataset_utils.py +4 -6
  38. careamics/dataset/dataset_utils/file_utils.py +9 -9
  39. careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
  40. careamics/dataset/in_memory_dataset.py +11 -12
  41. careamics/dataset/iterable_dataset.py +4 -4
  42. careamics/dataset/iterable_pred_dataset.py +2 -1
  43. careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
  44. careamics/dataset/patching/random_patching.py +11 -10
  45. careamics/dataset/patching/sequential_patching.py +26 -26
  46. careamics/dataset/patching/validate_patch_dimension.py +3 -3
  47. careamics/dataset/tiling/__init__.py +2 -2
  48. careamics/dataset/tiling/collate_tiles.py +3 -3
  49. careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
  50. careamics/dataset/tiling/tiled_patching.py +11 -10
  51. careamics/file_io/__init__.py +5 -5
  52. careamics/file_io/read/__init__.py +1 -1
  53. careamics/file_io/read/get_func.py +2 -2
  54. careamics/file_io/write/__init__.py +2 -2
  55. careamics/lightning/__init__.py +5 -5
  56. careamics/lightning/callbacks/__init__.py +1 -1
  57. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
  58. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
  59. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
  60. careamics/lightning/callbacks/progress_bar_callback.py +2 -2
  61. careamics/lightning/lightning_module.py +11 -7
  62. careamics/lightning/train_data_module.py +26 -26
  63. careamics/losses/__init__.py +3 -3
  64. careamics/model_io/__init__.py +1 -1
  65. careamics/model_io/bioimage/__init__.py +1 -1
  66. careamics/model_io/bioimage/_readme_factory.py +1 -1
  67. careamics/model_io/bioimage/model_description.py +17 -17
  68. careamics/model_io/bmz_io.py +6 -17
  69. careamics/model_io/model_io_utils.py +9 -9
  70. careamics/models/layers.py +16 -16
  71. careamics/models/lvae/lvae.py +0 -3
  72. careamics/models/model_factory.py +2 -15
  73. careamics/models/unet.py +8 -8
  74. careamics/prediction_utils/__init__.py +1 -1
  75. careamics/prediction_utils/prediction_outputs.py +15 -15
  76. careamics/prediction_utils/stitch_prediction.py +6 -6
  77. careamics/transforms/__init__.py +5 -5
  78. careamics/transforms/compose.py +13 -13
  79. careamics/transforms/n2v_manipulate.py +3 -3
  80. careamics/transforms/pixel_manipulation.py +9 -9
  81. careamics/transforms/xy_random_rotate90.py +4 -4
  82. careamics/utils/__init__.py +5 -5
  83. careamics/utils/context.py +2 -1
  84. careamics/utils/logging.py +11 -10
  85. careamics/utils/torch_utils.py +7 -7
  86. {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/METADATA +11 -11
  87. {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/RECORD +90 -85
  88. careamics/config/architectures/custom_model.py +0 -162
  89. careamics/config/architectures/register_model.py +0 -103
  90. careamics/config/configuration_model.py +0 -603
  91. careamics/config/fcn_algorithm_model.py +0 -152
  92. careamics/config/references/__init__.py +0 -45
  93. careamics/config/references/algorithm_descriptions.py +0 -132
  94. careamics/config/references/references.py +0 -39
  95. careamics/config/transformations/transform_union.py +0 -20
  96. {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/WHEEL +0 -0
  97. {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/entry_points.txt +0 -0
  98. {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/licenses/LICENSE +0 -0
careamics/lightning/train_data_module.py CHANGED
@@ -9,7 +9,7 @@ import pytorch_lightning as L
 from numpy.typing import NDArray
 from torch.utils.data import DataLoader, IterableDataset
 
-from careamics.config import DataConfig
+from careamics.config.data import DataConfig, GeneralDataConfig, N2VDataConfig
 from careamics.config.support import SupportedData
 from careamics.config.transformations import TransformModel
 from careamics.dataset.dataset_utils import (
@@ -119,7 +119,7 @@ class TrainDataModule(L.LightningDataModule):
 
     def __init__(
         self,
-        data_config: DataConfig,
+        data_config: GeneralDataConfig,
         train_data: Union[Path, str, NDArray],
         val_data: Optional[Union[Path, str, NDArray]] = None,
         train_data_target: Optional[Union[Path, str, NDArray]] = None,
@@ -219,7 +219,7 @@ class TrainDataModule(L.LightningDataModule):
         )
 
         # configuration
-        self.data_config: DataConfig = data_config
+        self.data_config: GeneralDataConfig = data_config
         self.data_type: str = data_config.data_type
         self.batch_size: int = data_config.batch_size
         self.use_in_memory: bool = use_in_memory
@@ -502,12 +502,23 @@ def create_train_datamodule(
     """Create a TrainDataModule.
 
     This function is used to explicitly pass the parameters usually contained in a
-    `data_model` configuration to a TrainDataModule.
+    `GenericDataConfig` to a TrainDataModule.
 
     Since the lightning datamodule has no access to the model, make sure that the
     parameters passed to the datamodule are consistent with the model's requirements and
     are coherent.
 
+    By default, the train DataModule will be set for Noise2Void if no target data is
+    provided. That means that it will add a `N2VManipulateModel` transformation to the
+    list of augmentations. The default augmentations are XY flip, XY rotation, and N2V
+    pixel manipulation. If you pass a training target data, the default behaviour is to
+    train a supervised model. It will use the default XY flip and rotation
+    augmentations.
+
+    To use a different set of transformations, you can pass a list of transforms to
+    `transforms`. Note that if you intend to use Noise2Void, you should add
+    `N2VManipulateModel` as the last transform in the list of transformations.
+
     The data module can be used with Path, str or numpy arrays. In the case of
     numpy arrays, it loads and computes all the patches in memory. For Path and str
     inputs, it calculates the total file size and estimate whether it can fit in
@@ -518,11 +529,6 @@ def create_train_datamodule(
     To use array data, set `data_type` to `array` and pass a numpy array to
     `train_data`.
 
-    In particular, N2V requires a specific transformation (N2V manipulates), which is
-    not compatible with supervised training. The default transformations applied to the
-    training patches are defined in `careamics.config.data_model`. To use different
-    transformations, pass a list of transforms. See examples for more details.
-
     By default, CAREamics only supports types defined in
     `careamics.config.support.SupportedData`. To read custom data types, you can set
     `data_type` to `custom` and provide a function that returns a numpy array from a
@@ -627,12 +633,12 @@ def create_train_datamodule(
     transforms:
     >>> import numpy as np
     >>> from careamics.lightning import create_train_datamodule
+    >>> from careamics.config.transformations import XYFlipModel, N2VManipulateModel
     >>> from careamics.config.support import SupportedTransform
    >>> my_array = np.arange(256).reshape(16, 16)
     >>> my_transforms = [
-    ...     {
-    ...         "name": SupportedTransform.XY_FLIP.value,
-    ...     }
+    ...     XYFlipModel(flip_y=False),
+    ...     N2VManipulateModel()
     ... ]
     >>> data_module = create_train_datamodule(
     ...     train_data=my_array,
@@ -659,21 +665,15 @@ def create_train_datamodule(
     if transforms is not None:
         data_dict["transforms"] = transforms
 
-    # validate configuration
-    data_config = DataConfig(**data_dict)
+    # TODO not compatible with HDN, consider adding an argument for n2v/hdn
+    if train_target_data is None:
+        data_config: GeneralDataConfig = N2VDataConfig(**data_dict)
+        assert isinstance(data_config, N2VDataConfig)
 
-    # N2V specific checks, N2V, structN2V, and transforms
-    if data_config.has_n2v_manipulate():
-        # there is not target, n2v2 and structN2V can be changed
-        if train_target_data is None:
-            data_config.set_N2V2(use_n2v2)
-            data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span)
-        else:
-            raise ValueError(
-                "Cannot have both supervised training (target data) and "
-                "N2V manipulation in the transforms. Pass a list of transforms "
-                "that is compatible with your supervised training."
-            )
+        data_config.set_n2v2(use_n2v2)
+        data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span)
+    else:
+        data_config = DataConfig(**data_dict)
 
     # sanity check on the dataloader parameters
     if "batch_size" in dataloader_params:
careamics/losses/__init__.py CHANGED
@@ -1,13 +1,13 @@
 """Losses module."""
 
 __all__ = [
+    "denoisplit_loss",
+    "denoisplit_musplit_loss",
     "loss_factory",
     "mae_loss",
     "mse_loss",
-    "n2v_loss",
-    "denoisplit_loss",
     "musplit_loss",
-    "denoisplit_musplit_loss",
+    "n2v_loss",
 ]
 
 from .fcn.losses import mae_loss, mse_loss, n2v_loss
careamics/model_io/__init__.py CHANGED
@@ -1,6 +1,6 @@
 """Model I/O utilities."""
 
-__all__ = ["load_pretrained", "export_to_bmz"]
+__all__ = ["export_to_bmz", "load_pretrained"]
 
 
 from .bmz_io import export_to_bmz
careamics/model_io/bioimage/__init__.py CHANGED
@@ -1,10 +1,10 @@
 """Bioimage Model Zoo format functions."""
 
 __all__ = [
+    "create_env_text",
     "create_model_description",
     "extract_model_path",
     "get_unzip_path",
-    "create_env_text",
 ]
 
 from .bioimage_utils import create_env_text, get_unzip_path
careamics/model_io/bioimage/_readme_factory.py CHANGED
@@ -55,7 +55,7 @@ def readme_factory(
     readme.touch()
 
     # algorithm pretty name
-    algorithm_flavour = config.get_algorithm_flavour()
+    algorithm_flavour = config.get_algorithm_friendly_name()
     algorithm_pretty_name = algorithm_flavour + " - CAREamics"
 
     description = [f"# {algorithm_pretty_name}\n\n"]
careamics/model_io/bioimage/model_description.py CHANGED
@@ -1,7 +1,7 @@
 """Module use to build BMZ model description."""
 
 from pathlib import Path
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Union
 
 import numpy as np
 from bioimageio.spec._internal.io import resolve_and_extract
@@ -28,17 +28,17 @@ from bioimageio.spec.model.v0_5 import (
     WeightsDescr,
 )
 
-from careamics.config import Configuration, DataConfig
+from careamics.config import Configuration, GeneralDataConfig
 
 from ._readme_factory import readme_factory
 
 
 def _create_axes(
     array: np.ndarray,
-    data_config: DataConfig,
-    channel_names: Optional[List[str]] = None,
+    data_config: GeneralDataConfig,
+    channel_names: Optional[list[str]] = None,
     is_input: bool = True,
-) -> List[AxisBase]:
+) -> list[AxisBase]:
     """Create axes description.
 
     Array shape is expected to be SC(Z)YX.
@@ -49,15 +49,15 @@ def _create_axes(
         Array.
     data_config : DataModel
         CAREamics data configuration.
-    channel_names : Optional[List[str]], optional
+    channel_names : Optional[list[str]], optional
         Channel names, by default None.
     is_input : bool, optional
         Whether the axes are input axes, by default True.
 
     Returns
     -------
-    List[AxisBase]
-        List of axes description.
+    list[AxisBase]
+        list of axes description.
 
     Raises
     ------
@@ -102,11 +102,11 @@ def _create_axes(
 def _create_inputs_ouputs(
     input_array: np.ndarray,
     output_array: np.ndarray,
-    data_config: DataConfig,
+    data_config: GeneralDataConfig,
     input_path: Union[Path, str],
     output_path: Union[Path, str],
-    channel_names: Optional[List[str]] = None,
-) -> Tuple[InputTensorDescr, OutputTensorDescr]:
+    channel_names: Optional[list[str]] = None,
+) -> tuple[InputTensorDescr, OutputTensorDescr]:
     """Create input and output tensor description.
 
     Input and output paths must point to a `.npy` file.
@@ -123,12 +123,12 @@ def _create_inputs_ouputs(
         Path to input .npy file.
     output_path : Union[Path, str]
         Path to output .npy file.
-    channel_names : Optional[List[str]], optional
+    channel_names : Optional[list[str]], optional
         Channel names, by default None.
 
     Returns
     -------
-    Tuple[InputTensorDescr, OutputTensorDescr]
+    tuple[InputTensorDescr, OutputTensorDescr]
         Input and output tensor descriptions.
     """
     input_axes = _create_axes(input_array, data_config, channel_names)
@@ -188,7 +188,7 @@ def create_model_description(
     name: str,
     general_description: str,
     data_description: str,
-    authors: List[Author],
+    authors: list[Author],
     inputs: Union[Path, str],
     outputs: Union[Path, str],
     weights_path: Union[Path, str],
@@ -197,7 +197,7 @@ def create_model_description(
     config_path: Union[Path, str],
     env_path: Union[Path, str],
     covers: list[Union[Path, str]],
-    channel_names: Optional[List[str]] = None,
+    channel_names: Optional[list[str]] = None,
     model_version: str = "0.1.0",
 ) -> ModelDescr:
     """Create model description.
@@ -212,7 +212,7 @@ def create_model_description(
         General description of the model.
     data_description : str
         Description of the data the model was trained on.
-    authors : List[Author]
+    authors : list[Author]
         Authors of the model.
     inputs : Union[Path, str]
         Path to input .npy file.
@@ -230,7 +230,7 @@ def create_model_description(
         Path to environment file.
     covers : list of pathlib.Path or str
         Paths to cover images.
-    channel_names : Optional[List[str]], optional
+    channel_names : Optional[list[str]], optional
         Channel names, by default None.
     model_version : str, default "0.1.0"
         Model version.
careamics/model_io/bmz_io.py CHANGED
@@ -2,7 +2,7 @@
 
 import tempfile
 from pathlib import Path
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Union
 
 import numpy as np
 import pkg_resources
@@ -87,11 +87,11 @@ def export_to_bmz(
     model_name: str,
     general_description: str,
     data_description: str,
-    authors: List[dict],
+    authors: list[dict],
     input_array: np.ndarray,
     output_array: np.ndarray,
     covers: Optional[list[Union[Path, str]]] = None,
-    channel_names: Optional[List[str]] = None,
+    channel_names: Optional[list[str]] = None,
     model_version: str = "0.1.0",
 ) -> None:
     """Export the model to BioImage Model Zoo format.
@@ -115,7 +115,7 @@ def export_to_bmz(
         General description of the model.
     data_description : str
         Description of the data the model was trained on.
-    authors : List[dict]
+    authors : list[dict]
         Authors of the model.
     input_array : np.ndarray
         Input array, should not have been normalized.
@@ -123,24 +123,13 @@ def export_to_bmz(
         Output array, should have been denormalized.
     covers : list of pathlib.Path or str, default=None
         Paths to the cover images.
-    channel_names : Optional[List[str]], optional
+    channel_names : Optional[list[str]], optional
         Channel names, by default None.
     model_version : str, default="0.1.0"
        Model version.
-
-    Raises
-    ------
-    ValueError
-        If the model is a Custom model.
     """
     path_to_archive = Path(path_to_archive)
 
-    # method is not compatible with Custom models
-    if config.algorithm_config.model.architecture == SupportedArchitecture.CUSTOM:
-        raise ValueError(
-            "Exporting Custom models to BioImage Model Zoo format is not supported."
-        )
-
     if path_to_archive.suffix != ".zip":
         raise ValueError(
             f"Path to archive must point to a zip file, got {path_to_archive}."
@@ -212,7 +201,7 @@ def export_to_bmz(
 
 def load_from_bmz(
     path: Union[Path, str, HttpUrl]
-) -> Tuple[Union[FCNModule, VAEModule], Configuration]:
+) -> tuple[Union[FCNModule, VAEModule], Configuration]:
     """Load a model from a BioImage Model Zoo archive.
 
     Parameters
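
Round-trip sketch for the simplified export path: with the Custom-model guard removed, any archive produced by export_to_bmz reloads the same way. Only load_from_bmz(path) and its (model, config) return type appear in the diff; the archive name below is a placeholder:

from careamics.model_io.bmz_io import load_from_bmz

# "my_model.zip" stands in for an archive previously written by export_to_bmz
model, config = load_from_bmz("my_model.zip")
print(type(model).__name__)                        # FCNModule or VAEModule
print(config.algorithm_config.model.architecture)  # field referenced in _load_checkpoint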
careamics/model_io/model_io_utils.py CHANGED
@@ -1,11 +1,11 @@
 """Utility functions to load pretrained models."""
 
 from pathlib import Path
-from typing import Tuple, Union
+from typing import Union
 
 import torch
 
-from careamics.config import Configuration
+from careamics.config import Configuration, configuration_factory
 from careamics.lightning.lightning_module import FCNModule, VAEModule
 from careamics.model_io.bmz_io import load_from_bmz
 from careamics.utils import check_path_exists
@@ -13,7 +13,7 @@ from careamics.utils import check_path_exists
 
 def load_pretrained(
     path: Union[Path, str]
-) -> Tuple[Union[FCNModule, VAEModule], Configuration]:
+) -> tuple[Union[FCNModule, VAEModule], Configuration]:
     """
     Load a pretrained model from a checkpoint or a BioImage Model Zoo model.
 
@@ -26,8 +26,8 @@ def load_pretrained(
 
     Returns
     -------
-    Tuple[CAREamicsKiln, Configuration]
-        Tuple of CAREamics model and its configuration.
+    tuple[CAREamicsKiln, Configuration]
+        tuple of CAREamics model and its configuration.
 
     Raises
     ------
@@ -48,7 +48,7 @@ def load_pretrained(
 
 def _load_checkpoint(
     path: Union[Path, str]
-) -> Tuple[Union[FCNModule, VAEModule], Configuration]:
+) -> tuple[Union[FCNModule, VAEModule], Configuration]:
     """
     Load a model from a checkpoint and return both model and configuration.
 
@@ -59,8 +59,8 @@ def _load_checkpoint(
 
     Returns
     -------
-    Tuple[CAREamicsKiln, Configuration]
-        Tuple of CAREamics model and its configuration.
+    tuple[CAREamicsKiln, Configuration]
+        tuple of CAREamics model and its configuration.
 
     Raises
     ------
@@ -92,4 +92,4 @@ def _load_checkpoint(
             f"{cfg_dict['algorithm_config']['model']['architecture']}"
         )
 
-    return model, Configuration(**cfg_dict)
+    return model, configuration_factory(cfg_dict)
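
_load_checkpoint now routes the stored dict through configuration_factory instead of instantiating Configuration directly, so checkpoints can rehydrate into the matching algorithm-specific configuration (the new care_configuration.py, n2n_configuration.py, and n2v_configuration.py in the file list above). A minimal sketch, assuming configuration_factory accepts the same plain dict that a saved CAREamics configuration yields; the file name is a placeholder:

import yaml

from careamics.config import configuration_factory

# "config.yml" stands in for a configuration previously saved by CAREamics
with open("config.yml") as f:
    cfg_dict = yaml.safe_load(f)

config = configuration_factory(cfg_dict)
print(type(config).__name__)  # e.g. an N2V-specific Configuration subclass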
careamics/models/layers.py CHANGED
@@ -4,7 +4,7 @@ Layer module.
 This submodule contains layers used in the CAREamics models.
 """
 
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Union
 
 import torch
 import torch.nn as nn
@@ -157,22 +157,22 @@ class Conv_Block(nn.Module):
 
 
 def _unpack_kernel_size(
-    kernel_size: Union[Tuple[int, ...], int], dim: int
-) -> Tuple[int, ...]:
+    kernel_size: Union[tuple[int, ...], int], dim: int
+) -> tuple[int, ...]:
     """Unpack kernel_size to a tuple of ints.
 
     Inspired by Kornia implementation. TODO: link
 
     Parameters
     ----------
-    kernel_size : Union[Tuple[int, ...], int]
+    kernel_size : Union[tuple[int, ...], int]
         Kernel size.
     dim : int
         Number of dimensions.
 
     Returns
     -------
-    Tuple[int, ...]
+    tuple[int, ...]
         Kernel size tuple.
     """
     if isinstance(kernel_size, int):
@@ -183,20 +183,20 @@ def _unpack_kernel_size(
 
 
 def _compute_zero_padding(
-    kernel_size: Union[Tuple[int, ...], int], dim: int
-) -> Tuple[int, ...]:
+    kernel_size: Union[tuple[int, ...], int], dim: int
+) -> tuple[int, ...]:
     """Utility function that computes zero padding tuple.
 
     Parameters
     ----------
-    kernel_size : Union[Tuple[int, ...], int]
+    kernel_size : Union[tuple[int, ...], int]
         Kernel size.
     dim : int
         Number of dimensions.
 
     Returns
     -------
-    Tuple[int, ...]
+    tuple[int, ...]
         Zero padding tuple.
     """
     kernel_dims = _unpack_kernel_size(kernel_size, dim)
@@ -245,8 +245,8 @@ def get_pascal_kernel_1d(
     >>> get_pascal_kernel_1d(6)
     tensor([ 1.,  5., 10., 10.,  5.,  1.])
     """
-    pre: List[float] = []
-    cur: List[float] = []
+    pre: list[float] = []
+    cur: list[float] = []
     for i in range(kernel_size):
         cur = [1.0] * (i + 1)
 
@@ -266,7 +266,7 @@ def get_pascal_kernel_1d(
 
 
 def _get_pascal_kernel_nd(
-    kernel_size: Union[Tuple[int, int], int],
+    kernel_size: Union[tuple[int, int], int],
     norm: bool = True,
     dim: int = 2,
     *,
@@ -282,7 +282,7 @@ def _get_pascal_kernel_nd(
 
     Parameters
     ----------
-    kernel_size : Union[Tuple[int, int], int]
+    kernel_size : Union[tuple[int, int], int]
         Kernel size for the pascal kernel.
     norm : bool
         Normalize the kernel, by default True.
@@ -420,7 +420,7 @@ class MaxBlurPool(nn.Module):
     ----------
     dim : int
         Toggles between 2D and 3D.
-    kernel_size : Union[Tuple[int, int], int]
+    kernel_size : Union[tuple[int, int], int]
         Kernel size for max pooling.
     stride : int
         Stride for pooling.
@@ -433,7 +433,7 @@ class MaxBlurPool(nn.Module):
     def __init__(
         self,
         dim: int,
-        kernel_size: Union[Tuple[int, int], int],
+        kernel_size: Union[tuple[int, int], int],
         stride: int = 2,
         max_pool_size: int = 2,
         ceil_mode: bool = False,
@@ -444,7 +444,7 @@ class MaxBlurPool(nn.Module):
         ----------
         dim : int
             Dimension of the convolution.
-        kernel_size : Union[Tuple[int, int], int]
+        kernel_size : Union[tuple[int, int], int]
            Kernel size for max pooling.
         stride : int, optional
             Stride, by default 2.
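
The kernel helpers above only changed notation (Tuple/List to the builtin generics), but their contract is easy to state. A standalone re-implementation sketch of the int-to-tuple unpacking and the Kornia-style zero padding they compute (local names, not the careamics helpers):

def unpack_kernel_size(kernel_size, dim):
    # an int k is broadcast to (k, k, ...) with one entry per spatial dimension
    if isinstance(kernel_size, int):
        return (kernel_size,) * dim
    return tuple(kernel_size)

def compute_zero_padding(kernel_size, dim):
    # Kornia convention: (k - 1) // 2 per dimension, i.e. "same" output for odd k
    return tuple((k - 1) // 2 for k in unpack_kernel_size(kernel_size, dim))

assert unpack_kernel_size(3, 3) == (3, 3, 3)
assert compute_zero_padding((3, 5), 2) == (1, 2)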
careamics/models/lvae/lvae.py CHANGED
@@ -12,8 +12,6 @@ import numpy as np
 import torch
 import torch.nn as nn
 
-from careamics.config.architectures import register_model
-
 from ..activation import get_activation
 from .layers import (
     BottomUpDeterministicResBlock,
@@ -25,7 +23,6 @@ from .layers import (
 from .utils import Interpolate, ModelType, crop_img_tensor
 
 
-@register_model("LVAE")
 class LadderVAE(nn.Module):
     """
     Constructor.
careamics/models/model_factory.py CHANGED
@@ -1,8 +1,4 @@
-"""
-Model factory.
-
-Model creation factory functions.
-"""
+"""Model creation factory functions."""
 
 from __future__ import annotations
 
@@ -10,10 +6,6 @@ from typing import TYPE_CHECKING, Union
 
 import torch
 
-from careamics.config.architectures import (
-    CustomModel,
-    get_custom_model,
-)
 from careamics.config.support import SupportedArchitecture
 from careamics.models.lvae import LadderVAE as LVAE
 from careamics.models.unet import UNet
@@ -21,7 +13,6 @@ from careamics.utils import get_logger
 
 if TYPE_CHECKING:
     from careamics.config.architectures import (
-        CustomModel,
         LVAEModel,
         UNetModel,
     )
@@ -31,7 +22,7 @@ logger = get_logger(__name__)
 
 
 def model_factory(
-    model_configuration: Union[UNetModel, LVAEModel, CustomModel],
+    model_configuration: Union[UNetModel, LVAEModel],
 ) -> torch.nn.Module:
     """
     Deep learning model factory.
@@ -57,10 +48,6 @@ def model_factory(
         return UNet(**model_configuration.model_dump())
     elif model_configuration.architecture == SupportedArchitecture.LVAE:
         return LVAE(**model_configuration.model_dump())
-    elif model_configuration.architecture == SupportedArchitecture.CUSTOM:
-        assert isinstance(model_configuration, CustomModel)
-        model = get_custom_model(model_configuration.name)
-        return model(**model_configuration.model_dump())
    else:
        raise NotImplementedError(
            f"Model {model_configuration.architecture} is not implemented or unknown."
careamics/models/unet.py CHANGED
@@ -4,7 +4,7 @@ UNet model.
 A UNet encoder, decoder and complete model.
 """
 
-from typing import Any, List, Tuple, Union
+from typing import Any, Union
 
 import torch
 import torch.nn as nn
@@ -104,7 +104,7 @@ class UnetEncoder(nn.Module):
             encoder_blocks.append(self.pooling)
         self.encoder_blocks = nn.ModuleList(encoder_blocks)
 
-    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
         """
         Forward pass.
 
@@ -115,7 +115,7 @@ class UnetEncoder(nn.Module):
 
         Returns
        -------
-        List[torch.Tensor]
+        list[torch.Tensor]
            Output of each encoder block (skip connections) and final output of the
            encoder.
        """
@@ -202,7 +202,7 @@ class UnetDecoder(nn.Module):
             groups=self.groups,
         )
 
-        decoder_blocks: List[nn.Module] = []
+        decoder_blocks: list[nn.Module] = []
         for n in range(depth):
             decoder_blocks.append(upsampling)
             in_channels = (num_channels_init * 2 ** (depth - n)) * groups
@@ -230,7 +230,7 @@ class UnetDecoder(nn.Module):
 
         Parameters
         ----------
-        *features : List[torch.Tensor]
+        *features : list[torch.Tensor]
            List containing the output of each encoder block(skip connections) and final
            output of the encoder.
 
@@ -240,7 +240,7 @@ class UnetDecoder(nn.Module):
            Output of the decoder.
        """
        x: torch.Tensor = features[0]
-        skip_connections: Tuple[torch.Tensor, ...] = features[-1:0:-1]
+        skip_connections: tuple[torch.Tensor, ...] = features[-1:0:-1]
 
        x = self.bottleneck(x)
 
@@ -289,10 +289,10 @@ class UnetDecoder(nn.Module):
            m = A.shape[1] // groups
            n = B.shape[1] // groups
 
-            A_groups: List[torch.Tensor] = [
+            A_groups: list[torch.Tensor] = [
                A[:, i * m : (i + 1) * m] for i in range(groups)
            ]
-            B_groups: List[torch.Tensor] = [
+            B_groups: list[torch.Tensor] = [
                B[:, i * n : (i + 1) * n] for i in range(groups)
            ]
 
careamics/prediction_utils/__init__.py CHANGED
@@ -1,9 +1,9 @@
 """Package to house various prediction utilies."""
 
 __all__ = [
+    "convert_outputs",
     "stitch_prediction",
     "stitch_prediction_single",
-    "convert_outputs",
 ]
 
 from .prediction_outputs import convert_outputs