careamics-0.1.0rc5-py3-none-any.whl → careamics-0.1.0rc7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic; see the registry's release page for more details.

Files changed (118)
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +164 -231
  4. careamics/config/algorithm_model.py +5 -18
  5. careamics/config/architectures/architecture_model.py +7 -0
  6. careamics/config/architectures/custom_model.py +11 -4
  7. careamics/config/architectures/register_model.py +3 -1
  8. careamics/config/architectures/unet_model.py +2 -0
  9. careamics/config/architectures/vae_model.py +2 -0
  10. careamics/config/callback_model.py +3 -15
  11. careamics/config/configuration_example.py +4 -5
  12. careamics/config/configuration_factory.py +27 -41
  13. careamics/config/configuration_model.py +11 -11
  14. careamics/config/data_model.py +89 -63
  15. careamics/config/inference_model.py +28 -81
  16. careamics/config/optimizer_models.py +11 -11
  17. careamics/config/support/__init__.py +0 -2
  18. careamics/config/support/supported_activations.py +2 -0
  19. careamics/config/support/supported_algorithms.py +3 -1
  20. careamics/config/support/supported_architectures.py +2 -0
  21. careamics/config/support/supported_data.py +2 -0
  22. careamics/config/support/supported_loggers.py +2 -0
  23. careamics/config/support/supported_losses.py +2 -0
  24. careamics/config/support/supported_optimizers.py +2 -0
  25. careamics/config/support/supported_pixel_manipulations.py +3 -3
  26. careamics/config/support/supported_struct_axis.py +2 -0
  27. careamics/config/support/supported_transforms.py +4 -16
  28. careamics/config/tile_information.py +28 -58
  29. careamics/config/transformations/__init__.py +3 -2
  30. careamics/config/transformations/normalize_model.py +32 -4
  31. careamics/config/transformations/xy_flip_model.py +43 -0
  32. careamics/config/transformations/xy_random_rotate90_model.py +11 -3
  33. careamics/config/validators/validator_utils.py +1 -1
  34. careamics/conftest.py +12 -0
  35. careamics/dataset/__init__.py +12 -1
  36. careamics/dataset/dataset_utils/__init__.py +8 -1
  37. careamics/dataset/dataset_utils/dataset_utils.py +4 -4
  38. careamics/dataset/dataset_utils/file_utils.py +4 -3
  39. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  40. careamics/dataset/dataset_utils/read_tiff.py +6 -11
  41. careamics/dataset/dataset_utils/read_utils.py +2 -0
  42. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  43. careamics/dataset/dataset_utils/running_stats.py +186 -0
  44. careamics/dataset/in_memory_dataset.py +88 -154
  45. careamics/dataset/in_memory_pred_dataset.py +88 -0
  46. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  47. careamics/dataset/iterable_dataset.py +121 -191
  48. careamics/dataset/iterable_pred_dataset.py +121 -0
  49. careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
  50. careamics/dataset/patching/patching.py +109 -39
  51. careamics/dataset/patching/random_patching.py +17 -6
  52. careamics/dataset/patching/sequential_patching.py +14 -8
  53. careamics/dataset/patching/validate_patch_dimension.py +7 -3
  54. careamics/dataset/tiling/__init__.py +10 -0
  55. careamics/dataset/tiling/collate_tiles.py +33 -0
  56. careamics/dataset/{patching → tiling}/tiled_patching.py +7 -5
  57. careamics/dataset/zarr_dataset.py +2 -0
  58. careamics/lightning_datamodule.py +46 -25
  59. careamics/lightning_module.py +19 -9
  60. careamics/lightning_prediction_datamodule.py +54 -84
  61. careamics/losses/__init__.py +2 -3
  62. careamics/losses/loss_factory.py +1 -1
  63. careamics/losses/losses.py +11 -7
  64. careamics/lvae_training/__init__.py +0 -0
  65. careamics/lvae_training/data_modules.py +1220 -0
  66. careamics/lvae_training/data_utils.py +618 -0
  67. careamics/lvae_training/eval_utils.py +905 -0
  68. careamics/lvae_training/get_config.py +84 -0
  69. careamics/lvae_training/lightning_module.py +701 -0
  70. careamics/lvae_training/metrics.py +214 -0
  71. careamics/lvae_training/train_lvae.py +339 -0
  72. careamics/lvae_training/train_utils.py +121 -0
  73. careamics/model_io/bioimage/model_description.py +40 -32
  74. careamics/model_io/bmz_io.py +3 -3
  75. careamics/model_io/model_io_utils.py +5 -2
  76. careamics/models/activation.py +2 -0
  77. careamics/models/layers.py +121 -25
  78. careamics/models/lvae/__init__.py +0 -0
  79. careamics/models/lvae/layers.py +1998 -0
  80. careamics/models/lvae/likelihoods.py +312 -0
  81. careamics/models/lvae/lvae.py +985 -0
  82. careamics/models/lvae/noise_models.py +409 -0
  83. careamics/models/lvae/utils.py +395 -0
  84. careamics/models/model_factory.py +1 -1
  85. careamics/models/unet.py +35 -14
  86. careamics/prediction_utils/__init__.py +12 -0
  87. careamics/prediction_utils/create_pred_datamodule.py +185 -0
  88. careamics/prediction_utils/prediction_outputs.py +165 -0
  89. careamics/prediction_utils/stitch_prediction.py +100 -0
  90. careamics/transforms/__init__.py +2 -2
  91. careamics/transforms/compose.py +33 -7
  92. careamics/transforms/n2v_manipulate.py +52 -14
  93. careamics/transforms/normalize.py +171 -48
  94. careamics/transforms/pixel_manipulation.py +35 -11
  95. careamics/transforms/struct_mask_parameters.py +3 -1
  96. careamics/transforms/transform.py +10 -19
  97. careamics/transforms/tta.py +43 -29
  98. careamics/transforms/xy_flip.py +123 -0
  99. careamics/transforms/xy_random_rotate90.py +38 -5
  100. careamics/utils/base_enum.py +28 -0
  101. careamics/utils/path_utils.py +2 -0
  102. careamics/utils/ram.py +4 -2
  103. careamics/utils/receptive_field.py +93 -87
  104. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +8 -6
  105. careamics-0.1.0rc7.dist-info/RECORD +130 -0
  106. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
  107. careamics/config/noise_models.py +0 -162
  108. careamics/config/support/supported_extraction_strategies.py +0 -25
  109. careamics/config/transformations/nd_flip_model.py +0 -27
  110. careamics/lightning_prediction_loop.py +0 -116
  111. careamics/losses/noise_model_factory.py +0 -40
  112. careamics/losses/noise_models.py +0 -524
  113. careamics/prediction/__init__.py +0 -7
  114. careamics/prediction/stitch_prediction.py +0 -74
  115. careamics/transforms/nd_flip.py +0 -67
  116. careamics/utils/running_stats.py +0 -43
  117. careamics-0.1.0rc5.dist-info/RECORD +0 -111
  118. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
@@ -1,162 +0,0 @@ careamics/config/noise_models.py
1
- from __future__ import annotations
2
-
3
- from enum import Enum
4
- from typing import Dict, Union
5
-
6
- from pydantic import BaseModel, ConfigDict, Field, field_validator
7
-
8
-
9
- class NoiseModelType(str, Enum):
10
- """
11
- Available noise models.
12
-
13
- Currently supported noise models:
14
-
15
- - hist: Histogram noise model.
16
- - gmm: Gaussian mixture model noise model.F
17
- """
18
-
19
- NONE = "none"
20
- HIST = "hist"
21
- GMM = "gmm"
22
-
23
- # TODO add validator decorator
24
- @classmethod
25
- def validate_noise_model_type(
26
- cls, noise_model: Union[str, NoiseModel], parameters: dict
27
- ) -> None:
28
- """_summary_.
29
-
30
- Parameters
31
- ----------
32
- noise_model : Union[str, NoiseModel]
33
- _description_
34
- parameters : dict
35
- _description_
36
-
37
- Returns
38
- -------
39
- BaseModel
40
- _description_
41
- """
42
- if noise_model == NoiseModelType.HIST.value:
43
- HistogramNoiseModel(**parameters)
44
- return HistogramNoiseModel().model_dump() if not parameters else parameters
45
-
46
- elif noise_model == NoiseModelType.GMM.value:
47
- GaussianMixtureNoiseModel(**parameters)
48
- return (
49
- GaussianMixtureNoiseModel().model_dump()
50
- if not parameters
51
- else parameters
52
- )
53
-
54
-
55
- class NoiseModel(BaseModel):
56
- """_summary_.
57
-
58
- Parameters
59
- ----------
60
- BaseModel : _type_
61
- _description_
62
-
63
- Returns
64
- -------
65
- _type_
66
- _description_
67
-
68
- Raises
69
- ------
70
- ValueError
71
- _description_
72
- """
73
-
74
- model_config = ConfigDict(
75
- use_enum_values=True,
76
- protected_namespaces=(), # allows to use model_* as a field name
77
- validate_assignment=True,
78
- )
79
-
80
- model_type: NoiseModelType
81
- parameters: Dict = Field(default_factory=dict, validate_default=True)
82
-
83
- @field_validator("parameters")
84
- @classmethod
85
- def validate_parameters(cls, data, values) -> Dict:
86
- """_summary_.
87
-
88
- Parameters
89
- ----------
90
- parameters : Dict
91
- _description_
92
-
93
- Returns
94
- -------
95
- Dict
96
- _description_
97
- """
98
- if values.data["model_type"] not in [NoiseModelType.GMM, NoiseModelType.HIST]:
99
- raise ValueError(
100
- f"Incorrect noise model {values.data['model_type']}."
101
- f"Please refer to the documentation" # TODO add link to documentation
102
- )
103
-
104
- parameters = NoiseModelType.validate_noise_model_type(
105
- values.data["model_type"], data
106
- )
107
- return parameters
108
-
109
-
110
- class HistogramNoiseModel(BaseModel):
111
- """
112
- Histogram noise model.
113
-
114
- Attributes
115
- ----------
116
- min_value : float
117
- Minimum value in the input.
118
- max_value : float
119
- Maximum value in the input.
120
- bins : int
121
- Number of bins of the histogram.
122
- """
123
-
124
- min_value: float = Field(default=350.0, ge=0.0, le=65535.0)
125
- max_value: float = Field(default=6500.0, ge=0.0, le=65535.0)
126
- bins: int = Field(default=256, ge=1)
127
-
128
-
129
- class GaussianMixtureNoiseModel(BaseModel):
130
- """
131
- Gaussian mixture model noise model.
132
-
133
- Attributes
134
- ----------
135
- min_signal : float
136
- Minimum signal intensity expected in the image.
137
- max_signal : float
138
- Maximum signal intensity expected in the image.
139
- weight : array
140
- A [3*n_gaussian, n_coeff] sized array containing the values of the weights
141
- describing the noise model.
142
- Each gaussian contributes three parameters (mean, standard deviation and weight),
143
- hence the number of rows in `weight` are 3*n_gaussian.
144
- If `weight = None`, the weight array is initialized using the `min_signal` and
145
- `max_signal` parameters.
146
- n_gaussian: int
147
- Number of gaussians.
148
- n_coeff: int
149
- Number of coefficients to describe the functional relationship between gaussian
150
- parameters and the signal.
151
- 2 implies a linear relationship, 3 implies a quadratic relationship and so on.
152
- device: device
153
- GPU device.
154
- min_sigma: int
155
- """
156
-
157
- num_components: int = Field(default=3, ge=1)
158
- min_value: float = Field(default=350.0, ge=0.0, le=65535.0)
159
- max_value: float = Field(default=6500.0, ge=0.0, le=65535.0)
160
- n_gaussian: int = Field(default=3, ge=1)
161
- n_coeff: int = Field(default=2, ge=1)
162
- min_sigma: int = Field(default=50, ge=1)
@@ -1,25 +0,0 @@ careamics/config/support/supported_extraction_strategies.py
1
- """
2
- Extraction strategy module.
3
-
4
- This module defines the various extraction strategies available in CAREamics.
5
- """
6
-
7
- from careamics.utils import BaseEnum
8
-
9
-
10
- class SupportedExtractionStrategy(str, BaseEnum):
11
- """
12
- Available extraction strategies.
13
-
14
- Currently supported:
15
- - random: random extraction.
16
- # TODO
17
- - sequential: grid extraction, can miss edge values.
18
- - tiled: tiled extraction, covers the whole image.
19
- """
20
-
21
- RANDOM = "random"
22
- RANDOM_ZARR = "random_zarr"
23
- SEQUENTIAL = "sequential"
24
- TILED = "tiled"
25
- NONE = "none"
@@ -1,27 +0,0 @@ careamics/config/transformations/nd_flip_model.py
1
- """Pydantic model for the NDFlip transform."""
2
-
3
- from typing import Literal, Optional
4
-
5
- from pydantic import ConfigDict
6
-
7
- from .transform_model import TransformModel
8
-
9
-
10
- class NDFlipModel(TransformModel):
11
- """
12
- Pydantic model used to represent NDFlip transformation.
13
-
14
- Attributes
15
- ----------
16
- name : Literal["NDFlip"]
17
- Name of the transformation.
18
- seed : Optional[int]
19
- Seed for the random number generator.
20
- """
21
-
22
- model_config = ConfigDict(
23
- validate_assignment=True,
24
- )
25
-
26
- name: Literal["NDFlip"] = "NDFlip"
27
- seed: Optional[int] = None
@@ -1,116 +0,0 @@ careamics/lightning_prediction_loop.py
1
- from typing import Optional
2
-
3
- import pytorch_lightning as L
4
- from pytorch_lightning.loops.fetchers import _DataLoaderIterDataFetcher
5
- from pytorch_lightning.loops.utilities import _no_grad_context
6
- from pytorch_lightning.trainer import call
7
- from pytorch_lightning.utilities.types import _PREDICT_OUTPUT
8
-
9
- from careamics.prediction import stitch_prediction
10
-
11
-
12
- class CAREamicsPredictionLoop(L.loops._PredictionLoop):
13
- """
14
- CAREamics prediction loop.
15
-
16
- This class extends the PyTorch Lightning `_PredictionLoop` class to include
17
- the stitching of the tiles into a single prediction result.
18
- """
19
-
20
- def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
21
- """
22
- Calls `on_predict_epoch_end` hook.
23
-
24
- Adapted from the parent method.
25
-
26
- Returns
27
- -------
28
- the results for all dataloaders
29
- """
30
- trainer = self.trainer
31
- call._call_callback_hooks(trainer, "on_predict_epoch_end")
32
- call._call_lightning_module_hook(trainer, "on_predict_epoch_end")
33
-
34
- if self.return_predictions:
35
- ########################################################
36
- ################ CAREamics specific code ###############
37
- if len(self.predicted_array) == 1:
38
- # TODO does this make sense to here? (force numpy array)
39
- return self.predicted_array[0].numpy()
40
- else:
41
- # TODO revisit logic
42
- return [element.numpy() for element in self.predicted_array]
43
- ########################################################
44
- return None
45
-
46
- @_no_grad_context
47
- def run(self) -> Optional[_PREDICT_OUTPUT]:
48
- """
49
- Runs the prediction loop.
50
-
51
- Adapted from the parent method in order to stitch the predictions.
52
-
53
- Returns
54
- -------
55
- Optional[_PREDICT_OUTPUT]
56
- Prediction output
57
- """
58
- self.setup_data()
59
- if self.skip:
60
- return None
61
- self.reset()
62
- self.on_run_start()
63
- data_fetcher = self._data_fetcher
64
- assert data_fetcher is not None
65
-
66
- self.predicted_array = []
67
- self.tiles = []
68
- self.stitching_data = []
69
-
70
- while True:
71
- try:
72
- if isinstance(data_fetcher, _DataLoaderIterDataFetcher):
73
- dataloader_iter = next(data_fetcher)
74
- # hook's batch_idx and dataloader_idx arguments correctness cannot
75
- # be guaranteed in this setting
76
- batch = data_fetcher._batch
77
- batch_idx = data_fetcher._batch_idx
78
- dataloader_idx = data_fetcher._dataloader_idx
79
- else:
80
- dataloader_iter = None
81
- batch, batch_idx, dataloader_idx = next(data_fetcher)
82
- self.batch_progress.is_last_batch = data_fetcher.done
83
-
84
- # run step hooks
85
- self._predict_step(batch, batch_idx, dataloader_idx, dataloader_iter)
86
-
87
- ########################################################
88
- ################ CAREamics specific code ###############
89
- is_tiled = len(self.predictions[batch_idx]) == 2
90
- if is_tiled:
91
- # extract the last tile flag and the coordinates (crop and stitch)
92
- last_tile, *stitch_data = self.predictions[batch_idx][1]
93
-
94
- # append the tile and the coordinates to the lists
95
- self.tiles.append(self.predictions[batch_idx][0])
96
- self.stitching_data.append(stitch_data)
97
-
98
- # if last tile, stitch the tiles and add array to the prediction
99
- if any(last_tile):
100
- predicted_batches = stitch_prediction(
101
- self.tiles, self.stitching_data
102
- )
103
- self.predicted_array.append(predicted_batches)
104
- self.tiles.clear()
105
- self.stitching_data.clear()
106
- else:
107
- # simply add the prediction to the list
108
- self.predicted_array.append(self.predictions[batch_idx])
109
- ########################################################
110
- except StopIteration:
111
- break
112
- finally:
113
- self._restarting = False
114
- return self.on_run_end()
115
-
116
- # TODO predictions aren't stacked, list returned
@@ -1,40 +0,0 @@ careamics/losses/noise_model_factory.py
1
- from typing import Type, Union
2
-
3
- from ..config.noise_models import NoiseModel, NoiseModelType
4
- from .noise_models import GaussianMixtureNoiseModel, HistogramNoiseModel
5
-
6
-
7
- def noise_model_factory(
8
- noise_config: NoiseModel,
9
- ) -> Type[Union[HistogramNoiseModel, GaussianMixtureNoiseModel, None]]:
10
- """Create loss model based on Configuration.
11
-
12
- Parameters
13
- ----------
14
- config : Configuration
15
- Configuration.
16
-
17
- Returns
18
- -------
19
- Noise model
20
-
21
- Raises
22
- ------
23
- NotImplementedError
24
- If the noise model is unknown.
25
- """
26
- noise_model_type = noise_config.model_type if noise_config else None
27
-
28
- if noise_model_type == NoiseModelType.HIST:
29
- return HistogramNoiseModel
30
-
31
- elif noise_model_type == NoiseModelType.GMM:
32
- return GaussianMixtureNoiseModel
33
-
34
- elif noise_model_type is None:
35
- return None
36
-
37
- else:
38
- raise NotImplementedError(
39
- f"Noise model {noise_model_type} is not yet supported."
40
- )