stouputils-1.14.0-py3-none-any.whl → stouputils-1.14.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. stouputils/__init__.pyi +15 -0
  2. stouputils/_deprecated.pyi +12 -0
  3. stouputils/all_doctests.pyi +46 -0
  4. stouputils/applications/__init__.pyi +2 -0
  5. stouputils/applications/automatic_docs.py +3 -0
  6. stouputils/applications/automatic_docs.pyi +106 -0
  7. stouputils/applications/upscaler/__init__.pyi +3 -0
  8. stouputils/applications/upscaler/config.pyi +18 -0
  9. stouputils/applications/upscaler/image.pyi +109 -0
  10. stouputils/applications/upscaler/video.pyi +60 -0
  11. stouputils/archive.pyi +67 -0
  12. stouputils/backup.pyi +109 -0
  13. stouputils/collections.pyi +86 -0
  14. stouputils/continuous_delivery/__init__.pyi +5 -0
  15. stouputils/continuous_delivery/cd_utils.pyi +129 -0
  16. stouputils/continuous_delivery/github.pyi +162 -0
  17. stouputils/continuous_delivery/pypi.pyi +52 -0
  18. stouputils/continuous_delivery/pyproject.pyi +67 -0
  19. stouputils/continuous_delivery/stubs.pyi +39 -0
  20. stouputils/ctx.pyi +211 -0
  21. stouputils/data_science/config/get.py +51 -51
  22. stouputils/data_science/data_processing/image/__init__.py +66 -66
  23. stouputils/data_science/data_processing/image/auto_contrast.py +79 -79
  24. stouputils/data_science/data_processing/image/axis_flip.py +58 -58
  25. stouputils/data_science/data_processing/image/bias_field_correction.py +74 -74
  26. stouputils/data_science/data_processing/image/binary_threshold.py +73 -73
  27. stouputils/data_science/data_processing/image/blur.py +59 -59
  28. stouputils/data_science/data_processing/image/brightness.py +54 -54
  29. stouputils/data_science/data_processing/image/canny.py +110 -110
  30. stouputils/data_science/data_processing/image/clahe.py +92 -92
  31. stouputils/data_science/data_processing/image/common.py +30 -30
  32. stouputils/data_science/data_processing/image/contrast.py +53 -53
  33. stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -74
  34. stouputils/data_science/data_processing/image/denoise.py +378 -378
  35. stouputils/data_science/data_processing/image/histogram_equalization.py +123 -123
  36. stouputils/data_science/data_processing/image/invert.py +64 -64
  37. stouputils/data_science/data_processing/image/laplacian.py +60 -60
  38. stouputils/data_science/data_processing/image/median_blur.py +52 -52
  39. stouputils/data_science/data_processing/image/noise.py +59 -59
  40. stouputils/data_science/data_processing/image/normalize.py +65 -65
  41. stouputils/data_science/data_processing/image/random_erase.py +66 -66
  42. stouputils/data_science/data_processing/image/resize.py +69 -69
  43. stouputils/data_science/data_processing/image/rotation.py +80 -80
  44. stouputils/data_science/data_processing/image/salt_pepper.py +68 -68
  45. stouputils/data_science/data_processing/image/sharpening.py +55 -55
  46. stouputils/data_science/data_processing/image/shearing.py +64 -64
  47. stouputils/data_science/data_processing/image/threshold.py +64 -64
  48. stouputils/data_science/data_processing/image/translation.py +71 -71
  49. stouputils/data_science/data_processing/image/zoom.py +83 -83
  50. stouputils/data_science/data_processing/image_augmentation.py +118 -118
  51. stouputils/data_science/data_processing/image_preprocess.py +183 -183
  52. stouputils/data_science/data_processing/prosthesis_detection.py +359 -359
  53. stouputils/data_science/data_processing/technique.py +481 -481
  54. stouputils/data_science/dataset/__init__.py +45 -45
  55. stouputils/data_science/dataset/dataset.py +292 -292
  56. stouputils/data_science/dataset/dataset_loader.py +135 -135
  57. stouputils/data_science/dataset/grouping_strategy.py +296 -296
  58. stouputils/data_science/dataset/image_loader.py +100 -100
  59. stouputils/data_science/dataset/xy_tuple.py +696 -696
  60. stouputils/data_science/metric_dictionnary.py +106 -106
  61. stouputils/data_science/mlflow_utils.py +206 -206
  62. stouputils/data_science/models/abstract_model.py +149 -149
  63. stouputils/data_science/models/all.py +85 -85
  64. stouputils/data_science/models/keras/all.py +38 -38
  65. stouputils/data_science/models/keras/convnext.py +62 -62
  66. stouputils/data_science/models/keras/densenet.py +50 -50
  67. stouputils/data_science/models/keras/efficientnet.py +60 -60
  68. stouputils/data_science/models/keras/mobilenet.py +56 -56
  69. stouputils/data_science/models/keras/resnet.py +52 -52
  70. stouputils/data_science/models/keras/squeezenet.py +233 -233
  71. stouputils/data_science/models/keras/vgg.py +42 -42
  72. stouputils/data_science/models/keras/xception.py +38 -38
  73. stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -20
  74. stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -219
  75. stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -148
  76. stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -31
  77. stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -249
  78. stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -66
  79. stouputils/data_science/models/keras_utils/losses/__init__.py +12 -12
  80. stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -56
  81. stouputils/data_science/models/keras_utils/visualizations.py +416 -416
  82. stouputils/data_science/models/sandbox.py +116 -116
  83. stouputils/data_science/range_tuple.py +234 -234
  84. stouputils/data_science/utils.py +285 -285
  85. stouputils/decorators.pyi +242 -0
  86. stouputils/image.pyi +172 -0
  87. stouputils/installer/__init__.py +18 -18
  88. stouputils/installer/__init__.pyi +5 -0
  89. stouputils/installer/common.pyi +39 -0
  90. stouputils/installer/downloader.pyi +24 -0
  91. stouputils/installer/linux.py +144 -144
  92. stouputils/installer/linux.pyi +39 -0
  93. stouputils/installer/main.py +223 -223
  94. stouputils/installer/main.pyi +57 -0
  95. stouputils/installer/windows.py +136 -136
  96. stouputils/installer/windows.pyi +31 -0
  97. stouputils/io.pyi +213 -0
  98. stouputils/parallel.py +12 -10
  99. stouputils/parallel.pyi +211 -0
  100. stouputils/print.pyi +136 -0
  101. stouputils/py.typed +1 -1
  102. stouputils/stouputils/parallel.pyi +4 -4
  103. stouputils/version_pkg.pyi +15 -0
  104. {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/METADATA +1 -1
  105. stouputils-1.14.2.dist-info/RECORD +171 -0
  106. stouputils-1.14.0.dist-info/RECORD +0 -140
  107. {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/WHEEL +0 -0
  108. {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/entry_points.txt +0 -0
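Most of this release is typing-related: the file list is dominated by new `.pyi` stub files (the `+N -0` entries), alongside the `py.typed` marker and regenerated `RECORD`/`METADATA`. Entries whose added and removed counts are identical (e.g. `+249 -249`) are apparently content-neutral rewrites, as the hunks reproduced below confirm for three files: every removed line is re-added unchanged. The only asymmetric source change is `stouputils/parallel.py` (`+12 -10`), whose contents this page does not show. A `.pyi` stub declares a module's public signatures without bodies, so type checkers such as mypy or pyright can verify callers. Below is a minimal sketch of the stub format, using hypothetical names rather than the actual stouputils API:

```python
# Hypothetical stub content (example.pyi) mirroring a module's public API.
# All names below are illustrative, not taken from stouputils: a stub keeps
# signatures and attribute types, with `...` in place of implementations.
from collections.abc import Iterable

VERSION: str

def make_archive(source: str, destinations: Iterable[str] | None = ..., compression_level: int = ...) -> bool: ...

class Downloader:
    timeout: float
    def __init__(self, timeout: float = ...) -> None: ...
    def fetch(self, url: str) -> bytes: ...
```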
--- a/stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py
+++ b/stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py
@@ -1,249 +1,249 @@
-
-# pyright: reportMissingTypeStubs=false
-
-# Imports
-from collections.abc import Callable
-from typing import Any
-
-from keras.callbacks import Callback
-from keras.models import Model
-from keras.optimizers import Optimizer
-
-
-class ProgressiveUnfreezing(Callback):
-    """ Callback inspired by the Learning Rate Finder to progressively unfreeze model layers during training.
-
-    Warning: This callback is not compatible with model.fit() as it modifies the trainable state of the model.
-    Prefer doing your own training loop instead.
-
-    This callback can operate in two modes:
-    1. Start with all layers frozen and incrementally unfreeze them from 0% to 100% (progressive_freeze=False)
-    2. Start with all layers unfrozen and incrementally freeze them from 100% to 0% (progressive_freeze=True)
-    """
-
-    def __init__(
-        self,
-        base_model: Model,
-        steps_per_epoch: int,
-        epochs: int,
-        reset_weights: bool = False,
-        reset_optimizer_function: Callable[[], Optimizer] | None = None,
-        update_per_epoch: bool = True,
-        update_interval: int = 5,
-        progressive_freeze: bool = False
-    ) -> None:
-        """ Initialize the progressive unfreezing callback.
-
-        Args:
-            base_model (Model): Base model to unfreeze.
-            steps_per_epoch (int): Number of steps per epoch.
-            epochs (int): Total number of epochs.
-            reset_weights (bool): If True, reset weights after each unfreeze.
-            reset_optimizer_function (Callable | None):
-                If set, use this function to reset the optimizer every update_interval.
-                The function should return a compiled optimizer, e.g. `lambda: model._get_optimizer(AdamW(...))`.
-            update_per_epoch (bool): If True, unfreeze per epoch, else per batch.
-            update_interval (int): Number of steps between each unfreeze to allow model to stabilize.
-            progressive_freeze (bool): If True, start with all layers unfrozen and progressively freeze them.
-        """
-        super().__init__()
-        self.base_model: Model = base_model
-        """ Base model to unfreeze. """
-        self.model: Model
-        """ Model to apply the progressive unfreezing to. """
-        self.steps_per_epoch: int = int(steps_per_epoch)
-        """ Number of steps per epoch. """
-        self.epochs: int = int(epochs)
-        """ Total number of epochs. """
-        self.reset_weights: bool = bool(reset_weights)
-        """ If True, reset weights after each unfreeze. """
-        self.reset_optimizer_function: Callable[[], Optimizer] | None = reset_optimizer_function
-        """ If reset_weights is True and this is not None, use this function to get a new optimizer. """
-        self.update_per_epoch: bool = bool(update_per_epoch)
-        """ If True, unfreeze per epoch, else per batch. """
-        self.update_interval: int = max(1, int(update_interval))
-        """ Number of steps between each unfreeze to allow model to stabilize. """
-        self.progressive_freeze: bool = bool(progressive_freeze)
-        """ If True, start with all layers unfrozen and progressively freeze them. """
-
-        # If updating per epoch, remove to self.epochs the update interval to allow the last step to train with 100% unfreeze
-        if self.update_per_epoch:
-            self.epochs -= self.update_interval
-
-        # Calculate total steps considering the update interval
-        total_steps_raw: int = self.epochs if self.update_per_epoch else self.steps_per_epoch * self.epochs
-        self.total_steps: int = total_steps_raw // self.update_interval
-        """ Total number of update steps (considering update_interval). """
-
-        self.fraction_unfrozen: list[float] = []
-        """ Fraction of layers unfrozen. """
-        self.losses: list[float] = []
-        """ Losses. """
-        self._all_layers: list[Any] = []
-        """ All layers. """
-        self._initial_trainable: list[bool] = []
-        """ Initial trainable states. """
-        self._initial_weights: list[Any] | None = None
-        """ Initial weights of the model. """
-        self._last_update_step: int = -1
-        """ Last step when layers were unfrozen. """
-        self.params: dict[str, Any]
-
-    def on_train_begin(self, logs: dict[str, Any] | None = None) -> None:
-        """ Set initial layer trainable states at the start of training and store initial states and weights.
-
-        Args:
-            logs (dict | None): Training logs.
-        """
-        # Collect all layers from the model and preserve their original trainable states for potential restoration
-        self._all_layers = self.base_model.layers
-        self._initial_trainable = [bool(layer.trainable) for layer in self._all_layers]
-
-        # Store initial weights to reset after each unfreeze
-        if self.reset_weights:
-            self._initial_weights = self.model.get_weights()
-
-        # Set initial trainable state based on mode
-        for layer in self._all_layers:
-            layer.trainable = self.progressive_freeze  # If progressive_freeze, start with all layers unfrozen
-
-    def _update_layers(self, step: int) -> None:
-        """ Update layer trainable states based on the current step and mode.
-        Reset weights after each update to prevent bias in the results.
-
-        Args:
-            step (int): Current training step.
-        """
-        # Calculate the effective step considering the update interval
-        effective_step: int = step // self.update_interval
-
-        # Skip if we haven't reached the next update interval
-        if effective_step <= self._last_update_step:
-            return
-        self._last_update_step = effective_step
-
-        # Calculate the number of layers to unfreeze based on current effective step
-        n_layers: int = len(self._all_layers)
-
-        if self.progressive_freeze:
-            # For progressive freezing, start at 1.0 (all unfrozen) and decrease to 0.0
-            fraction: float = max(0.0, 1.0 - (effective_step + 1) / self.total_steps)
-        else:
-            # For progressive unfreezing, start at 0.0 (all frozen) and increase to 1.0
-            fraction: float = min(1.0, (effective_step + 1) / self.total_steps)
-
-        n_unfreeze: int = int(n_layers * fraction)  # Number of layers to keep unfrozen
-        self.fraction_unfrozen.append(fraction)
-
-        # Set trainable state for each layer based on position
-        # For both modes, we unfreeze from the top (output layers) to the bottom (input layers)
-        for i, layer in enumerate(self._all_layers):
-            layer.trainable = i >= (n_layers - n_unfreeze)
-
-        # Reset weights to initial state to prevent bias and reset optimizer
-        if self._initial_weights is not None:
-            self.model.set_weights(self._initial_weights)  # pyright: ignore [reportUnknownMemberType]
-            if self.reset_optimizer_function is not None:
-                self.model.optimizer = self.reset_optimizer_function()
-                self.model.optimizer.build(self.model.trainable_variables)  # pyright: ignore [reportUnknownMemberType]
-
-    def _track_loss(self, logs: dict[str, Any] | None = None) -> None:
-        """ Track the current loss.
-
-        Args:
-            logs (dict | None): Training logs containing loss information.
-        """
-        if logs and "loss" in logs:
-            self.losses.append(logs["loss"])
-
-    def on_batch_begin(self, batch: int, logs: dict[str, Any] | None = None) -> None:
-        """ Update layer trainable states at the start of each batch if not updating per epoch.
-
-        Args:
-            batch (int): Current batch index.
-            logs (dict | None): Training logs.
-        """
-        # Skip if we're updating per epoch instead of per batch
-        if self.update_per_epoch:
-            return
-
-        # Calculate the current step across all epochs and update layers
-        step: int = self.params.get("steps", self.steps_per_epoch) * self.params.get("epoch", 0) + batch
-        self._update_layers(step)
-
-    def on_batch_end(self, batch: int, logs: dict[str, Any] | None = None) -> None:
-        """ Track loss at the end of each batch if not updating per epoch.
-
-        Args:
-            batch (int): Current batch index.
-            logs (dict | None): Training logs.
-        """
-        # Skip if we're updating per epoch instead of per batch
-        if self.update_per_epoch:
-            return
-
-        # Record the loss if update interval is reached
-        if batch % self.update_interval == 0:
-            self._track_loss(logs)
-
-    def on_epoch_begin(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
-        """ Update layer trainable states at the start of each epoch if updating per epoch.
-
-        Args:
-            epoch (int): Current epoch index.
-            logs (dict | None): Training logs.
-        """
-        # Skip if we're updating per batch instead of per epoch
-        if not self.update_per_epoch:
-            return
-
-        # Update layers based on current epoch
-        self._update_layers(epoch)
-
-    def on_epoch_end(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
-        """ Track loss at the end of each epoch if updating per epoch.
-
-        Args:
-            epoch (int): Current epoch index.
-            logs (dict | None): Training logs.
-        """
-        # Skip if we're updating per batch instead of per epoch
-        if not self.update_per_epoch:
-            return
-
-        # Record the loss if update interval is reached
-        if epoch % self.update_interval == 0:
-            self._track_loss(logs)
-
-    def on_train_end(self, logs: dict[str, Any] | None = None) -> None:
-        """ Restore original trainable states at the end of training.
-
-        Args:
-            logs (dict | None): Training logs.
-        """
-        # Restore each layer's original trainable state
-        for layer, trainable in zip(self._all_layers, self._initial_trainable, strict=False):
-            layer.trainable = trainable
-
-    def get_results(self, multiply_by_100: bool = True) -> tuple[list[float], list[float]]:
-        """ Get the results of the progressive unfreezing from 0% to 100% even if progressive_freeze is True.
-
-        Args:
-            multiply_by_100 (bool): If True, multiply the fractions by 100 to get percentages.
-
-        Returns:
-            tuple[list[float], list[float]]: fractions of layers unfrozen, and losses.
-        """
-        fractions: list[float] = self.fraction_unfrozen
-
-        # Reverse the order if progressive_freeze is True
-        if self.progressive_freeze:
-            fractions = fractions[::-1]
-
-        # Multiply by 100 if requested
-        if multiply_by_100:
-            fractions = [x * 100 for x in fractions]
-
-        # Return the results
-        return fractions, self.losses
-
+
+# pyright: reportMissingTypeStubs=false
+
+# Imports
+from collections.abc import Callable
+from typing import Any
+
+from keras.callbacks import Callback
+from keras.models import Model
+from keras.optimizers import Optimizer
+
+
+class ProgressiveUnfreezing(Callback):
+    """ Callback inspired by the Learning Rate Finder to progressively unfreeze model layers during training.
+
+    Warning: This callback is not compatible with model.fit() as it modifies the trainable state of the model.
+    Prefer doing your own training loop instead.
+
+    This callback can operate in two modes:
+    1. Start with all layers frozen and incrementally unfreeze them from 0% to 100% (progressive_freeze=False)
+    2. Start with all layers unfrozen and incrementally freeze them from 100% to 0% (progressive_freeze=True)
+    """
+
+    def __init__(
+        self,
+        base_model: Model,
+        steps_per_epoch: int,
+        epochs: int,
+        reset_weights: bool = False,
+        reset_optimizer_function: Callable[[], Optimizer] | None = None,
+        update_per_epoch: bool = True,
+        update_interval: int = 5,
+        progressive_freeze: bool = False
+    ) -> None:
+        """ Initialize the progressive unfreezing callback.
+
+        Args:
+            base_model (Model): Base model to unfreeze.
+            steps_per_epoch (int): Number of steps per epoch.
+            epochs (int): Total number of epochs.
+            reset_weights (bool): If True, reset weights after each unfreeze.
+            reset_optimizer_function (Callable | None):
+                If set, use this function to reset the optimizer every update_interval.
+                The function should return a compiled optimizer, e.g. `lambda: model._get_optimizer(AdamW(...))`.
+            update_per_epoch (bool): If True, unfreeze per epoch, else per batch.
+            update_interval (int): Number of steps between each unfreeze to allow model to stabilize.
+            progressive_freeze (bool): If True, start with all layers unfrozen and progressively freeze them.
+        """
+        super().__init__()
+        self.base_model: Model = base_model
+        """ Base model to unfreeze. """
+        self.model: Model
+        """ Model to apply the progressive unfreezing to. """
+        self.steps_per_epoch: int = int(steps_per_epoch)
+        """ Number of steps per epoch. """
+        self.epochs: int = int(epochs)
+        """ Total number of epochs. """
+        self.reset_weights: bool = bool(reset_weights)
+        """ If True, reset weights after each unfreeze. """
+        self.reset_optimizer_function: Callable[[], Optimizer] | None = reset_optimizer_function
+        """ If reset_weights is True and this is not None, use this function to get a new optimizer. """
+        self.update_per_epoch: bool = bool(update_per_epoch)
+        """ If True, unfreeze per epoch, else per batch. """
+        self.update_interval: int = max(1, int(update_interval))
+        """ Number of steps between each unfreeze to allow model to stabilize. """
+        self.progressive_freeze: bool = bool(progressive_freeze)
+        """ If True, start with all layers unfrozen and progressively freeze them. """
+
+        # If updating per epoch, remove to self.epochs the update interval to allow the last step to train with 100% unfreeze
+        if self.update_per_epoch:
+            self.epochs -= self.update_interval
+
+        # Calculate total steps considering the update interval
+        total_steps_raw: int = self.epochs if self.update_per_epoch else self.steps_per_epoch * self.epochs
+        self.total_steps: int = total_steps_raw // self.update_interval
+        """ Total number of update steps (considering update_interval). """
+
+        self.fraction_unfrozen: list[float] = []
+        """ Fraction of layers unfrozen. """
+        self.losses: list[float] = []
+        """ Losses. """
+        self._all_layers: list[Any] = []
+        """ All layers. """
+        self._initial_trainable: list[bool] = []
+        """ Initial trainable states. """
+        self._initial_weights: list[Any] | None = None
+        """ Initial weights of the model. """
+        self._last_update_step: int = -1
+        """ Last step when layers were unfrozen. """
+        self.params: dict[str, Any]
+
+    def on_train_begin(self, logs: dict[str, Any] | None = None) -> None:
+        """ Set initial layer trainable states at the start of training and store initial states and weights.
+
+        Args:
+            logs (dict | None): Training logs.
+        """
+        # Collect all layers from the model and preserve their original trainable states for potential restoration
+        self._all_layers = self.base_model.layers
+        self._initial_trainable = [bool(layer.trainable) for layer in self._all_layers]
+
+        # Store initial weights to reset after each unfreeze
+        if self.reset_weights:
+            self._initial_weights = self.model.get_weights()
+
+        # Set initial trainable state based on mode
+        for layer in self._all_layers:
+            layer.trainable = self.progressive_freeze  # If progressive_freeze, start with all layers unfrozen
+
+    def _update_layers(self, step: int) -> None:
+        """ Update layer trainable states based on the current step and mode.
+        Reset weights after each update to prevent bias in the results.
+
+        Args:
+            step (int): Current training step.
+        """
+        # Calculate the effective step considering the update interval
+        effective_step: int = step // self.update_interval
+
+        # Skip if we haven't reached the next update interval
+        if effective_step <= self._last_update_step:
+            return
+        self._last_update_step = effective_step
+
+        # Calculate the number of layers to unfreeze based on current effective step
+        n_layers: int = len(self._all_layers)
+
+        if self.progressive_freeze:
+            # For progressive freezing, start at 1.0 (all unfrozen) and decrease to 0.0
+            fraction: float = max(0.0, 1.0 - (effective_step + 1) / self.total_steps)
+        else:
+            # For progressive unfreezing, start at 0.0 (all frozen) and increase to 1.0
+            fraction: float = min(1.0, (effective_step + 1) / self.total_steps)
+
+        n_unfreeze: int = int(n_layers * fraction)  # Number of layers to keep unfrozen
+        self.fraction_unfrozen.append(fraction)
+
+        # Set trainable state for each layer based on position
+        # For both modes, we unfreeze from the top (output layers) to the bottom (input layers)
+        for i, layer in enumerate(self._all_layers):
+            layer.trainable = i >= (n_layers - n_unfreeze)
+
+        # Reset weights to initial state to prevent bias and reset optimizer
+        if self._initial_weights is not None:
+            self.model.set_weights(self._initial_weights)  # pyright: ignore [reportUnknownMemberType]
+            if self.reset_optimizer_function is not None:
+                self.model.optimizer = self.reset_optimizer_function()
+                self.model.optimizer.build(self.model.trainable_variables)  # pyright: ignore [reportUnknownMemberType]
+
+    def _track_loss(self, logs: dict[str, Any] | None = None) -> None:
+        """ Track the current loss.
+
+        Args:
+            logs (dict | None): Training logs containing loss information.
+        """
+        if logs and "loss" in logs:
+            self.losses.append(logs["loss"])
+
+    def on_batch_begin(self, batch: int, logs: dict[str, Any] | None = None) -> None:
+        """ Update layer trainable states at the start of each batch if not updating per epoch.
+
+        Args:
+            batch (int): Current batch index.
+            logs (dict | None): Training logs.
+        """
+        # Skip if we're updating per epoch instead of per batch
+        if self.update_per_epoch:
+            return
+
+        # Calculate the current step across all epochs and update layers
+        step: int = self.params.get("steps", self.steps_per_epoch) * self.params.get("epoch", 0) + batch
+        self._update_layers(step)
+
+    def on_batch_end(self, batch: int, logs: dict[str, Any] | None = None) -> None:
+        """ Track loss at the end of each batch if not updating per epoch.
+
+        Args:
+            batch (int): Current batch index.
+            logs (dict | None): Training logs.
+        """
+        # Skip if we're updating per epoch instead of per batch
+        if self.update_per_epoch:
+            return
+
+        # Record the loss if update interval is reached
+        if batch % self.update_interval == 0:
+            self._track_loss(logs)
+
+    def on_epoch_begin(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
+        """ Update layer trainable states at the start of each epoch if updating per epoch.
+
+        Args:
+            epoch (int): Current epoch index.
+            logs (dict | None): Training logs.
+        """
+        # Skip if we're updating per batch instead of per epoch
+        if not self.update_per_epoch:
+            return
+
+        # Update layers based on current epoch
+        self._update_layers(epoch)
+
+    def on_epoch_end(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
+        """ Track loss at the end of each epoch if updating per epoch.
+
+        Args:
+            epoch (int): Current epoch index.
+            logs (dict | None): Training logs.
+        """
+        # Skip if we're updating per batch instead of per epoch
+        if not self.update_per_epoch:
+            return
+
+        # Record the loss if update interval is reached
+        if epoch % self.update_interval == 0:
+            self._track_loss(logs)
+
+    def on_train_end(self, logs: dict[str, Any] | None = None) -> None:
+        """ Restore original trainable states at the end of training.
+
+        Args:
+            logs (dict | None): Training logs.
+        """
+        # Restore each layer's original trainable state
+        for layer, trainable in zip(self._all_layers, self._initial_trainable, strict=False):
+            layer.trainable = trainable
+
+    def get_results(self, multiply_by_100: bool = True) -> tuple[list[float], list[float]]:
+        """ Get the results of the progressive unfreezing from 0% to 100% even if progressive_freeze is True.
+
+        Args:
+            multiply_by_100 (bool): If True, multiply the fractions by 100 to get percentages.
+
+        Returns:
+            tuple[list[float], list[float]]: fractions of layers unfrozen, and losses.
+        """
+        fractions: list[float] = self.fraction_unfrozen
+
+        # Reverse the order if progressive_freeze is True
+        if self.progressive_freeze:
+            fractions = fractions[::-1]
+
+        # Multiply by 100 if requested
+        if multiply_by_100:
+            fractions = [x * 100 for x in fractions]
+
+        # Return the results
+        return fractions, self.losses
+
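Since the class docstring above steers users away from model.fit(), here is a hedged sketch of driving ProgressiveUnfreezing from a manual training loop. The toy data, model, and hyperparameters are invented for illustration and are not part of stouputils; a plain gradient step stands in for a stateful optimizer because the set of trainable variables grows as layers unfreeze, which is also why the callback offers reset_optimizer_function.

```python
import tensorflow as tf
from keras import layers, models

from stouputils.data_science.models.keras_utils.callbacks.progressive_unfreezing import ProgressiveUnfreezing

# Hypothetical toy data and model, only to exercise the callback
x = tf.random.normal((256, 32))
y = tf.cast(tf.random.uniform((256, 1)) > 0.5, tf.float32)

base = models.Sequential([layers.Dense(64, activation="relu"), layers.Dense(64, activation="relu")])
model = models.Sequential([base, layers.Dense(1, activation="sigmoid")])
loss_fn = tf.keras.losses.BinaryCrossentropy()

epochs = 25
cb = ProgressiveUnfreezing(base_model=base, steps_per_epoch=1, epochs=epochs, update_interval=5)
cb.set_model(model)  # fit() would normally wire this up

cb.on_train_begin()
for epoch in range(epochs):
    cb.on_epoch_begin(epoch)  # adjusts which layers are trainable this epoch
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    # Only variables of currently unfrozen layers appear in trainable_variables
    grads = tape.gradient(loss, model.trainable_variables)
    for g, v in zip(grads, model.trainable_variables):
        v.assign_sub(0.01 * g)  # plain SGD step; avoids optimizer slot-state issues as layers unfreeze
    cb.on_epoch_end(epoch, logs={"loss": float(loss)})
cb.on_train_end()

fractions, losses = cb.get_results()  # percent of layers unfrozen vs. loss sampled every update_interval
```

With these numbers, the constructor shortens self.epochs to 20 and total_steps to 4, so the unfrozen fraction steps through 25%, 50%, 75%, and 100% every five epochs, and one loss value is recorded per step.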
--- a/stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py
+++ b/stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py
@@ -1,66 +1,66 @@
-
-# pyright: reportMissingTypeStubs=false
-
-# Imports
-from typing import Any
-
-import tensorflow as tf
-from keras.callbacks import Callback
-from keras.models import Model
-
-
-class WarmupScheduler(Callback):
-    """ Keras Callback for learning rate warmup.
-
-    Sources:
-    - Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour: https://arxiv.org/abs/1706.02677
-    - Attention Is All You Need: https://arxiv.org/abs/1706.03762
-
-    This callback implements a learning rate warmup strategy where the learning rate
-    gradually increases from an initial value to a target value over a specified
-    number of epochs. This helps stabilize training in the early stages.
-
-    The learning rate increases linearly from the initial value to the target value
-    over the warmup period, and then remains at the target value.
-    """
-
-    def __init__(self, warmup_epochs: int, initial_lr: float, target_lr: float) -> None:
-        """ Initialize the warmup scheduler.
-
-        Args:
-            warmup_epochs (int): Number of epochs for warmup.
-            initial_lr (float): Starting learning rate for warmup.
-            target_lr (float): Target learning rate after warmup.
-        """
-        super().__init__()
-        self.warmup_epochs: int = warmup_epochs
-        """ Number of epochs for warmup. """
-        self.initial_lr: float = initial_lr
-        """ Starting learning rate for warmup. """
-        self.target_lr: float = target_lr
-        """ Target learning rate after warmup. """
-        self.model: Model
-        """ Model to apply the warmup scheduler to. """
-
-        # Pre-compute learning rates for each epoch to avoid calculations during training
-        self.epoch_learning_rates: list[float] = []
-        for epoch in range(warmup_epochs + 1):
-            if epoch < warmup_epochs:
-                lr = initial_lr + (target_lr - initial_lr) * (epoch + 1) / warmup_epochs
-            else:
-                lr = target_lr
-            self.epoch_learning_rates.append(lr)
-
-    def on_epoch_begin(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
-        """ Adjust learning rate at the beginning of each epoch during warmup.
-
-        Args:
-            epoch (int): Current epoch index.
-            logs (dict | None): Training logs.
-        """
-        if self.warmup_epochs <= 0 or epoch > self.warmup_epochs:
-            return
-
-        # Use pre-computed learning rate to avoid calculations during training
-        tf.keras.backend.set_value(self.model.optimizer.learning_rate, self.epoch_learning_rates[epoch])  # type: ignore
-
+
+# pyright: reportMissingTypeStubs=false
+
+# Imports
+from typing import Any
+
+import tensorflow as tf
+from keras.callbacks import Callback
+from keras.models import Model
+
+
+class WarmupScheduler(Callback):
+    """ Keras Callback for learning rate warmup.
+
+    Sources:
+    - Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour: https://arxiv.org/abs/1706.02677
+    - Attention Is All You Need: https://arxiv.org/abs/1706.03762
+
+    This callback implements a learning rate warmup strategy where the learning rate
+    gradually increases from an initial value to a target value over a specified
+    number of epochs. This helps stabilize training in the early stages.
+
+    The learning rate increases linearly from the initial value to the target value
+    over the warmup period, and then remains at the target value.
+    """
+
+    def __init__(self, warmup_epochs: int, initial_lr: float, target_lr: float) -> None:
+        """ Initialize the warmup scheduler.
+
+        Args:
+            warmup_epochs (int): Number of epochs for warmup.
+            initial_lr (float): Starting learning rate for warmup.
+            target_lr (float): Target learning rate after warmup.
+        """
+        super().__init__()
+        self.warmup_epochs: int = warmup_epochs
+        """ Number of epochs for warmup. """
+        self.initial_lr: float = initial_lr
+        """ Starting learning rate for warmup. """
+        self.target_lr: float = target_lr
+        """ Target learning rate after warmup. """
+        self.model: Model
+        """ Model to apply the warmup scheduler to. """
+
+        # Pre-compute learning rates for each epoch to avoid calculations during training
+        self.epoch_learning_rates: list[float] = []
+        for epoch in range(warmup_epochs + 1):
+            if epoch < warmup_epochs:
+                lr = initial_lr + (target_lr - initial_lr) * (epoch + 1) / warmup_epochs
+            else:
+                lr = target_lr
+            self.epoch_learning_rates.append(lr)
+
+    def on_epoch_begin(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
+        """ Adjust learning rate at the beginning of each epoch during warmup.
+
+        Args:
+            epoch (int): Current epoch index.
+            logs (dict | None): Training logs.
+        """
+        if self.warmup_epochs <= 0 or epoch > self.warmup_epochs:
+            return
+
+        # Use pre-computed learning rate to avoid calculations during training
+        tf.keras.backend.set_value(self.model.optimizer.learning_rate, self.epoch_learning_rates[epoch])  # type: ignore
+
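A short, hedged usage sketch: the list comprehension reproduces the schedule the constructor above pre-computes (note that the first warmup epoch already takes one linear step above initial_lr, and the rate then holds at target_lr), and the fit() call shows where the callback plugs in. The model, data, and hyperparameter values are hypothetical.

```python
import numpy as np
from keras import layers, models

from stouputils.data_science.models.keras_utils.callbacks.warmup_scheduler import WarmupScheduler

# Reproduce the pre-computed schedule for warmup_epochs=4, initial_lr=1e-4, target_lr=1e-3
warmup_epochs, initial_lr, target_lr = 4, 1e-4, 1e-3
schedule = [
    initial_lr + (target_lr - initial_lr) * (e + 1) / warmup_epochs if e < warmup_epochs else target_lr
    for e in range(warmup_epochs + 1)
]
print(schedule)  # ~[0.000325, 0.00055, 0.000775, 0.001, 0.001]

# Hypothetical model and data, only to show the attachment point
model = models.Sequential([layers.Dense(1, activation="sigmoid")])
model.compile(optimizer="adam", loss="binary_crossentropy")
model.fit(
    np.random.rand(64, 8).astype("float32"),
    np.random.randint(0, 2, size=(64, 1)),
    epochs=10,
    verbose=0,
    callbacks=[WarmupScheduler(warmup_epochs=warmup_epochs, initial_lr=initial_lr, target_lr=target_lr)],
)
```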
--- a/stouputils/data_science/models/keras_utils/losses/__init__.py
+++ b/stouputils/data_science/models/keras_utils/losses/__init__.py
@@ -1,12 +1,12 @@
-""" Custom losses for Keras models.
-
-Features:
-
-- Next Generation Loss
-"""
-
-# Imports
-from .next_generation_loss import NextGenerationLoss
-
-__all__ = ["NextGenerationLoss"]
-
+""" Custom losses for Keras models.
+
+Features:
+
+- Next Generation Loss
+"""
+
+# Imports
+from .next_generation_loss import NextGenerationLoss
+
+__all__ = ["NextGenerationLoss"]
+
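Finally, a hedged import sketch for the re-exported loss. Only the re-export is visible in this diff, so the no-argument constructor and the binary-classification setup below are assumptions, not confirmed API.

```python
import numpy as np
from keras import layers, models

from stouputils.data_science.models.keras_utils.losses import NextGenerationLoss

model = models.Sequential([layers.Dense(1, activation="sigmoid")])
model.compile(optimizer="adam", loss=NextGenerationLoss())  # no-arg constructor is an assumption
model.fit(np.random.rand(32, 8).astype("float32"), np.random.randint(0, 2, size=(32, 1)), epochs=1, verbose=0)
```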