stouputils 1.14.0__py3-none-any.whl → 1.14.2__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- stouputils/__init__.pyi +15 -0
- stouputils/_deprecated.pyi +12 -0
- stouputils/all_doctests.pyi +46 -0
- stouputils/applications/__init__.pyi +2 -0
- stouputils/applications/automatic_docs.py +3 -0
- stouputils/applications/automatic_docs.pyi +106 -0
- stouputils/applications/upscaler/__init__.pyi +3 -0
- stouputils/applications/upscaler/config.pyi +18 -0
- stouputils/applications/upscaler/image.pyi +109 -0
- stouputils/applications/upscaler/video.pyi +60 -0
- stouputils/archive.pyi +67 -0
- stouputils/backup.pyi +109 -0
- stouputils/collections.pyi +86 -0
- stouputils/continuous_delivery/__init__.pyi +5 -0
- stouputils/continuous_delivery/cd_utils.pyi +129 -0
- stouputils/continuous_delivery/github.pyi +162 -0
- stouputils/continuous_delivery/pypi.pyi +52 -0
- stouputils/continuous_delivery/pyproject.pyi +67 -0
- stouputils/continuous_delivery/stubs.pyi +39 -0
- stouputils/ctx.pyi +211 -0
- stouputils/data_science/config/get.py +51 -51
- stouputils/data_science/data_processing/image/__init__.py +66 -66
- stouputils/data_science/data_processing/image/auto_contrast.py +79 -79
- stouputils/data_science/data_processing/image/axis_flip.py +58 -58
- stouputils/data_science/data_processing/image/bias_field_correction.py +74 -74
- stouputils/data_science/data_processing/image/binary_threshold.py +73 -73
- stouputils/data_science/data_processing/image/blur.py +59 -59
- stouputils/data_science/data_processing/image/brightness.py +54 -54
- stouputils/data_science/data_processing/image/canny.py +110 -110
- stouputils/data_science/data_processing/image/clahe.py +92 -92
- stouputils/data_science/data_processing/image/common.py +30 -30
- stouputils/data_science/data_processing/image/contrast.py +53 -53
- stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -74
- stouputils/data_science/data_processing/image/denoise.py +378 -378
- stouputils/data_science/data_processing/image/histogram_equalization.py +123 -123
- stouputils/data_science/data_processing/image/invert.py +64 -64
- stouputils/data_science/data_processing/image/laplacian.py +60 -60
- stouputils/data_science/data_processing/image/median_blur.py +52 -52
- stouputils/data_science/data_processing/image/noise.py +59 -59
- stouputils/data_science/data_processing/image/normalize.py +65 -65
- stouputils/data_science/data_processing/image/random_erase.py +66 -66
- stouputils/data_science/data_processing/image/resize.py +69 -69
- stouputils/data_science/data_processing/image/rotation.py +80 -80
- stouputils/data_science/data_processing/image/salt_pepper.py +68 -68
- stouputils/data_science/data_processing/image/sharpening.py +55 -55
- stouputils/data_science/data_processing/image/shearing.py +64 -64
- stouputils/data_science/data_processing/image/threshold.py +64 -64
- stouputils/data_science/data_processing/image/translation.py +71 -71
- stouputils/data_science/data_processing/image/zoom.py +83 -83
- stouputils/data_science/data_processing/image_augmentation.py +118 -118
- stouputils/data_science/data_processing/image_preprocess.py +183 -183
- stouputils/data_science/data_processing/prosthesis_detection.py +359 -359
- stouputils/data_science/data_processing/technique.py +481 -481
- stouputils/data_science/dataset/__init__.py +45 -45
- stouputils/data_science/dataset/dataset.py +292 -292
- stouputils/data_science/dataset/dataset_loader.py +135 -135
- stouputils/data_science/dataset/grouping_strategy.py +296 -296
- stouputils/data_science/dataset/image_loader.py +100 -100
- stouputils/data_science/dataset/xy_tuple.py +696 -696
- stouputils/data_science/metric_dictionnary.py +106 -106
- stouputils/data_science/mlflow_utils.py +206 -206
- stouputils/data_science/models/abstract_model.py +149 -149
- stouputils/data_science/models/all.py +85 -85
- stouputils/data_science/models/keras/all.py +38 -38
- stouputils/data_science/models/keras/convnext.py +62 -62
- stouputils/data_science/models/keras/densenet.py +50 -50
- stouputils/data_science/models/keras/efficientnet.py +60 -60
- stouputils/data_science/models/keras/mobilenet.py +56 -56
- stouputils/data_science/models/keras/resnet.py +52 -52
- stouputils/data_science/models/keras/squeezenet.py +233 -233
- stouputils/data_science/models/keras/vgg.py +42 -42
- stouputils/data_science/models/keras/xception.py +38 -38
- stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -20
- stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -219
- stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -148
- stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -31
- stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -249
- stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -66
- stouputils/data_science/models/keras_utils/losses/__init__.py +12 -12
- stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -56
- stouputils/data_science/models/keras_utils/visualizations.py +416 -416
- stouputils/data_science/models/sandbox.py +116 -116
- stouputils/data_science/range_tuple.py +234 -234
- stouputils/data_science/utils.py +285 -285
- stouputils/decorators.pyi +242 -0
- stouputils/image.pyi +172 -0
- stouputils/installer/__init__.py +18 -18
- stouputils/installer/__init__.pyi +5 -0
- stouputils/installer/common.pyi +39 -0
- stouputils/installer/downloader.pyi +24 -0
- stouputils/installer/linux.py +144 -144
- stouputils/installer/linux.pyi +39 -0
- stouputils/installer/main.py +223 -223
- stouputils/installer/main.pyi +57 -0
- stouputils/installer/windows.py +136 -136
- stouputils/installer/windows.pyi +31 -0
- stouputils/io.pyi +213 -0
- stouputils/parallel.py +12 -10
- stouputils/parallel.pyi +211 -0
- stouputils/print.pyi +136 -0
- stouputils/py.typed +1 -1
- stouputils/stouputils/parallel.pyi +4 -4
- stouputils/version_pkg.pyi +15 -0
- {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/METADATA +1 -1
- stouputils-1.14.2.dist-info/RECORD +171 -0
- stouputils-1.14.0.dist-info/RECORD +0 -140
- {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/WHEEL +0 -0
- {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/entry_points.txt +0 -0
stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py
@@ -1,249 +1,249 @@
Every line in this hunk is removed and re-added with identical text, so the file content appears once below.


# pyright: reportMissingTypeStubs=false

# Imports
from collections.abc import Callable
from typing import Any

from keras.callbacks import Callback
from keras.models import Model
from keras.optimizers import Optimizer


class ProgressiveUnfreezing(Callback):
	""" Callback inspired by the Learning Rate Finder to progressively unfreeze model layers during training.

	Warning: This callback is not compatible with model.fit() as it modifies the trainable state of the model.
	Prefer doing your own training loop instead.

	This callback can operate in two modes:
	1. Start with all layers frozen and incrementally unfreeze them from 0% to 100% (progressive_freeze=False)
	2. Start with all layers unfrozen and incrementally freeze them from 100% to 0% (progressive_freeze=True)
	"""

	def __init__(
		self,
		base_model: Model,
		steps_per_epoch: int,
		epochs: int,
		reset_weights: bool = False,
		reset_optimizer_function: Callable[[], Optimizer] | None = None,
		update_per_epoch: bool = True,
		update_interval: int = 5,
		progressive_freeze: bool = False
	) -> None:
		""" Initialize the progressive unfreezing callback.

		Args:
			base_model (Model): Base model to unfreeze.
			steps_per_epoch (int): Number of steps per epoch.
			epochs (int): Total number of epochs.
			reset_weights (bool): If True, reset weights after each unfreeze.
			reset_optimizer_function (Callable | None):
				If set, use this function to reset the optimizer every update_interval.
				The function should return a compiled optimizer, e.g. `lambda: model._get_optimizer(AdamW(...))`.
			update_per_epoch (bool): If True, unfreeze per epoch, else per batch.
			update_interval (int): Number of steps between each unfreeze to allow model to stabilize.
			progressive_freeze (bool): If True, start with all layers unfrozen and progressively freeze them.
		"""
		super().__init__()
		self.base_model: Model = base_model
		""" Base model to unfreeze. """
		self.model: Model
		""" Model to apply the progressive unfreezing to. """
		self.steps_per_epoch: int = int(steps_per_epoch)
		""" Number of steps per epoch. """
		self.epochs: int = int(epochs)
		""" Total number of epochs. """
		self.reset_weights: bool = bool(reset_weights)
		""" If True, reset weights after each unfreeze. """
		self.reset_optimizer_function: Callable[[], Optimizer] | None = reset_optimizer_function
		""" If reset_weights is True and this is not None, use this function to get a new optimizer. """
		self.update_per_epoch: bool = bool(update_per_epoch)
		""" If True, unfreeze per epoch, else per batch. """
		self.update_interval: int = max(1, int(update_interval))
		""" Number of steps between each unfreeze to allow model to stabilize. """
		self.progressive_freeze: bool = bool(progressive_freeze)
		""" If True, start with all layers unfrozen and progressively freeze them. """

		# If updating per epoch, remove to self.epochs the update interval to allow the last step to train with 100% unfreeze
		if self.update_per_epoch:
			self.epochs -= self.update_interval

		# Calculate total steps considering the update interval
		total_steps_raw: int = self.epochs if self.update_per_epoch else self.steps_per_epoch * self.epochs
		self.total_steps: int = total_steps_raw // self.update_interval
		""" Total number of update steps (considering update_interval). """

		self.fraction_unfrozen: list[float] = []
		""" Fraction of layers unfrozen. """
		self.losses: list[float] = []
		""" Losses. """
		self._all_layers: list[Any] = []
		""" All layers. """
		self._initial_trainable: list[bool] = []
		""" Initial trainable states. """
		self._initial_weights: list[Any] | None = None
		""" Initial weights of the model. """
		self._last_update_step: int = -1
		""" Last step when layers were unfrozen. """
		self.params: dict[str, Any]

	def on_train_begin(self, logs: dict[str, Any] | None = None) -> None:
		""" Set initial layer trainable states at the start of training and store initial states and weights.

		Args:
			logs (dict | None): Training logs.
		"""
		# Collect all layers from the model and preserve their original trainable states for potential restoration
		self._all_layers = self.base_model.layers
		self._initial_trainable = [bool(layer.trainable) for layer in self._all_layers]

		# Store initial weights to reset after each unfreeze
		if self.reset_weights:
			self._initial_weights = self.model.get_weights()

		# Set initial trainable state based on mode
		for layer in self._all_layers:
			layer.trainable = self.progressive_freeze  # If progressive_freeze, start with all layers unfrozen

	def _update_layers(self, step: int) -> None:
		""" Update layer trainable states based on the current step and mode.
		Reset weights after each update to prevent bias in the results.

		Args:
			step (int): Current training step.
		"""
		# Calculate the effective step considering the update interval
		effective_step: int = step // self.update_interval

		# Skip if we haven't reached the next update interval
		if effective_step <= self._last_update_step:
			return
		self._last_update_step = effective_step

		# Calculate the number of layers to unfreeze based on current effective step
		n_layers: int = len(self._all_layers)

		if self.progressive_freeze:
			# For progressive freezing, start at 1.0 (all unfrozen) and decrease to 0.0
			fraction: float = max(0.0, 1.0 - (effective_step + 1) / self.total_steps)
		else:
			# For progressive unfreezing, start at 0.0 (all frozen) and increase to 1.0
			fraction: float = min(1.0, (effective_step + 1) / self.total_steps)

		n_unfreeze: int = int(n_layers * fraction)  # Number of layers to keep unfrozen
		self.fraction_unfrozen.append(fraction)

		# Set trainable state for each layer based on position
		# For both modes, we unfreeze from the top (output layers) to the bottom (input layers)
		for i, layer in enumerate(self._all_layers):
			layer.trainable = i >= (n_layers - n_unfreeze)

		# Reset weights to initial state to prevent bias and reset optimizer
		if self._initial_weights is not None:
			self.model.set_weights(self._initial_weights)  # pyright: ignore [reportUnknownMemberType]
			if self.reset_optimizer_function is not None:
				self.model.optimizer = self.reset_optimizer_function()
				self.model.optimizer.build(self.model.trainable_variables)  # pyright: ignore [reportUnknownMemberType]

	def _track_loss(self, logs: dict[str, Any] | None = None) -> None:
		""" Track the current loss.

		Args:
			logs (dict | None): Training logs containing loss information.
		"""
		if logs and "loss" in logs:
			self.losses.append(logs["loss"])

	def on_batch_begin(self, batch: int, logs: dict[str, Any] | None = None) -> None:
		""" Update layer trainable states at the start of each batch if not updating per epoch.

		Args:
			batch (int): Current batch index.
			logs (dict | None): Training logs.
		"""
		# Skip if we're updating per epoch instead of per batch
		if self.update_per_epoch:
			return

		# Calculate the current step across all epochs and update layers
		step: int = self.params.get("steps", self.steps_per_epoch) * self.params.get("epoch", 0) + batch
		self._update_layers(step)

	def on_batch_end(self, batch: int, logs: dict[str, Any] | None = None) -> None:
		""" Track loss at the end of each batch if not updating per epoch.

		Args:
			batch (int): Current batch index.
			logs (dict | None): Training logs.
		"""
		# Skip if we're updating per epoch instead of per batch
		if self.update_per_epoch:
			return

		# Record the loss if update interval is reached
		if batch % self.update_interval == 0:
			self._track_loss(logs)

	def on_epoch_begin(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
		""" Update layer trainable states at the start of each epoch if updating per epoch.

		Args:
			epoch (int): Current epoch index.
			logs (dict | None): Training logs.
		"""
		# Skip if we're updating per batch instead of per epoch
		if not self.update_per_epoch:
			return

		# Update layers based on current epoch
		self._update_layers(epoch)

	def on_epoch_end(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
		""" Track loss at the end of each epoch if updating per epoch.

		Args:
			epoch (int): Current epoch index.
			logs (dict | None): Training logs.
		"""
		# Skip if we're updating per batch instead of per epoch
		if not self.update_per_epoch:
			return

		# Record the loss if update interval is reached
		if epoch % self.update_interval == 0:
			self._track_loss(logs)

	def on_train_end(self, logs: dict[str, Any] | None = None) -> None:
		""" Restore original trainable states at the end of training.

		Args:
			logs (dict | None): Training logs.
		"""
		# Restore each layer's original trainable state
		for layer, trainable in zip(self._all_layers, self._initial_trainable, strict=False):
			layer.trainable = trainable

	def get_results(self, multiply_by_100: bool = True) -> tuple[list[float], list[float]]:
		""" Get the results of the progressive unfreezing from 0% to 100% even if progressive_freeze is True.

		Args:
			multiply_by_100 (bool): If True, multiply the fractions by 100 to get percentages.

		Returns:
			tuple[list[float], list[float]]: fractions of layers unfrozen, and losses.
		"""
		fractions: list[float] = self.fraction_unfrozen

		# Reverse the order if progressive_freeze is True
		if self.progressive_freeze:
			fractions = fractions[::-1]

		# Multiply by 100 if requested
		if multiply_by_100:
			fractions = [x * 100 for x in fractions]

		# Return the results
		return fractions, self.losses
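The docstring above warns that ProgressiveUnfreezing is not meant to be passed to model.fit() and recommends a custom training loop. Below is a minimal, hypothetical sketch of such a loop; it is not part of the package, and the toy model, data, hyperparameters, and the set_model()/hook-driving pattern are assumptions about how the callback would typically be wired up.

# Hypothetical usage sketch (not from the package): drive the callback hooks from a
# manual training loop, as the docstring recommends instead of model.fit().
import tensorflow as tf
from keras import layers, losses, models
from stouputils.data_science.models.keras_utils.callbacks.progressive_unfreezing import ProgressiveUnfreezing

# Toy model and data, purely illustrative
model = models.Sequential([layers.Dense(16, activation="relu"), layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")
x = tf.random.uniform((128, 8))
y = tf.random.uniform((128, 1))
loss_fn = losses.MeanSquaredError()

epochs, steps_per_epoch = 12, 4
callback = ProgressiveUnfreezing(base_model=model, steps_per_epoch=steps_per_epoch, epochs=epochs, update_interval=2)
callback.set_model(model)  # the callback reads and writes self.model
callback.on_train_begin()

for epoch in range(epochs):
	callback.on_epoch_begin(epoch)  # unfreezes the next slice of layers (per-epoch mode)
	epoch_losses: list[float] = []
	for step in range(steps_per_epoch):
		batch_x, batch_y = x[step::steps_per_epoch], y[step::steps_per_epoch]
		with tf.GradientTape() as tape:
			loss = loss_fn(batch_y, model(batch_x, training=True))
		if model.trainable_weights:  # early epochs may have every layer frozen
			grads = tape.gradient(loss, model.trainable_weights)
			model.optimizer.apply_gradients(zip(grads, model.trainable_weights))
		epoch_losses.append(float(loss))
	callback.on_epoch_end(epoch, {"loss": sum(epoch_losses) / len(epoch_losses)})

callback.on_train_end()
fractions, recorded_losses = callback.get_results()  # percentage of layers unfrozen vs. recorded loss

A tape-based loop re-reads model.trainable_weights every step, so freezing and unfreezing take effect immediately; this is presumably why the docstring steers users away from model.fit(), where changes to layer.trainable generally require a recompile to be picked up.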
stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py
@@ -1,66 +1,66 @@
Every line in this hunk is removed and re-added with identical text, so the file content appears once below.


# pyright: reportMissingTypeStubs=false

# Imports
from typing import Any

import tensorflow as tf
from keras.callbacks import Callback
from keras.models import Model


class WarmupScheduler(Callback):
	""" Keras Callback for learning rate warmup.

	Sources:
	- Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour: https://arxiv.org/abs/1706.02677
	- Attention Is All You Need: https://arxiv.org/abs/1706.03762

	This callback implements a learning rate warmup strategy where the learning rate
	gradually increases from an initial value to a target value over a specified
	number of epochs. This helps stabilize training in the early stages.

	The learning rate increases linearly from the initial value to the target value
	over the warmup period, and then remains at the target value.
	"""

	def __init__(self, warmup_epochs: int, initial_lr: float, target_lr: float) -> None:
		""" Initialize the warmup scheduler.

		Args:
			warmup_epochs (int): Number of epochs for warmup.
			initial_lr (float): Starting learning rate for warmup.
			target_lr (float): Target learning rate after warmup.
		"""
		super().__init__()
		self.warmup_epochs: int = warmup_epochs
		""" Number of epochs for warmup. """
		self.initial_lr: float = initial_lr
		""" Starting learning rate for warmup. """
		self.target_lr: float = target_lr
		""" Target learning rate after warmup. """
		self.model: Model
		""" Model to apply the warmup scheduler to. """

		# Pre-compute learning rates for each epoch to avoid calculations during training
		self.epoch_learning_rates: list[float] = []
		for epoch in range(warmup_epochs + 1):
			if epoch < warmup_epochs:
				lr = initial_lr + (target_lr - initial_lr) * (epoch + 1) / warmup_epochs
			else:
				lr = target_lr
			self.epoch_learning_rates.append(lr)

	def on_epoch_begin(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
		""" Adjust learning rate at the beginning of each epoch during warmup.

		Args:
			epoch (int): Current epoch index.
			logs (dict | None): Training logs.
		"""
		if self.warmup_epochs <= 0 or epoch > self.warmup_epochs:
			return

		# Use pre-computed learning rate to avoid calculations during training
		tf.keras.backend.set_value(self.model.optimizer.learning_rate, self.epoch_learning_rates[epoch])  # type: ignore
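Since WarmupScheduler only adjusts the optimizer's learning rate, it can be passed directly to model.fit(). The short sketch below is a hypothetical illustration, not part of the package; the toy model, data, and the chosen learning rates are assumptions.

# Hypothetical usage sketch (not from the package): linear warmup from 1e-5 to 1e-3
# over the first 5 epochs, after which the learning rate stays at the target value.
import numpy as np
from keras import layers, models, optimizers
from stouputils.data_science.models.keras_utils.callbacks.warmup_scheduler import WarmupScheduler

model = models.Sequential([layers.Dense(16, activation="relu"), layers.Dense(1)])
model.compile(optimizer=optimizers.Adam(learning_rate=1e-5), loss="mse")

x = np.random.rand(256, 8).astype("float32")
y = np.random.rand(256, 1).astype("float32")

warmup = WarmupScheduler(warmup_epochs=5, initial_lr=1e-5, target_lr=1e-3)
model.fit(x, y, epochs=20, batch_size=32, callbacks=[warmup], verbose=0)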
stouputils/data_science/models/keras_utils/losses/__init__.py
@@ -1,12 +1,12 @@
Every line in this hunk is removed and re-added with identical text, so the file content appears once below.

""" Custom losses for Keras models.

Features:

- Next Generation Loss
"""

# Imports
from .next_generation_loss import NextGenerationLoss

__all__ = ["NextGenerationLoss"]