stouputils 1.12.1__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- stouputils/__init__.py +40 -0
- stouputils/__init__.pyi +14 -0
- stouputils/__main__.py +81 -0
- stouputils/_deprecated.py +37 -0
- stouputils/_deprecated.pyi +12 -0
- stouputils/all_doctests.py +160 -0
- stouputils/all_doctests.pyi +46 -0
- stouputils/applications/__init__.py +22 -0
- stouputils/applications/__init__.pyi +2 -0
- stouputils/applications/automatic_docs.py +634 -0
- stouputils/applications/automatic_docs.pyi +106 -0
- stouputils/applications/upscaler/__init__.py +39 -0
- stouputils/applications/upscaler/__init__.pyi +3 -0
- stouputils/applications/upscaler/config.py +128 -0
- stouputils/applications/upscaler/config.pyi +18 -0
- stouputils/applications/upscaler/image.py +247 -0
- stouputils/applications/upscaler/image.pyi +109 -0
- stouputils/applications/upscaler/video.py +287 -0
- stouputils/applications/upscaler/video.pyi +60 -0
- stouputils/archive.py +344 -0
- stouputils/archive.pyi +67 -0
- stouputils/backup.py +488 -0
- stouputils/backup.pyi +109 -0
- stouputils/collections.py +244 -0
- stouputils/collections.pyi +86 -0
- stouputils/continuous_delivery/__init__.py +27 -0
- stouputils/continuous_delivery/__init__.pyi +5 -0
- stouputils/continuous_delivery/cd_utils.py +243 -0
- stouputils/continuous_delivery/cd_utils.pyi +129 -0
- stouputils/continuous_delivery/github.py +522 -0
- stouputils/continuous_delivery/github.pyi +162 -0
- stouputils/continuous_delivery/pypi.py +91 -0
- stouputils/continuous_delivery/pypi.pyi +43 -0
- stouputils/continuous_delivery/pyproject.py +147 -0
- stouputils/continuous_delivery/pyproject.pyi +67 -0
- stouputils/continuous_delivery/stubs.py +86 -0
- stouputils/continuous_delivery/stubs.pyi +39 -0
- stouputils/ctx.py +408 -0
- stouputils/ctx.pyi +211 -0
- stouputils/data_science/config/get.py +51 -0
- stouputils/data_science/config/set.py +125 -0
- stouputils/data_science/data_processing/image/__init__.py +66 -0
- stouputils/data_science/data_processing/image/auto_contrast.py +79 -0
- stouputils/data_science/data_processing/image/axis_flip.py +58 -0
- stouputils/data_science/data_processing/image/bias_field_correction.py +74 -0
- stouputils/data_science/data_processing/image/binary_threshold.py +73 -0
- stouputils/data_science/data_processing/image/blur.py +59 -0
- stouputils/data_science/data_processing/image/brightness.py +54 -0
- stouputils/data_science/data_processing/image/canny.py +110 -0
- stouputils/data_science/data_processing/image/clahe.py +92 -0
- stouputils/data_science/data_processing/image/common.py +30 -0
- stouputils/data_science/data_processing/image/contrast.py +53 -0
- stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -0
- stouputils/data_science/data_processing/image/denoise.py +378 -0
- stouputils/data_science/data_processing/image/histogram_equalization.py +123 -0
- stouputils/data_science/data_processing/image/invert.py +64 -0
- stouputils/data_science/data_processing/image/laplacian.py +60 -0
- stouputils/data_science/data_processing/image/median_blur.py +52 -0
- stouputils/data_science/data_processing/image/noise.py +59 -0
- stouputils/data_science/data_processing/image/normalize.py +65 -0
- stouputils/data_science/data_processing/image/random_erase.py +66 -0
- stouputils/data_science/data_processing/image/resize.py +69 -0
- stouputils/data_science/data_processing/image/rotation.py +80 -0
- stouputils/data_science/data_processing/image/salt_pepper.py +68 -0
- stouputils/data_science/data_processing/image/sharpening.py +55 -0
- stouputils/data_science/data_processing/image/shearing.py +64 -0
- stouputils/data_science/data_processing/image/threshold.py +64 -0
- stouputils/data_science/data_processing/image/translation.py +71 -0
- stouputils/data_science/data_processing/image/zoom.py +83 -0
- stouputils/data_science/data_processing/image_augmentation.py +118 -0
- stouputils/data_science/data_processing/image_preprocess.py +183 -0
- stouputils/data_science/data_processing/prosthesis_detection.py +359 -0
- stouputils/data_science/data_processing/technique.py +481 -0
- stouputils/data_science/dataset/__init__.py +45 -0
- stouputils/data_science/dataset/dataset.py +292 -0
- stouputils/data_science/dataset/dataset_loader.py +135 -0
- stouputils/data_science/dataset/grouping_strategy.py +296 -0
- stouputils/data_science/dataset/image_loader.py +100 -0
- stouputils/data_science/dataset/xy_tuple.py +696 -0
- stouputils/data_science/metric_dictionnary.py +106 -0
- stouputils/data_science/metric_utils.py +847 -0
- stouputils/data_science/mlflow_utils.py +206 -0
- stouputils/data_science/models/abstract_model.py +149 -0
- stouputils/data_science/models/all.py +85 -0
- stouputils/data_science/models/base_keras.py +765 -0
- stouputils/data_science/models/keras/all.py +38 -0
- stouputils/data_science/models/keras/convnext.py +62 -0
- stouputils/data_science/models/keras/densenet.py +50 -0
- stouputils/data_science/models/keras/efficientnet.py +60 -0
- stouputils/data_science/models/keras/mobilenet.py +56 -0
- stouputils/data_science/models/keras/resnet.py +52 -0
- stouputils/data_science/models/keras/squeezenet.py +233 -0
- stouputils/data_science/models/keras/vgg.py +42 -0
- stouputils/data_science/models/keras/xception.py +38 -0
- stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -0
- stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -0
- stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -0
- stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -0
- stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -0
- stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -0
- stouputils/data_science/models/keras_utils/losses/__init__.py +12 -0
- stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -0
- stouputils/data_science/models/keras_utils/visualizations.py +416 -0
- stouputils/data_science/models/model_interface.py +939 -0
- stouputils/data_science/models/sandbox.py +116 -0
- stouputils/data_science/range_tuple.py +234 -0
- stouputils/data_science/scripts/augment_dataset.py +77 -0
- stouputils/data_science/scripts/exhaustive_process.py +133 -0
- stouputils/data_science/scripts/preprocess_dataset.py +70 -0
- stouputils/data_science/scripts/routine.py +168 -0
- stouputils/data_science/utils.py +285 -0
- stouputils/decorators.py +595 -0
- stouputils/decorators.pyi +242 -0
- stouputils/image.py +441 -0
- stouputils/image.pyi +172 -0
- stouputils/installer/__init__.py +18 -0
- stouputils/installer/__init__.pyi +5 -0
- stouputils/installer/common.py +67 -0
- stouputils/installer/common.pyi +39 -0
- stouputils/installer/downloader.py +101 -0
- stouputils/installer/downloader.pyi +24 -0
- stouputils/installer/linux.py +144 -0
- stouputils/installer/linux.pyi +39 -0
- stouputils/installer/main.py +223 -0
- stouputils/installer/main.pyi +57 -0
- stouputils/installer/windows.py +136 -0
- stouputils/installer/windows.pyi +31 -0
- stouputils/io.py +486 -0
- stouputils/io.pyi +213 -0
- stouputils/parallel.py +453 -0
- stouputils/parallel.pyi +211 -0
- stouputils/print.py +527 -0
- stouputils/print.pyi +146 -0
- stouputils/py.typed +1 -0
- stouputils-1.12.1.dist-info/METADATA +179 -0
- stouputils-1.12.1.dist-info/RECORD +138 -0
- stouputils-1.12.1.dist-info/WHEEL +4 -0
- stouputils-1.12.1.dist-info/entry_points.txt +3 -0
stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py
@@ -0,0 +1,219 @@

# pyright: reportMissingTypeStubs=false

# Imports
from typing import Any

import tensorflow as tf
from keras.callbacks import Callback
from keras.models import Model
from tqdm.auto import tqdm

from .....print import MAGENTA
from .....parallel import BAR_FORMAT

class ColoredProgressBar(Callback):
    """ Progress bar using tqdm for Keras training.

    A callback that displays a progress bar using tqdm during model training.
    Shows the training progress across steps with a customized format
    instead of the default Keras one showing multiple lines.
    """
    def __init__(
        self,
        desc: str = "Training",
        track_epochs: bool = False,
        show_lr: bool = False,
        update_frequency: int = 1,
        color: str = MAGENTA
    ) -> None:
        """ Initialize the progress bar callback.

        Args:
            desc (str): Custom description for the progress bar.
            track_epochs (bool): Whether to track epochs instead of batches.
            show_lr (bool): Whether to show the learning rate.
            update_frequency (int): How often to update the progress bar (every N batches).
            color (str): Color of the progress bar.
        """
        super().__init__()
        self.desc: str = desc
        """ Description of the progress bar. """
        self.track_epochs: bool = track_epochs
        """ Whether to track epochs instead of batches. """
        self.show_lr: bool = show_lr
        """ Whether to show the learning rate. """
        self.latest_val_loss: float = 0.0
        """ Latest validation loss, updated at the end of each epoch. """
        self.latest_lr: float = 0.0
        """ Latest learning rate, updated during batch and epoch processing. """
        self.batch_count: int = 0
        """ Counter to update the progress bar less frequently. """
        self.update_frequency: int = max(1, update_frequency)  # Ensure frequency is at least 1
        """ How often to update the progress bar (every N batches). """
        self.color: str = color
        """ Color of the progress bar. """
        self.pbar: tqdm[Any] | None = None
        """ The tqdm progress bar instance. """
        self.epochs: int = 0
        """ Total number of epochs. """
        self.steps: int = 0
        """ Number of steps per epoch. """
        self.total: int = 0
        """ Total number of steps/epochs to track. """
        self.params: dict[str, Any]
        """ Training parameters. """
        self.model: Model

    def on_train_begin(self, logs: dict[str, Any] | None = None) -> None:
        """ Initialize the progress bar at the start of training.

        Args:
            logs (dict | None): Training logs.
        """
        # Get training parameters
        self.epochs = self.params.get("epochs", 0)
        self.steps = self.params.get("steps", 0)

        # Determine total units and initial description
        if self.track_epochs:
            desc: str = f"{self.color}{self.desc} (Epochs)"
            self.total = self.epochs
        else:
            desc: str = f"{self.color}{self.desc} (Epoch 1/{self.epochs})"
            self.total = self.epochs * self.steps

        # Initialize tqdm bar
        self.pbar = tqdm(
            total=self.total,
            desc=desc,
            position=0,
            leave=True,
            bar_format=BAR_FORMAT,
            ascii=False
        )
        # Reset state variables
        self.latest_val_loss = 0.0
        self.latest_lr = 0.0
        self.batch_count = 0

    def on_batch_end(self, batch: int, logs: dict[str, Any] | None = None) -> None:
        """ Update the progress bar after each batch, based on update frequency.

        Args:
            batch (int): Current batch number (0-indexed).
            logs (dict | None): Dictionary of logs containing metrics for the batch.
        """
        # Skip updates if tracking epochs, pbar isn't initialized, or steps are unknown
        if self.track_epochs or self.pbar is None or self.steps == 0:
            return

        self.batch_count += 1
        is_last_batch: bool = (batch + 1) == self.steps

        # Update only every `update_frequency` batches or on the last batch
        if self.batch_count % self.update_frequency == 0 or is_last_batch:
            increment: int = self.batch_count
            self.batch_count = 0  # Reset counter

            # Calculate current epoch (1-based) based on the progress bar's state *before* this update
            # Ensure epoch doesn't exceed total epochs in description
            current_epoch: int = min(self.epochs, self.pbar.n // self.steps + 1)
            current_step: int = batch + 1  # Convert to 1-indexed for display
            self.pbar.set_description(
                f"{self.color}{self.desc} (Epoch {current_epoch}/{self.epochs}, Step {current_step}/{self.steps})"
            )

            # Update learning rate if model and optimizer are available
            if self.model and hasattr(self.model, "optimizer"):
                self.latest_lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))  # type: ignore

            # Update postfix with batch loss and the latest known validation loss
            if logs and "loss" in logs:
                loss: float = logs["loss"]
                val_loss_str: str = ""
                if self.latest_val_loss != 0.0:
                    if self.latest_val_loss < 1e-3:
                        val_loss_str = f", val_loss: {self.latest_val_loss:.2e}"
                    else:
                        val_loss_str = f", val_loss: {self.latest_val_loss:.5f}"

                # Format learning rate string
                lr_str: str = ""
                if self.show_lr and self.latest_lr != 0.0:
                    if self.latest_lr < 1e-3:
                        lr_str = f", lr: {self.latest_lr:.2e}"
                    else:
                        lr_str = f", lr: {self.latest_lr:.5f}"

                if loss < 1e-3:
                    self.pbar.set_postfix_str(f"loss: {loss:.2e}{val_loss_str}{lr_str}", refresh=False)
                else:
                    self.pbar.set_postfix_str(f"loss: {loss:.5f}{val_loss_str}{lr_str}", refresh=False)

            # Update progress bar position, ensuring not to exceed total
            actual_increment: int = min(increment, self.total - self.pbar.n)
            if actual_increment > 0:
                self.pbar.update(actual_increment)

    def on_epoch_end(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
        """ Update metrics and progress bar position at the end of each epoch.

        Args:
            epoch (int): Current epoch number (0-indexed).
            logs (dict | None): Dictionary of logs containing metrics for the epoch.
        """
        if self.pbar is None:
            return

        # Update the latest validation loss from epoch logs
        if logs:
            current_val_loss: float = logs.get("val_loss", 0.0)
            if current_val_loss != 0.0:
                self.latest_val_loss = current_val_loss

        # Update learning rate if model and optimizer are available
        if self.model and hasattr(self.model, "optimizer"):
            self.latest_lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))  # type: ignore

        # Update postfix string with final epoch metrics
        loss: float = logs.get("loss", 0.0)
        val_loss_str: str = f", val_loss: {self.latest_val_loss:.5f}" if self.latest_val_loss != 0.0 else ""

        # Format learning rate string
        lr_str: str = ""
        if self.show_lr and self.latest_lr != 0.0:
            if self.latest_lr < 1e-3:
                lr_str = f", lr: {self.latest_lr:.2e}"
            else:
                lr_str = f", lr: {self.latest_lr:.5f}"

        if loss != 0.0:  # Only update if loss is available
            self.pbar.set_postfix_str(f"loss: {loss:.5f}{val_loss_str}{lr_str}", refresh=True)

        # Update progress bar position
        if self.track_epochs:
            # Increment by 1 epoch if not already at total
            if self.pbar.n < self.total:
                self.pbar.update(1)
        else:
            # Ensure the progress bar is exactly at the end of the current epoch
            expected_position: int = min(self.total, (epoch + 1) * self.steps)
            increment: int = expected_position - self.pbar.n
            if increment > 0:
                self.pbar.update(increment)

    def on_train_end(self, logs: dict[str, Any] | None = None) -> None:
        """ Close the progress bar when training is complete.

        Args:
            logs (dict | None): Training logs.
        """
        if self.pbar is not None:
            # Ensure the bar reaches 100%
            increment: int = self.total - self.pbar.n
            if increment > 0:
                self.pbar.update(increment)
            self.pbar.close()
            self.pbar = None  # Reset pbar instance
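For orientation, a minimal usage sketch that is not taken from the package itself: the callback is simply passed to `model.fit()` while the default Keras logging is silenced. The toy data, model, and hyperparameters below are placeholders.

```python
import numpy as np
from keras import layers, models

from stouputils.data_science.models.keras_utils.callbacks.colored_progress_bar import ColoredProgressBar

# Toy data and model so the snippet runs standalone (placeholders, not from the package)
x_train = np.random.rand(256, 32).astype("float32")
y_train = np.random.rand(256, 1).astype("float32")
model = models.Sequential([layers.Input(shape=(32,)), layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# verbose=0 hides the default multi-line Keras output so only the tqdm bar is shown
model.fit(
    x_train, y_train,
    epochs=5,
    batch_size=32,
    callbacks=[ColoredProgressBar(desc="Training", show_lr=True, update_frequency=4)],
    verbose=0,
)
```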
stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py
@@ -0,0 +1,148 @@

# pyright: reportMissingTypeStubs=false

# Imports
from typing import Any

import tensorflow as tf
from keras.callbacks import Callback
from keras.models import Model


class LearningRateFinder(Callback):
    """ Callback to find optimal learning rate by increasing LR during training.

    Sources:
    - Inspired by: https://github.com/WittmannF/LRFinder
    - Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186 (first description of the method)

    This callback gradually increases the learning rate from a minimum to a maximum value
    during training, allowing you to identify the optimal learning rate range for your model.

    It works by:

    1. Starting with a very small learning rate
    2. Exponentially increasing it after each batch or epoch
    3. Recording the loss at each learning rate
    4. Restoring the model's initial weights after training

    The optimal learning rate is typically found where the loss is decreasing most rapidly
    before it starts to diverge.

    .. image:: https://blog.dataiku.com/hubfs/training%20loss.png
        :alt: Learning rate finder curve example
    """

    def __init__(
        self,
        min_lr: float,
        max_lr: float,
        steps_per_epoch: int,
        epochs: int,
        update_per_epoch: bool = False,
        update_interval: int = 5
    ) -> None:
        """ Initialize the learning rate finder.

        Args:
            min_lr (float): Minimum learning rate
            max_lr (float): Maximum learning rate
            steps_per_epoch (int): Steps per epoch
            epochs (int): Number of epochs
            update_per_epoch (bool): If True, update LR once per epoch instead of every batch.
            update_interval (int): Number of steps between each lr increase; a bigger value means a more stable loss.
        """
        super().__init__()
        self.min_lr: float = min_lr
        """ Minimum learning rate. """
        self.max_lr: float = max_lr
        """ Maximum learning rate. """
        self.total_updates: int = (epochs if update_per_epoch else steps_per_epoch * epochs) // update_interval
        """ Total number of update steps (considering update_interval). """
        self.update_per_epoch: bool = update_per_epoch
        """ Whether to update learning rate per epoch instead of per batch. """
        self.update_interval: int = max(1, int(update_interval))
        """ Number of steps between each lr increase; a bigger value means a more stable loss. """
        self.lr_mult: float = (max_lr / min_lr) ** (1 / self.total_updates)
        """ Learning rate multiplier. """
        self.learning_rates: list[float] = []
        """ List of learning rates. """
        self.losses: list[float] = []
        """ List of losses. """
        self.best_lr: float = min_lr
        """ Best learning rate. """
        self.best_loss: float = float("inf")
        """ Best loss. """
        self.model: Model
        """ Model to apply the learning rate finder to. """
        self.initial_weights: list[Any] | None = None
        """ Stores the initial weights of the model. """

    def on_train_begin(self, logs: dict[str, Any] | None = None) -> None:
        """ Set initial learning rate and save initial model weights at the start of training.

        Args:
            logs (dict | None): Training logs.
        """
        self.initial_weights = self.model.get_weights()
        tf.keras.backend.set_value(self.model.optimizer.learning_rate, self.min_lr)  # type: ignore

    def _update_lr_and_track_metrics(self, logs: dict[str, Any] | None = None) -> None:
        """ Update learning rate and track metrics.

        Args:
            logs (dict | None): Logs from training
        """
        if logs is None:
            return

        # Get current learning rate and loss
        current_lr: float = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))  # type: ignore
        current_loss: float = logs["loss"]

        # Record values
        self.learning_rates.append(current_lr)
        self.losses.append(current_loss)

        # Track best values
        if current_loss < self.best_loss:
            self.best_loss = current_loss
            self.best_lr = current_lr

        # Update learning rate
        new_lr: float = current_lr * self.lr_mult
        tf.keras.backend.set_value(self.model.optimizer.learning_rate, new_lr)  # type: ignore

    def on_batch_end(self, batch: int, logs: dict[str, Any] | None = None) -> None:
        """ Record loss and increase learning rate after each batch if not updating per epoch.

        Args:
            batch (int): Current batch index.
            logs (dict | None): Training logs.
        """
        if self.update_per_epoch:
            return
        if batch % self.update_interval == 0:
            self._update_lr_and_track_metrics(logs)

    def on_epoch_end(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
        """ Record loss and increase learning rate after each epoch if updating per epoch.

        Args:
            epoch (int): Current epoch index.
            logs (dict | None): Training logs.
        """
        if not self.update_per_epoch:
            return
        if epoch % self.update_interval == 0:
            self._update_lr_and_track_metrics(logs)

    def on_train_end(self, logs: dict[str, Any] | None = None) -> None:
        """ Restore initial model weights at the end of training.

        Args:
            logs (dict | None): Training logs.
        """
        if self.initial_weights is not None:
            self.model.set_weights(self.initial_weights)  # pyright: ignore [reportUnknownMemberType]
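A hedged sketch of how the finder would typically be driven, again with placeholder data and model: run a short training pass with an exponentially increasing learning rate, then inspect the recorded (learning rate, loss) pairs. Since `on_train_end` restores the initial weights, the range test does not consume the model.

```python
import numpy as np
from keras import layers, models

from stouputils.data_science.models.keras_utils.callbacks.learning_rate_finder import LearningRateFinder

# Placeholder data and model
x = np.random.rand(512, 32).astype("float32")
y = np.random.rand(512, 1).astype("float32")
model = models.Sequential([layers.Input(shape=(32,)), layers.Dense(64, activation="relu"), layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

batch_size = 32
steps_per_epoch = len(x) // batch_size
finder = LearningRateFinder(min_lr=1e-6, max_lr=1e-1, steps_per_epoch=steps_per_epoch, epochs=3, update_interval=1)

model.fit(x, y, epochs=3, batch_size=batch_size, callbacks=[finder], verbose=0)

# Pick a learning rate slightly below the point where the loss starts to diverge,
# e.g. by plotting finder.losses against finder.learning_rates on a log-x axis.
print(f"Lowest loss {finder.best_loss:.4f} reached at lr {finder.best_lr:.2e}")
```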
stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py
@@ -0,0 +1,31 @@

# pyright: reportMissingTypeStubs=false
# pyright: reportUnknownMemberType=false

# Imports
from typing import Any

from keras.callbacks import ModelCheckpoint


class ModelCheckpointV2(ModelCheckpoint):
    """ Model checkpoint callback but only starts after a given number of epochs.

    Args:
        epochs_before_start (int): Number of epochs before starting the checkpointing
    """

    def __init__(self, epochs_before_start: int = 3, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.epochs_before_start = epochs_before_start
        self.current_epoch = 0

    def on_batch_end(self, batch: int, logs: dict[str, Any] | None = None) -> None:
        if self.current_epoch >= self.epochs_before_start:
            super().on_batch_end(batch, logs)

    def on_epoch_end(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
        self.current_epoch = epoch
        if epoch >= self.epochs_before_start:
            super().on_epoch_end(epoch, logs)
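A short, hedged sketch: everything except `epochs_before_start` is forwarded to the standard `keras.callbacks.ModelCheckpoint`, so the usual checkpoint arguments apply. The file path and monitored metric below are illustrative.

```python
from stouputils.data_science.models.keras_utils.callbacks.model_checkpoint_v2 import ModelCheckpointV2

# Skip checkpointing during the first (typically noisy) epochs, then behave like ModelCheckpoint
checkpoint = ModelCheckpointV2(
    epochs_before_start=3,
    filepath="best_model.keras",  # illustrative path
    monitor="val_loss",
    save_best_only=True,
)
# model.fit(x, y, validation_data=(x_val, y_val), callbacks=[checkpoint])
```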
stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py
@@ -0,0 +1,249 @@

# pyright: reportMissingTypeStubs=false

# Imports
from collections.abc import Callable
from typing import Any

from keras.callbacks import Callback
from keras.models import Model
from keras.optimizers import Optimizer


class ProgressiveUnfreezing(Callback):
    """ Callback inspired by the Learning Rate Finder to progressively unfreeze model layers during training.

    Warning: This callback is not compatible with model.fit() as it modifies the trainable state of the model.
    Prefer doing your own training loop instead.

    This callback can operate in two modes:
    1. Start with all layers frozen and incrementally unfreeze them from 0% to 100% (progressive_freeze=False)
    2. Start with all layers unfrozen and incrementally freeze them from 100% to 0% (progressive_freeze=True)
    """

    def __init__(
        self,
        base_model: Model,
        steps_per_epoch: int,
        epochs: int,
        reset_weights: bool = False,
        reset_optimizer_function: Callable[[], Optimizer] | None = None,
        update_per_epoch: bool = True,
        update_interval: int = 5,
        progressive_freeze: bool = False
    ) -> None:
        """ Initialize the progressive unfreezing callback.

        Args:
            base_model (Model): Base model to unfreeze.
            steps_per_epoch (int): Number of steps per epoch.
            epochs (int): Total number of epochs.
            reset_weights (bool): If True, reset weights after each unfreeze.
            reset_optimizer_function (Callable | None):
                If set, use this function to reset the optimizer every update_interval.
                The function should return a compiled optimizer, e.g. `lambda: model._get_optimizer(AdamW(...))`.
            update_per_epoch (bool): If True, unfreeze per epoch, else per batch.
            update_interval (int): Number of steps between each unfreeze to allow model to stabilize.
            progressive_freeze (bool): If True, start with all layers unfrozen and progressively freeze them.
        """
        super().__init__()
        self.base_model: Model = base_model
        """ Base model to unfreeze. """
        self.model: Model
        """ Model to apply the progressive unfreezing to. """
        self.steps_per_epoch: int = int(steps_per_epoch)
        """ Number of steps per epoch. """
        self.epochs: int = int(epochs)
        """ Total number of epochs. """
        self.reset_weights: bool = bool(reset_weights)
        """ If True, reset weights after each unfreeze. """
        self.reset_optimizer_function: Callable[[], Optimizer] | None = reset_optimizer_function
        """ If reset_weights is True and this is not None, use this function to get a new optimizer. """
        self.update_per_epoch: bool = bool(update_per_epoch)
        """ If True, unfreeze per epoch, else per batch. """
        self.update_interval: int = max(1, int(update_interval))
        """ Number of steps between each unfreeze to allow model to stabilize. """
        self.progressive_freeze: bool = bool(progressive_freeze)
        """ If True, start with all layers unfrozen and progressively freeze them. """

        # If updating per epoch, subtract the update interval from self.epochs so the last step trains with 100% unfrozen
        if self.update_per_epoch:
            self.epochs -= self.update_interval

        # Calculate total steps considering the update interval
        total_steps_raw: int = self.epochs if self.update_per_epoch else self.steps_per_epoch * self.epochs
        self.total_steps: int = total_steps_raw // self.update_interval
        """ Total number of update steps (considering update_interval). """

        self.fraction_unfrozen: list[float] = []
        """ Fraction of layers unfrozen. """
        self.losses: list[float] = []
        """ Losses. """
        self._all_layers: list[Any] = []
        """ All layers. """
        self._initial_trainable: list[bool] = []
        """ Initial trainable states. """
        self._initial_weights: list[Any] | None = None
        """ Initial weights of the model. """
        self._last_update_step: int = -1
        """ Last step when layers were unfrozen. """
        self.params: dict[str, Any]

    def on_train_begin(self, logs: dict[str, Any] | None = None) -> None:
        """ Set initial layer trainable states at the start of training and store initial states and weights.

        Args:
            logs (dict | None): Training logs.
        """
        # Collect all layers from the model and preserve their original trainable states for potential restoration
        self._all_layers = self.base_model.layers
        self._initial_trainable = [bool(layer.trainable) for layer in self._all_layers]

        # Store initial weights to reset after each unfreeze
        if self.reset_weights:
            self._initial_weights = self.model.get_weights()

        # Set initial trainable state based on mode
        for layer in self._all_layers:
            layer.trainable = self.progressive_freeze  # If progressive_freeze, start with all layers unfrozen

    def _update_layers(self, step: int) -> None:
        """ Update layer trainable states based on the current step and mode.
        Reset weights after each update to prevent bias in the results.

        Args:
            step (int): Current training step.
        """
        # Calculate the effective step considering the update interval
        effective_step: int = step // self.update_interval

        # Skip if we haven't reached the next update interval
        if effective_step <= self._last_update_step:
            return
        self._last_update_step = effective_step

        # Calculate the number of layers to unfreeze based on current effective step
        n_layers: int = len(self._all_layers)

        if self.progressive_freeze:
            # For progressive freezing, start at 1.0 (all unfrozen) and decrease to 0.0
            fraction: float = max(0.0, 1.0 - (effective_step + 1) / self.total_steps)
        else:
            # For progressive unfreezing, start at 0.0 (all frozen) and increase to 1.0
            fraction: float = min(1.0, (effective_step + 1) / self.total_steps)

        n_unfreeze: int = int(n_layers * fraction)  # Number of layers to keep unfrozen
        self.fraction_unfrozen.append(fraction)

        # Set trainable state for each layer based on position
        # For both modes, we unfreeze from the top (output layers) to the bottom (input layers)
        for i, layer in enumerate(self._all_layers):
            layer.trainable = i >= (n_layers - n_unfreeze)

        # Reset weights to initial state to prevent bias and reset optimizer
        if self._initial_weights is not None:
            self.model.set_weights(self._initial_weights)  # pyright: ignore [reportUnknownMemberType]
            if self.reset_optimizer_function is not None:
                self.model.optimizer = self.reset_optimizer_function()
                self.model.optimizer.build(self.model.trainable_variables)  # pyright: ignore [reportUnknownMemberType]

    def _track_loss(self, logs: dict[str, Any] | None = None) -> None:
        """ Track the current loss.

        Args:
            logs (dict | None): Training logs containing loss information.
        """
        if logs and "loss" in logs:
            self.losses.append(logs["loss"])

    def on_batch_begin(self, batch: int, logs: dict[str, Any] | None = None) -> None:
        """ Update layer trainable states at the start of each batch if not updating per epoch.

        Args:
            batch (int): Current batch index.
            logs (dict | None): Training logs.
        """
        # Skip if we're updating per epoch instead of per batch
        if self.update_per_epoch:
            return

        # Calculate the current step across all epochs and update layers
        step: int = self.params.get("steps", self.steps_per_epoch) * self.params.get("epoch", 0) + batch
        self._update_layers(step)

    def on_batch_end(self, batch: int, logs: dict[str, Any] | None = None) -> None:
        """ Track loss at the end of each batch if not updating per epoch.

        Args:
            batch (int): Current batch index.
            logs (dict | None): Training logs.
        """
        # Skip if we're updating per epoch instead of per batch
        if self.update_per_epoch:
            return

        # Record the loss if update interval is reached
        if batch % self.update_interval == 0:
            self._track_loss(logs)

    def on_epoch_begin(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
        """ Update layer trainable states at the start of each epoch if updating per epoch.

        Args:
            epoch (int): Current epoch index.
            logs (dict | None): Training logs.
        """
        # Skip if we're updating per batch instead of per epoch
        if not self.update_per_epoch:
            return

        # Update layers based on current epoch
        self._update_layers(epoch)

    def on_epoch_end(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
        """ Track loss at the end of each epoch if updating per epoch.

        Args:
            epoch (int): Current epoch index.
            logs (dict | None): Training logs.
        """
        # Skip if we're updating per batch instead of per epoch
        if not self.update_per_epoch:
            return

        # Record the loss if update interval is reached
        if epoch % self.update_interval == 0:
            self._track_loss(logs)

    def on_train_end(self, logs: dict[str, Any] | None = None) -> None:
        """ Restore original trainable states at the end of training.

        Args:
            logs (dict | None): Training logs.
        """
        # Restore each layer's original trainable state
        for layer, trainable in zip(self._all_layers, self._initial_trainable, strict=False):
            layer.trainable = trainable

    def get_results(self, multiply_by_100: bool = True) -> tuple[list[float], list[float]]:
        """ Get the results of the progressive unfreezing from 0% to 100% even if progressive_freeze is True.

        Args:
            multiply_by_100 (bool): If True, multiply the fractions by 100 to get percentages.

        Returns:
            tuple[list[float], list[float]]: fractions of layers unfrozen, and losses.
        """
        fractions: list[float] = self.fraction_unfrozen

        # Reverse the order if progressive_freeze is True
        if self.progressive_freeze:
            fractions = fractions[::-1]

        # Multiply by 100 if requested
        if multiply_by_100:
            fractions = [x * 100 for x in fractions]

        # Return the results
        return fractions, self.losses
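Since the docstring warns that this callback is not meant to be passed to `model.fit()`, one possible way to drive it, sketched below with toy placeholders, is to call its hooks by hand around per-epoch training. Recompiling after each trainable-state change is an assumption of this sketch, not something the package prescribes.

```python
import numpy as np
from keras import layers, models

from stouputils.data_science.models.keras_utils.callbacks.progressive_unfreezing import ProgressiveUnfreezing

# Toy "backbone" wrapped by a small head (placeholders, not from the package)
base_model = models.Sequential([layers.Dense(64, activation="relu") for _ in range(8)])
model = models.Sequential([layers.Input(shape=(32,)), base_model, layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")
x = np.random.rand(320, 32).astype("float32")
y = np.random.rand(320, 1).astype("float32")

epochs, steps_per_epoch = 10, 10
unfreezer = ProgressiveUnfreezing(
    base_model=base_model,
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    update_per_epoch=True,
    update_interval=1,
)
unfreezer.set_model(model)  # standard keras.callbacks.Callback plumbing
unfreezer.params = {"epochs": epochs, "steps": steps_per_epoch}

# Hand-rolled loop: hooks are called explicitly instead of going through model.fit(callbacks=...)
unfreezer.on_train_begin()
for epoch in range(epochs):
    unfreezer.on_epoch_begin(epoch)              # adjust layer.trainable for this step
    model.compile(optimizer="adam", loss="mse")  # assumption: recompile so the new flags take effect
    history = model.fit(x, y, batch_size=32, epochs=1, verbose=0)
    unfreezer.on_epoch_end(epoch, {"loss": history.history["loss"][-1]})
unfreezer.on_train_end()                          # restore the original trainable states

fractions, losses = unfreezer.get_results()       # fraction of layers unfrozen (in %) vs. recorded loss
```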