stouputils-1.12.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. stouputils/__init__.py +40 -0
  2. stouputils/__init__.pyi +14 -0
  3. stouputils/__main__.py +81 -0
  4. stouputils/_deprecated.py +37 -0
  5. stouputils/_deprecated.pyi +12 -0
  6. stouputils/all_doctests.py +160 -0
  7. stouputils/all_doctests.pyi +46 -0
  8. stouputils/applications/__init__.py +22 -0
  9. stouputils/applications/__init__.pyi +2 -0
  10. stouputils/applications/automatic_docs.py +634 -0
  11. stouputils/applications/automatic_docs.pyi +106 -0
  12. stouputils/applications/upscaler/__init__.py +39 -0
  13. stouputils/applications/upscaler/__init__.pyi +3 -0
  14. stouputils/applications/upscaler/config.py +128 -0
  15. stouputils/applications/upscaler/config.pyi +18 -0
  16. stouputils/applications/upscaler/image.py +247 -0
  17. stouputils/applications/upscaler/image.pyi +109 -0
  18. stouputils/applications/upscaler/video.py +287 -0
  19. stouputils/applications/upscaler/video.pyi +60 -0
  20. stouputils/archive.py +344 -0
  21. stouputils/archive.pyi +67 -0
  22. stouputils/backup.py +488 -0
  23. stouputils/backup.pyi +109 -0
  24. stouputils/collections.py +244 -0
  25. stouputils/collections.pyi +86 -0
  26. stouputils/continuous_delivery/__init__.py +27 -0
  27. stouputils/continuous_delivery/__init__.pyi +5 -0
  28. stouputils/continuous_delivery/cd_utils.py +243 -0
  29. stouputils/continuous_delivery/cd_utils.pyi +129 -0
  30. stouputils/continuous_delivery/github.py +522 -0
  31. stouputils/continuous_delivery/github.pyi +162 -0
  32. stouputils/continuous_delivery/pypi.py +91 -0
  33. stouputils/continuous_delivery/pypi.pyi +43 -0
  34. stouputils/continuous_delivery/pyproject.py +147 -0
  35. stouputils/continuous_delivery/pyproject.pyi +67 -0
  36. stouputils/continuous_delivery/stubs.py +86 -0
  37. stouputils/continuous_delivery/stubs.pyi +39 -0
  38. stouputils/ctx.py +408 -0
  39. stouputils/ctx.pyi +211 -0
  40. stouputils/data_science/config/get.py +51 -0
  41. stouputils/data_science/config/set.py +125 -0
  42. stouputils/data_science/data_processing/image/__init__.py +66 -0
  43. stouputils/data_science/data_processing/image/auto_contrast.py +79 -0
  44. stouputils/data_science/data_processing/image/axis_flip.py +58 -0
  45. stouputils/data_science/data_processing/image/bias_field_correction.py +74 -0
  46. stouputils/data_science/data_processing/image/binary_threshold.py +73 -0
  47. stouputils/data_science/data_processing/image/blur.py +59 -0
  48. stouputils/data_science/data_processing/image/brightness.py +54 -0
  49. stouputils/data_science/data_processing/image/canny.py +110 -0
  50. stouputils/data_science/data_processing/image/clahe.py +92 -0
  51. stouputils/data_science/data_processing/image/common.py +30 -0
  52. stouputils/data_science/data_processing/image/contrast.py +53 -0
  53. stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -0
  54. stouputils/data_science/data_processing/image/denoise.py +378 -0
  55. stouputils/data_science/data_processing/image/histogram_equalization.py +123 -0
  56. stouputils/data_science/data_processing/image/invert.py +64 -0
  57. stouputils/data_science/data_processing/image/laplacian.py +60 -0
  58. stouputils/data_science/data_processing/image/median_blur.py +52 -0
  59. stouputils/data_science/data_processing/image/noise.py +59 -0
  60. stouputils/data_science/data_processing/image/normalize.py +65 -0
  61. stouputils/data_science/data_processing/image/random_erase.py +66 -0
  62. stouputils/data_science/data_processing/image/resize.py +69 -0
  63. stouputils/data_science/data_processing/image/rotation.py +80 -0
  64. stouputils/data_science/data_processing/image/salt_pepper.py +68 -0
  65. stouputils/data_science/data_processing/image/sharpening.py +55 -0
  66. stouputils/data_science/data_processing/image/shearing.py +64 -0
  67. stouputils/data_science/data_processing/image/threshold.py +64 -0
  68. stouputils/data_science/data_processing/image/translation.py +71 -0
  69. stouputils/data_science/data_processing/image/zoom.py +83 -0
  70. stouputils/data_science/data_processing/image_augmentation.py +118 -0
  71. stouputils/data_science/data_processing/image_preprocess.py +183 -0
  72. stouputils/data_science/data_processing/prosthesis_detection.py +359 -0
  73. stouputils/data_science/data_processing/technique.py +481 -0
  74. stouputils/data_science/dataset/__init__.py +45 -0
  75. stouputils/data_science/dataset/dataset.py +292 -0
  76. stouputils/data_science/dataset/dataset_loader.py +135 -0
  77. stouputils/data_science/dataset/grouping_strategy.py +296 -0
  78. stouputils/data_science/dataset/image_loader.py +100 -0
  79. stouputils/data_science/dataset/xy_tuple.py +696 -0
  80. stouputils/data_science/metric_dictionnary.py +106 -0
  81. stouputils/data_science/metric_utils.py +847 -0
  82. stouputils/data_science/mlflow_utils.py +206 -0
  83. stouputils/data_science/models/abstract_model.py +149 -0
  84. stouputils/data_science/models/all.py +85 -0
  85. stouputils/data_science/models/base_keras.py +765 -0
  86. stouputils/data_science/models/keras/all.py +38 -0
  87. stouputils/data_science/models/keras/convnext.py +62 -0
  88. stouputils/data_science/models/keras/densenet.py +50 -0
  89. stouputils/data_science/models/keras/efficientnet.py +60 -0
  90. stouputils/data_science/models/keras/mobilenet.py +56 -0
  91. stouputils/data_science/models/keras/resnet.py +52 -0
  92. stouputils/data_science/models/keras/squeezenet.py +233 -0
  93. stouputils/data_science/models/keras/vgg.py +42 -0
  94. stouputils/data_science/models/keras/xception.py +38 -0
  95. stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -0
  96. stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -0
  97. stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -0
  98. stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -0
  99. stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -0
  100. stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -0
  101. stouputils/data_science/models/keras_utils/losses/__init__.py +12 -0
  102. stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -0
  103. stouputils/data_science/models/keras_utils/visualizations.py +416 -0
  104. stouputils/data_science/models/model_interface.py +939 -0
  105. stouputils/data_science/models/sandbox.py +116 -0
  106. stouputils/data_science/range_tuple.py +234 -0
  107. stouputils/data_science/scripts/augment_dataset.py +77 -0
  108. stouputils/data_science/scripts/exhaustive_process.py +133 -0
  109. stouputils/data_science/scripts/preprocess_dataset.py +70 -0
  110. stouputils/data_science/scripts/routine.py +168 -0
  111. stouputils/data_science/utils.py +285 -0
  112. stouputils/decorators.py +595 -0
  113. stouputils/decorators.pyi +242 -0
  114. stouputils/image.py +441 -0
  115. stouputils/image.pyi +172 -0
  116. stouputils/installer/__init__.py +18 -0
  117. stouputils/installer/__init__.pyi +5 -0
  118. stouputils/installer/common.py +67 -0
  119. stouputils/installer/common.pyi +39 -0
  120. stouputils/installer/downloader.py +101 -0
  121. stouputils/installer/downloader.pyi +24 -0
  122. stouputils/installer/linux.py +144 -0
  123. stouputils/installer/linux.pyi +39 -0
  124. stouputils/installer/main.py +223 -0
  125. stouputils/installer/main.pyi +57 -0
  126. stouputils/installer/windows.py +136 -0
  127. stouputils/installer/windows.pyi +31 -0
  128. stouputils/io.py +486 -0
  129. stouputils/io.pyi +213 -0
  130. stouputils/parallel.py +453 -0
  131. stouputils/parallel.pyi +211 -0
  132. stouputils/print.py +527 -0
  133. stouputils/print.pyi +146 -0
  134. stouputils/py.typed +1 -0
  135. stouputils-1.12.1.dist-info/METADATA +179 -0
  136. stouputils-1.12.1.dist-info/RECORD +138 -0
  137. stouputils-1.12.1.dist-info/WHEEL +4 -0
  138. stouputils-1.12.1.dist-info/entry_points.txt +3 -0
stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py
@@ -0,0 +1,219 @@
+
+ # pyright: reportMissingTypeStubs=false
+
+ # Imports
+ from typing import Any
+
+ import tensorflow as tf
+ from keras.callbacks import Callback
+ from keras.models import Model
+ from tqdm.auto import tqdm
+
+ from .....print import MAGENTA
+ from .....parallel import BAR_FORMAT
+
+ class ColoredProgressBar(Callback):
+     """ Progress bar using tqdm for Keras training.
+
+     A callback that displays a progress bar using tqdm during model training.
+     It shows training progress across steps in a custom single-line format,
+     instead of the default multi-line Keras output.
+     """
+     def __init__(
+         self,
+         desc: str = "Training",
+         track_epochs: bool = False,
+         show_lr: bool = False,
+         update_frequency: int = 1,
+         color: str = MAGENTA
+     ) -> None:
+         """ Initialize the progress bar callback.
+
+         Args:
+             desc (str): Custom description for the progress bar.
+             track_epochs (bool): Whether to track epochs instead of batches.
+             show_lr (bool): Whether to show the learning rate.
+             update_frequency (int): How often to update the progress bar (every N batches).
+             color (str): Color of the progress bar.
+         """
+         super().__init__()
+         self.desc: str = desc
+         """ Description of the progress bar. """
+         self.track_epochs: bool = track_epochs
+         """ Whether to track epochs instead of batches. """
+         self.show_lr: bool = show_lr
+         """ Whether to show the learning rate. """
+         self.latest_val_loss: float = 0.0
+         """ Latest validation loss, updated at the end of each epoch. """
+         self.latest_lr: float = 0.0
+         """ Latest learning rate, updated during batch and epoch processing. """
+         self.batch_count: int = 0
+         """ Counter used to update the progress bar less frequently. """
+         self.update_frequency: int = max(1, update_frequency)  # Ensure frequency is at least 1
+         """ How often to update the progress bar (every N batches). """
+         self.color: str = color
+         """ Color of the progress bar. """
+         self.pbar: tqdm[Any] | None = None
+         """ The tqdm progress bar instance. """
+         self.epochs: int = 0
+         """ Total number of epochs. """
+         self.steps: int = 0
+         """ Number of steps per epoch. """
+         self.total: int = 0
+         """ Total number of steps/epochs to track. """
+         self.params: dict[str, Any]
+         """ Training parameters. """
+         self.model: Model
+
+     def on_train_begin(self, logs: dict[str, Any] | None = None) -> None:
+         """ Initialize the progress bar at the start of training.
+
+         Args:
+             logs (dict | None): Training logs.
+         """
+         # Get training parameters
+         self.epochs = self.params.get("epochs", 0)
+         self.steps = self.params.get("steps", 0)
+
+         # Determine total units and initial description
+         if self.track_epochs:
+             desc: str = f"{self.color}{self.desc} (Epochs)"
+             self.total = self.epochs
+         else:
+             desc = f"{self.color}{self.desc} (Epoch 1/{self.epochs})"
+             self.total = self.epochs * self.steps
+
+         # Initialize tqdm bar
+         self.pbar = tqdm(
+             total=self.total,
+             desc=desc,
+             position=0,
+             leave=True,
+             bar_format=BAR_FORMAT,
+             ascii=False
+         )
+         # Reset state variables
+         self.latest_val_loss = 0.0
+         self.latest_lr = 0.0
+         self.batch_count = 0
+
+     def on_batch_end(self, batch: int, logs: dict[str, Any] | None = None) -> None:
+         """ Update the progress bar after each batch, based on update frequency.
+
+         Args:
+             batch (int): Current batch number (0-indexed).
+             logs (dict | None): Dictionary of logs containing metrics for the batch.
+         """
+         # Skip updates if tracking epochs, pbar isn't initialized, or steps are unknown
+         if self.track_epochs or self.pbar is None or self.steps == 0:
+             return
+
+         self.batch_count += 1
+         is_last_batch: bool = (batch + 1) == self.steps
+
+         # Update only every `update_frequency` batches or on the last batch
+         if self.batch_count % self.update_frequency == 0 or is_last_batch:
+             increment: int = self.batch_count
+             self.batch_count = 0  # Reset counter
+
+             # Calculate current epoch (1-based) based on the progress bar's state *before* this update
+             # Ensure epoch doesn't exceed total epochs in description
+             current_epoch: int = min(self.epochs, self.pbar.n // self.steps + 1)
+             current_step: int = batch + 1  # Convert to 1-indexed for display
+             self.pbar.set_description(
+                 f"{self.color}{self.desc} (Epoch {current_epoch}/{self.epochs}, Step {current_step}/{self.steps})"
+             )
+
+             # Update learning rate if model and optimizer are available
+             if self.model and hasattr(self.model, "optimizer"):
+                 self.latest_lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))  # type: ignore
+
+             # Update postfix with batch loss and the latest known validation loss
+             if logs and "loss" in logs:
+                 loss: float = logs["loss"]
+                 val_loss_str: str = ""
+                 if self.latest_val_loss != 0.0:
+                     if self.latest_val_loss < 1e-3:
+                         val_loss_str = f", val_loss: {self.latest_val_loss:.2e}"
+                     else:
+                         val_loss_str = f", val_loss: {self.latest_val_loss:.5f}"
+
+                 # Format learning rate string
+                 lr_str: str = ""
+                 if self.show_lr and self.latest_lr != 0.0:
+                     if self.latest_lr < 1e-3:
+                         lr_str = f", lr: {self.latest_lr:.2e}"
+                     else:
+                         lr_str = f", lr: {self.latest_lr:.5f}"
+
+                 if loss < 1e-3:
+                     self.pbar.set_postfix_str(f"loss: {loss:.2e}{val_loss_str}{lr_str}", refresh=False)
+                 else:
+                     self.pbar.set_postfix_str(f"loss: {loss:.5f}{val_loss_str}{lr_str}", refresh=False)
+
+             # Update progress bar position, ensuring not to exceed total
+             actual_increment: int = min(increment, self.total - self.pbar.n)
+             if actual_increment > 0:
+                 self.pbar.update(actual_increment)
+
+     def on_epoch_end(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
+         """ Update metrics and progress bar position at the end of each epoch.
+
+         Args:
+             epoch (int): Current epoch number (0-indexed).
+             logs (dict | None): Dictionary of logs containing metrics for the epoch.
+         """
+         if self.pbar is None:
+             return
+
+         # Update the latest validation loss from epoch logs
+         if logs:
+             current_val_loss: float = logs.get("val_loss", 0.0)
+             if current_val_loss != 0.0:
+                 self.latest_val_loss = current_val_loss
+
+             # Update learning rate if model and optimizer are available
+             if self.model and hasattr(self.model, "optimizer"):
+                 self.latest_lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))  # type: ignore
+
+             # Update postfix string with final epoch metrics
+             loss: float = logs.get("loss", 0.0)
+             val_loss_str: str = f", val_loss: {self.latest_val_loss:.5f}" if self.latest_val_loss != 0.0 else ""
+
+             # Format learning rate string
+             lr_str: str = ""
+             if self.show_lr and self.latest_lr != 0.0:
+                 if self.latest_lr < 1e-3:
+                     lr_str = f", lr: {self.latest_lr:.2e}"
+                 else:
+                     lr_str = f", lr: {self.latest_lr:.5f}"
+
+             if loss != 0.0:  # Only update if loss is available
+                 self.pbar.set_postfix_str(f"loss: {loss:.5f}{val_loss_str}{lr_str}", refresh=True)
+
+         # Update progress bar position
+         if self.track_epochs:
+             # Increment by 1 epoch if not already at total
+             if self.pbar.n < self.total:
+                 self.pbar.update(1)
+         else:
+             # Ensure the progress bar is exactly at the end of the current epoch
+             expected_position: int = min(self.total, (epoch + 1) * self.steps)
+             increment: int = expected_position - self.pbar.n
+             if increment > 0:
+                 self.pbar.update(increment)
+
+     def on_train_end(self, logs: dict[str, Any] | None = None) -> None:
+         """ Close the progress bar when training is complete.
+
+         Args:
+             logs (dict | None): Training logs.
+         """
+         if self.pbar is not None:
+             # Ensure the bar reaches 100%
+             increment: int = self.total - self.pbar.n
+             if increment > 0:
+                 self.pbar.update(increment)
+             self.pbar.close()
+             self.pbar = None  # Reset pbar instance
+
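For reference, a minimal usage sketch, not part of the diff: it assumes ColoredProgressBar is re-exported from the package's callbacks __init__ (which the file list suggests), and that model, x_train, y_train, x_val, y_val are an already-compiled Keras model and placeholder data:

    from stouputils.data_science.models.keras_utils.callbacks import ColoredProgressBar

    # Silence the default multi-line Keras output and show a single tqdm bar instead,
    # refreshed every 10 batches and displaying the current learning rate.
    model.fit(
        x_train, y_train,
        validation_data=(x_val, y_val),
        epochs=10,
        verbose=0,  # the callback replaces the built-in progress output
        callbacks=[ColoredProgressBar(show_lr=True, update_frequency=10)],
    )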
stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py
@@ -0,0 +1,148 @@
+
+ # pyright: reportMissingTypeStubs=false
+
+ # Imports
+ from typing import Any
+
+ import tensorflow as tf
+ from keras.callbacks import Callback
+ from keras.models import Model
+
+
+ class LearningRateFinder(Callback):
+     """ Callback to find an optimal learning rate by increasing the LR during training.
+
+     Sources:
+     - Inspired by: https://github.com/WittmannF/LRFinder
+     - Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186 (first description of the method)
+
+     This callback gradually increases the learning rate from a minimum to a maximum value
+     during training, allowing you to identify the optimal learning rate range for your model.
+
+     It works by:
+
+     1. Starting with a very small learning rate
+     2. Exponentially increasing it after each batch or epoch
+     3. Recording the loss at each learning rate
+     4. Restoring the model's initial weights after training
+
+     The optimal learning rate is typically found where the loss is decreasing most rapidly,
+     before it starts to diverge.
+
+     .. image:: https://blog.dataiku.com/hubfs/training%20loss.png
+         :alt: Learning rate finder curve example
+     """
+
+     def __init__(
+         self,
+         min_lr: float,
+         max_lr: float,
+         steps_per_epoch: int,
+         epochs: int,
+         update_per_epoch: bool = False,
+         update_interval: int = 5
+     ) -> None:
+         """ Initialize the learning rate finder.
+
+         Args:
+             min_lr (float): Minimum learning rate.
+             max_lr (float): Maximum learning rate.
+             steps_per_epoch (int): Steps per epoch.
+             epochs (int): Number of epochs.
+             update_per_epoch (bool): If True, update the LR once per epoch instead of every batch.
+             update_interval (int): Number of steps between each LR increase; a bigger value means a more stable loss.
+         """
+         super().__init__()
+         self.min_lr: float = min_lr
+         """ Minimum learning rate. """
+         self.max_lr: float = max_lr
+         """ Maximum learning rate. """
+         self.total_updates: int = (epochs if update_per_epoch else steps_per_epoch * epochs) // update_interval
+         """ Total number of update steps (considering update_interval). """
+         self.update_per_epoch: bool = update_per_epoch
+         """ Whether to update the learning rate per epoch instead of per batch. """
+         self.update_interval: int = max(1, int(update_interval))
+         """ Number of steps between each LR increase; a bigger value means a more stable loss. """
+         self.lr_mult: float = (max_lr / min_lr) ** (1 / self.total_updates)
+         """ Learning rate multiplier. """
+         self.learning_rates: list[float] = []
+         """ List of learning rates. """
+         self.losses: list[float] = []
+         """ List of losses. """
+         self.best_lr: float = min_lr
+         """ Best learning rate. """
+         self.best_loss: float = float("inf")
+         """ Best loss. """
+         self.model: Model
+         """ Model to apply the learning rate finder to. """
+         self.initial_weights: list[Any] | None = None
+         """ Stores the initial weights of the model. """
+
+     def on_train_begin(self, logs: dict[str, Any] | None = None) -> None:
+         """ Set the initial learning rate and save the initial model weights at the start of training.
+
+         Args:
+             logs (dict | None): Training logs.
+         """
+         self.initial_weights = self.model.get_weights()
+         tf.keras.backend.set_value(self.model.optimizer.learning_rate, self.min_lr)  # type: ignore
+
+     def _update_lr_and_track_metrics(self, logs: dict[str, Any] | None = None) -> None:
+         """ Update the learning rate and track metrics.
+
+         Args:
+             logs (dict | None): Logs from training.
+         """
+         if logs is None:
+             return
+
+         # Get current learning rate and loss
+         current_lr: float = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))  # type: ignore
+         current_loss: float = logs["loss"]
+
+         # Record values
+         self.learning_rates.append(current_lr)
+         self.losses.append(current_loss)
+
+         # Track best values
+         if current_loss < self.best_loss:
+             self.best_loss = current_loss
+             self.best_lr = current_lr
+
+         # Update learning rate
+         new_lr: float = current_lr * self.lr_mult
+         tf.keras.backend.set_value(self.model.optimizer.learning_rate, new_lr)  # type: ignore
+
+     def on_batch_end(self, batch: int, logs: dict[str, Any] | None = None) -> None:
+         """ Record the loss and increase the learning rate after each batch if not updating per epoch.
+
+         Args:
+             batch (int): Current batch index.
+             logs (dict | None): Training logs.
+         """
+         if self.update_per_epoch:
+             return
+         if batch % self.update_interval == 0:
+             self._update_lr_and_track_metrics(logs)
+
+     def on_epoch_end(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
+         """ Record the loss and increase the learning rate after each epoch if updating per epoch.
+
+         Args:
+             epoch (int): Current epoch index.
+             logs (dict | None): Training logs.
+         """
+         if not self.update_per_epoch:
+             return
+         if epoch % self.update_interval == 0:
+             self._update_lr_and_track_metrics(logs)
+
+     def on_train_end(self, logs: dict[str, Any] | None = None) -> None:
+         """ Restore the initial model weights at the end of training.
+
+         Args:
+             logs (dict | None): Training logs.
+         """
+         if self.initial_weights is not None:
+             self.model.set_weights(self.initial_weights)  # pyright: ignore [reportUnknownMemberType]
+
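The exponential schedule follows directly from lr_mult: after total_updates multiplicative steps, min_lr * lr_mult ** total_updates = max_lr, so the sweep covers the whole range regardless of how many batches it spans. A short sketch of a sweep (hypothetical values; model, x_train, y_train are assumed placeholders):

    finder = LearningRateFinder(min_lr=1e-6, max_lr=1e-1, steps_per_epoch=500, epochs=1)
    model.fit(x_train, y_train, epochs=1, verbose=0, callbacks=[finder])

    # Here total_updates = (500 * 1) // 5 = 100, so lr_mult = (1e-1 / 1e-6) ** (1 / 100) ≈ 1.122:
    # every 5th batch multiplies the LR by ~12%, covering five orders of magnitude in one epoch.
    # Plot finder.learning_rates against finder.losses (log-scaled x axis) and pick a value
    # somewhat below finder.best_lr, where the loss was still falling steeply.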
stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py
@@ -0,0 +1,31 @@
+
+ # pyright: reportMissingTypeStubs=false
+ # pyright: reportUnknownMemberType=false
+
+ # Imports
+ from typing import Any
+
+ from keras.callbacks import ModelCheckpoint
+
+
+ class ModelCheckpointV2(ModelCheckpoint):
+     """ Model checkpoint callback that only starts checkpointing after a given number of epochs.
+
+     Args:
+         epochs_before_start (int): Number of epochs to wait before checkpointing starts.
+     """
+
+     def __init__(self, epochs_before_start: int = 3, *args: Any, **kwargs: Any) -> None:
+         super().__init__(*args, **kwargs)
+         self.epochs_before_start = epochs_before_start
+         self.current_epoch = 0
+
+     def on_batch_end(self, batch: int, logs: dict[str, Any] | None = None) -> None:
+         # Delegate to ModelCheckpoint only once enough epochs have elapsed
+         if self.current_epoch >= self.epochs_before_start:
+             super().on_batch_end(batch, logs)
+
+     def on_epoch_end(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
+         self.current_epoch = epoch
+         if epoch >= self.epochs_before_start:
+             super().on_epoch_end(epoch, logs)
+
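A hedged usage sketch, not part of the diff: since epochs_before_start is the only new parameter, the usual ModelCheckpoint arguments pass through *args/**kwargs unchanged. This avoids the common failure mode where a noisy early epoch claims the "best" checkpoint. The file path and data names below are placeholders:

    checkpoint = ModelCheckpointV2(
        epochs_before_start=3,        # skip the (often unstable) first 3 epochs
        filepath="best_model.keras",  # standard ModelCheckpoint kwargs from here on
        monitor="val_loss",
        save_best_only=True,
    )
    model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=20, callbacks=[checkpoint])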
stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py
@@ -0,0 +1,249 @@
+
+ # pyright: reportMissingTypeStubs=false
+
+ # Imports
+ from collections.abc import Callable
+ from typing import Any
+
+ from keras.callbacks import Callback
+ from keras.models import Model
+ from keras.optimizers import Optimizer
+
+
+ class ProgressiveUnfreezing(Callback):
+     """ Callback inspired by the Learning Rate Finder to progressively unfreeze model layers during training.
+
+     Warning: This callback is not compatible with model.fit() as it modifies the trainable state of the model.
+     Prefer writing your own training loop instead.
+
+     This callback can operate in two modes:
+
+     1. Start with all layers frozen and incrementally unfreeze them from 0% to 100% (progressive_freeze=False)
+     2. Start with all layers unfrozen and incrementally freeze them from 100% to 0% (progressive_freeze=True)
+     """
+
+     def __init__(
+         self,
+         base_model: Model,
+         steps_per_epoch: int,
+         epochs: int,
+         reset_weights: bool = False,
+         reset_optimizer_function: Callable[[], Optimizer] | None = None,
+         update_per_epoch: bool = True,
+         update_interval: int = 5,
+         progressive_freeze: bool = False
+     ) -> None:
+         """ Initialize the progressive unfreezing callback.
+
+         Args:
+             base_model (Model): Base model to unfreeze.
+             steps_per_epoch (int): Number of steps per epoch.
+             epochs (int): Total number of epochs.
+             reset_weights (bool): If True, reset weights after each unfreeze.
+             reset_optimizer_function (Callable | None):
+                 If set, use this function to reset the optimizer every update_interval.
+                 The function should return a compiled optimizer, e.g. `lambda: model._get_optimizer(AdamW(...))`.
+             update_per_epoch (bool): If True, unfreeze per epoch, else per batch.
+             update_interval (int): Number of steps between each unfreeze, to allow the model to stabilize.
+             progressive_freeze (bool): If True, start with all layers unfrozen and progressively freeze them.
+         """
+         super().__init__()
+         self.base_model: Model = base_model
+         """ Base model to unfreeze. """
+         self.model: Model
+         """ Model to apply the progressive unfreezing to. """
+         self.steps_per_epoch: int = int(steps_per_epoch)
+         """ Number of steps per epoch. """
+         self.epochs: int = int(epochs)
+         """ Total number of epochs. """
+         self.reset_weights: bool = bool(reset_weights)
+         """ If True, reset weights after each unfreeze. """
+         self.reset_optimizer_function: Callable[[], Optimizer] | None = reset_optimizer_function
+         """ If reset_weights is True and this is not None, use this function to get a new optimizer. """
+         self.update_per_epoch: bool = bool(update_per_epoch)
+         """ If True, unfreeze per epoch, else per batch. """
+         self.update_interval: int = max(1, int(update_interval))
+         """ Number of steps between each unfreeze, to allow the model to stabilize. """
+         self.progressive_freeze: bool = bool(progressive_freeze)
+         """ If True, start with all layers unfrozen and progressively freeze them. """
+
+         # If updating per epoch, subtract the update interval from self.epochs so the last step trains 100% unfrozen
+         if self.update_per_epoch:
+             self.epochs -= self.update_interval
+
+         # Calculate total steps considering the update interval
+         total_steps_raw: int = self.epochs if self.update_per_epoch else self.steps_per_epoch * self.epochs
+         self.total_steps: int = total_steps_raw // self.update_interval
+         """ Total number of update steps (considering update_interval). """
+
+         self.fraction_unfrozen: list[float] = []
+         """ Fraction of layers unfrozen. """
+         self.losses: list[float] = []
+         """ Losses. """
+         self._all_layers: list[Any] = []
+         """ All layers. """
+         self._initial_trainable: list[bool] = []
+         """ Initial trainable states. """
+         self._initial_weights: list[Any] | None = None
+         """ Initial weights of the model. """
+         self._last_update_step: int = -1
+         """ Last step when layers were unfrozen. """
+         self.params: dict[str, Any]
+
+     def on_train_begin(self, logs: dict[str, Any] | None = None) -> None:
+         """ Set initial layer trainable states at the start of training and store initial states and weights.
+
+         Args:
+             logs (dict | None): Training logs.
+         """
+         # Collect all layers from the model and preserve their original trainable states for potential restoration
+         self._all_layers = self.base_model.layers
+         self._initial_trainable = [bool(layer.trainable) for layer in self._all_layers]
+
+         # Store initial weights to reset after each unfreeze
+         if self.reset_weights:
+             self._initial_weights = self.model.get_weights()
+
+         # Set initial trainable state based on mode
+         for layer in self._all_layers:
+             layer.trainable = self.progressive_freeze  # If progressive_freeze, start with all layers unfrozen
+
+     def _update_layers(self, step: int) -> None:
+         """ Update layer trainable states based on the current step and mode.
+         Reset weights after each update to prevent bias in the results.
+
+         Args:
+             step (int): Current training step.
+         """
+         # Calculate the effective step considering the update interval
+         effective_step: int = step // self.update_interval
+
+         # Skip if we haven't reached the next update interval
+         if effective_step <= self._last_update_step:
+             return
+         self._last_update_step = effective_step
+
+         # Calculate the number of layers to unfreeze based on the current effective step
+         n_layers: int = len(self._all_layers)
+
+         if self.progressive_freeze:
+             # For progressive freezing, start at 1.0 (all unfrozen) and decrease to 0.0
+             fraction: float = max(0.0, 1.0 - (effective_step + 1) / self.total_steps)
+         else:
+             # For progressive unfreezing, start at 0.0 (all frozen) and increase to 1.0
+             fraction = min(1.0, (effective_step + 1) / self.total_steps)
+
+         n_unfreeze: int = int(n_layers * fraction)  # Number of layers to keep unfrozen
+         self.fraction_unfrozen.append(fraction)
+
+         # Set the trainable state for each layer based on position.
+         # For both modes, we unfreeze from the top (output layers) to the bottom (input layers)
+         for i, layer in enumerate(self._all_layers):
+             layer.trainable = i >= (n_layers - n_unfreeze)
+
+         # Reset weights to their initial state to prevent bias, and reset the optimizer
+         if self._initial_weights is not None:
+             self.model.set_weights(self._initial_weights)  # pyright: ignore [reportUnknownMemberType]
+             if self.reset_optimizer_function is not None:
+                 self.model.optimizer = self.reset_optimizer_function()
+                 self.model.optimizer.build(self.model.trainable_variables)  # pyright: ignore [reportUnknownMemberType]
+
+     def _track_loss(self, logs: dict[str, Any] | None = None) -> None:
+         """ Track the current loss.
+
+         Args:
+             logs (dict | None): Training logs containing loss information.
+         """
+         if logs and "loss" in logs:
+             self.losses.append(logs["loss"])
+
+     def on_batch_begin(self, batch: int, logs: dict[str, Any] | None = None) -> None:
+         """ Update layer trainable states at the start of each batch if not updating per epoch.
+
+         Args:
+             batch (int): Current batch index.
+             logs (dict | None): Training logs.
+         """
+         # Skip if we're updating per epoch instead of per batch
+         if self.update_per_epoch:
+             return
+
+         # Calculate the current step across all epochs and update layers
+         step: int = self.params.get("steps", self.steps_per_epoch) * self.params.get("epoch", 0) + batch
+         self._update_layers(step)
+
+     def on_batch_end(self, batch: int, logs: dict[str, Any] | None = None) -> None:
+         """ Track the loss at the end of each batch if not updating per epoch.
+
+         Args:
+             batch (int): Current batch index.
+             logs (dict | None): Training logs.
+         """
+         # Skip if we're updating per epoch instead of per batch
+         if self.update_per_epoch:
+             return
+
+         # Record the loss if the update interval is reached
+         if batch % self.update_interval == 0:
+             self._track_loss(logs)
+
+     def on_epoch_begin(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
+         """ Update layer trainable states at the start of each epoch if updating per epoch.
+
+         Args:
+             epoch (int): Current epoch index.
+             logs (dict | None): Training logs.
+         """
+         # Skip if we're updating per batch instead of per epoch
+         if not self.update_per_epoch:
+             return
+
+         # Update layers based on the current epoch
+         self._update_layers(epoch)
+
+     def on_epoch_end(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
+         """ Track the loss at the end of each epoch if updating per epoch.
+
+         Args:
+             epoch (int): Current epoch index.
+             logs (dict | None): Training logs.
+         """
+         # Skip if we're updating per batch instead of per epoch
+         if not self.update_per_epoch:
+             return
+
+         # Record the loss if the update interval is reached
+         if epoch % self.update_interval == 0:
+             self._track_loss(logs)
+
+     def on_train_end(self, logs: dict[str, Any] | None = None) -> None:
+         """ Restore the original trainable states at the end of training.
+
+         Args:
+             logs (dict | None): Training logs.
+         """
+         # Restore each layer's original trainable state
+         for layer, trainable in zip(self._all_layers, self._initial_trainable, strict=False):
+             layer.trainable = trainable
+
+     def get_results(self, multiply_by_100: bool = True) -> tuple[list[float], list[float]]:
+         """ Get the results of the progressive unfreezing, ordered from 0% to 100% even if progressive_freeze is True.
+
+         Args:
+             multiply_by_100 (bool): If True, multiply the fractions by 100 to get percentages.
+
+         Returns:
+             tuple[list[float], list[float]]: Fractions of layers unfrozen, and losses.
+         """
+         fractions: list[float] = self.fraction_unfrozen
+
+         # Reverse the order if progressive_freeze is True
+         if self.progressive_freeze:
+             fractions = fractions[::-1]
+
+         # Multiply by 100 if requested
+         if multiply_by_100:
+             fractions = [x * 100 for x in fractions]
+
+         # Return the results
+         return fractions, self.losses
+
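Since the docstring above warns against model.fit(), a minimal sketch of driving the callback from a custom loop, using the standard Keras set_model/set_params hooks. Everything here is illustrative: train_one_epoch is a hypothetical helper, base_model is assumed to be the pretrained backbone wrapped inside model:

    unfreezer = ProgressiveUnfreezing(base_model=base_model, steps_per_epoch=100, epochs=30)
    unfreezer.set_model(model)                      # set manually since model.fit() is bypassed
    unfreezer.set_params({"epochs": 30, "steps": 100})

    unfreezer.on_train_begin()
    for epoch in range(30):
        unfreezer.on_epoch_begin(epoch)             # unfreezes the next slice of layers
        logs = train_one_epoch(model)               # hypothetical: returns e.g. {"loss": 0.42}
        unfreezer.on_epoch_end(epoch, logs)         # records the loss every update_interval epochs
    unfreezer.on_train_end()                        # restores the original trainable flags

    # Percentage of layers unfrozen vs. loss, always ordered from 0% to 100%:
    # useful for deciding how much of the backbone is actually worth unfreezing.
    percentages, losses = unfreezer.get_results()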