stouputils-1.14.0-py3-none-any.whl → stouputils-1.14.2-py3-none-any.whl
This diff compares the contents of two package versions as published to their public registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the registry.
- stouputils/__init__.pyi +15 -0
- stouputils/_deprecated.pyi +12 -0
- stouputils/all_doctests.pyi +46 -0
- stouputils/applications/__init__.pyi +2 -0
- stouputils/applications/automatic_docs.py +3 -0
- stouputils/applications/automatic_docs.pyi +106 -0
- stouputils/applications/upscaler/__init__.pyi +3 -0
- stouputils/applications/upscaler/config.pyi +18 -0
- stouputils/applications/upscaler/image.pyi +109 -0
- stouputils/applications/upscaler/video.pyi +60 -0
- stouputils/archive.pyi +67 -0
- stouputils/backup.pyi +109 -0
- stouputils/collections.pyi +86 -0
- stouputils/continuous_delivery/__init__.pyi +5 -0
- stouputils/continuous_delivery/cd_utils.pyi +129 -0
- stouputils/continuous_delivery/github.pyi +162 -0
- stouputils/continuous_delivery/pypi.pyi +52 -0
- stouputils/continuous_delivery/pyproject.pyi +67 -0
- stouputils/continuous_delivery/stubs.pyi +39 -0
- stouputils/ctx.pyi +211 -0
- stouputils/data_science/config/get.py +51 -51
- stouputils/data_science/data_processing/image/__init__.py +66 -66
- stouputils/data_science/data_processing/image/auto_contrast.py +79 -79
- stouputils/data_science/data_processing/image/axis_flip.py +58 -58
- stouputils/data_science/data_processing/image/bias_field_correction.py +74 -74
- stouputils/data_science/data_processing/image/binary_threshold.py +73 -73
- stouputils/data_science/data_processing/image/blur.py +59 -59
- stouputils/data_science/data_processing/image/brightness.py +54 -54
- stouputils/data_science/data_processing/image/canny.py +110 -110
- stouputils/data_science/data_processing/image/clahe.py +92 -92
- stouputils/data_science/data_processing/image/common.py +30 -30
- stouputils/data_science/data_processing/image/contrast.py +53 -53
- stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -74
- stouputils/data_science/data_processing/image/denoise.py +378 -378
- stouputils/data_science/data_processing/image/histogram_equalization.py +123 -123
- stouputils/data_science/data_processing/image/invert.py +64 -64
- stouputils/data_science/data_processing/image/laplacian.py +60 -60
- stouputils/data_science/data_processing/image/median_blur.py +52 -52
- stouputils/data_science/data_processing/image/noise.py +59 -59
- stouputils/data_science/data_processing/image/normalize.py +65 -65
- stouputils/data_science/data_processing/image/random_erase.py +66 -66
- stouputils/data_science/data_processing/image/resize.py +69 -69
- stouputils/data_science/data_processing/image/rotation.py +80 -80
- stouputils/data_science/data_processing/image/salt_pepper.py +68 -68
- stouputils/data_science/data_processing/image/sharpening.py +55 -55
- stouputils/data_science/data_processing/image/shearing.py +64 -64
- stouputils/data_science/data_processing/image/threshold.py +64 -64
- stouputils/data_science/data_processing/image/translation.py +71 -71
- stouputils/data_science/data_processing/image/zoom.py +83 -83
- stouputils/data_science/data_processing/image_augmentation.py +118 -118
- stouputils/data_science/data_processing/image_preprocess.py +183 -183
- stouputils/data_science/data_processing/prosthesis_detection.py +359 -359
- stouputils/data_science/data_processing/technique.py +481 -481
- stouputils/data_science/dataset/__init__.py +45 -45
- stouputils/data_science/dataset/dataset.py +292 -292
- stouputils/data_science/dataset/dataset_loader.py +135 -135
- stouputils/data_science/dataset/grouping_strategy.py +296 -296
- stouputils/data_science/dataset/image_loader.py +100 -100
- stouputils/data_science/dataset/xy_tuple.py +696 -696
- stouputils/data_science/metric_dictionnary.py +106 -106
- stouputils/data_science/mlflow_utils.py +206 -206
- stouputils/data_science/models/abstract_model.py +149 -149
- stouputils/data_science/models/all.py +85 -85
- stouputils/data_science/models/keras/all.py +38 -38
- stouputils/data_science/models/keras/convnext.py +62 -62
- stouputils/data_science/models/keras/densenet.py +50 -50
- stouputils/data_science/models/keras/efficientnet.py +60 -60
- stouputils/data_science/models/keras/mobilenet.py +56 -56
- stouputils/data_science/models/keras/resnet.py +52 -52
- stouputils/data_science/models/keras/squeezenet.py +233 -233
- stouputils/data_science/models/keras/vgg.py +42 -42
- stouputils/data_science/models/keras/xception.py +38 -38
- stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -20
- stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -219
- stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -148
- stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -31
- stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -249
- stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -66
- stouputils/data_science/models/keras_utils/losses/__init__.py +12 -12
- stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -56
- stouputils/data_science/models/keras_utils/visualizations.py +416 -416
- stouputils/data_science/models/sandbox.py +116 -116
- stouputils/data_science/range_tuple.py +234 -234
- stouputils/data_science/utils.py +285 -285
- stouputils/decorators.pyi +242 -0
- stouputils/image.pyi +172 -0
- stouputils/installer/__init__.py +18 -18
- stouputils/installer/__init__.pyi +5 -0
- stouputils/installer/common.pyi +39 -0
- stouputils/installer/downloader.pyi +24 -0
- stouputils/installer/linux.py +144 -144
- stouputils/installer/linux.pyi +39 -0
- stouputils/installer/main.py +223 -223
- stouputils/installer/main.pyi +57 -0
- stouputils/installer/windows.py +136 -136
- stouputils/installer/windows.pyi +31 -0
- stouputils/io.pyi +213 -0
- stouputils/parallel.py +12 -10
- stouputils/parallel.pyi +211 -0
- stouputils/print.pyi +136 -0
- stouputils/py.typed +1 -1
- stouputils/stouputils/parallel.pyi +4 -4
- stouputils/version_pkg.pyi +15 -0
- {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/METADATA +1 -1
- stouputils-1.14.2.dist-info/RECORD +171 -0
- stouputils-1.14.0.dist-info/RECORD +0 -140
- {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/WHEEL +0 -0
- {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/entry_points.txt +0 -0
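
For reference, a wheel-to-wheel diff like the hunks below can be reproduced locally. This is a minimal sketch using only the Python standard library; the wheel filenames come from the header above and are assumed to have been fetched beforehand (e.g. with pip download stouputils==1.14.0 and pip download stouputils==1.14.2) — the registry's own diff tooling is not shown here:

import difflib
import zipfile

# Assumed local paths to the two downloaded wheels
OLD = "stouputils-1.14.0-py3-none-any.whl"
NEW = "stouputils-1.14.2-py3-none-any.whl"

with zipfile.ZipFile(OLD) as old, zipfile.ZipFile(NEW) as new:
    old_names, new_names = set(old.namelist()), set(new.namelist())
    for name in sorted(old_names | new_names):
        # Treat a file missing on one side as empty so added/removed files still diff
        a = old.read(name).decode("utf-8", "replace").splitlines() if name in old_names else []
        b = new.read(name).decode("utf-8", "replace").splitlines() if name in new_names else []
        for line in difflib.unified_diff(a, b, f"a/{name}", f"b/{name}", lineterm=""):
            print(line)

Note that a file rewritten with only invisible changes (e.g. line endings) shows up as every line removed and re-added, which is what the two hunks below display.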
--- a/stouputils/data_science/models/keras_utils/callbacks/__init__.py
+++ b/stouputils/data_science/models/keras_utils/callbacks/__init__.py
@@ -1,20 +1,20 @@
-""" Custom callbacks for Keras models.
-
-Features:
-
-- Learning rate finder callback for finding the optimal learning rate
-- Warmup scheduler callback for warmup training
-- Progressive unfreezing callback for unfreezing layers during training (incompatible with model.fit(), need a custom training loop)
-- Tqdm progress bar callback for better training visualization
-- Model checkpoint callback that only starts checkpointing after a given number of epochs
-"""
-
-# Imports
-from .colored_progress_bar import ColoredProgressBar
-from .learning_rate_finder import LearningRateFinder
-from .model_checkpoint_v2 import ModelCheckpointV2
-from .progressive_unfreezing import ProgressiveUnfreezing
-from .warmup_scheduler import WarmupScheduler
-
-__all__ = ["ColoredProgressBar", "LearningRateFinder", "ModelCheckpointV2", "ProgressiveUnfreezing", "WarmupScheduler"]
-
+""" Custom callbacks for Keras models.
+
+Features:
+
+- Learning rate finder callback for finding the optimal learning rate
+- Warmup scheduler callback for warmup training
+- Progressive unfreezing callback for unfreezing layers during training (incompatible with model.fit(), need a custom training loop)
+- Tqdm progress bar callback for better training visualization
+- Model checkpoint callback that only starts checkpointing after a given number of epochs
+"""
+
+# Imports
+from .colored_progress_bar import ColoredProgressBar
+from .learning_rate_finder import LearningRateFinder
+from .model_checkpoint_v2 import ModelCheckpointV2
+from .progressive_unfreezing import ProgressiveUnfreezing
+from .warmup_scheduler import WarmupScheduler
+
+__all__ = ["ColoredProgressBar", "LearningRateFinder", "ModelCheckpointV2", "ProgressiveUnfreezing", "WarmupScheduler"]
+
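As context for the hunk above and the one below: these callbacks plug into an ordinary Keras training call. A minimal sketch, assuming only what this diff shows (the ColoredProgressBar signature appears in the next hunk; the model and data here are throwaway placeholders):

import keras
import numpy as np

from stouputils.data_science.models.keras_utils.callbacks import ColoredProgressBar

# Placeholder model and data, purely for illustration
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")
x, y = np.random.rand(64, 4), np.random.rand(64, 1)

# verbose=0 silences the built-in Keras output so the tqdm bar is the only one shown
model.fit(
    x, y,
    epochs=3,
    batch_size=8,
    verbose=0,
    callbacks=[ColoredProgressBar(desc="Demo", show_lr=True, update_frequency=2)],
)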
--- a/stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py
+++ b/stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py
@@ -1,219 +1,219 @@
-
-# pyright: reportMissingTypeStubs=false
-
-# Imports
-from typing import Any
-
-import tensorflow as tf
-from keras.callbacks import Callback
-from keras.models import Model
-from tqdm.auto import tqdm
-
-from .....print import MAGENTA
-from .....parallel import BAR_FORMAT
-
-class ColoredProgressBar(Callback):
-    """ Progress bar using tqdm for Keras training.
-
-    A callback that displays a progress bar using tqdm during model training.
-    Shows the training progress across steps with a customized format
-    instead of the default Keras one showing multiple lines.
-    """
-    def __init__(
-        self,
-        desc: str = "Training",
-        track_epochs: bool = False,
-        show_lr: bool = False,
-        update_frequency: int = 1,
-        color: str = MAGENTA
-    ) -> None:
-        """ Initialize the progress bar callback.
-
-        Args:
-            desc (str): Custom description for the progress bar.
-            track_epochs (bool): Whether to track epochs instead of batches.
-            show_lr (bool): Whether to show the learning rate.
-            update_frequency (int): How often to update the progress bar (every N batches).
-            color (str): Color of the progress bar.
-        """
-        super().__init__()
-        self.desc: str = desc
-        """ Description of the progress bar. """
-        self.track_epochs: bool = track_epochs
-        """ Whether to track epochs instead of batches. """
-        self.show_lr: bool = show_lr
-        """ Whether to show the learning rate. """
-        self.latest_val_loss: float = 0.0
-        """ Latest validation loss, updated at the end of each epoch. """
-        self.latest_lr: float = 0.0
-        """ Latest learning rate, updated during batch and epoch processing. """
-        self.batch_count: int = 0
-        """ Counter to update the progress bar less frequently. """
-        self.update_frequency: int = max(1, update_frequency) # Ensure frequency is at least 1
-        """ How often to update the progress bar (every N batches). """
-        self.color: str = color
-        """ Color of the progress bar. """
-        self.pbar: tqdm[Any] | None = None
-        """ The tqdm progress bar instance. """
-        self.epochs: int = 0
-        """ Total number of epochs. """
-        self.steps: int = 0
-        """ Number of steps per epoch. """
-        self.total: int = 0
-        """ Total number of steps/epochs to track. """
-        self.params: dict[str, Any]
-        """ Training parameters. """
-        self.model: Model
-
-    def on_train_begin(self, logs: dict[str, Any] | None = None) -> None:
-        """ Initialize the progress bar at the start of training.
-
-        Args:
-            logs (dict | None): Training logs.
-        """
-        # Get training parameters
-        self.epochs = self.params.get("epochs", 0)
-        self.steps = self.params.get("steps", 0)
-
-        # Determine total units and initial description
-        if self.track_epochs:
-            desc: str = f"{self.color}{self.desc} (Epochs)"
-            self.total = self.epochs
-        else:
-            desc: str = f"{self.color}{self.desc} (Epoch 1/{self.epochs})"
-            self.total = self.epochs * self.steps
-
-        # Initialize tqdm bar
-        self.pbar = tqdm(
-            total=self.total,
-            desc=desc,
-            position=0,
-            leave=True,
-            bar_format=BAR_FORMAT,
-            ascii=False
-        )
-        # Reset state variables
-        self.latest_val_loss = 0.0
-        self.latest_lr = 0.0
-        self.batch_count = 0
-
-    def on_batch_end(self, batch: int, logs: dict[str, Any] | None = None) -> None:
-        """ Update the progress bar after each batch, based on update frequency.
-
-        Args:
-            batch (int): Current batch number (0-indexed).
-            logs (dict | None): Dictionary of logs containing metrics for the batch.
-        """
-        # Skip updates if tracking epochs, pbar isn't initialized, or steps are unknown
-        if self.track_epochs or self.pbar is None or self.steps == 0:
-            return
-
-        self.batch_count += 1
-        is_last_batch: bool = (batch + 1) == self.steps
-
-        # Update only every `update_frequency` batches or on the last batch
-        if self.batch_count % self.update_frequency == 0 or is_last_batch:
-            increment: int = self.batch_count
-            self.batch_count = 0 # Reset counter
-
-            # Calculate current epoch (1-based) based on the progress bar's state *before* this update
-            # Ensure epoch doesn't exceed total epochs in description
-            current_epoch: int = min(self.epochs, self.pbar.n // self.steps + 1)
-            current_step: int = batch + 1 # Convert to 1-indexed for display
-            self.pbar.set_description(
-                f"{self.color}{self.desc} (Epoch {current_epoch}/{self.epochs}, Step {current_step}/{self.steps})"
-            )
-
-            # Update learning rate if model and optimizer are available
-            if self.model and hasattr(self.model, "optimizer"):
-                self.latest_lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate)) # type: ignore
-
-            # Update postfix with batch loss and the latest known validation loss
-            if logs and "loss" in logs:
-                loss: float = logs["loss"]
-                val_loss_str: str = ""
-                if self.latest_val_loss != 0.0:
-                    if self.latest_val_loss < 1e-3:
-                        val_loss_str = f", val_loss: {self.latest_val_loss:.2e}"
-                    else:
-                        val_loss_str = f", val_loss: {self.latest_val_loss:.5f}"
-
-                # Format learning rate string
-                lr_str: str = ""
-                if self.show_lr and self.latest_lr != 0.0:
-                    if self.latest_lr < 1e-3:
-                        lr_str = f", lr: {self.latest_lr:.2e}"
-                    else:
-                        lr_str = f", lr: {self.latest_lr:.5f}"
-
-                if loss < 1e-3:
-                    self.pbar.set_postfix_str(f"loss: {loss:.2e}{val_loss_str}{lr_str}", refresh=False)
-                else:
-                    self.pbar.set_postfix_str(f"loss: {loss:.5f}{val_loss_str}{lr_str}", refresh=False)
-
-            # Update progress bar position, ensuring not to exceed total
-            actual_increment: int = min(increment, self.total - self.pbar.n)
-            if actual_increment > 0:
-                self.pbar.update(actual_increment)
-
-    def on_epoch_end(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
-        """ Update metrics and progress bar position at the end of each epoch.
-
-        Args:
-            epoch (int): Current epoch number (0-indexed).
-            logs (dict | None): Dictionary of logs containing metrics for the epoch.
-        """
-        if self.pbar is None:
-            return
-
-        # Update the latest validation loss from epoch logs
-        if logs:
-            current_val_loss: float = logs.get("val_loss", 0.0)
-            if current_val_loss != 0.0:
-                self.latest_val_loss = current_val_loss
-
-            # Update learning rate if model and optimizer are available
-            if self.model and hasattr(self.model, "optimizer"):
-                self.latest_lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate)) # type: ignore
-
-            # Update postfix string with final epoch metrics
-            loss: float = logs.get("loss", 0.0)
-            val_loss_str: str = f", val_loss: {self.latest_val_loss:.5f}" if self.latest_val_loss != 0.0 else ""
-
-            # Format learning rate string
-            lr_str: str = ""
-            if self.show_lr and self.latest_lr != 0.0:
-                if self.latest_lr < 1e-3:
-                    lr_str = f", lr: {self.latest_lr:.2e}"
-                else:
-                    lr_str = f", lr: {self.latest_lr:.5f}"
-
-            if loss != 0.0: # Only update if loss is available
-                self.pbar.set_postfix_str(f"loss: {loss:.5f}{val_loss_str}{lr_str}", refresh=True)
-
-        # Update progress bar position
-        if self.track_epochs:
-            # Increment by 1 epoch if not already at total
-            if self.pbar.n < self.total:
-                self.pbar.update(1)
-        else:
-            # Ensure the progress bar is exactly at the end of the current epoch
-            expected_position: int = min(self.total, (epoch + 1) * self.steps)
-            increment: int = expected_position - self.pbar.n
-            if increment > 0:
-                self.pbar.update(increment)
-
-    def on_train_end(self, logs: dict[str, Any] | None = None) -> None:
-        """ Close the progress bar when training is complete.
-
-        Args:
-            logs (dict | None): Training logs.
-        """
-        if self.pbar is not None:
-            # Ensure the bar reaches 100%
-            increment: int = self.total - self.pbar.n
-            if increment > 0:
-                self.pbar.update(increment)
-            self.pbar.close()
-            self.pbar = None # Reset pbar instance
-
+
+# pyright: reportMissingTypeStubs=false
+
+# Imports
+from typing import Any
+
+import tensorflow as tf
+from keras.callbacks import Callback
+from keras.models import Model
+from tqdm.auto import tqdm
+
+from .....print import MAGENTA
+from .....parallel import BAR_FORMAT
+
+class ColoredProgressBar(Callback):
+    """ Progress bar using tqdm for Keras training.
+
+    A callback that displays a progress bar using tqdm during model training.
+    Shows the training progress across steps with a customized format
+    instead of the default Keras one showing multiple lines.
+    """
+    def __init__(
+        self,
+        desc: str = "Training",
+        track_epochs: bool = False,
+        show_lr: bool = False,
+        update_frequency: int = 1,
+        color: str = MAGENTA
+    ) -> None:
+        """ Initialize the progress bar callback.
+
+        Args:
+            desc (str): Custom description for the progress bar.
+            track_epochs (bool): Whether to track epochs instead of batches.
+            show_lr (bool): Whether to show the learning rate.
+            update_frequency (int): How often to update the progress bar (every N batches).
+            color (str): Color of the progress bar.
+        """
+        super().__init__()
+        self.desc: str = desc
+        """ Description of the progress bar. """
+        self.track_epochs: bool = track_epochs
+        """ Whether to track epochs instead of batches. """
+        self.show_lr: bool = show_lr
+        """ Whether to show the learning rate. """
+        self.latest_val_loss: float = 0.0
+        """ Latest validation loss, updated at the end of each epoch. """
+        self.latest_lr: float = 0.0
+        """ Latest learning rate, updated during batch and epoch processing. """
+        self.batch_count: int = 0
+        """ Counter to update the progress bar less frequently. """
+        self.update_frequency: int = max(1, update_frequency) # Ensure frequency is at least 1
+        """ How often to update the progress bar (every N batches). """
+        self.color: str = color
+        """ Color of the progress bar. """
+        self.pbar: tqdm[Any] | None = None
+        """ The tqdm progress bar instance. """
+        self.epochs: int = 0
+        """ Total number of epochs. """
+        self.steps: int = 0
+        """ Number of steps per epoch. """
+        self.total: int = 0
+        """ Total number of steps/epochs to track. """
+        self.params: dict[str, Any]
+        """ Training parameters. """
+        self.model: Model
+
+    def on_train_begin(self, logs: dict[str, Any] | None = None) -> None:
+        """ Initialize the progress bar at the start of training.
+
+        Args:
+            logs (dict | None): Training logs.
+        """
+        # Get training parameters
+        self.epochs = self.params.get("epochs", 0)
+        self.steps = self.params.get("steps", 0)
+
+        # Determine total units and initial description
+        if self.track_epochs:
+            desc: str = f"{self.color}{self.desc} (Epochs)"
+            self.total = self.epochs
+        else:
+            desc: str = f"{self.color}{self.desc} (Epoch 1/{self.epochs})"
+            self.total = self.epochs * self.steps
+
+        # Initialize tqdm bar
+        self.pbar = tqdm(
+            total=self.total,
+            desc=desc,
+            position=0,
+            leave=True,
+            bar_format=BAR_FORMAT,
+            ascii=False
+        )
+        # Reset state variables
+        self.latest_val_loss = 0.0
+        self.latest_lr = 0.0
+        self.batch_count = 0
+
+    def on_batch_end(self, batch: int, logs: dict[str, Any] | None = None) -> None:
+        """ Update the progress bar after each batch, based on update frequency.
+
+        Args:
+            batch (int): Current batch number (0-indexed).
+            logs (dict | None): Dictionary of logs containing metrics for the batch.
+        """
+        # Skip updates if tracking epochs, pbar isn't initialized, or steps are unknown
+        if self.track_epochs or self.pbar is None or self.steps == 0:
+            return
+
+        self.batch_count += 1
+        is_last_batch: bool = (batch + 1) == self.steps
+
+        # Update only every `update_frequency` batches or on the last batch
+        if self.batch_count % self.update_frequency == 0 or is_last_batch:
+            increment: int = self.batch_count
+            self.batch_count = 0 # Reset counter
+
+            # Calculate current epoch (1-based) based on the progress bar's state *before* this update
+            # Ensure epoch doesn't exceed total epochs in description
+            current_epoch: int = min(self.epochs, self.pbar.n // self.steps + 1)
+            current_step: int = batch + 1 # Convert to 1-indexed for display
+            self.pbar.set_description(
+                f"{self.color}{self.desc} (Epoch {current_epoch}/{self.epochs}, Step {current_step}/{self.steps})"
+            )
+
+            # Update learning rate if model and optimizer are available
+            if self.model and hasattr(self.model, "optimizer"):
+                self.latest_lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate)) # type: ignore
+
+            # Update postfix with batch loss and the latest known validation loss
+            if logs and "loss" in logs:
+                loss: float = logs["loss"]
+                val_loss_str: str = ""
+                if self.latest_val_loss != 0.0:
+                    if self.latest_val_loss < 1e-3:
+                        val_loss_str = f", val_loss: {self.latest_val_loss:.2e}"
+                    else:
+                        val_loss_str = f", val_loss: {self.latest_val_loss:.5f}"
+
+                # Format learning rate string
+                lr_str: str = ""
+                if self.show_lr and self.latest_lr != 0.0:
+                    if self.latest_lr < 1e-3:
+                        lr_str = f", lr: {self.latest_lr:.2e}"
+                    else:
+                        lr_str = f", lr: {self.latest_lr:.5f}"
+
+                if loss < 1e-3:
+                    self.pbar.set_postfix_str(f"loss: {loss:.2e}{val_loss_str}{lr_str}", refresh=False)
+                else:
+                    self.pbar.set_postfix_str(f"loss: {loss:.5f}{val_loss_str}{lr_str}", refresh=False)
+
+            # Update progress bar position, ensuring not to exceed total
+            actual_increment: int = min(increment, self.total - self.pbar.n)
+            if actual_increment > 0:
+                self.pbar.update(actual_increment)
+
+    def on_epoch_end(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
+        """ Update metrics and progress bar position at the end of each epoch.
+
+        Args:
+            epoch (int): Current epoch number (0-indexed).
+            logs (dict | None): Dictionary of logs containing metrics for the epoch.
+        """
+        if self.pbar is None:
+            return
+
+        # Update the latest validation loss from epoch logs
+        if logs:
+            current_val_loss: float = logs.get("val_loss", 0.0)
+            if current_val_loss != 0.0:
+                self.latest_val_loss = current_val_loss
+
+            # Update learning rate if model and optimizer are available
+            if self.model and hasattr(self.model, "optimizer"):
+                self.latest_lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate)) # type: ignore
+
+            # Update postfix string with final epoch metrics
+            loss: float = logs.get("loss", 0.0)
+            val_loss_str: str = f", val_loss: {self.latest_val_loss:.5f}" if self.latest_val_loss != 0.0 else ""
+
+            # Format learning rate string
+            lr_str: str = ""
+            if self.show_lr and self.latest_lr != 0.0:
+                if self.latest_lr < 1e-3:
+                    lr_str = f", lr: {self.latest_lr:.2e}"
+                else:
+                    lr_str = f", lr: {self.latest_lr:.5f}"
+
+            if loss != 0.0: # Only update if loss is available
+                self.pbar.set_postfix_str(f"loss: {loss:.5f}{val_loss_str}{lr_str}", refresh=True)
+
+        # Update progress bar position
+        if self.track_epochs:
+            # Increment by 1 epoch if not already at total
+            if self.pbar.n < self.total:
+                self.pbar.update(1)
+        else:
+            # Ensure the progress bar is exactly at the end of the current epoch
+            expected_position: int = min(self.total, (epoch + 1) * self.steps)
+            increment: int = expected_position - self.pbar.n
+            if increment > 0:
+                self.pbar.update(increment)
+
+    def on_train_end(self, logs: dict[str, Any] | None = None) -> None:
+        """ Close the progress bar when training is complete.
+
+        Args:
+            logs (dict | None): Training logs.
+        """
+        if self.pbar is not None:
+            # Ensure the bar reaches 100%
+            increment: int = self.total - self.pbar.n
+            if increment > 0:
+                self.pbar.update(increment)
+            self.pbar.close()
+            self.pbar = None # Reset pbar instance
+
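
To make the position bookkeeping in on_batch_end concrete, here is the epoch-counter arithmetic for a hypothetical 3-epoch run with 10 steps per epoch (a worked example, not code from the package):

epochs, steps = 3, 10
total = epochs * steps              # the bar counts 30 batch-sized units
for n in (0, 9, 10, 29):            # pbar.n just before an update
    current_epoch = min(epochs, n // steps + 1)
    print(f"pbar.n={n} -> Epoch {current_epoch}/{epochs}")
# pbar.n=0  -> Epoch 1/3
# pbar.n=9  -> Epoch 1/3
# pbar.n=10 -> Epoch 2/3
# pbar.n=29 -> Epoch 3/3

The min() clamp keeps the description at the final epoch even when pbar.n has already reached total on the last batch.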