stouputils 1.14.0__py3-none-any.whl → 1.14.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. stouputils/__init__.pyi +15 -0
  2. stouputils/_deprecated.pyi +12 -0
  3. stouputils/all_doctests.pyi +46 -0
  4. stouputils/applications/__init__.pyi +2 -0
  5. stouputils/applications/automatic_docs.py +3 -0
  6. stouputils/applications/automatic_docs.pyi +106 -0
  7. stouputils/applications/upscaler/__init__.pyi +3 -0
  8. stouputils/applications/upscaler/config.pyi +18 -0
  9. stouputils/applications/upscaler/image.pyi +109 -0
  10. stouputils/applications/upscaler/video.pyi +60 -0
  11. stouputils/archive.pyi +67 -0
  12. stouputils/backup.pyi +109 -0
  13. stouputils/collections.pyi +86 -0
  14. stouputils/continuous_delivery/__init__.pyi +5 -0
  15. stouputils/continuous_delivery/cd_utils.pyi +129 -0
  16. stouputils/continuous_delivery/github.pyi +162 -0
  17. stouputils/continuous_delivery/pypi.pyi +52 -0
  18. stouputils/continuous_delivery/pyproject.pyi +67 -0
  19. stouputils/continuous_delivery/stubs.pyi +39 -0
  20. stouputils/ctx.pyi +211 -0
  21. stouputils/data_science/config/get.py +51 -51
  22. stouputils/data_science/data_processing/image/__init__.py +66 -66
  23. stouputils/data_science/data_processing/image/auto_contrast.py +79 -79
  24. stouputils/data_science/data_processing/image/axis_flip.py +58 -58
  25. stouputils/data_science/data_processing/image/bias_field_correction.py +74 -74
  26. stouputils/data_science/data_processing/image/binary_threshold.py +73 -73
  27. stouputils/data_science/data_processing/image/blur.py +59 -59
  28. stouputils/data_science/data_processing/image/brightness.py +54 -54
  29. stouputils/data_science/data_processing/image/canny.py +110 -110
  30. stouputils/data_science/data_processing/image/clahe.py +92 -92
  31. stouputils/data_science/data_processing/image/common.py +30 -30
  32. stouputils/data_science/data_processing/image/contrast.py +53 -53
  33. stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -74
  34. stouputils/data_science/data_processing/image/denoise.py +378 -378
  35. stouputils/data_science/data_processing/image/histogram_equalization.py +123 -123
  36. stouputils/data_science/data_processing/image/invert.py +64 -64
  37. stouputils/data_science/data_processing/image/laplacian.py +60 -60
  38. stouputils/data_science/data_processing/image/median_blur.py +52 -52
  39. stouputils/data_science/data_processing/image/noise.py +59 -59
  40. stouputils/data_science/data_processing/image/normalize.py +65 -65
  41. stouputils/data_science/data_processing/image/random_erase.py +66 -66
  42. stouputils/data_science/data_processing/image/resize.py +69 -69
  43. stouputils/data_science/data_processing/image/rotation.py +80 -80
  44. stouputils/data_science/data_processing/image/salt_pepper.py +68 -68
  45. stouputils/data_science/data_processing/image/sharpening.py +55 -55
  46. stouputils/data_science/data_processing/image/shearing.py +64 -64
  47. stouputils/data_science/data_processing/image/threshold.py +64 -64
  48. stouputils/data_science/data_processing/image/translation.py +71 -71
  49. stouputils/data_science/data_processing/image/zoom.py +83 -83
  50. stouputils/data_science/data_processing/image_augmentation.py +118 -118
  51. stouputils/data_science/data_processing/image_preprocess.py +183 -183
  52. stouputils/data_science/data_processing/prosthesis_detection.py +359 -359
  53. stouputils/data_science/data_processing/technique.py +481 -481
  54. stouputils/data_science/dataset/__init__.py +45 -45
  55. stouputils/data_science/dataset/dataset.py +292 -292
  56. stouputils/data_science/dataset/dataset_loader.py +135 -135
  57. stouputils/data_science/dataset/grouping_strategy.py +296 -296
  58. stouputils/data_science/dataset/image_loader.py +100 -100
  59. stouputils/data_science/dataset/xy_tuple.py +696 -696
  60. stouputils/data_science/metric_dictionnary.py +106 -106
  61. stouputils/data_science/mlflow_utils.py +206 -206
  62. stouputils/data_science/models/abstract_model.py +149 -149
  63. stouputils/data_science/models/all.py +85 -85
  64. stouputils/data_science/models/keras/all.py +38 -38
  65. stouputils/data_science/models/keras/convnext.py +62 -62
  66. stouputils/data_science/models/keras/densenet.py +50 -50
  67. stouputils/data_science/models/keras/efficientnet.py +60 -60
  68. stouputils/data_science/models/keras/mobilenet.py +56 -56
  69. stouputils/data_science/models/keras/resnet.py +52 -52
  70. stouputils/data_science/models/keras/squeezenet.py +233 -233
  71. stouputils/data_science/models/keras/vgg.py +42 -42
  72. stouputils/data_science/models/keras/xception.py +38 -38
  73. stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -20
  74. stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -219
  75. stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -148
  76. stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -31
  77. stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -249
  78. stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -66
  79. stouputils/data_science/models/keras_utils/losses/__init__.py +12 -12
  80. stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -56
  81. stouputils/data_science/models/keras_utils/visualizations.py +416 -416
  82. stouputils/data_science/models/sandbox.py +116 -116
  83. stouputils/data_science/range_tuple.py +234 -234
  84. stouputils/data_science/utils.py +285 -285
  85. stouputils/decorators.pyi +242 -0
  86. stouputils/image.pyi +172 -0
  87. stouputils/installer/__init__.py +18 -18
  88. stouputils/installer/__init__.pyi +5 -0
  89. stouputils/installer/common.pyi +39 -0
  90. stouputils/installer/downloader.pyi +24 -0
  91. stouputils/installer/linux.py +144 -144
  92. stouputils/installer/linux.pyi +39 -0
  93. stouputils/installer/main.py +223 -223
  94. stouputils/installer/main.pyi +57 -0
  95. stouputils/installer/windows.py +136 -136
  96. stouputils/installer/windows.pyi +31 -0
  97. stouputils/io.pyi +213 -0
  98. stouputils/parallel.py +12 -10
  99. stouputils/parallel.pyi +211 -0
  100. stouputils/print.pyi +136 -0
  101. stouputils/py.typed +1 -1
  102. stouputils/stouputils/parallel.pyi +4 -4
  103. stouputils/version_pkg.pyi +15 -0
  104. {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/METADATA +1 -1
  105. stouputils-1.14.2.dist-info/RECORD +171 -0
  106. stouputils-1.14.0.dist-info/RECORD +0 -140
  107. {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/WHEEL +0 -0
  108. {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/entry_points.txt +0 -0
@@ -1,20 +1,20 @@
1
- """ Custom callbacks for Keras models.
2
-
3
- Features:
4
-
5
- - Learning rate finder callback for finding the optimal learning rate
6
- - Warmup scheduler callback for warmup training
7
- - Progressive unfreezing callback for unfreezing layers during training (incompatible with model.fit(), need a custom training loop)
8
- - Tqdm progress bar callback for better training visualization
9
- - Model checkpoint callback that only starts checkpointing after a given number of epochs
10
- """
11
-
12
- # Imports
13
- from .colored_progress_bar import ColoredProgressBar
14
- from .learning_rate_finder import LearningRateFinder
15
- from .model_checkpoint_v2 import ModelCheckpointV2
16
- from .progressive_unfreezing import ProgressiveUnfreezing
17
- from .warmup_scheduler import WarmupScheduler
18
-
19
- __all__ = ["ColoredProgressBar", "LearningRateFinder", "ModelCheckpointV2", "ProgressiveUnfreezing", "WarmupScheduler"]
20
-
1
+ """ Custom callbacks for Keras models.
2
+
3
+ Features:
4
+
5
+ - Learning rate finder callback for finding the optimal learning rate
6
+ - Warmup scheduler callback for warmup training
7
+ - Progressive unfreezing callback for unfreezing layers during training (incompatible with model.fit(), need a custom training loop)
8
+ - Tqdm progress bar callback for better training visualization
9
+ - Model checkpoint callback that only starts checkpointing after a given number of epochs
10
+ """
11
+
12
+ # Imports
13
+ from .colored_progress_bar import ColoredProgressBar
14
+ from .learning_rate_finder import LearningRateFinder
15
+ from .model_checkpoint_v2 import ModelCheckpointV2
16
+ from .progressive_unfreezing import ProgressiveUnfreezing
17
+ from .warmup_scheduler import WarmupScheduler
18
+
19
+ __all__ = ["ColoredProgressBar", "LearningRateFinder", "ModelCheckpointV2", "ProgressiveUnfreezing", "WarmupScheduler"]
20
+
@@ -1,219 +1,219 @@
1
-
2
- # pyright: reportMissingTypeStubs=false
3
-
4
- # Imports
5
- from typing import Any
6
-
7
- import tensorflow as tf
8
- from keras.callbacks import Callback
9
- from keras.models import Model
10
- from tqdm.auto import tqdm
11
-
12
- from .....print import MAGENTA
13
- from .....parallel import BAR_FORMAT
14
-
15
- class ColoredProgressBar(Callback):
16
- """ Progress bar using tqdm for Keras training.
17
-
18
- A callback that displays a progress bar using tqdm during model training.
19
- Shows the training progress across steps with a customized format
20
- instead of the default Keras one showing multiple lines.
21
- """
22
- def __init__(
23
- self,
24
- desc: str = "Training",
25
- track_epochs: bool = False,
26
- show_lr: bool = False,
27
- update_frequency: int = 1,
28
- color: str = MAGENTA
29
- ) -> None:
30
- """ Initialize the progress bar callback.
31
-
32
- Args:
33
- desc (str): Custom description for the progress bar.
34
- track_epochs (bool): Whether to track epochs instead of batches.
35
- show_lr (bool): Whether to show the learning rate.
36
- update_frequency (int): How often to update the progress bar (every N batches).
37
- color (str): Color of the progress bar.
38
- """
39
- super().__init__()
40
- self.desc: str = desc
41
- """ Description of the progress bar. """
42
- self.track_epochs: bool = track_epochs
43
- """ Whether to track epochs instead of batches. """
44
- self.show_lr: bool = show_lr
45
- """ Whether to show the learning rate. """
46
- self.latest_val_loss: float = 0.0
47
- """ Latest validation loss, updated at the end of each epoch. """
48
- self.latest_lr: float = 0.0
49
- """ Latest learning rate, updated during batch and epoch processing. """
50
- self.batch_count: int = 0
51
- """ Counter to update the progress bar less frequently. """
52
- self.update_frequency: int = max(1, update_frequency) # Ensure frequency is at least 1
53
- """ How often to update the progress bar (every N batches). """
54
- self.color: str = color
55
- """ Color of the progress bar. """
56
- self.pbar: tqdm[Any] | None = None
57
- """ The tqdm progress bar instance. """
58
- self.epochs: int = 0
59
- """ Total number of epochs. """
60
- self.steps: int = 0
61
- """ Number of steps per epoch. """
62
- self.total: int = 0
63
- """ Total number of steps/epochs to track. """
64
- self.params: dict[str, Any]
65
- """ Training parameters. """
66
- self.model: Model
67
-
68
- def on_train_begin(self, logs: dict[str, Any] | None = None) -> None:
69
- """ Initialize the progress bar at the start of training.
70
-
71
- Args:
72
- logs (dict | None): Training logs.
73
- """
74
- # Get training parameters
75
- self.epochs = self.params.get("epochs", 0)
76
- self.steps = self.params.get("steps", 0)
77
-
78
- # Determine total units and initial description
79
- if self.track_epochs:
80
- desc: str = f"{self.color}{self.desc} (Epochs)"
81
- self.total = self.epochs
82
- else:
83
- desc: str = f"{self.color}{self.desc} (Epoch 1/{self.epochs})"
84
- self.total = self.epochs * self.steps
85
-
86
- # Initialize tqdm bar
87
- self.pbar = tqdm(
88
- total=self.total,
89
- desc=desc,
90
- position=0,
91
- leave=True,
92
- bar_format=BAR_FORMAT,
93
- ascii=False
94
- )
95
- # Reset state variables
96
- self.latest_val_loss = 0.0
97
- self.latest_lr = 0.0
98
- self.batch_count = 0
99
-
100
- def on_batch_end(self, batch: int, logs: dict[str, Any] | None = None) -> None:
101
- """ Update the progress bar after each batch, based on update frequency.
102
-
103
- Args:
104
- batch (int): Current batch number (0-indexed).
105
- logs (dict | None): Dictionary of logs containing metrics for the batch.
106
- """
107
- # Skip updates if tracking epochs, pbar isn't initialized, or steps are unknown
108
- if self.track_epochs or self.pbar is None or self.steps == 0:
109
- return
110
-
111
- self.batch_count += 1
112
- is_last_batch: bool = (batch + 1) == self.steps
113
-
114
- # Update only every `update_frequency` batches or on the last batch
115
- if self.batch_count % self.update_frequency == 0 or is_last_batch:
116
- increment: int = self.batch_count
117
- self.batch_count = 0 # Reset counter
118
-
119
- # Calculate current epoch (1-based) based on the progress bar's state *before* this update
120
- # Ensure epoch doesn't exceed total epochs in description
121
- current_epoch: int = min(self.epochs, self.pbar.n // self.steps + 1)
122
- current_step: int = batch + 1 # Convert to 1-indexed for display
123
- self.pbar.set_description(
124
- f"{self.color}{self.desc} (Epoch {current_epoch}/{self.epochs}, Step {current_step}/{self.steps})"
125
- )
126
-
127
- # Update learning rate if model and optimizer are available
128
- if self.model and hasattr(self.model, "optimizer"):
129
- self.latest_lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate)) # type: ignore
130
-
131
- # Update postfix with batch loss and the latest known validation loss
132
- if logs and "loss" in logs:
133
- loss: float = logs["loss"]
134
- val_loss_str: str = ""
135
- if self.latest_val_loss != 0.0:
136
- if self.latest_val_loss < 1e-3:
137
- val_loss_str = f", val_loss: {self.latest_val_loss:.2e}"
138
- else:
139
- val_loss_str = f", val_loss: {self.latest_val_loss:.5f}"
140
-
141
- # Format learning rate string
142
- lr_str: str = ""
143
- if self.show_lr and self.latest_lr != 0.0:
144
- if self.latest_lr < 1e-3:
145
- lr_str = f", lr: {self.latest_lr:.2e}"
146
- else:
147
- lr_str = f", lr: {self.latest_lr:.5f}"
148
-
149
- if loss < 1e-3:
150
- self.pbar.set_postfix_str(f"loss: {loss:.2e}{val_loss_str}{lr_str}", refresh=False)
151
- else:
152
- self.pbar.set_postfix_str(f"loss: {loss:.5f}{val_loss_str}{lr_str}", refresh=False)
153
-
154
- # Update progress bar position, ensuring not to exceed total
155
- actual_increment: int = min(increment, self.total - self.pbar.n)
156
- if actual_increment > 0:
157
- self.pbar.update(actual_increment)
158
-
159
- def on_epoch_end(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
160
- """ Update metrics and progress bar position at the end of each epoch.
161
-
162
- Args:
163
- epoch (int): Current epoch number (0-indexed).
164
- logs (dict | None): Dictionary of logs containing metrics for the epoch.
165
- """
166
- if self.pbar is None:
167
- return
168
-
169
- # Update the latest validation loss from epoch logs
170
- if logs:
171
- current_val_loss: float = logs.get("val_loss", 0.0)
172
- if current_val_loss != 0.0:
173
- self.latest_val_loss = current_val_loss
174
-
175
- # Update learning rate if model and optimizer are available
176
- if self.model and hasattr(self.model, "optimizer"):
177
- self.latest_lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate)) # type: ignore
178
-
179
- # Update postfix string with final epoch metrics
180
- loss: float = logs.get("loss", 0.0)
181
- val_loss_str: str = f", val_loss: {self.latest_val_loss:.5f}" if self.latest_val_loss != 0.0 else ""
182
-
183
- # Format learning rate string
184
- lr_str: str = ""
185
- if self.show_lr and self.latest_lr != 0.0:
186
- if self.latest_lr < 1e-3:
187
- lr_str = f", lr: {self.latest_lr:.2e}"
188
- else:
189
- lr_str = f", lr: {self.latest_lr:.5f}"
190
-
191
- if loss != 0.0: # Only update if loss is available
192
- self.pbar.set_postfix_str(f"loss: {loss:.5f}{val_loss_str}{lr_str}", refresh=True)
193
-
194
- # Update progress bar position
195
- if self.track_epochs:
196
- # Increment by 1 epoch if not already at total
197
- if self.pbar.n < self.total:
198
- self.pbar.update(1)
199
- else:
200
- # Ensure the progress bar is exactly at the end of the current epoch
201
- expected_position: int = min(self.total, (epoch + 1) * self.steps)
202
- increment: int = expected_position - self.pbar.n
203
- if increment > 0:
204
- self.pbar.update(increment)
205
-
206
- def on_train_end(self, logs: dict[str, Any] | None = None) -> None:
207
- """ Close the progress bar when training is complete.
208
-
209
- Args:
210
- logs (dict | None): Training logs.
211
- """
212
- if self.pbar is not None:
213
- # Ensure the bar reaches 100%
214
- increment: int = self.total - self.pbar.n
215
- if increment > 0:
216
- self.pbar.update(increment)
217
- self.pbar.close()
218
- self.pbar = None # Reset pbar instance
219
-
1
+
2
+ # pyright: reportMissingTypeStubs=false
3
+
4
+ # Imports
5
+ from typing import Any
6
+
7
+ import tensorflow as tf
8
+ from keras.callbacks import Callback
9
+ from keras.models import Model
10
+ from tqdm.auto import tqdm
11
+
12
+ from .....print import MAGENTA
13
+ from .....parallel import BAR_FORMAT
14
+
15
+ class ColoredProgressBar(Callback):
16
+ """ Progress bar using tqdm for Keras training.
17
+
18
+ A callback that displays a progress bar using tqdm during model training.
19
+ Shows the training progress across steps with a customized format
20
+ instead of the default Keras one showing multiple lines.
21
+ """
22
+ def __init__(
23
+ self,
24
+ desc: str = "Training",
25
+ track_epochs: bool = False,
26
+ show_lr: bool = False,
27
+ update_frequency: int = 1,
28
+ color: str = MAGENTA
29
+ ) -> None:
30
+ """ Initialize the progress bar callback.
31
+
32
+ Args:
33
+ desc (str): Custom description for the progress bar.
34
+ track_epochs (bool): Whether to track epochs instead of batches.
35
+ show_lr (bool): Whether to show the learning rate.
36
+ update_frequency (int): How often to update the progress bar (every N batches).
37
+ color (str): Color of the progress bar.
38
+ """
39
+ super().__init__()
40
+ self.desc: str = desc
41
+ """ Description of the progress bar. """
42
+ self.track_epochs: bool = track_epochs
43
+ """ Whether to track epochs instead of batches. """
44
+ self.show_lr: bool = show_lr
45
+ """ Whether to show the learning rate. """
46
+ self.latest_val_loss: float = 0.0
47
+ """ Latest validation loss, updated at the end of each epoch. """
48
+ self.latest_lr: float = 0.0
49
+ """ Latest learning rate, updated during batch and epoch processing. """
50
+ self.batch_count: int = 0
51
+ """ Counter to update the progress bar less frequently. """
52
+ self.update_frequency: int = max(1, update_frequency) # Ensure frequency is at least 1
53
+ """ How often to update the progress bar (every N batches). """
54
+ self.color: str = color
55
+ """ Color of the progress bar. """
56
+ self.pbar: tqdm[Any] | None = None
57
+ """ The tqdm progress bar instance. """
58
+ self.epochs: int = 0
59
+ """ Total number of epochs. """
60
+ self.steps: int = 0
61
+ """ Number of steps per epoch. """
62
+ self.total: int = 0
63
+ """ Total number of steps/epochs to track. """
64
+ self.params: dict[str, Any]
65
+ """ Training parameters. """
66
+ self.model: Model
67
+
68
+ def on_train_begin(self, logs: dict[str, Any] | None = None) -> None:
69
+ """ Initialize the progress bar at the start of training.
70
+
71
+ Args:
72
+ logs (dict | None): Training logs.
73
+ """
74
+ # Get training parameters
75
+ self.epochs = self.params.get("epochs", 0)
76
+ self.steps = self.params.get("steps", 0)
77
+
78
+ # Determine total units and initial description
79
+ if self.track_epochs:
80
+ desc: str = f"{self.color}{self.desc} (Epochs)"
81
+ self.total = self.epochs
82
+ else:
83
+ desc: str = f"{self.color}{self.desc} (Epoch 1/{self.epochs})"
84
+ self.total = self.epochs * self.steps
85
+
86
+ # Initialize tqdm bar
87
+ self.pbar = tqdm(
88
+ total=self.total,
89
+ desc=desc,
90
+ position=0,
91
+ leave=True,
92
+ bar_format=BAR_FORMAT,
93
+ ascii=False
94
+ )
95
+ # Reset state variables
96
+ self.latest_val_loss = 0.0
97
+ self.latest_lr = 0.0
98
+ self.batch_count = 0
99
+
100
+ def on_batch_end(self, batch: int, logs: dict[str, Any] | None = None) -> None:
101
+ """ Update the progress bar after each batch, based on update frequency.
102
+
103
+ Args:
104
+ batch (int): Current batch number (0-indexed).
105
+ logs (dict | None): Dictionary of logs containing metrics for the batch.
106
+ """
107
+ # Skip updates if tracking epochs, pbar isn't initialized, or steps are unknown
108
+ if self.track_epochs or self.pbar is None or self.steps == 0:
109
+ return
110
+
111
+ self.batch_count += 1
112
+ is_last_batch: bool = (batch + 1) == self.steps
113
+
114
+ # Update only every `update_frequency` batches or on the last batch
115
+ if self.batch_count % self.update_frequency == 0 or is_last_batch:
116
+ increment: int = self.batch_count
117
+ self.batch_count = 0 # Reset counter
118
+
119
+ # Calculate current epoch (1-based) based on the progress bar's state *before* this update
120
+ # Ensure epoch doesn't exceed total epochs in description
121
+ current_epoch: int = min(self.epochs, self.pbar.n // self.steps + 1)
122
+ current_step: int = batch + 1 # Convert to 1-indexed for display
123
+ self.pbar.set_description(
124
+ f"{self.color}{self.desc} (Epoch {current_epoch}/{self.epochs}, Step {current_step}/{self.steps})"
125
+ )
126
+
127
+ # Update learning rate if model and optimizer are available
128
+ if self.model and hasattr(self.model, "optimizer"):
129
+ self.latest_lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate)) # type: ignore
130
+
131
+ # Update postfix with batch loss and the latest known validation loss
132
+ if logs and "loss" in logs:
133
+ loss: float = logs["loss"]
134
+ val_loss_str: str = ""
135
+ if self.latest_val_loss != 0.0:
136
+ if self.latest_val_loss < 1e-3:
137
+ val_loss_str = f", val_loss: {self.latest_val_loss:.2e}"
138
+ else:
139
+ val_loss_str = f", val_loss: {self.latest_val_loss:.5f}"
140
+
141
+ # Format learning rate string
142
+ lr_str: str = ""
143
+ if self.show_lr and self.latest_lr != 0.0:
144
+ if self.latest_lr < 1e-3:
145
+ lr_str = f", lr: {self.latest_lr:.2e}"
146
+ else:
147
+ lr_str = f", lr: {self.latest_lr:.5f}"
148
+
149
+ if loss < 1e-3:
150
+ self.pbar.set_postfix_str(f"loss: {loss:.2e}{val_loss_str}{lr_str}", refresh=False)
151
+ else:
152
+ self.pbar.set_postfix_str(f"loss: {loss:.5f}{val_loss_str}{lr_str}", refresh=False)
153
+
154
+ # Update progress bar position, ensuring not to exceed total
155
+ actual_increment: int = min(increment, self.total - self.pbar.n)
156
+ if actual_increment > 0:
157
+ self.pbar.update(actual_increment)
158
+
159
+ def on_epoch_end(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
160
+ """ Update metrics and progress bar position at the end of each epoch.
161
+
162
+ Args:
163
+ epoch (int): Current epoch number (0-indexed).
164
+ logs (dict | None): Dictionary of logs containing metrics for the epoch.
165
+ """
166
+ if self.pbar is None:
167
+ return
168
+
169
+ # Update the latest validation loss from epoch logs
170
+ if logs:
171
+ current_val_loss: float = logs.get("val_loss", 0.0)
172
+ if current_val_loss != 0.0:
173
+ self.latest_val_loss = current_val_loss
174
+
175
+ # Update learning rate if model and optimizer are available
176
+ if self.model and hasattr(self.model, "optimizer"):
177
+ self.latest_lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate)) # type: ignore
178
+
179
+ # Update postfix string with final epoch metrics
180
+ loss: float = logs.get("loss", 0.0)
181
+ val_loss_str: str = f", val_loss: {self.latest_val_loss:.5f}" if self.latest_val_loss != 0.0 else ""
182
+
183
+ # Format learning rate string
184
+ lr_str: str = ""
185
+ if self.show_lr and self.latest_lr != 0.0:
186
+ if self.latest_lr < 1e-3:
187
+ lr_str = f", lr: {self.latest_lr:.2e}"
188
+ else:
189
+ lr_str = f", lr: {self.latest_lr:.5f}"
190
+
191
+ if loss != 0.0: # Only update if loss is available
192
+ self.pbar.set_postfix_str(f"loss: {loss:.5f}{val_loss_str}{lr_str}", refresh=True)
193
+
194
+ # Update progress bar position
195
+ if self.track_epochs:
196
+ # Increment by 1 epoch if not already at total
197
+ if self.pbar.n < self.total:
198
+ self.pbar.update(1)
199
+ else:
200
+ # Ensure the progress bar is exactly at the end of the current epoch
201
+ expected_position: int = min(self.total, (epoch + 1) * self.steps)
202
+ increment: int = expected_position - self.pbar.n
203
+ if increment > 0:
204
+ self.pbar.update(increment)
205
+
206
+ def on_train_end(self, logs: dict[str, Any] | None = None) -> None:
207
+ """ Close the progress bar when training is complete.
208
+
209
+ Args:
210
+ logs (dict | None): Training logs.
211
+ """
212
+ if self.pbar is not None:
213
+ # Ensure the bar reaches 100%
214
+ increment: int = self.total - self.pbar.n
215
+ if increment > 0:
216
+ self.pbar.update(increment)
217
+ self.pbar.close()
218
+ self.pbar = None # Reset pbar instance
219
+