careamics 0.0.1__py3-none-any.whl → 0.1.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +7 -1
- careamics/bioimage/__init__.py +15 -0
- careamics/bioimage/docs/Noise2Void.md +5 -0
- careamics/bioimage/docs/__init__.py +1 -0
- careamics/bioimage/io.py +182 -0
- careamics/bioimage/rdf.py +105 -0
- careamics/config/__init__.py +11 -0
- careamics/config/algorithm.py +231 -0
- careamics/config/config.py +297 -0
- careamics/config/config_filter.py +44 -0
- careamics/config/data.py +194 -0
- careamics/config/torch_optim.py +118 -0
- careamics/config/training.py +534 -0
- careamics/dataset/__init__.py +1 -0
- careamics/dataset/dataset_utils.py +111 -0
- careamics/dataset/extraction_strategy.py +21 -0
- careamics/dataset/in_memory_dataset.py +202 -0
- careamics/dataset/patching.py +492 -0
- careamics/dataset/prepare_dataset.py +175 -0
- careamics/dataset/tiff_dataset.py +212 -0
- careamics/engine.py +1014 -0
- careamics/losses/__init__.py +4 -0
- careamics/losses/loss_factory.py +38 -0
- careamics/losses/losses.py +34 -0
- careamics/manipulation/__init__.py +4 -0
- careamics/manipulation/pixel_manipulation.py +158 -0
- careamics/models/__init__.py +4 -0
- careamics/models/layers.py +152 -0
- careamics/models/model_factory.py +251 -0
- careamics/models/unet.py +322 -0
- careamics/prediction/__init__.py +9 -0
- careamics/prediction/prediction_utils.py +106 -0
- careamics/utils/__init__.py +20 -0
- careamics/utils/ascii_logo.txt +9 -0
- careamics/utils/augment.py +65 -0
- careamics/utils/context.py +45 -0
- careamics/utils/logging.py +321 -0
- careamics/utils/metrics.py +160 -0
- careamics/utils/normalization.py +55 -0
- careamics/utils/torch_utils.py +89 -0
- careamics/utils/validators.py +170 -0
- careamics/utils/wandb.py +121 -0
- careamics-0.1.0rc2.dist-info/METADATA +81 -0
- careamics-0.1.0rc2.dist-info/RECORD +47 -0
- {careamics-0.0.1.dist-info → careamics-0.1.0rc2.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.1.0rc2.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,321 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Logging submodule.
|
|
3
|
+
|
|
4
|
+
The methods are responsible for the in-console logger.
|
|
5
|
+
"""
|
|
6
|
+
import logging
|
|
7
|
+
import sys
|
|
8
|
+
import time
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Any, Dict, Generator, List, Optional, Union
|
|
11
|
+
|
|
12
|
+
LOGGERS: dict = {}


def get_logger(
    name: str,
    log_level: int = logging.INFO,
    log_path: Optional[Union[str, Path]] = None,
) -> logging.Logger:
    """
    Create a python logger instance with configured handlers.

    The first call for a given name attaches a console handler (plus a file
    handler when `log_path` is given), switches off propagation to ancestor
    loggers and records the name in `LOGGERS`. Later calls with the same name
    return the logger untouched, so `log_level` and `log_path` are ignored for
    them.

    A dotted descendant of an already configured logger (e.g. "pkg.sub" after
    "pkg") is returned without handlers of its own and keeps propagating, so its
    records are emitted by the handlers of the configured ancestor.

    Parameters
    ----------
    name : str
        Name of the logger.
    log_level : int, optional
        Log level (info, error etc.), by default logging.INFO.
    log_path : Optional[Union[str, Path]], optional
        Path in which to save the log, by default None.

    Returns
    -------
    logging.Logger
        Logger.
    """
    logger = logging.getLogger(name)

    # Already configured by an earlier call: hand back the same instance.
    if name in LOGGERS:
        return logger

    # Only genuine descendants ("parent.child") may rely on the parent's
    # handlers. Propagation must stay enabled for them, otherwise their records
    # never reach any handler. A bare prefix match would also catch unrelated
    # names such as "parentx", which need their own handlers.
    if any(name.startswith(f"{configured}.") for configured in LOGGERS):
        return logger

    handlers: List[logging.Handler] = [logging.StreamHandler()]
    if log_path:
        handlers.append(logging.FileHandler(log_path))

    formatter = logging.Formatter("%(message)s")
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    logger.setLevel(log_level)

    # This logger now owns its handlers; stop records from also travelling up
    # to ancestor loggers and being printed twice.
    logger.propagate = False
    LOGGERS[name] = True

    return logger
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class ProgressBar:
    """
    Keras style progress bar.

    Adapted from https://github.com/yueyericardo/pkbar.

    When `max_value` is known a `step/max [====>....]` bar is drawn; otherwise a
    spinner with a mode-dependent message and a running tile count is shown.

    Parameters
    ----------
    max_value : Optional[int], optional
        Maximum progress bar value, by default None.
    epoch : Optional[int], optional
        Zero-indexed current epoch, by default None.
    num_epochs : Optional[int], optional
        Total number of epochs, by default None.
    stateful_metrics : Optional[List], optional
        Iterable of string names of metrics that should *not* be averaged over time.
        Metrics in this list will be displayed as-is. All others will be averaged by
        the progress bar before display, by default None.
    always_stateful : bool, optional
        Whether to set all metrics to be stateful, by default False.
    mode : str, optional
        Mode, one of "train", "val", or "predict", by default "train".
    """

    def __init__(
        self,
        max_value: Optional[int] = None,
        epoch: Optional[int] = None,
        num_epochs: Optional[int] = None,
        stateful_metrics: Optional[List] = None,
        always_stateful: bool = False,
        mode: str = "train",
    ) -> None:
        """
        Constructor.

        Prints an "Epoch: i/n" header when both `epoch` and `num_epochs` are given.

        Parameters
        ----------
        max_value : Optional[int], optional
            Maximum progress bar value, by default None.
        epoch : Optional[int], optional
            Zero-indexed current epoch, by default None.
        num_epochs : Optional[int], optional
            Total number of epochs, by default None.
        stateful_metrics : Optional[List], optional
            Iterable of string names of metrics that should *not* be averaged over time.
            Metrics in this list will be displayed as-is. All others will be averaged by
            the progress bar before display, by default None.
        always_stateful : bool, optional
            Whether to set all metrics to be stateful, by default False.
        mode : str, optional
            Mode, one of "train", "val", or "predict", by default "train". Any other
            value falls back to a generic spinner message.
        """
        self.max_value = max_value
        # Width of the progress bar
        self.width = 30
        self.always_stateful = always_stateful

        if (epoch is not None) and (num_epochs is not None):
            print(f"Epoch: {epoch + 1}/{num_epochs}")

        self.stateful_metrics = set(stateful_metrics or ())

        self._dynamic_display = (
            (hasattr(sys.stdout, "isatty") and sys.stdout.isatty())
            or "ipykernel" in sys.modules
            or "posix" in sys.modules
        )
        self._total_width = 0
        self._seen_so_far = 0
        # We use a dict + list to avoid garbage collection
        # issues found in OrderedDict
        self._values: Dict[Any, Any] = {}
        self._values_order: List[Any] = []
        self._start = time.time()
        self._last_update = 0.0
        # The spinner is only needed when the total number of steps is unknown.
        self.spin = self.spinning_cursor() if self.max_value is None else None

        # Always define the message so that `update` cannot fail with an
        # AttributeError when an unexpected mode is combined with an unknown
        # `max_value`.
        mode_messages = {
            "train": "Estimating dataset size",
            "val": "Validating",
            "predict": "Denoising",
        }
        self.message = mode_messages.get(mode, "Processing")

    def update(
        self, current_step: int, batch_size: int = 1, values: Optional[List] = None
    ) -> None:
        """
        Update the progress bar.

        Parameters
        ----------
        current_step : int
            Index of the current step.
        batch_size : int, optional
            Batch size, by default 1.
        values : Optional[List], optional
            Updated metrics values as (name, value) pairs, by default None.
        """
        # Must run before `_seen_so_far` is advanced: averaged metrics are
        # weighted by the number of steps since the previous update.
        self._accumulate(current_step, values or [])
        self._seen_so_far = current_step

        now = time.time()
        info = f" - {(now - self._start):.0f}s"

        prev_total_width = self._total_width
        if self._dynamic_display:
            # Rewind the cursor so that the new bar overwrites the previous one.
            sys.stdout.write("\b" * prev_total_width)
            sys.stdout.write("\r")
        else:
            sys.stdout.write("\n")

        bar = self._render_bar(current_step, batch_size)
        self._total_width = len(bar)
        sys.stdout.write(bar)

        if current_step > 0:
            time_per_unit = (now - self._start) / current_step
        else:
            time_per_unit = 0

        # Pick the unit that keeps the displayed number readable.
        if time_per_unit >= 1 or time_per_unit == 0:
            info += f" {time_per_unit:.0f}s/step"
        elif time_per_unit >= 1e-3:
            info += f" {time_per_unit * 1e3:.0f}ms/step"
        else:
            info += f" {time_per_unit * 1e6:.0f}us/step"

        for k in self._values_order:
            info += f" - {k}:"
            if isinstance(self._values[k], list):
                avg = self._values[k][0] / max(1, self._values[k][1])
                if abs(avg) > 1e-3:
                    info += f" {avg:.4f}"
                else:
                    info += f" {avg:.4e}"
            else:
                # Defensive: entries written by `update` are always
                # [weighted sum, weight] lists.
                info += f" {self._values[k]}s"

        self._total_width += len(info)
        if prev_total_width > self._total_width:
            # Blank out leftovers of a longer previous line.
            info += " " * (prev_total_width - self._total_width)

        if self.max_value is not None and current_step >= self.max_value:
            info += "\n"

        sys.stdout.write(info)
        sys.stdout.flush()

        self._last_update = now

    def _accumulate(self, current_step: int, values: List) -> None:
        """
        Fold new metric values into the running statistics.

        Parameters
        ----------
        current_step : int
            Index of the current step.
        values : List
            (name, value) pairs to record.
        """
        steps_since_last = current_step - self._seen_so_far
        for k, v in values:
            # if torch tensor, convert it to numpy
            if str(type(v)) == "<class 'torch.Tensor'>":
                v = v.detach().cpu().numpy()

            if k not in self._values_order:
                self._values_order.append(k)

            if k not in self.stateful_metrics and not self.always_stateful:
                if k not in self._values:
                    self._values[k] = [v * steps_since_last, steps_since_last]
                else:
                    self._values[k][0] += v * steps_since_last
                    self._values[k][1] += steps_since_last
            else:
                # Stateful metrics output a numeric value. This representation
                # means "take an average from a single value" but keeps the
                # numeric formatting.
                self._values[k] = [v, 1]

    def _render_bar(self, current_step: int, batch_size: int) -> str:
        """
        Build the bar (known `max_value`) or the spinner text (unknown `max_value`).

        Parameters
        ----------
        current_step : int
            Index of the current step.
        batch_size : int
            Batch size, used for the tile count shown next to the spinner.

        Returns
        -------
        str
            Text to write in front of the timing and metric information.
        """
        if self.max_value is None:
            return (
                f"{self.message} {next(self.spin)}, tile "  # type: ignore
                f"No. {current_step * batch_size}"
            )

        bar = f"{current_step}/{self.max_value} ["
        # An empty run (max_value == 0) is shown as complete instead of
        # dividing by zero.
        progress = float(current_step) / self.max_value if self.max_value else 1.0
        progress_width = int(self.width * progress)
        if progress_width > 0:
            bar += "=" * (progress_width - 1)
            bar += ">" if current_step < self.max_value else "="
        bar += "." * (self.width - progress_width)
        bar += "]"
        return bar

    def add(self, n: int, values: Optional[List] = None) -> None:
        """
        Update the progress bar by n steps.

        Parameters
        ----------
        n : int
            Number of steps to increase the progress bar with.
        values : Optional[List], optional
            Updated metrics values, by default None.
        """
        self.update(self._seen_so_far + n, 1, values=values)

    def spinning_cursor(self) -> Generator:
        """
        Generate a spinning cursor animation.

        Taken from https://github.com/manrajgrover/py-spinners/tree/master.

        The arrow moves through five positions, first between the "▓"/"▒" borders
        and then between the "▒"/"░" borders; every frame is yielded three times
        in a row and the whole sequence repeats forever.

        Returns
        -------
        Generator
            Generator of animation frames.
        """
        borders = (("▓", "▒"), ("▒", "░"))
        tracks = ("-----", "->---", "-->--", "--->-", "---->")
        frames = [
            f"{left} {track} {right}"
            for left, right in borders
            for track in tracks
            for _ in range(3)
        ]
        while True:
            yield from frames
|
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Metrics submodule.
|
|
3
|
+
|
|
4
|
+
This module contains various metrics and a metrics tracking class.
|
|
5
|
+
"""
|
|
6
|
+
from typing import Union
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import torch
|
|
10
|
+
from skimage.metrics import peak_signal_noise_ratio
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def psnr(gt: np.ndarray, pred: np.ndarray, range: float = 255.0) -> float:
    """
    Compute the peak signal to noise ratio between two images.

    Thin wrapper around skimage.metrics.peak_signal_noise_ratio, to which `range`
    is forwarded as `data_range`. See:
    https://scikit-image.org/docs/dev/api/skimage.metrics.html.

    Parameters
    ----------
    gt : NumPy array
        Ground truth image.
    pred : NumPy array
        Predicted image.
    range : float, optional
        The images pixel range, by default 255.0.

    Returns
    -------
    float
        PSNR value.
    """
    score = peak_signal_noise_ratio(gt, pred, data_range=range)
    return score
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _zero_mean(x: np.ndarray) -> np.ndarray:
|
|
38
|
+
"""
|
|
39
|
+
Zero the mean of an array.
|
|
40
|
+
|
|
41
|
+
Parameters
|
|
42
|
+
----------
|
|
43
|
+
x : NumPy array
|
|
44
|
+
Input array.
|
|
45
|
+
|
|
46
|
+
Returns
|
|
47
|
+
-------
|
|
48
|
+
NumPy array
|
|
49
|
+
Zero-mean array.
|
|
50
|
+
"""
|
|
51
|
+
return x - np.mean(x)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _fix_range(gt: np.ndarray, x: np.ndarray) -> np.ndarray:
|
|
55
|
+
"""
|
|
56
|
+
Adjust the range of an array based on a reference ground-truth array.
|
|
57
|
+
|
|
58
|
+
Parameters
|
|
59
|
+
----------
|
|
60
|
+
gt : np.ndarray
|
|
61
|
+
Ground truth image.
|
|
62
|
+
x : np.ndarray
|
|
63
|
+
Input array.
|
|
64
|
+
|
|
65
|
+
Returns
|
|
66
|
+
-------
|
|
67
|
+
np.ndarray
|
|
68
|
+
Range-adjusted array.
|
|
69
|
+
"""
|
|
70
|
+
a = np.sum(gt * x) / (np.sum(x * x))
|
|
71
|
+
return x * a
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _fix(gt: np.ndarray, x: np.ndarray) -> np.ndarray:
    """
    Zero-mean an array and adjust its range to a zero-mean ground truth.

    Both `gt` and `x` have their mean removed; the centred `x` is then rescaled
    against the centred `gt`.

    Parameters
    ----------
    gt : np.ndarray
        Ground truth image.
    x : np.ndarray
        Input array.

    Returns
    -------
    np.ndarray
        Zero-mean and range-adjusted array.
    """
    centred_gt = _zero_mean(gt)
    centred_x = _zero_mean(x)
    return _fix_range(centred_gt, centred_x)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def scale_invariant_psnr(
    gt: np.ndarray, pred: np.ndarray
) -> Union[float, torch.Tensor]:
    """
    Scale invariant PSNR.

    The ground truth is standardised (zero mean, unit standard deviation) and the
    prediction is zero-meaned and rescaled to match it before the PSNR is
    computed. The data range passed to the PSNR is the ground-truth range
    expressed in units of its standard deviation.

    Parameters
    ----------
    gt : np.ndarray
        Ground truth image.
    pred : np.ndarray
        Predicted image.

    Returns
    -------
    Union[float, torch.Tensor]
        Scale invariant PSNR value.
    """
    # The standard deviation is used twice below; compute it once.
    gt_std = np.std(gt)
    range_parameter = (np.max(gt) - np.min(gt)) / gt_std
    gt_ = _zero_mean(gt) / gt_std
    return psnr(_zero_mean(gt_), _fix(gt_, pred), range_parameter)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class MetricTracker:
    """
    Metric tracker class.

    This class is used to track values, sum, count and average of a metric over time.

    Attributes
    ----------
    val : float
        Last value passed to `update`.
    avg : float
        Running weighted average of the metric, i.e. `sum / count`.
    sum : float
        Sum of the metric values, each multiplied by its number of values.
    count : float
        Accumulated number of values.
    """

    def __init__(self) -> None:
        """Constructor."""
        self.reset()

    def reset(self) -> None:
        """Reset the metric tracker state."""
        self.val = 0.0
        self.avg: float = 0.0
        self.sum = 0.0
        self.count = 0.0

    def update(self, value: float, n: int = 1) -> None:
        """
        Update the metric tracker state.

        The average is left unchanged while the accumulated count is zero, so an
        update with `n=0` on a fresh tracker does not divide by zero.

        Parameters
        ----------
        value : float
            Value to update the metric tracker with.
        n : int
            Number of values, equals to batch size.
        """
        self.val = value
        self.sum += value * n
        self.count += n
        if self.count:
            self.avg = self.sum / self.count
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Normalization submodule.
|
|
3
|
+
|
|
4
|
+
These methods are used to normalize and denormalize images.
|
|
5
|
+
"""
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def normalize(img: np.ndarray, mean: float, std: float) -> np.ndarray:
    """
    Standardise an image with the given mean and standard deviation.

    The mean is subtracted from the image and the result is divided by the
    standard deviation.

    Parameters
    ----------
    img : np.ndarray
        Image to normalize.
    mean : float
        Mean.
    std : float
        Standard deviation.

    Returns
    -------
    np.ndarray
        Normalized array.
    """
    return (img - mean) / std
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def denormalize(img: np.ndarray, mean: float, std: float) -> np.ndarray:
    """
    Undo a mean/standard-deviation normalization.

    The image is multiplied by the standard deviation and the mean is added back.

    Parameters
    ----------
    img : np.ndarray
        Image to denormalize.
    mean : float
        Mean.
    std : float
        Standard deviation.

    Returns
    -------
    np.ndarray
        Denormalized array.
    """
    rescaled = img * std
    return rescaled + mean
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Convenience functions using torch.
|
|
3
|
+
|
|
4
|
+
These functions are used to control certain aspects and behaviours of PyTorch.
|
|
5
|
+
"""
|
|
6
|
+
import logging
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def get_device() -> torch.device:
    """
    Pick the torch device to run on.

    The choice is logged at INFO level through the `logging` module functions.

    Returns
    -------
    torch.device
        CUDA or CPU device, depending on availability of CUDA devices.
    """
    if not torch.cuda.is_available():
        logging.info("CUDA not available. Using CPU.")
        return torch.device("cpu")

    logging.info("CUDA available. Using GPU.")
    return torch.device("cuda")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# def compile_model(model: torch.nn.Module) -> torch.nn.Module:
|
|
30
|
+
# """
|
|
31
|
+
# Torch.compile wrapper.
|
|
32
|
+
|
|
33
|
+
# Parameters
|
|
34
|
+
# ----------
|
|
35
|
+
# model : torch.nn.Module
|
|
36
|
+
# Model.
|
|
37
|
+
|
|
38
|
+
# Returns
|
|
39
|
+
# -------
|
|
40
|
+
# torch.nn.Module
|
|
41
|
+
# Compiled model if compile is available, the model itself otherwise.
|
|
42
|
+
# """
|
|
43
|
+
# if hasattr(torch, "compile") and sys.version_info.minor <= 9:
|
|
44
|
+
# return torch.compile(model, mode="reduce-overhead")
|
|
45
|
+
# else:
|
|
46
|
+
# return model
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# def seed_everything(seed: int) -> None:
|
|
50
|
+
# """
|
|
51
|
+
# Seed all random number generators for reproducibility.
|
|
52
|
+
|
|
53
|
+
# Parameters
|
|
54
|
+
# ----------
|
|
55
|
+
# seed : int
|
|
56
|
+
# Seed.
|
|
57
|
+
# """
|
|
58
|
+
# import random
|
|
59
|
+
|
|
60
|
+
# import numpy as np
|
|
61
|
+
|
|
62
|
+
# random.seed(seed)
|
|
63
|
+
# np.random.seed(seed)
|
|
64
|
+
# torch.manual_seed(seed)
|
|
65
|
+
# torch.cuda.manual_seed_all(seed)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# def setup_cudnn_reproducibility(
|
|
69
|
+
# deterministic: bool = True, benchmark: bool = True
|
|
70
|
+
# ) -> None:
|
|
71
|
+
# """
|
|
72
|
+
# Prepare CuDNN benchmark and sets it to be deterministic/non-deterministic mode.
|
|
73
|
+
|
|
74
|
+
# Parameters
|
|
75
|
+
# ----------
|
|
76
|
+
# deterministic : bool
|
|
77
|
+
# Deterministic mode, if running CuDNN backend.
|
|
78
|
+
# benchmark : bool
|
|
79
|
+
# If True, uses CuDNN heuristics to figure out which algorithm will be most
|
|
80
|
+
# performant for your model architecture and input. False may slow down training
|
|
81
|
+
# """
|
|
82
|
+
# if torch.cuda.is_available():
|
|
83
|
+
# if deterministic:
|
|
84
|
+
# deterministic = os.environ.get("CUDNN_DETERMINISTIC", "True") == "True"
|
|
85
|
+
# torch.backends.cudnn.deterministic = deterministic
|
|
86
|
+
|
|
87
|
+
# if benchmark:
|
|
88
|
+
# benchmark = os.environ.get("CUDNN_BENCHMARK", "True") == "True"
|
|
89
|
+
# torch.backends.cudnn.benchmark = benchmark
|