careamics 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic; see the package's registry page for more details.

Files changed (87)
  1. careamics/careamist.py +39 -28
  2. careamics/cli/__init__.py +5 -0
  3. careamics/cli/conf.py +391 -0
  4. careamics/cli/main.py +134 -0
  5. careamics/config/__init__.py +7 -3
  6. careamics/config/architectures/__init__.py +2 -2
  7. careamics/config/architectures/architecture_model.py +1 -1
  8. careamics/config/architectures/custom_model.py +11 -8
  9. careamics/config/architectures/lvae_model.py +170 -0
  10. careamics/config/configuration_factory.py +481 -170
  11. careamics/config/configuration_model.py +6 -3
  12. careamics/config/data_model.py +31 -20
  13. careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +35 -45
  14. careamics/config/likelihood_model.py +60 -0
  15. careamics/config/nm_model.py +127 -0
  16. careamics/config/optimizer_models.py +3 -1
  17. careamics/config/support/supported_activations.py +1 -0
  18. careamics/config/support/supported_algorithms.py +17 -4
  19. careamics/config/support/supported_architectures.py +8 -11
  20. careamics/config/support/supported_losses.py +3 -1
  21. careamics/config/support/supported_optimizers.py +1 -1
  22. careamics/config/support/supported_transforms.py +1 -0
  23. careamics/config/training_model.py +35 -6
  24. careamics/config/transformations/__init__.py +4 -1
  25. careamics/config/transformations/n2v_manipulate_model.py +1 -1
  26. careamics/config/transformations/transform_union.py +20 -0
  27. careamics/config/vae_algorithm_model.py +137 -0
  28. careamics/dataset/tiling/lvae_tiled_patching.py +364 -0
  29. careamics/file_io/read/tiff.py +1 -1
  30. careamics/lightning/__init__.py +3 -2
  31. careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
  32. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
  33. careamics/lightning/lightning_module.py +367 -9
  34. careamics/lightning/predict_data_module.py +2 -2
  35. careamics/lightning/train_data_module.py +4 -4
  36. careamics/losses/__init__.py +11 -1
  37. careamics/losses/fcn/__init__.py +1 -0
  38. careamics/losses/{losses.py → fcn/losses.py} +1 -1
  39. careamics/losses/loss_factory.py +112 -6
  40. careamics/losses/lvae/__init__.py +1 -0
  41. careamics/losses/lvae/loss_utils.py +83 -0
  42. careamics/losses/lvae/losses.py +445 -0
  43. careamics/lvae_training/dataset/__init__.py +15 -0
  44. careamics/lvae_training/dataset/config.py +123 -0
  45. careamics/lvae_training/dataset/lc_dataset.py +267 -0
  46. careamics/lvae_training/{data_modules.py → dataset/multich_dataset.py} +375 -501
  47. careamics/lvae_training/dataset/multifile_dataset.py +334 -0
  48. careamics/lvae_training/dataset/types.py +43 -0
  49. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  50. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  51. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  52. careamics/lvae_training/dataset/utils/index_manager.py +232 -0
  53. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  54. careamics/lvae_training/eval_utils.py +109 -64
  55. careamics/lvae_training/get_config.py +1 -1
  56. careamics/lvae_training/train_lvae.py +6 -3
  57. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  58. careamics/model_io/bioimage/model_description.py +2 -2
  59. careamics/model_io/bmz_io.py +20 -7
  60. careamics/model_io/model_io_utils.py +16 -4
  61. careamics/models/__init__.py +1 -3
  62. careamics/models/activation.py +2 -0
  63. careamics/models/lvae/__init__.py +3 -0
  64. careamics/models/lvae/layers.py +21 -21
  65. careamics/models/lvae/likelihoods.py +190 -129
  66. careamics/models/lvae/lvae.py +60 -148
  67. careamics/models/lvae/noise_models.py +318 -186
  68. careamics/models/lvae/utils.py +2 -2
  69. careamics/models/model_factory.py +22 -7
  70. careamics/prediction_utils/lvae_prediction.py +158 -0
  71. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  72. careamics/prediction_utils/stitch_prediction.py +16 -2
  73. careamics/transforms/compose.py +90 -15
  74. careamics/transforms/n2v_manipulate.py +6 -2
  75. careamics/transforms/normalize.py +14 -3
  76. careamics/transforms/pixel_manipulation.py +1 -1
  77. careamics/transforms/xy_flip.py +16 -6
  78. careamics/transforms/xy_random_rotate90.py +16 -7
  79. careamics/utils/metrics.py +277 -24
  80. careamics/utils/serializers.py +60 -0
  81. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/METADATA +5 -4
  82. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/RECORD +85 -60
  83. careamics-0.0.4.dist-info/entry_points.txt +2 -0
  84. careamics/config/architectures/vae_model.py +0 -42
  85. careamics/lvae_training/data_utils.py +0 -618
  86. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/WHEEL +0 -0
  87. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/licenses/LICENSE +0 -0
@@ -1,7 +1,8 @@
1
1
  """Configuration module."""
2
2
 
3
3
  __all__ = [
4
- "AlgorithmConfig",
4
+ "FCNAlgorithmConfig",
5
+ "VAEAlgorithmConfig",
5
6
  "DataConfig",
6
7
  "Configuration",
7
8
  "CheckpointModel",
@@ -15,9 +16,9 @@ __all__ = [
15
16
  "register_model",
16
17
  "CustomModel",
17
18
  "clear_custom_models",
19
+ "GaussianMixtureNMConfig",
20
+ "MultiChannelNMConfig",
18
21
  ]
19
-
20
- from .algorithm_model import AlgorithmConfig
21
22
  from .architectures import CustomModel, clear_custom_models, register_model
22
23
  from .callback_model import CheckpointModel
23
24
  from .configuration_factory import (
@@ -31,5 +32,8 @@ from .configuration_model import (
31
32
  save_configuration,
32
33
  )
33
34
  from .data_model import DataConfig
35
+ from .fcn_algorithm_model import FCNAlgorithmConfig
34
36
  from .inference_model import InferenceConfig
37
+ from .nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig
35
38
  from .training_model import TrainingConfig
39
+ from .vae_algorithm_model import VAEAlgorithmConfig
@@ -4,7 +4,7 @@ __all__ = [
4
4
  "ArchitectureModel",
5
5
  "CustomModel",
6
6
  "UNetModel",
7
- "VAEModel",
7
+ "LVAEModel",
8
8
  "clear_custom_models",
9
9
  "get_custom_model",
10
10
  "register_model",
@@ -12,6 +12,6 @@ __all__ = [
12
12
 
13
13
  from .architecture_model import ArchitectureModel
14
14
  from .custom_model import CustomModel
15
+ from .lvae_model import LVAEModel
15
16
  from .register_model import clear_custom_models, get_custom_model, register_model
16
17
  from .unet_model import UNetModel
17
- from .vae_model import VAEModel
@@ -27,7 +27,7 @@ class ArchitectureModel(BaseModel):
27
27
  Returns
28
28
  -------
29
29
  dict[str, Any]
30
- Model as a dictionnary.
30
+ Model as a dictionary.
31
31
  """
32
32
  model_dict = super().model_dump(**kwargs)
33
33
 
@@ -2,6 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ import inspect
5
6
  from pprint import pformat
6
7
  from typing import Any, Literal
7
8
 
@@ -23,12 +24,13 @@ class CustomModel(ArchitectureModel):
23
24
 
24
25
  Attributes
25
26
  ----------
26
- architecture : Literal["Custom"]
27
- Discriminator for the custom model, must be set to "Custom".
27
+ architecture : Literal["custom"]
28
+ Discriminator for the custom model, must be set to "custom".
28
29
  name : str
29
30
  Name of the custom model.
30
31
  parameters : CustomParametersModel
31
- Parameters of the custom model.
32
+ All parameters, required for the initialization of the torch module have to be
33
+ passed here.
32
34
 
33
35
  Raises
34
36
  ------
@@ -57,7 +59,7 @@ class CustomModel(ArchitectureModel):
57
59
  ...
58
60
  >>> # Create a configuration
59
61
  >>> config_dict = {
60
- ... "architecture": "Custom",
62
+ ... "architecture": "custom",
61
63
  ... "name": "my_linear",
62
64
  ... "in_features": 10,
63
65
  ... "out_features": 5,
@@ -71,10 +73,9 @@ class CustomModel(ArchitectureModel):
71
73
  )
72
74
 
73
75
  # discriminator used for choosing the pydantic model in Model
74
- architecture: Literal["Custom"]
76
+ architecture: Literal["custom"]
75
77
  """Name of the architecture."""
76
78
 
77
- # name of the custom model
78
79
  name: str
79
80
  """Name of the custom model."""
80
81
 
@@ -120,10 +121,12 @@ class CustomModel(ArchitectureModel):
120
121
  get_custom_model(self.name)(**self.model_dump())
121
122
  except Exception as e:
122
123
  raise ValueError(
123
- f"error while passing parameters to the model {e}. Verify that all "
124
+ f"while passing parameters to the model {e}. Verify that all "
124
125
  f"mandatory parameters are provided, and that either the {e} accepts "
125
126
  f"*args and **kwargs in its __init__() method, or that no additional"
126
- f"parameter is provided."
127
+ f"parameter is provided. Trace: "
128
+ f"filename: {inspect.trace()[-1].filename}, function: "
129
+ f"{inspect.trace()[-1].function}, line: {inspect.trace()[-1].lineno}"
127
130
  ) from None
128
131
 
129
132
  return self
@@ -0,0 +1,170 @@
"""LVAE Pydantic model."""

from typing import Literal

from pydantic import ConfigDict, Field, field_validator, model_validator
from typing_extensions import Self

from .architecture_model import ArchitectureModel


# TODO: it is quite confusing to call this LVAEModel, as it is basically a config
class LVAEModel(ArchitectureModel):
    """LVAE model configuration.

    Attributes
    ----------
    architecture : Literal["LVAE"]
        Discriminator for the LVAE architecture, must be set to "LVAE".
    input_shape : int
        Input size, between 8 and 1024 (default 64).
    multiscale_count : int
        Multiscale count (default 5). Its valid range is not enforced yet.
    z_dims : list
        Latent dimensions; must contain at least 2 entries.
    output_channels : int
        Number of output channels, at least 1.
    encoder_n_filters : int
        Number of encoder filters, even, between 8 and 1024.
    decoder_n_filters : int
        Number of decoder filters, even, between 8 and 1024.
    encoder_dropout : float
        Encoder dropout, between 0.0 and 0.9.
    decoder_dropout : float
        Decoder dropout, between 0.0 and 0.9.
    nonlinearity : str
        Name of the activation, one of "None", "Sigmoid", "Softmax", "Tanh",
        "ReLU", "LeakyReLU" or "ELU" (default "ELU").
    predict_logvar : Literal[None, "pixelwise"]
        Log-variance prediction mode, or None (default).
    analytical_kl : bool
        Whether to use the analytical KL (default False).
    """

    model_config = ConfigDict(validate_assignment=True, validate_default=True)

    architecture: Literal["LVAE"]
    input_shape: int = Field(default=64, ge=8, le=1024)
    multiscale_count: int = Field(default=5)  # TODO clarify
    # 0 - off, len(z_dims) + 1 # TODO can/should be le to z_dims len + 1
    # pydantic copies field defaults per instance, so this list is not shared
    z_dims: list = Field(default=[128, 128, 128, 128])
    output_channels: int = Field(default=1, ge=1)
    encoder_n_filters: int = Field(default=64, ge=8, le=1024)
    decoder_n_filters: int = Field(default=64, ge=8, le=1024)
    encoder_dropout: float = Field(default=0.1, ge=0.0, le=0.9)
    decoder_dropout: float = Field(default=0.1, ge=0.0, le=0.9)
    nonlinearity: Literal[
        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
    ] = Field(
        default="ELU",
    )

    predict_logvar: Literal[None, "pixelwise"] = None

    analytical_kl: bool = Field(
        default=False,
    )

    @field_validator("encoder_n_filters")
    @classmethod
    def validate_encoder_even(cls, encoder_n_filters: int) -> int:
        """
        Validate that `encoder_n_filters` is even.

        Parameters
        ----------
        encoder_n_filters : int
            Number of encoder filters.

        Returns
        -------
        int
            Validated number of encoder filters.

        Raises
        ------
        ValueError
            If the number of encoder filters is odd.
        """
        if encoder_n_filters % 2 != 0:
            raise ValueError(
                "Number of channels for the bottom layer must be even"
                f" (got {encoder_n_filters})."
            )

        return encoder_n_filters

    @field_validator("decoder_n_filters")
    @classmethod
    def validate_decoder_even(cls, decoder_n_filters: int) -> int:
        """
        Validate that `decoder_n_filters` is even.

        Parameters
        ----------
        decoder_n_filters : int
            Number of decoder filters.

        Returns
        -------
        int
            Validated number of decoder filters.

        Raises
        ------
        ValueError
            If the number of decoder filters is odd.
        """
        if decoder_n_filters % 2 != 0:
            raise ValueError(
                "Number of channels for the bottom layer must be even"
                f" (got {decoder_n_filters})."
            )

        return decoder_n_filters

    @field_validator("z_dims")
    @classmethod
    def validate_z_dims(cls, z_dims: list) -> list:
        """
        Validate the z_dims.

        Parameters
        ----------
        z_dims : list
            List of z dimensions.

        Returns
        -------
        list
            Validated z dimensions.

        Raises
        ------
        ValueError
            If fewer than 2 z dimensions are given.
        """
        if len(z_dims) < 2:
            raise ValueError(
                f"Number of z dimensions must be at least 2 (got {len(z_dims)})."
            )

        return z_dims

    @model_validator(mode="after")
    def validate_multiscale_count(self) -> Self:
        """
        Validate the multiscale count.

        Currently performs no check and returns the model unchanged.

        Returns
        -------
        Self
            The validated model.
        """
        # Pydantic interprets an "after" model validator with two positional
        # parameters as `(model, info)`; keeping this a plain one-argument
        # instance method guarantees the model itself is returned.
        # TODO: decide on and enforce the valid range of `multiscale_count`
        # (e.g. 0 for off, otherwise at most `len(self.z_dims) + 1`).
        return self

    def set_3D(self, is_3D: bool) -> None:
        """
        Set 3D model by setting the `conv_dims` parameters.

        Not supported yet for this architecture.

        Parameters
        ----------
        is_3D : bool
            Whether the algorithm is 3D or not.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("VAE is not implemented yet.")

    def is_3D(self) -> bool:
        """
        Return whether the model is 3D or not.

        Not supported yet for this architecture.

        Returns
        -------
        bool
            Whether the model is 3D or not.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("VAE is not implemented yet.")