careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (155)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
@@ -0,0 +1,215 @@
1
+ """Module containing convenience function to create `WriteStrategy`."""
2
+
3
+ from typing import Any, Optional
4
+
5
+ from careamics.config.support import SupportedData
6
+ from careamics.file_io import SupportedWriteType, WriteFunc, get_write_func
7
+
8
+ from .write_strategy import CacheTiles, WriteImage, WriteStrategy
9
+
10
+
11
def create_write_strategy(
    write_type: SupportedWriteType,
    tiled: bool,
    write_func: Optional[WriteFunc] = None,
    write_extension: Optional[str] = None,
    write_func_kwargs: Optional[dict[str, Any]] = None,
) -> WriteStrategy:
    """
    Build a `WriteStrategy` from a handful of convenience parameters.

    Parameters
    ----------
    write_type : {"tiff", "custom"}
        The data type to save as, includes custom.
    tiled : bool
        Whether the prediction will be tiled or not.
    write_func : WriteFunc, optional
        Ignored for a known `write_type`. For a custom `write_type` a function
        to save the data must be passed. See notes below.
    write_extension : str, optional
        Ignored for a known `write_type`. For a custom `write_type` an
        extension to save the data with must be passed.
    write_func_kwargs : dict of {str: any}, optional
        Additional keyword arguments to be passed to the save function.
        Defaults to an empty dict.

    Returns
    -------
    WriteStrategy
        A `WriteImage` strategy for untiled predictions, otherwise the strategy
        produced by `_create_tiled_write_strategy`.

    Notes
    -----
    The `write_func` function signature must match that of the example below
    ```
    write_func(file_path: Path, img: NDArray, *args, **kwargs) -> None: ...
    ```

    The `write_func_kwargs` will be passed to the `write_func` doing the following:
    ```
    write_func(file_path=file_path, img=img, **kwargs)
    ```
    """
    kwargs: dict[str, Any] = {} if write_func_kwargs is None else write_func_kwargs

    if tiled:
        # Tiled predictions are delegated so the tiled-specific choice
        # (currently CacheTiles) lives in one place.
        return _create_tiled_write_strategy(
            write_type=write_type,
            write_func=write_func,
            write_extension=write_extension,
            write_func_kwargs=kwargs,
        )

    # Untiled: resolve the writer first, then the extension, and write each
    # image as a whole.
    return WriteImage(
        write_func=select_write_func(write_type=write_type, write_func=write_func),
        write_extension=select_write_extension(
            write_type=write_type, write_extension=write_extension
        ),
        write_func_kwargs=kwargs,
    )
77
+
78
+
79
def _create_tiled_write_strategy(
    write_type: SupportedWriteType,
    write_func: Optional[WriteFunc],
    write_extension: Optional[str],
    write_func_kwargs: dict[str, Any],
) -> WriteStrategy:
    """
    Create the write strategy used for tiled predictions.

    At present this is always `CacheTiles`, which keeps tiles in memory until
    a whole image has been predicted. Writing tiles directly to a zarr store
    is not supported yet.

    Parameters
    ----------
    write_type : {"tiff", "custom"}
        The data type to save as, includes custom.
    write_func : WriteFunc, optional
        Ignored for a known `write_type`. For a custom `write_type` a function
        to save the data must be passed.
    write_extension : str, optional
        Ignored for a known `write_type`. For a custom `write_type` an
        extension to save the data with must be passed.
    write_func_kwargs : dict of {str: any}
        Additional keyword arguments to be passed to the save function.

    Returns
    -------
    WriteStrategy
        A `CacheTiles` strategy for writing tiled predictions.

    Raises
    ------
    NotImplementedError
        If `write_type="zarr"` is chosen.
    """
    # A direct-to-disk zarr strategy (e.g. a future `WriteTilesZarr`) would be
    # selected here once implemented.
    if write_type == "zarr":
        raise NotImplementedError("Saving to zarr is not implemented yet.")

    resolved_func = select_write_func(write_type=write_type, write_func=write_func)
    resolved_extension = select_write_extension(
        write_type=write_type, write_extension=write_extension
    )
    return CacheTiles(
        write_func=resolved_func,
        write_extension=resolved_extension,
        write_func_kwargs=write_func_kwargs,
    )
130
+
131
+
132
def select_write_func(
    write_type: SupportedWriteType, write_func: Optional[WriteFunc] = None
) -> WriteFunc:
    """
    Pick the function used to write images.

    For a custom `write_type` the caller-supplied `write_func` is returned;
    for any other `write_type` the built-in writer from `get_write_func` is
    used and `write_func` is ignored.

    Parameters
    ----------
    write_type : {"tiff", "custom"}
        The data type to save as, includes custom.
    write_func : WriteFunc, optional
        Ignored for a known `write_type`. For a custom `write_type` a function
        to save the data must be passed. See notes below.

    Returns
    -------
    WriteFunc
        A function for writing images.

    Raises
    ------
    ValueError
        If `write_type="custom"` but `write_func` has not been given.

    Notes
    -----
    The `write_func` function signature must match that of the example below
    ```
    write_func(file_path: Path, img: NDArray, *args, **kwargs) -> None: ...
    ```
    """
    if write_type != SupportedData.CUSTOM:
        # Known data type: the registered writer wins.
        return get_write_func(write_type)

    if write_func is None:
        raise ValueError(
            "A save function must be provided for custom data types."
            # TODO: link to how save functions should be implemented
        )
    return write_func
177
+
178
+
179
def select_write_extension(
    write_type: SupportedWriteType, write_extension: Optional[str] = None
) -> str:
    """
    Return an extension to add to file paths.

    If `write_type` is "custom" then `write_extension` is returned, otherwise
    the extension associated with the known `write_type` is selected.

    Parameters
    ----------
    write_type : {"tiff", "custom"}
        The data type to save as, includes custom.
    write_extension : str, optional
        If a known `write_type` is selected this argument is ignored. For a custom
        `write_type` an extension to save the data with must be passed.

    Returns
    -------
    str
        The extension to be added to file paths.

    Raises
    ------
    ValueError
        If `write_type="custom"` but `write_extension` has not been given.

    Notes
    -----
    `write_type` is first converted to `SupportedData`; an unsupported value is
    presumably rejected by that conversion (standard `Enum` lookup) — confirm
    against `SupportedData`.
    """
    # New variable so mypy sees a concrete `SupportedData` rather than the
    # `SupportedWriteType` of the parameter.
    write_type_: SupportedData = SupportedData(write_type)
    if write_type_ == SupportedData.CUSTOM:
        if write_extension is None:
            raise ValueError("A save extension must be provided for custom data types.")
        return write_extension
    # `get_extension` is reached through the member itself, passing the member
    # as the data type (kind of a weird pattern -> reason to move get_extension
    # out of SupportedData).
    return write_type_.get_extension(write_type_)
@@ -0,0 +1,90 @@
1
+ """Progressbar callback."""
2
+
3
+ import sys
4
+ from typing import Dict, Union
5
+
6
+ from pytorch_lightning import LightningModule, Trainer
7
+ from pytorch_lightning.callbacks import TQDMProgressBar
8
+ from tqdm import tqdm
9
+
10
+
11
class ProgressBarCallback(TQDMProgressBar):
    """Progress bar for training and validation steps."""

    def init_train_tqdm(self) -> tqdm:
        """Override this to customize the tqdm bar for training.

        Returns
        -------
        tqdm
            A tqdm bar.
        """
        bar = tqdm(
            desc="Training",
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=True,
            file=sys.stdout,
            smoothing=0,
        )
        return bar

    def init_validation_tqdm(self) -> tqdm:
        """Override this to customize the tqdm bar for validation.

        When a training bar is present, the validation bar is placed one row
        below it.

        Returns
        -------
        tqdm
            A tqdm bar.
        """
        # The main progress bar doesn't exist in `trainer.validate()`. Older
        # Lightning versions return None from `train_progress_bar` in that
        # case, while newer ones raise TypeError instead. Handle both so that
        # standalone validation does not crash.
        # NOTE(review): confirm against the pinned pytorch_lightning version.
        try:
            has_main_bar = self.train_progress_bar is not None
        except TypeError:
            has_main_bar = False
        bar = tqdm(
            desc="Validating",
            position=(2 * self.process_position + has_main_bar),
            disable=self.is_disabled,
            leave=False,
            dynamic_ncols=True,
            file=sys.stdout,
        )
        return bar

    def init_test_tqdm(self) -> tqdm:
        """Override this to customize the tqdm bar for testing.

        Unlike the training and validation bars, this bar has a fixed width of
        100 columns.

        Returns
        -------
        tqdm
            A tqdm bar.
        """
        bar = tqdm(
            desc="Testing",
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=False,
            ncols=100,
            file=sys.stdout,
        )
        return bar

    def get_metrics(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> Dict[str, Union[int, str, float, Dict[str, float]]]:
        """Override this to customize the metrics displayed in the progress bar.

        Parameters
        ----------
        trainer : Trainer
            The trainer object.
        pl_module : LightningModule
            The LightningModule object, unused.

        Returns
        -------
        dict
            A shallow copy of `trainer.progress_bar_metrics`, shown as-is in the
            progress bar.
        """
        pbar_metrics = trainer.progress_bar_metrics
        return {**pbar_metrics}