careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (155) hide show
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
@@ -0,0 +1,101 @@
1
+ """Patch transform applying XY random 90 degrees rotations."""
2
+
3
+ from typing import Optional, Tuple
4
+
5
+ import numpy as np
6
+
7
+ from careamics.transforms.transform import Transform
8
+
9
+
10
class XYRandomRotate90(Transform):
    """Randomly rotate patches by a multiple of 90 degrees in the YX plane.

    The rotation is applied to the last two axes of the array, which matches the
    C(Z)YX layout this transform expects.

    Attributes
    ----------
    rng : np.random.Generator
        Random number generator used for both the "apply or skip" draw and the
        number of rotations.
    p : float
        Probability of applying the transform.

    Parameters
    ----------
    p : float
        Probability of applying the transform, by default 0.5.
    seed : Optional[int]
        Seed passed to `np.random.default_rng`, by default None. The seed itself is
        not stored on the instance.

    Raises
    ------
    ValueError
        If `p` lies outside of [0, 1].
    """

    def __init__(self, p: float = 0.5, seed: Optional[int] = None):
        """Constructor.

        Parameters
        ----------
        p : float
            Probability of applying the transform, by default 0.5.
        seed : Optional[int]
            Random seed, by default None.

        Raises
        ------
        ValueError
            If `p` lies outside of [0, 1].
        """
        if p < 0 or p > 1:
            raise ValueError("Probability must be in [0, 1].")

        # probability to apply the transform
        self.p = p

        # numpy random generator
        self.rng = np.random.default_rng(seed=seed)

    def __call__(
        self, patch: np.ndarray, target: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Apply the transform to the source patch and the target (optional).

        The same number of rotations is applied to the patch and to the target.

        Parameters
        ----------
        patch : np.ndarray
            Patch, 2D or 3D, shape C(Z)YX.
        target : Optional[np.ndarray], optional
            Target for the patch, by default None.

        Returns
        -------
        Tuple[np.ndarray, Optional[np.ndarray]]
            Transformed patch and target. The inputs are returned untouched when the
            transform is skipped.
        """
        # `rng.random()` is drawn from the half-open interval [0, 1). Using `>=`
        # guarantees that p=0 never applies the transform (a draw of exactly 0.0
        # would slip through a strict `>`), while p=1 still always applies it.
        if self.rng.random() >= self.p:
            return patch, target

        # number of rotations: `integers(1, 4)` excludes the upper bound, so 1, 2 or 3
        n_rot = self.rng.integers(1, 4)

        # rotate in the plane spanned by the last two axes (Y and X)
        axes = (-2, -1)
        patch_transformed = self._apply(patch, n_rot, axes)
        target_transformed = (
            self._apply(target, n_rot, axes) if target is not None else None
        )

        return patch_transformed, target_transformed

    def _apply(
        self, patch: np.ndarray, n_rot: int, axes: Tuple[int, int]
    ) -> np.ndarray:
        """Rotate an array by `n_rot` quarter turns in the plane given by `axes`.

        Parameters
        ----------
        patch : np.ndarray
            Image or image patch, 2D or 3D, shape C(Z)YX.
        n_rot : int
            Number of 90 degree rotations.
        axes : Tuple[int, int]
            Axes along which to rotate the patch.

        Returns
        -------
        np.ndarray
            Transformed patch, as a C-contiguous array.
        """
        return np.ascontiguousarray(np.rot90(patch, k=n_rot, axes=axes))
@@ -0,0 +1,19 @@
1
+ """Utils module."""
2
+
3
+ __all__ = [
4
+ "cwd",
5
+ "get_ram_size",
6
+ "check_path_exists",
7
+ "BaseEnum",
8
+ "get_logger",
9
+ "get_careamics_home",
10
+ "autocorrelation",
11
+ ]
12
+
13
+
14
+ from .autocorrelation import autocorrelation
15
+ from .base_enum import BaseEnum
16
+ from .context import cwd, get_careamics_home
17
+ from .logging import get_logger
18
+ from .path_utils import check_path_exists
19
+ from .ram import get_ram_size
@@ -0,0 +1,40 @@
1
+ """Autocorrelation function."""
2
+
3
+ import numpy as np
4
+ from numpy.typing import NDArray
5
+
6
+
7
def autocorrelation(image: NDArray) -> NDArray:
    """Return the normalized autocorrelation of an image.

    The autocorrelation is a convenient way to look for spatial correlations in an
    image, in particular in its noise.

    The image is standardized (zero mean, unit standard deviation), its power
    spectrum is computed with an N-dimensional FFT and transformed back. The result
    is divided by its zero-shift value and shifted so that the zero-shift value sits
    at the center of the returned array.

    Parameters
    ----------
    image : NDArray
        Input image.

    Returns
    -------
    numpy.ndarray
        Autocorrelation of the input image, same shape as the input.
    """
    # zero mean, unit standard deviation
    standardized = (image - np.mean(image)) / np.std(image)

    # the inverse transform of the power spectrum is the (circular) autocorrelation
    power_spectrum = np.abs(np.fft.fftn(standardized)) ** 2
    correlation = np.fft.ifftn(power_spectrum).real

    # the first element holds the zero-shift value; use it as the unit
    correlation = correlation / correlation.flat[0]

    # move the zero-shift value to the center of the array
    return np.fft.fftshift(correlation)
@@ -0,0 +1,60 @@
1
+ """A base class for Enum that allows checking if a value is in the Enum."""
2
+
3
+ from enum import Enum, EnumMeta
4
+ from typing import Any
5
+
6
+
7
class _ContainerEnum(EnumMeta):
    """Metaclass for Enum adding value-based membership checks.

    Methods defined here are looked up on the Enum *classes* created with this
    metaclass, so `cls` is the Enum class in every method below.
    """

    def __contains__(cls, item: Any) -> bool:
        """Check if an item is in the Enum.

        The check is done by trying to build a member from `item`, hence both
        member values and members themselves are reported as contained.

        Parameters
        ----------
        item : Any
            Item to check.

        Returns
        -------
        bool
            True if the item is in the Enum, False otherwise.
        """
        try:
            cls(item)
        except ValueError:
            return False
        return True

    # NOTE: this must be a plain method. Decorating a metaclass method with
    # `@classmethod` binds `cls` to the metaclass itself, which has no
    # `_value2member_map_`, and every call would raise AttributeError.
    def has_value(cls, value: Any) -> bool:
        """Check if a value is in the Enum.

        In contrast to the `in` operator, only member *values* are considered.

        Parameters
        ----------
        value : Any
            Value to check.

        Returns
        -------
        bool
            True if the value is in the Enum, False otherwise.
        """
        try:
            return value in cls._value2member_map_
        except TypeError:
            # unhashable values cannot be looked up in the map, compare them to the
            # member values one by one instead
            return any(member.value == value for member in cls)


class BaseEnum(Enum, metaclass=_ContainerEnum):
    """Base Enum class, allowing checking if a value is in the enum.

    Example
    -------
    >>> from careamics.utils.base_enum import BaseEnum
    >>> # Define a new enum
    >>> class BaseEnumExtension(BaseEnum):
    ...     VALUE = "value"
    >>> # Check if value is in the enum
    >>> "value" in BaseEnumExtension
    True
    >>> BaseEnumExtension.has_value("other")
    False
    """

    pass
@@ -0,0 +1,66 @@
1
+ """
2
+ Context submodule.
3
+
4
+ A convenience function to change the working directory in order to save data.
5
+ """
6
+
7
+ import os
8
+ from contextlib import contextmanager
9
+ from pathlib import Path
10
+ from typing import Iterator, Union
11
+
12
+
13
def get_careamics_home() -> Path:
    """Return the CAREamics home directory, creating it if necessary.

    The CAREamics home directory is the hidden `.careamics` folder located in the
    user home directory.

    Returns
    -------
    Path
        CAREamics home directory path.
    """
    careamics_dir = Path.home().joinpath(".careamics")

    # nothing to do if the path is already there
    if careamics_dir.exists():
        return careamics_dir

    careamics_dir.mkdir(parents=True, exist_ok=True)
    return careamics_dir
29
+
30
+
31
@contextmanager
def cwd(path: Union[str, Path]) -> Iterator[None]:
    """
    Temporarily change the current working directory.

    The directory is created (including its parents) if it does not exist. When the
    context is left, normally or through an exception, the working directory that
    was active on entry is restored.

    Parameters
    ----------
    path : Union[str,Path]
        New working directory path.

    Returns
    -------
    Iterator[None]
        None values.

    Examples
    --------
    The working directory is changed within the block and then restored to the
    original one.

    >>> with cwd(my_path):
    ...     pass # do something
    """
    target_dir = Path(path)

    if not target_dir.exists():
        target_dir.mkdir(parents=True, exist_ok=True)

    # remember where we came from before moving
    previous_dir = os.getcwd()
    os.chdir(target_dir)
    try:
        yield
    finally:
        os.chdir(previous_dir)
@@ -0,0 +1,322 @@
1
+ """
2
+ Logging submodule.
3
+
4
+ The methods are responsible for the in-console logger.
5
+ """
6
+
7
+ import logging
8
+ import sys
9
+ import time
10
+ from pathlib import Path
11
+ from typing import Any, Dict, Generator, List, Optional, Union
12
+
13
# Registry of the logger names that have been configured by `get_logger`.
LOGGERS: dict = {}


def get_logger(
    name: str,
    log_level: int = logging.INFO,
    log_path: Optional[Union[str, Path]] = None,
) -> logging.Logger:
    """
    Create a python logger instance with configured handlers.

    A logger is configured only once: a console handler (and a file handler if
    `log_path` is given) printing the bare message is attached, and propagation to
    the ancestors is switched off to avoid duplicated output.

    Loggers that are descendants (in the dotted-name sense) of an already configured
    logger are returned without handlers and keep propagating, so that their records
    are emitted by the handlers of the configured ancestor. In that case `log_level`
    and `log_path` are ignored.

    Parameters
    ----------
    name : str
        Name of the logger.
    log_level : int, optional
        Log level (info, error etc.), by default logging.INFO.
    log_path : Optional[Union[str, Path]], optional
        Path in which to save the log, by default None.

    Returns
    -------
    logging.Logger
        Logger.
    """
    logger = logging.getLogger(name)

    if name in LOGGERS:
        # already configured, simply make sure it still does not propagate
        logger.propagate = False
        return logger

    # Descendants rely on the handlers of their configured ancestor. Propagation
    # must stay enabled for them, otherwise their records would reach no handler at
    # all. The trailing dot restricts the match to real children ("pkg.mod" for
    # "pkg") and excludes look-alike names such as "pkg_extra".
    if any(name.startswith(f"{configured}.") for configured in LOGGERS):
        return logger

    handlers: List[logging.Handler] = [logging.StreamHandler()]
    if log_path:
        handlers.append(logging.FileHandler(log_path))

    formatter = logging.Formatter("%(message)s")

    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    logger.setLevel(log_level)
    logger.propagate = False
    LOGGERS[name] = True

    return logger
71
+
72
+
73
class ProgressBar:
    """
    Keras style progress bar.

    Adapted from https://github.com/yueyericardo/pkbar.

    When `max_value` is known, a bar of fixed width is drawn. Otherwise a spinner is
    shown together with a caption depending on `mode` and the number of tiles seen
    so far.

    Parameters
    ----------
    max_value : Optional[int], optional
        Maximum progress bar value, by default None.
    epoch : Optional[int], optional
        Zero-indexed current epoch, by default None.
    num_epochs : Optional[int], optional
        Total number of epochs, by default None.
    stateful_metrics : Optional[List], optional
        Iterable of string names of metrics that should *not* be averaged over time.
        Metrics in this list will be displayed as-is. All others will be averaged by
        the progress bar before display, by default None.
    always_stateful : bool, optional
        Whether to set all metrics to be stateful, by default False.
    mode : str, optional
        Mode, one of "train", "val", or "predict", by default "train". Any other
        value falls back to a generic spinner caption.
    """

    # Spinner captions used when the total number of steps is unknown.
    _MODE_MESSAGES: Dict[str, str] = {
        "train": "Estimating dataset size",
        "val": "Validating",
        "predict": "Denoising",
    }

    def __init__(
        self,
        max_value: Optional[int] = None,
        epoch: Optional[int] = None,
        num_epochs: Optional[int] = None,
        stateful_metrics: Optional[List] = None,
        always_stateful: bool = False,
        mode: str = "train",
    ) -> None:
        """
        Constructor.

        Prints an "Epoch: i/n" header if both `epoch` and `num_epochs` are given.

        Parameters
        ----------
        max_value : Optional[int], optional
            Maximum progress bar value, by default None.
        epoch : Optional[int], optional
            Zero-indexed current epoch, by default None.
        num_epochs : Optional[int], optional
            Total number of epochs, by default None.
        stateful_metrics : Optional[List], optional
            Iterable of string names of metrics that should *not* be averaged over time.
            Metrics in this list will be displayed as-is. All others will be averaged by
            the progress bar before display, by default None.
        always_stateful : bool, optional
            Whether to set all metrics to be stateful, by default False.
        mode : str, optional
            Mode, one of "train", "val", or "predict", by default "train".
        """
        self.max_value = max_value
        # Width of the progress bar
        self.width = 30
        self.always_stateful = always_stateful

        if (epoch is not None) and (num_epochs is not None):
            print(f"Epoch: {epoch + 1}/{num_epochs}")

        self.stateful_metrics = set(stateful_metrics) if stateful_metrics else set()

        # whether the current line can be rewritten in place
        self._dynamic_display = (
            (hasattr(sys.stdout, "isatty") and sys.stdout.isatty())
            or "ipykernel" in sys.modules
            or "posix" in sys.modules
        )
        self._total_width = 0
        self._seen_so_far = 0
        # We use a dict + list to avoid garbage collection
        # issues found in OrderedDict
        self._values: Dict[Any, Any] = {}
        self._values_order: List[Any] = []
        self._start = time.time()
        self._last_update = 0.0
        self.spin = self.spinning_cursor() if self.max_value is None else None
        # Always define the caption: it is read by `update` whenever `max_value` is
        # None, whatever the mode is.
        self.message = self._MODE_MESSAGES.get(mode, "Processing")

    def update(
        self, current_step: int, batch_size: int = 1, values: Optional[List] = None
    ) -> None:
        """
        Update the progress bar.

        Parameters
        ----------
        current_step : int
            Index of the current step.
        batch_size : int, optional
            Batch size, only used to report the tile number when the total number
            of steps is unknown, by default 1.
        values : Optional[List], optional
            Updated metrics values, as (name, value) pairs, by default None.
        """
        self._accumulate(current_step, values or [])
        self._seen_so_far = current_step

        now = time.time()
        elapsed = now - self._start
        prev_total_width = self._total_width

        # go back to the beginning of the line, or start a new one
        if self._dynamic_display:
            sys.stdout.write("\b" * prev_total_width)
            sys.stdout.write("\r")
        else:
            sys.stdout.write("\n")

        bar = self._render_bar(current_step, batch_size)
        self._total_width = len(bar)
        sys.stdout.write(bar)

        info = (
            f" - {elapsed:.0f}s"
            + self._format_rate(elapsed, current_step)
            + self._format_metrics()
        )

        # pad with spaces to erase the leftovers of a longer previous line
        self._total_width += len(info)
        if prev_total_width > self._total_width:
            info += " " * (prev_total_width - self._total_width)

        if self.max_value is not None and current_step >= self.max_value:
            info += "\n"

        sys.stdout.write(info)
        sys.stdout.flush()

        self._last_update = now

    def _accumulate(self, current_step: int, values: List) -> None:
        """
        Store the metrics values, averaging the non-stateful ones.

        Parameters
        ----------
        current_step : int
            Index of the current step.
        values : List
            Updated metrics values, as (name, value) pairs.
        """
        # number of steps since the last update, used as weight in the average
        n_new_steps = current_step - self._seen_so_far

        for k, v in values:
            # if torch tensor, convert it to numpy (the type is compared by name,
            # which does not require torch to be imported here)
            if str(type(v)) == "<class 'torch.Tensor'>":
                v = v.detach().cpu().numpy()

            if k not in self._values_order:
                self._values_order.append(k)

            if k not in self.stateful_metrics and not self.always_stateful:
                if k not in self._values:
                    self._values[k] = [v * n_new_steps, n_new_steps]
                else:
                    self._values[k][0] += v * n_new_steps
                    self._values[k][1] += n_new_steps
            else:
                # Stateful metrics output a numeric value. This representation
                # means "take an average from a single value" but keeps the
                # numeric formatting.
                self._values[k] = [v, 1]

    def _render_bar(self, current_step: int, batch_size: int) -> str:
        """
        Build the bar, or the spinner when the total number of steps is unknown.

        Parameters
        ----------
        current_step : int
            Index of the current step.
        batch_size : int
            Batch size.

        Returns
        -------
        str
            Text of the bar.
        """
        if self.max_value is None:
            return (
                f"{self.message} {next(self.spin)}, tile "  # type: ignore
                f"No. {current_step * batch_size}"
            )

        # an empty run (max_value of 0) is displayed as complete rather than
        # dividing by zero
        progress = float(current_step) / self.max_value if self.max_value else 1.0
        progress_width = int(self.width * progress)

        bar = f"{current_step}/{self.max_value} ["
        if progress_width > 0:
            bar += "=" * (progress_width - 1)
            bar += ">" if current_step < self.max_value else "="
        bar += "." * (self.width - progress_width)
        bar += "]"
        return bar

    @staticmethod
    def _format_rate(elapsed: float, current_step: int) -> str:
        """
        Format the average time spent per step with a suitable unit.

        Parameters
        ----------
        elapsed : float
            Time elapsed since the creation of the progress bar, in seconds.
        current_step : int
            Index of the current step.

        Returns
        -------
        str
            Time per step, in s, ms or us.
        """
        time_per_unit = elapsed / current_step if current_step > 0 else 0

        if time_per_unit >= 1 or time_per_unit == 0:
            return f" {time_per_unit:.0f}s/step"
        if time_per_unit >= 1e-3:
            return f" {time_per_unit * 1e3:.0f}ms/step"
        return f" {time_per_unit * 1e6:.0f}us/step"

    def _format_metrics(self) -> str:
        """
        Format the metrics in the order in which they were first reported.

        Returns
        -------
        str
            Text listing each metric name and its (averaged) value.
        """
        text = ""
        for k in self._values_order:
            text += f" - {k}:"
            if isinstance(self._values[k], list):
                avg = self._values[k][0] / max(1, self._values[k][1])
                # switch to scientific notation for very small values
                if abs(avg) > 1e-3:
                    text += f" {avg:.4f}"
                else:
                    text += f" {avg:.4e}"
            else:
                text += f" {self._values[k]}s"
        return text

    def add(self, n: int, values: Optional[List] = None) -> None:
        """
        Update the progress bar by n steps.

        Parameters
        ----------
        n : int
            Number of steps to increase the progress bar with.
        values : Optional[List], optional
            Updated metrics values, by default None.
        """
        self.update(self._seen_so_far + n, 1, values=values)

    def spinning_cursor(self) -> Generator:
        """
        Generate a spinning cursor animation.

        Taken from https://github.com/manrajgrover/py-spinners/tree/master.

        Returns
        -------
        Generator
            Endless generator of animation frames.
        """
        # an arrow travelling along a five character track
        tracks = ["-----"] + ["-" * i + ">" + "-" * (4 - i) for i in range(1, 5)]

        # each frame is shown three times, first with dark then with light borders
        frames = [
            f"{left} {track} {right}"
            for left, right in (("▓", "▒"), ("▒", "░"))
            for track in tracks
            for _ in range(3)
        ]

        while True:
            yield from frames