careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +6 -1
- careamics/careamist.py +729 -0
- careamics/config/__init__.py +39 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +162 -0
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +583 -0
- careamics/config/configuration_model.py +604 -0
- careamics/config/data_model.py +527 -0
- careamics/config/fcn_algorithm_model.py +147 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +33 -0
- careamics/config/support/supported_architectures.py +17 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +29 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +18 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +632 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +15 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +98 -0
- careamics/losses/loss_factory.py +155 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/dataset/data_utils.py +701 -0
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/dataset/vae_dataset.py +1054 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +246 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +39 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +364 -0
- careamics/models/lvae/lvae.py +901 -0
- careamics/models/lvae/noise_models.py +541 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +67 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +112 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +188 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.3.dist-info/METADATA +78 -0
- careamics-0.0.3.dist-info/RECORD +154 -0
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
"""Patch transform applying XY random 90 degrees rotations."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional, Tuple
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from careamics.transforms.transform import Transform
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class XYRandomRotate90(Transform):
    """Applies random 90 degree rotations to the YX axis.

    This transform expects C(Z)YX dimensions. When a target is provided, the
    same rotation is applied to both the patch and the target.

    Attributes
    ----------
    rng : np.random.Generator
        Random number generator.
    p : float
        Probability of applying the transform.

    Parameters
    ----------
    p : float
        Probability of applying the transform, by default 0.5.
    seed : Optional[int]
        Random seed, by default None.
    """

    def __init__(self, p: float = 0.5, seed: Optional[int] = None):
        """Constructor.

        Parameters
        ----------
        p : float
            Probability of applying the transform, by default 0.5.
        seed : Optional[int]
            Random seed, by default None.

        Raises
        ------
        ValueError
            If `p` is not in [0, 1] (this includes NaN).
        """
        # the chained comparison also rejects NaN, which `p < 0 or p > 1` let
        # through
        if not 0 <= p <= 1:
            raise ValueError("Probability must be in [0, 1].")

        # probability to apply the transform
        self.p = p

        # numpy random generator
        self.rng = np.random.default_rng(seed=seed)

    def __call__(
        self, patch: np.ndarray, target: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Apply the transform to the source patch and the target (optional).

        Parameters
        ----------
        patch : np.ndarray
            Patch, 2D or 3D, shape C(Z)YX.
        target : Optional[np.ndarray], optional
            Target for the patch, by default None.

        Returns
        -------
        Tuple[np.ndarray, Optional[np.ndarray]]
            Transformed patch and target. If the transform is not applied, the
            inputs are returned unchanged.
        """
        # `random()` draws from [0, 1): skipping when the draw is >= p applies
        # the transform with probability exactly p (never for p=0, always for
        # p=1).
        if self.rng.random() >= self.p:
            return patch, target

        # number of rotations, drawn from {1, 2, 3} (upper bound exclusive)
        n_rot = self.rng.integers(1, 4)

        # rotate in the plane of the last two (Y, X) axes
        axes = (-2, -1)
        patch_transformed = self._apply(patch, n_rot, axes)
        target_transformed = (
            self._apply(target, n_rot, axes) if target is not None else None
        )

        return patch_transformed, target_transformed

    def _apply(
        self, patch: np.ndarray, n_rot: int, axes: Tuple[int, int]
    ) -> np.ndarray:
        """Apply the transform to the image.

        Parameters
        ----------
        patch : np.ndarray
            Image or image patch, 2D or 3D, shape C(Z)YX.
        n_rot : int
            Number of 90 degree rotations.
        axes : Tuple[int, int]
            Axes along which to rotate the patch.

        Returns
        -------
        np.ndarray
            Transformed patch, as a C-contiguous array.
        """
        # np.rot90 returns a view; copy it into a contiguous array
        return np.ascontiguousarray(np.rot90(patch, k=n_rot, axes=axes))
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""Utils module."""
|
|
2
|
+
|
|
3
|
+
__all__ = [
|
|
4
|
+
"cwd",
|
|
5
|
+
"get_ram_size",
|
|
6
|
+
"check_path_exists",
|
|
7
|
+
"BaseEnum",
|
|
8
|
+
"get_logger",
|
|
9
|
+
"get_careamics_home",
|
|
10
|
+
"autocorrelation",
|
|
11
|
+
]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
from .autocorrelation import autocorrelation
|
|
15
|
+
from .base_enum import BaseEnum
|
|
16
|
+
from .context import cwd, get_careamics_home
|
|
17
|
+
from .logging import get_logger
|
|
18
|
+
from .path_utils import check_path_exists
|
|
19
|
+
from .ram import get_ram_size
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
"""Autocorrelation function."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from numpy.typing import NDArray
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def autocorrelation(image: NDArray) -> NDArray:
    """Compute the autocorrelation of an image.

    This method is used to explore spatial correlations in images,
    in particular in the noise.

    The autocorrelation is normalized to the zero-shift value, which is centered in
    the resulting images.

    Parameters
    ----------
    image : NDArray
        Input image.

    Returns
    -------
    numpy.ndarray
        Autocorrelation of the input image, with the same shape as the input.

    Raises
    ------
    ValueError
        If the image is constant (zero standard deviation), in which case the
        normalized autocorrelation is undefined.
    """
    image = np.asarray(image)

    # normalize image to zero mean and unit variance; a constant image would
    # otherwise yield a division by zero and an all-NaN result
    std = np.std(image)
    if std == 0:
        raise ValueError(
            "Cannot compute the autocorrelation of a constant image "
            "(zero standard deviation)."
        )
    image = (image - np.mean(image)) / std

    # compute autocorrelation in fourier space (Wiener-Khinchin: inverse FFT
    # of the power spectrum)
    image = np.fft.fftn(image)
    image = np.abs(image) ** 2
    image = np.fft.ifftn(image).real

    # normalize to zero shift value, located at index (0, ..., 0) before the shift
    image = image / image.flat[0]

    # shift zero frequency to center
    image = np.fft.fftshift(image)

    return image
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""A base class for Enum that allows checking if a value is in the Enum."""
|
|
2
|
+
|
|
3
|
+
from enum import Enum, EnumMeta
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class _ContainerEnum(EnumMeta):
    """Metaclass for Enum with `__contains__` and `has_value` methods."""

    def __contains__(cls, item: Any) -> bool:
        """Check if an item is in the Enum.

        Any item accepted by the Enum constructor `cls(item)` is considered to
        be in the Enum.

        Parameters
        ----------
        item : Any
            Item to check.

        Returns
        -------
        bool
            True if the item is in the Enum, False otherwise.
        """
        try:
            cls(item)
        except ValueError:
            return False
        return True

    def has_value(cls, value: Any) -> bool:
        """Check if a value is in the Enum.

        This is deliberately a plain method of the metaclass (not a
        `classmethod`): when called as `MyEnum.has_value(...)`, `cls` is then
        bound to the enum class `MyEnum`. A `classmethod` defined on the
        metaclass would instead bind `cls` to the metaclass itself, which has
        no `_value2member_map_`.

        Parameters
        ----------
        value : Any
            Value to check.

        Returns
        -------
        bool
            True if the value is in the Enum, False otherwise.
        """
        return value in cls._value2member_map_


class BaseEnum(Enum, metaclass=_ContainerEnum):
    """Base Enum class, allowing checking if a value is in the enum.

    Example
    -------
    >>> from careamics.utils.base_enum import BaseEnum
    >>> # Define a new enum
    >>> class BaseEnumExtension(BaseEnum):
    ...     VALUE = "value"
    >>> # Check if value is in the enum
    >>> "value" in BaseEnumExtension
    True
    >>> BaseEnumExtension.has_value("value")
    True
    """

    pass
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Context submodule.
|
|
3
|
+
|
|
4
|
+
A convenience function to change the working directory in order to save data.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import os
|
|
8
|
+
from contextlib import contextmanager
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Iterator, Union
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def get_careamics_home() -> Path:
    """Return the path to the CAREamics home folder, creating it if missing.

    The folder is the hidden ``.careamics`` directory inside the user's home
    directory. If nothing exists at that location yet, it is created together
    with any missing parent directories.

    Returns
    -------
    Path
        Path to the CAREamics home directory.
    """
    careamics_dir = Path.home().joinpath(".careamics")

    # something already exists there: hand it back untouched
    if careamics_dir.exists():
        return careamics_dir

    careamics_dir.mkdir(parents=True, exist_ok=True)
    return careamics_dir
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@contextmanager
def cwd(path: Union[str, Path]) -> Iterator[None]:
    """
    Temporarily change the current working directory to the given path.

    The target directory, and any missing parents, is created if it does not
    exist. When the context exits, including through an exception, the working
    directory is restored to the one that was active on entry. This can be
    used to generate files in a specific directory.

    Parameters
    ----------
    path : Union[str, Path]
        New working directory path.

    Yields
    ------
    None
        Control is handed to the ``with`` block while in the new directory.

    Examples
    --------
    The working directory is changed within the block and then restored to the
    original one.

    >>> with cwd(my_path):
    ...     pass  # do something
    """
    target = Path(path)
    if not target.exists():
        target.mkdir(parents=True, exist_ok=True)

    previous = Path.cwd()
    os.chdir(target)
    try:
        yield
    finally:
        # restore the original directory whatever happens inside the block
        os.chdir(previous)
|
|
@@ -0,0 +1,322 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Logging submodule.
|
|
3
|
+
|
|
4
|
+
The methods are responsible for the in-console logger.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
import sys
|
|
9
|
+
import time
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Any, Dict, Generator, List, Optional, Union
|
|
12
|
+
|
|
13
|
+
# Names of the loggers already configured by `get_logger`, used to avoid
# attaching duplicate handlers when the same logger is requested again.
LOGGERS: dict = {}


def get_logger(
    name: str,
    log_level: int = logging.INFO,
    log_path: Optional[Union[str, Path]] = None,
) -> logging.Logger:
    """
    Create a python logger instance with configured handlers.

    The first call for a given name attaches a console handler (and a file
    handler if `log_path` is provided) to the logger, sets its level and stops
    propagation to the root logger. Later calls with the same name return the
    logger unchanged.

    A logger whose name is a child, in the logging hierarchy, of an already
    configured logger (e.g. "pkg.module" after "pkg") is returned without its
    own handlers. Its propagation is left enabled, so its records are emitted
    by the parent's handlers.

    Parameters
    ----------
    name : str
        Name of the logger.
    log_level : int, optional
        Log level (info, error etc.), by default logging.INFO.
    log_path : Optional[Union[str, Path]], optional
        Path in which to save the log, by default None.

    Returns
    -------
    logging.Logger
        Logger.
    """
    logger = logging.getLogger(name)

    # already configured by a previous call
    if name in LOGGERS:
        return logger

    # Child of a configured logger: rely on propagation to the parent's
    # handlers rather than attaching duplicates. Matching on "<parent>." keeps
    # the check hierarchical ("pkg.model_io" is not a child of "pkg.model").
    if any(name.startswith(f"{configured}.") for configured in LOGGERS):
        return logger

    handlers: List[logging.Handler] = [logging.StreamHandler()]
    if log_path:
        handlers.append(logging.FileHandler(log_path))

    formatter = logging.Formatter("%(message)s")
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    logger.setLevel(log_level)
    # do not forward records to the root logger, which would print them twice
    # if the root logger also has handlers
    logger.propagate = False
    LOGGERS[name] = True

    return logger
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class ProgressBar:
    """
    Keras style progress bar.

    Adapted from https://github.com/yueyericardo/pkbar.

    When `max_value` is known, a bar such as ``3/10 [=====>....]`` is drawn.
    Otherwise a spinner with a mode-dependent message and a running tile count
    is shown. Timing and averaged metric values are appended after the bar.

    Parameters
    ----------
    max_value : Optional[int], optional
        Maximum progress bar value, by default None.
    epoch : Optional[int], optional
        Zero-indexed current epoch, by default None.
    num_epochs : Optional[int], optional
        Total number of epochs, by default None.
    stateful_metrics : Optional[List], optional
        Iterable of string names of metrics that should *not* be averaged over time.
        Metrics in this list will be displayed as-is. All others will be averaged by
        the progress bar before display, by default None.
    always_stateful : bool, optional
        Whether to set all metrics to be stateful, by default False.
    mode : str, optional
        Mode, one of "train", "val", or "predict", by default "train".
    """

    def __init__(
        self,
        max_value: Optional[int] = None,
        epoch: Optional[int] = None,
        num_epochs: Optional[int] = None,
        stateful_metrics: Optional[List] = None,
        always_stateful: bool = False,
        mode: str = "train",
    ) -> None:
        """
        Constructor.

        Prints an ``Epoch: i/n`` header when both `epoch` and `num_epochs`
        are given.

        Parameters
        ----------
        max_value : Optional[int], optional
            Maximum progress bar value, by default None.
        epoch : Optional[int], optional
            Zero-indexed current epoch, by default None.
        num_epochs : Optional[int], optional
            Total number of epochs, by default None.
        stateful_metrics : Optional[List], optional
            Iterable of string names of metrics that should *not* be averaged over time.
            Metrics in this list will be displayed as-is. All others will be averaged by
            the progress bar before display, by default None.
        always_stateful : bool, optional
            Whether to set all metrics to be stateful, by default False.
        mode : str, optional
            Mode, one of "train", "val", or "predict", by default "train".
        """
        self.max_value = max_value
        # Width of the progress bar
        self.width = 30
        self.always_stateful = always_stateful

        if (epoch is not None) and (num_epochs is not None):
            # `epoch` is zero-indexed, hence the +1 for display
            print(f"Epoch: {epoch + 1}/{num_epochs}")

        if stateful_metrics:
            self.stateful_metrics = set(stateful_metrics)
        else:
            self.stateful_metrics = set()

        # Whether the bar is redrawn in place (backspaces + carriage return)
        # instead of printing a new line at each update.
        # NOTE(review): "posix" is presumably present in sys.modules on any
        # POSIX interpreter, which would make this always True there; confirm
        # this is intended (the check mirrors the Keras progress bar).
        self._dynamic_display = (
            (hasattr(sys.stdout, "isatty") and sys.stdout.isatty())
            or "ipykernel" in sys.modules
            or "posix" in sys.modules
        )
        # length of the last line written, used to erase/pad on redraw
        self._total_width = 0
        # last step index passed to `update`
        self._seen_so_far = 0
        # We use a dict + list to avoid garbage collection
        # issues found in OrderedDict
        self._values: Dict[Any, Any] = {}
        self._values_order: List[Any] = []
        self._start = time.time()
        # NOTE(review): written in `update` but not read anywhere in this class
        self._last_update = 0.0
        # the spinner is only used when the total number of steps is unknown
        self.spin = self.spinning_cursor() if self.max_value is None else None
        # NOTE(review): `self.message` is only set in the cases below, but
        # `update` reads it whenever `max_value` is None; any other `mode`
        # combined with `max_value=None` would raise AttributeError there.
        if mode == "train" and self.max_value is None:
            self.message = "Estimating dataset size"
        elif mode == "val":
            self.message = "Validating"
        elif mode == "predict":
            self.message = "Denoising"

    def update(
        self, current_step: int, batch_size: int = 1, values: Optional[List] = None
    ) -> None:
        """
        Update the progress bar.

        Writes the redrawn bar, elapsed time, time per step and metric values
        to stdout.

        Parameters
        ----------
        current_step : int
            Index of the current step.
        batch_size : int, optional
            Batch size, by default 1. Only used to compute the tile count shown
            when `max_value` is None.
        values : Optional[List], optional
            Updated metrics values as (name, value) pairs, by default None.
        """
        values = values or []
        for k, v in values:
            # if torch tensor, convert it to numpy
            # NOTE(review): exact type-string comparison, so tensor subclasses
            # are not converted
            if str(type(v)) == "<class 'torch.Tensor'>":
                v = v.detach().cpu().numpy()

            if k not in self._values_order:
                self._values_order.append(k)
            if k not in self.stateful_metrics and not self.always_stateful:
                # Averaged metric: accumulate [value weighted by the number of
                # steps since the last update, number of steps].
                if k not in self._values:
                    self._values[k] = [
                        v * (current_step - self._seen_so_far),
                        current_step - self._seen_so_far,
                    ]
                else:
                    self._values[k][0] += v * (current_step - self._seen_so_far)
                    self._values[k][1] += current_step - self._seen_so_far
            else:
                # Stateful metrics output a numeric value. This representation
                # means "take an average from a single value" but keeps the
                # numeric formatting.
                self._values[k] = [v, 1]

        self._seen_so_far = current_step

        now = time.time()
        info = f" - {(now - self._start):.0f}s"

        prev_total_width = self._total_width
        if self._dynamic_display:
            # move the cursor back to the start of the line to redraw in place
            sys.stdout.write("\b" * prev_total_width)
            sys.stdout.write("\r")
        else:
            sys.stdout.write("\n")

        if self.max_value is not None:
            bar = f"{current_step}/{self.max_value} ["
            progress = float(current_step) / self.max_value
            progress_width = int(self.width * progress)
            if progress_width > 0:
                bar += "=" * (progress_width - 1)
                # arrow head while in progress, plain "=" once complete
                if current_step < self.max_value:
                    bar += ">"
                else:
                    bar += "="
            bar += "." * (self.width - progress_width)
            bar += "]"
        else:
            # unknown total: spinner frame and number of tiles processed so far
            bar = (
                f"{self.message} {next(self.spin)}, tile "  # type: ignore
                f"No. {current_step * batch_size}"
            )

        self._total_width = len(bar)
        sys.stdout.write(bar)

        # average time per step since construction
        if current_step > 0:
            time_per_unit = (now - self._start) / current_step
        else:
            time_per_unit = 0

        if time_per_unit >= 1 or time_per_unit == 0:
            info += f" {time_per_unit:.0f}s/step"
        elif time_per_unit >= 1e-3:
            info += f" {time_per_unit * 1e3:.0f}ms/step"
        else:
            info += f" {time_per_unit * 1e6:.0f}us/step"

        for k in self._values_order:
            info += f" - {k}:"
            if isinstance(self._values[k], list):
                # max(1, ...) guards against a zero step count
                avg = self._values[k][0] / max(1, self._values[k][1])
                if abs(avg) > 1e-3:
                    info += f" {avg:.4f}"
                else:
                    info += f" {avg:.4e}"
            else:
                # NOTE(review): every assignment to `_values` in this class
                # stores a list, so this branch appears unreachable
                info += f" {self._values[k]}s"

        self._total_width += len(info)
        # pad with spaces to overwrite leftovers of a longer previous line
        if prev_total_width > self._total_width:
            info += " " * (prev_total_width - self._total_width)

        # finish the line once the last step is reached
        if self.max_value is not None and current_step >= self.max_value:
            info += "\n"

        sys.stdout.write(info)
        sys.stdout.flush()

        self._last_update = now

    def add(self, n: int, values: Optional[List] = None) -> None:
        """
        Update the progress bar by n steps.

        Parameters
        ----------
        n : int
            Number of steps to increase the progress bar with.
        values : Optional[List], optional
            Updated metrics values as (name, value) pairs, by default None.
        """
        # batch_size is fixed to 1 when advancing by a number of steps
        self.update(self._seen_so_far + n, 1, values=values)

    def spinning_cursor(self) -> Generator:
        """
        Generate a spinning cursor animation.

        Taken from https://github.com/manrajgrover/py-spinners/tree/master.

        Returns
        -------
        Generator
            Infinite generator of animation frames; each frame is yielded three
            times in a row.
        """
        while True:
            yield from [
                "▓ ----- ▒",
                "▓ ----- ▒",
                "▓ ----- ▒",
                "▓ ->--- ▒",
                "▓ ->--- ▒",
                "▓ ->--- ▒",
                "▓ -->-- ▒",
                "▓ -->-- ▒",
                "▓ -->-- ▒",
                "▓ --->- ▒",
                "▓ --->- ▒",
                "▓ --->- ▒",
                "▓ ----> ▒",
                "▓ ----> ▒",
                "▓ ----> ▒",
                "▒ ----- ░",
                "▒ ----- ░",
                "▒ ----- ░",
                "▒ ->--- ░",
                "▒ ->--- ░",
                "▒ ->--- ░",
                "▒ -->-- ░",
                "▒ -->-- ░",
                "▒ -->-- ░",
                "▒ --->- ░",
                "▒ --->- ░",
                "▒ --->- ░",
                "▒ ----> ░",
                "▒ ----> ░",
                "▒ ----> ░",
            ]
|