careamics 0.0.1__py3-none-any.whl → 0.1.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (48) hide show
  1. careamics/__init__.py +7 -1
  2. careamics/bioimage/__init__.py +15 -0
  3. careamics/bioimage/docs/Noise2Void.md +5 -0
  4. careamics/bioimage/docs/__init__.py +1 -0
  5. careamics/bioimage/io.py +182 -0
  6. careamics/bioimage/rdf.py +105 -0
  7. careamics/config/__init__.py +11 -0
  8. careamics/config/algorithm.py +231 -0
  9. careamics/config/config.py +297 -0
  10. careamics/config/config_filter.py +44 -0
  11. careamics/config/data.py +194 -0
  12. careamics/config/torch_optim.py +118 -0
  13. careamics/config/training.py +534 -0
  14. careamics/dataset/__init__.py +1 -0
  15. careamics/dataset/dataset_utils.py +111 -0
  16. careamics/dataset/extraction_strategy.py +21 -0
  17. careamics/dataset/in_memory_dataset.py +202 -0
  18. careamics/dataset/patching.py +492 -0
  19. careamics/dataset/prepare_dataset.py +175 -0
  20. careamics/dataset/tiff_dataset.py +212 -0
  21. careamics/engine.py +1014 -0
  22. careamics/losses/__init__.py +4 -0
  23. careamics/losses/loss_factory.py +38 -0
  24. careamics/losses/losses.py +34 -0
  25. careamics/manipulation/__init__.py +4 -0
  26. careamics/manipulation/pixel_manipulation.py +158 -0
  27. careamics/models/__init__.py +4 -0
  28. careamics/models/layers.py +152 -0
  29. careamics/models/model_factory.py +251 -0
  30. careamics/models/unet.py +322 -0
  31. careamics/prediction/__init__.py +9 -0
  32. careamics/prediction/prediction_utils.py +106 -0
  33. careamics/utils/__init__.py +20 -0
  34. careamics/utils/ascii_logo.txt +9 -0
  35. careamics/utils/augment.py +65 -0
  36. careamics/utils/context.py +45 -0
  37. careamics/utils/logging.py +321 -0
  38. careamics/utils/metrics.py +160 -0
  39. careamics/utils/normalization.py +55 -0
  40. careamics/utils/torch_utils.py +89 -0
  41. careamics/utils/validators.py +170 -0
  42. careamics/utils/wandb.py +121 -0
  43. careamics-0.1.0rc2.dist-info/METADATA +81 -0
  44. careamics-0.1.0rc2.dist-info/RECORD +47 -0
  45. {careamics-0.0.1.dist-info → careamics-0.1.0rc2.dist-info}/WHEEL +1 -1
  46. {careamics-0.0.1.dist-info → careamics-0.1.0rc2.dist-info}/licenses/LICENSE +1 -1
  47. careamics-0.0.1.dist-info/METADATA +0 -46
  48. careamics-0.0.1.dist-info/RECORD +0 -6
@@ -0,0 +1,321 @@
1
+ """
2
+ Logging submodule.
3
+
4
+ The methods are responsible for the in-console logger.
5
+ """
6
+ import logging
7
+ import sys
8
+ import time
9
+ from pathlib import Path
10
+ from typing import Any, Dict, Generator, List, Optional, Union
11
+
12
LOGGERS: dict = {}


def get_logger(
    name: str,
    log_level: int = logging.INFO,
    log_path: Optional[Union[str, Path]] = None,
) -> logging.Logger:
    """
    Create a python logger instance with configured handlers.

    Each logger hierarchy is configured only once:

    - a name that was already configured is returned unchanged;
    - a dotted child of a configured logger (e.g. ``"pkg.sub"`` when ``"pkg"``
      is registered) is returned without handlers, so that its records
      propagate to the parent's handlers instead of being printed twice;
    - any other name receives a stream handler (and a file handler if
      ``log_path`` is given) and stops propagating to the root logger.

    Parameters
    ----------
    name : str
        Name of the logger.
    log_level : int, optional
        Log level (info, error etc.), by default logging.INFO. Ignored for
        loggers that are already configured or are children of one.
    log_path : Optional[Union[str, Path]], optional
        Path in which to save the log, by default None.

    Returns
    -------
    logging.Logger
        Logger.
    """
    logger = logging.getLogger(name)

    if name in LOGGERS:
        return logger

    # Children reuse the parent's handlers through propagation, so they must
    # NOT have propagation disabled (that would drop their records). Match on
    # a "." boundary so that e.g. "pkg2" is not mistaken for a child of "pkg".
    for logger_name in LOGGERS:
        if name.startswith(f"{logger_name}."):
            return logger

    # New hierarchy root: keep records away from the root logger's handlers,
    # which would otherwise duplicate the output.
    logger.propagate = False

    handlers: List[logging.Handler] = [logging.StreamHandler()]
    if log_path:
        handlers.append(logging.FileHandler(log_path))

    formatter = logging.Formatter("%(message)s")
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    logger.setLevel(log_level)
    LOGGERS[name] = True

    return logger
70
+
71
+
72
class ProgressBar:
    """
    Keras style progress bar.

    Adapted from https://github.com/yueyericardo/pkbar.

    Parameters
    ----------
    max_value : Optional[int], optional
        Maximum progress bar value, by default None. When None, a spinner with
        a running tile counter is shown instead of a bar.
    epoch : Optional[int], optional
        Zero-indexed current epoch, by default None.
    num_epochs : Optional[int], optional
        Total number of epochs, by default None.
    stateful_metrics : Optional[List], optional
        Iterable of string names of metrics that should *not* be averaged over time.
        Metrics in this list will be displayed as-is. All others will be averaged by
        the progress bar before display, by default None.
    always_stateful : bool, optional
        Whether to set all metrics to be stateful, by default False.
    mode : str, optional
        Mode, one of "train", "val", or "predict", by default "train".
    """

    def __init__(
        self,
        max_value: Optional[int] = None,
        epoch: Optional[int] = None,
        num_epochs: Optional[int] = None,
        stateful_metrics: Optional[List] = None,
        always_stateful: bool = False,
        mode: str = "train",
    ) -> None:
        """
        Constructor.

        Parameters
        ----------
        max_value : Optional[int], optional
            Maximum progress bar value, by default None.
        epoch : Optional[int], optional
            Zero-indexed current epoch, by default None.
        num_epochs : Optional[int], optional
            Total number of epochs, by default None.
        stateful_metrics : Optional[List], optional
            Iterable of string names of metrics that should *not* be averaged over time.
            Metrics in this list will be displayed as-is. All others will be averaged by
            the progress bar before display, by default None.
        always_stateful : bool, optional
            Whether to set all metrics to be stateful, by default False.
        mode : str, optional
            Mode, one of "train", "val", or "predict", by default "train".
        """
        self.max_value = max_value
        # Width of the progress bar
        self.width = 30
        self.always_stateful = always_stateful

        if (epoch is not None) and (num_epochs is not None):
            print(f"Epoch: {epoch + 1}/{num_epochs}")

        if stateful_metrics:
            self.stateful_metrics = set(stateful_metrics)
        else:
            self.stateful_metrics = set()

        self._dynamic_display = (
            (hasattr(sys.stdout, "isatty") and sys.stdout.isatty())
            or "ipykernel" in sys.modules
            or "posix" in sys.modules
        )
        self._total_width = 0
        self._seen_so_far = 0
        # We use a dict + list to avoid garbage collection
        # issues found in OrderedDict
        self._values: Dict[Any, Any] = {}
        self._values_order: List[Any] = []
        self._start = time.time()
        self._last_update = 0.0
        self.spin = self.spinning_cursor() if self.max_value is None else None

        # The message is only shown by the spinner (max_value is None). It is
        # always defined so that an unexpected mode does not raise an
        # AttributeError on the first update.
        if mode == "val":
            self.message = "Validating"
        elif mode == "predict":
            self.message = "Denoising"
        else:
            self.message = "Estimating dataset size"

    def update(
        self, current_step: int, batch_size: int = 1, values: Optional[List] = None
    ) -> None:
        """
        Update the progress bar.

        Parameters
        ----------
        current_step : int
            Index of the current step.
        batch_size : int, optional
            Batch size, by default 1. Only used by the spinner tile counter.
        values : Optional[List], optional
            Updated metrics values as ``(name, value)`` pairs, by default None.
        """
        values = values or []
        for k, v in values:
            # if torch tensor, convert it to numpy
            if str(type(v)) == "<class 'torch.Tensor'>":
                v = v.detach().cpu().numpy()

            if k not in self._values_order:
                self._values_order.append(k)
            if k not in self.stateful_metrics and not self.always_stateful:
                # Accumulate [weighted sum, number of steps] for averaging.
                if k not in self._values:
                    self._values[k] = [
                        v * (current_step - self._seen_so_far),
                        current_step - self._seen_so_far,
                    ]
                else:
                    self._values[k][0] += v * (current_step - self._seen_so_far)
                    self._values[k][1] += current_step - self._seen_so_far
            else:
                # Stateful metrics output a numeric value. This representation
                # means "take an average from a single value" but keeps the
                # numeric formatting.
                self._values[k] = [v, 1]

        self._seen_so_far = current_step

        now = time.time()
        info = f" - {(now - self._start):.0f}s"

        # Erase the previous line (dynamic display) or start a new one.
        prev_total_width = self._total_width
        if self._dynamic_display:
            sys.stdout.write("\b" * prev_total_width)
            sys.stdout.write("\r")
        else:
            sys.stdout.write("\n")

        if self.max_value is not None:
            bar = f"{current_step}/{self.max_value} ["
            # A zero-length run (e.g. empty dataset) is displayed as complete
            # instead of raising ZeroDivisionError.
            if self.max_value != 0:
                progress = float(current_step) / self.max_value
            else:
                progress = 1.0
            progress_width = int(self.width * progress)
            if progress_width > 0:
                bar += "=" * (progress_width - 1)
                if current_step < self.max_value:
                    bar += ">"
                else:
                    bar += "="
            bar += "." * (self.width - progress_width)
            bar += "]"
        else:
            bar = (
                f"{self.message} {next(self.spin)}, tile "  # type: ignore
                f"No. {current_step * batch_size}"
            )

        self._total_width = len(bar)
        sys.stdout.write(bar)

        if current_step > 0:
            time_per_unit = (now - self._start) / current_step
        else:
            time_per_unit = 0

        if time_per_unit >= 1 or time_per_unit == 0:
            info += f" {time_per_unit:.0f}s/step"
        elif time_per_unit >= 1e-3:
            info += f" {time_per_unit * 1e3:.0f}ms/step"
        else:
            info += f" {time_per_unit * 1e6:.0f}us/step"

        for k in self._values_order:
            info += f" - {k}:"
            if isinstance(self._values[k], list):
                avg = self._values[k][0] / max(1, self._values[k][1])
                if abs(avg) > 1e-3:
                    info += f" {avg:.4f}"
                else:
                    info += f" {avg:.4e}"
            else:
                info += f" {self._values[k]}s"

        # Pad with spaces so that leftovers of a longer previous line vanish.
        self._total_width += len(info)
        if prev_total_width > self._total_width:
            info += " " * (prev_total_width - self._total_width)

        if self.max_value is not None and current_step >= self.max_value:
            info += "\n"

        sys.stdout.write(info)
        sys.stdout.flush()

        self._last_update = now

    def add(self, n: int, values: Optional[List] = None) -> None:
        """
        Update the progress bar by n steps.

        Parameters
        ----------
        n : int
            Number of steps to increase the progress bar with.
        values : Optional[List], optional
            Updated metrics values, by default None.
        """
        self.update(self._seen_so_far + n, 1, values=values)

    def spinning_cursor(self) -> Generator:
        """
        Generate a spinning cursor animation.

        Taken from https://github.com/manrajgrover/py-spinners/tree/master.

        Returns
        -------
        Generator
            Infinite generator of animation frames.
        """
        while True:
            yield from [
                "▓ ----- ▒",
                "▓ ----- ▒",
                "▓ ----- ▒",
                "▓ ->--- ▒",
                "▓ ->--- ▒",
                "▓ ->--- ▒",
                "▓ -->-- ▒",
                "▓ -->-- ▒",
                "▓ -->-- ▒",
                "▓ --->- ▒",
                "▓ --->- ▒",
                "▓ --->- ▒",
                "▓ ----> ▒",
                "▓ ----> ▒",
                "▓ ----> ▒",
                "▒ ----- ░",
                "▒ ----- ░",
                "▒ ----- ░",
                "▒ ->--- ░",
                "▒ ->--- ░",
                "▒ ->--- ░",
                "▒ -->-- ░",
                "▒ -->-- ░",
                "▒ -->-- ░",
                "▒ --->- ░",
                "▒ --->- ░",
                "▒ --->- ░",
                "▒ ----> ░",
                "▒ ----> ░",
                "▒ ----> ░",
            ]
@@ -0,0 +1,160 @@
1
+ """
2
+ Metrics submodule.
3
+
4
+ This module contains various metrics and a metrics tracking class.
5
+ """
6
+ from typing import Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ from skimage.metrics import peak_signal_noise_ratio
11
+
12
+
13
def psnr(gt: np.ndarray, pred: np.ndarray, range: float = 255.0) -> float:
    """
    Compute the peak signal-to-noise ratio between two images.

    Thin wrapper around ``skimage.metrics.peak_signal_noise_ratio``, see
    https://scikit-image.org/docs/dev/api/skimage.metrics.html.

    Parameters
    ----------
    gt : np.ndarray
        Reference (ground truth) image.
    pred : np.ndarray
        Image compared against the reference.
    range : float, optional
        Pixel value range of the images, forwarded as ``data_range``, by
        default 255.0.

    Returns
    -------
    float
        PSNR value.
    """
    data_range = range
    return peak_signal_noise_ratio(gt, pred, data_range=data_range)
35
+
36
+
37
+ def _zero_mean(x: np.ndarray) -> np.ndarray:
38
+ """
39
+ Zero the mean of an array.
40
+
41
+ Parameters
42
+ ----------
43
+ x : NumPy array
44
+ Input array.
45
+
46
+ Returns
47
+ -------
48
+ NumPy array
49
+ Zero-mean array.
50
+ """
51
+ return x - np.mean(x)
52
+
53
+
54
+ def _fix_range(gt: np.ndarray, x: np.ndarray) -> np.ndarray:
55
+ """
56
+ Adjust the range of an array based on a reference ground-truth array.
57
+
58
+ Parameters
59
+ ----------
60
+ gt : np.ndarray
61
+ Ground truth image.
62
+ x : np.ndarray
63
+ Input array.
64
+
65
+ Returns
66
+ -------
67
+ np.ndarray
68
+ Range-adjusted array.
69
+ """
70
+ a = np.sum(gt * x) / (np.sum(x * x))
71
+ return x * a
72
+
73
+
74
def _fix(gt: np.ndarray, x: np.ndarray) -> np.ndarray:
    """
    Zero-mean both arrays, then rescale ``x`` to match the zero-mean ``gt``.

    Parameters
    ----------
    gt : np.ndarray
        Ground truth image.
    x : np.ndarray
        Input array.

    Returns
    -------
    np.ndarray
        Zero-mean version of ``x`` scaled (least squares) towards the zero-mean
        ground truth.
    """
    centred_x = _zero_mean(x)
    centred_gt = _zero_mean(gt)
    return _fix_range(centred_gt, centred_x)
92
+
93
+
94
def scale_invariant_psnr(
    gt: np.ndarray, pred: np.ndarray
) -> Union[float, torch.Tensor]:
    """
    Scale invariant PSNR.

    The ground truth is standardised (zero mean, unit standard deviation). The
    prediction is zero-meaned and rescaled, in the least-squares sense, to the
    standardised ground truth before the PSNR is computed. The metric is
    therefore insensitive to an offset or a (non-zero) gain applied to the
    prediction.

    Parameters
    ----------
    gt : np.ndarray
        Ground truth image.
    pred : np.ndarray
        Predicted image.

    Returns
    -------
    Union[float, torch.Tensor]
        Scale invariant PSNR value.

    Notes
    -----
    A constant ``gt`` has zero standard deviation, which makes the
    standardisation divide by zero.
    """
    gt_std = np.std(gt)
    # Pixel range of the standardised ground truth, used as the PSNR data range.
    range_parameter = (np.max(gt) - np.min(gt)) / gt_std
    gt_ = _zero_mean(gt) / gt_std
    return psnr(_zero_mean(gt_), _fix(gt_, pred), range_parameter)
115
+
116
+
117
class MetricTracker:
    """
    Metric tracker class.

    Tracks the last value, the weighted sum, the count and the running average
    of a metric over time.

    Attributes
    ----------
    val : float or torch.Tensor
        Last value passed to ``update``.
    avg : float or torch.Tensor
        Running average, ``sum / count``.
    sum : float or torch.Tensor
        Sum of the values, each weighted by its ``n``.
    count : float
        Total number of values (sum of the ``n`` arguments).
    """

    def __init__(self) -> None:
        """Constructor."""
        self.reset()

    def reset(self) -> None:
        """Reset the metric tracker state."""
        self.val: Union[float, torch.Tensor] = 0.0
        self.avg: Union[float, torch.Tensor] = 0.0
        self.sum: Union[float, torch.Tensor] = 0.0
        self.count: float = 0.0

    def update(self, value: Union[float, torch.Tensor], n: int = 1) -> None:
        """
        Update the metric tracker state.

        Parameters
        ----------
        value : Union[float, torch.Tensor]
            Value to update the metric tracker with (e.g. a batch-averaged
            loss). It must support multiplication by an int and addition.
        n : int
            Number of values, equals to batch size.
        """
        self.val = value
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count
@@ -0,0 +1,55 @@
1
+ """
2
+ Normalization submodule.
3
+
4
+ These methods are used to normalize and denormalize images.
5
+ """
6
+ import numpy as np
7
+
8
+
9
def normalize(img: np.ndarray, mean: float, std: float) -> np.ndarray:
    """
    Standardise an image with the given statistics.

    Computes ``(img - mean) / std``.

    Parameters
    ----------
    img : np.ndarray
        Image to normalize.
    mean : float
        Value subtracted from every pixel.
    std : float
        Value every centred pixel is divided by.

    Returns
    -------
    np.ndarray
        Normalized array.
    """
    return (img - mean) / std
32
+
33
+
34
def denormalize(img: np.ndarray, mean: float, std: float) -> np.ndarray:
    """
    Undo a mean/standard-deviation normalization.

    Computes ``img * std + mean``.

    Parameters
    ----------
    img : np.ndarray
        Image to denormalize.
    mean : float
        Value added back after rescaling.
    std : float
        Factor every pixel is multiplied by.

    Returns
    -------
    np.ndarray
        Denormalized array.
    """
    rescaled = img * std
    return rescaled + mean
@@ -0,0 +1,89 @@
1
+ """
2
+ Convenience functions using torch.
3
+
4
+ These functions are used to control certain aspects and behaviours of PyTorch.
5
+ """
6
+ import logging
7
+
8
+ import torch
9
+
10
+
11
def get_device() -> torch.device:
    """
    Pick the torch device to run on.

    Returns
    -------
    torch.device
        ``cuda`` when a CUDA device is available, ``cpu`` otherwise.
    """
    if not torch.cuda.is_available():
        logging.info("CUDA not available. Using CPU.")
        return torch.device("cpu")
    logging.info("CUDA available. Using GPU.")
    return torch.device("cuda")
27
+
28
+
29
+ # def compile_model(model: torch.nn.Module) -> torch.nn.Module:
30
+ # """
31
+ # Torch.compile wrapper.
32
+
33
+ # Parameters
34
+ # ----------
35
+ # model : torch.nn.Module
36
+ # Model.
37
+
38
+ # Returns
39
+ # -------
40
+ # torch.nn.Module
41
+ # Compiled model if compile is available, the model itself otherwise.
42
+ # """
43
+ # if hasattr(torch, "compile") and sys.version_info.minor <= 9:
44
+ # return torch.compile(model, mode="reduce-overhead")
45
+ # else:
46
+ # return model
47
+
48
+
49
+ # def seed_everything(seed: int) -> None:
50
+ # """
51
+ # Seed all random number generators for reproducibility.
52
+
53
+ # Parameters
54
+ # ----------
55
+ # seed : int
56
+ # Seed.
57
+ # """
58
+ # import random
59
+
60
+ # import numpy as np
61
+
62
+ # random.seed(seed)
63
+ # np.random.seed(seed)
64
+ # torch.manual_seed(seed)
65
+ # torch.cuda.manual_seed_all(seed)
66
+
67
+
68
+ # def setup_cudnn_reproducibility(
69
+ # deterministic: bool = True, benchmark: bool = True
70
+ # ) -> None:
71
+ # """
72
+ # Prepare CuDNN benchmark and sets it to be deterministic/non-deterministic mode.
73
+
74
+ # Parameters
75
+ # ----------
76
+ # deterministic : bool
77
+ # Deterministic mode, if running CuDNN backend.
78
+ # benchmark : bool
79
+ # If True, uses CuDNN heuristics to figure out which algorithm will be most
80
+ # performant for your model architecture and input. False may slow down training
81
+ # """
82
+ # if torch.cuda.is_available():
83
+ # if deterministic:
84
+ # deterministic = os.environ.get("CUDNN_DETERMINISTIC", "True") == "True"
85
+ # torch.backends.cudnn.deterministic = deterministic
86
+
87
+ # if benchmark:
88
+ # benchmark = os.environ.get("CUDNN_BENCHMARK", "True") == "True"
89
+ # torch.backends.cudnn.benchmark = benchmark