dragon-ml-toolbox 19.10.0__py3-none-any.whl → 19.12.0__py3-none-any.whl
This diff compares the contents of two publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
- {dragon_ml_toolbox-19.10.0.dist-info → dragon_ml_toolbox-19.12.0.dist-info}/METADATA +1 -1
- {dragon_ml_toolbox-19.10.0.dist-info → dragon_ml_toolbox-19.12.0.dist-info}/RECORD +19 -19
- ml_tools/ML_callbacks.py +8 -4
- ml_tools/_core/_MICE_imputation.py +2 -2
- ml_tools/_core/_ML_callbacks.py +461 -171
- ml_tools/_core/_ML_trainer.py +50 -50
- ml_tools/_core/_ML_utilities.py +153 -50
- ml_tools/_core/_PSO_optimization.py +1 -1
- ml_tools/_core/_ensemble_inference.py +1 -1
- ml_tools/_core/_keys.py +32 -1
- ml_tools/_core/_optimization_tools.py +1 -1
- ml_tools/_core/_path_manager.py +149 -27
- ml_tools/_core/_utilities.py +6 -2
- ml_tools/keys.py +2 -0
- ml_tools/path_manager.py +5 -1
- {dragon_ml_toolbox-19.10.0.dist-info → dragon_ml_toolbox-19.12.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-19.10.0.dist-info → dragon_ml_toolbox-19.12.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-19.10.0.dist-info → dragon_ml_toolbox-19.12.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-19.10.0.dist-info → dragon_ml_toolbox-19.12.0.dist-info}/top_level.txt +0 -0
ml_tools/_core/_ML_callbacks.py
CHANGED
@@ -1,5 +1,6 @@
 import numpy as np
 import torch
+from collections import deque
 from tqdm.auto import tqdm
 from typing import Union, Literal, Optional
 from pathlib import Path
@@ -16,9 +17,11 @@ _LOGGER = get_logger("Callbacks")
 __all__ = [
     "History",
     "TqdmProgressBar",
-    "DragonEarlyStopping",
+    "DragonPatienceEarlyStopping",
+    "DragonPrecheltEarlyStopping",
     "DragonModelCheckpoint",
-    "DragonLRScheduler",
+    "DragonScheduler",
+    "DragonReduceLROnPlateau"
 ]
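The renames above change the public surface of the callbacks module (the re-export module `ml_tools/ML_callbacks.py` also changed in this release, +8/-4). A hypothetical migration sketch for downstream code, assuming that module re-exports the names listed in `__all__` (the old names are reconstructed from the hunk headers below):

# Hypothetical migration sketch; assumes ml_tools.ML_callbacks re-exports
# the names listed in __all__ above.

# 19.10.0 (old, reconstructed names):
# from ml_tools.ML_callbacks import DragonEarlyStopping, DragonLRScheduler

# 19.12.0 (new): one class per strategy.
from ml_tools.ML_callbacks import (
    DragonPatienceEarlyStopping,  # patience counter on a monitored loss
    DragonPrecheltEarlyStopping,  # Prechelt's progress-modified GL criterion
    DragonModelCheckpoint,
    DragonScheduler,              # standard PyTorch schedulers (StepLR, etc.)
    DragonReduceLROnPlateau,      # self-initializing ReduceLROnPlateau wrapper
)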
@@ -112,67 +115,89 @@ class TqdmProgressBar(_Callback):
         self.epoch_bar.close() # type: ignore
 
 
-class DragonEarlyStopping(_Callback):
+class _DragonEarlyStopping(_Callback):
     """
-
+    Base class for Early Stopping strategies.
+    Ensures type compatibility and shared logging logic.
     """
-    def __init__(self,
-
+    def __init__(self,
+                 monitor: str,
+                 verbose: int = 1):
+        super().__init__()
+        self.monitor = monitor
+        self.verbose = verbose
+        self.stopped_epoch = 0
+
+    def _stop_training(self, epoch: int, reason: str):
+        """Helper to trigger the stop."""
+        self.stopped_epoch = epoch
+        self.trainer.stop_training = True # type: ignore
+        if self.verbose > 0:
+            _LOGGER.info(f"Epoch {epoch}: Early stopping triggered. Reason: {reason}")
+
+
+class DragonPatienceEarlyStopping(_DragonEarlyStopping):
+    """
+    Standard early stopping: Tracks minimum validation loss (or other metric) with a patience counter.
+    """
+    def __init__(self,
+                 monitor: Literal["Training Loss", "Validation Loss"] = "Validation Loss",
+                 min_delta: float = 0.0,
+                 patience: int = 10,
+                 mode: Literal['min', 'max'] = 'min',
+                 verbose: int = 1):
+        """
         Args:
-            monitor (str):
-            min_delta (float): Minimum change
+            monitor (str): Metric to monitor.
+            min_delta (float): Minimum change to qualify as an improvement.
             patience (int): Number of epochs with no improvement after which training will be stopped.
-            mode (str): One of {'min', 'max', 'auto'}. In 'min' mode, training will stop when the quantity
-                monitored has stopped decreasing; in 'max' mode it will stop when the quantity
-                monitored has stopped increasing; in 'auto' mode, the direction is automatically
-                inferred from the name of the monitored quantity.
+            mode (str): One of {'min', 'max'}. In 'min' mode, training will stop when the quantity monitored has stopped decreasing; in 'max' mode it will stop when the quantity monitored has stopped increasing.
             verbose (int): Verbosity mode.
         """
-
-
+        # standardize monitor key
+        if monitor == "Training Loss":
+            std_monitor = PyTorchLogKeys.TRAIN_LOSS
+        elif monitor == "Validation Loss":
+            std_monitor = PyTorchLogKeys.VAL_LOSS
+        else:
+            _LOGGER.error(f"Unknown monitor key: {monitor}.")
+            raise ValueError()
+
+        super().__init__(std_monitor, verbose)
         self.patience = patience
         self.min_delta = min_delta
         self.wait = 0
-        self.
-        self.verbose = verbose
+        self.mode = mode
 
-        if mode not in ['min', 'max', 'auto']:
-            _LOGGER.error(f"EarlyStopping mode {mode} is unknown, choose one of ('min', 'max', 'auto')")
+        if mode not in ['min', 'max']:
+            _LOGGER.error(f"EarlyStopping mode {mode} is unknown, choose one of ('min', 'max')")
             raise ValueError()
-        self.mode = mode
 
-        # Determine the comparison operator
+        # Determine the comparison operator
         if self.mode == 'min':
             self.monitor_op = np.less
         elif self.mode == 'max':
             self.monitor_op = np.greater
-        else:
-
-
-
-            self.monitor_op = np.less
+        else:
+            # raise error for unknown mode
+            _LOGGER.error(f"EarlyStopping mode {mode} is unknown, choose one of ('min', 'max')")
+            raise ValueError()
 
         self.best = np.inf if self.monitor_op == np.less else -np.inf
 
     def on_train_begin(self, logs=None):
-        # Reset state at the beginning of training
         self.wait = 0
-        self.stopped_epoch = 0
         self.best = np.inf if self.monitor_op == np.less else -np.inf
-
+
     def on_epoch_end(self, epoch, logs=None):
         current = logs.get(self.monitor) # type: ignore
         if current is None:
             return
 
-        #
+        # Check improvement
         if self.monitor_op == np.less:
-            # For 'min' mode, we need to be smaller than 'best' by at least 'min_delta'
-            # Correct check: current < self.best - self.min_delta
            is_improvement = self.monitor_op(current, self.best - self.min_delta)
         else:
-            # For 'max' mode, we need to be greater than 'best' by at least 'min_delta'
-            # Correct check: current > self.best + self.min_delta
            is_improvement = self.monitor_op(current, self.best + self.min_delta)
 
         if is_improvement:
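The hunk above introduces the shared `_DragonEarlyStopping` base and the constructor of `DragonPatienceEarlyStopping`; its epoch-end logic continues in the next hunk. A minimal usage sketch, assuming the trainer exposes a `callbacks` list as suggested by the `set_trainer`/`self.trainer` plumbing (the trainer itself lives in the changed `_ML_trainer.py`, not shown in this diff):

from ml_tools.ML_callbacks import DragonPatienceEarlyStopping

# Stop once validation loss fails to improve by at least 1e-4
# for 10 consecutive epochs.
early_stop = DragonPatienceEarlyStopping(
    monitor="Validation Loss",
    min_delta=1e-4,
    patience=10,
    mode="min",
)
# trainer = MLTrainer(..., callbacks=[early_stop])  # hypothetical wiring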
@@ -183,137 +208,224 @@ class DragonEarlyStopping(_Callback):
         else:
             self.wait += 1
             if self.wait >= self.patience:
-                self.
-
-
-
+                self._stop_training(epoch, f"No improvement in {self.monitor} for {self.wait} epochs.")
+
+
+class DragonPrecheltEarlyStopping(_DragonEarlyStopping):
+    """
+    Implements Prechelt's 'Progress-Modified GL' criterion.
+    Tracks the ratio between Generalization Loss (overfitting) and Training Progress.
+
+    References:
+        Prechelt, L. (1998). Early Stopping - But When?
+    """
+    def __init__(self,
+                 alpha: float = 0.75,
+                 k: int = 5,
+                 verbose: int = 1):
+        """
+        This early stopping strategy monitors both validation loss and training loss to determine the optimal stopping point.
+
+        Args:
+            alpha (float): The threshold for the stopping criterion.
+            k (int): The window size for calculating training progress.
+            verbose (int): Verbosity mode.
+
+        NOTE:
+
+        - **The Strip Size (k)**:
+            - `5`: The empirical "gold standard." It is long enough to smooth out batch noise but short enough to react to convergence plateaus quickly.
+            - `10` to `20`: Use if the training curve is very jagged (e.g., noisy data, small batch sizes, high dropout, or Reinforcement Learning). A larger k value prevents premature stopping due to random volatility.
+        - **The threshold (alpha)**:
+            - `< 0.5`: Aggressive. Stops training very early.
+            - `0.75` to `0.80`: Prechelt found this range to be the most robust across different datasets. It typically yields the best trade-off between generalization and training cost.
+            - `1.0` to `1.2`: Useful for complex tasks (like Transformers) where training progress might dip temporarily before recovering. It risks slightly more overfitting but ensures potential is exhausted.
+        """
+        super().__init__(PyTorchLogKeys.VAL_LOSS, verbose)
+        self.train_monitor = PyTorchLogKeys.TRAIN_LOSS
+        self.alpha = alpha
+        self.k = k
+
+        self.best_val_loss = np.inf
+        self.train_strip = deque(maxlen=k)
+
+    def on_train_begin(self, logs=None):
+        self.best_val_loss = np.inf
+        self.train_strip.clear()
+
+    def on_epoch_end(self, epoch, logs=None):
+        val_loss = logs.get(self.monitor) # type: ignore
+        train_loss = logs.get(self.train_monitor) # type: ignore
+
+        if val_loss is None or train_loss is None:
+            return
+
+        # 1. Update Best Validation Loss
+        if val_loss < self.best_val_loss:
+            self.best_val_loss = val_loss
+
+        # 2. Update Training Strip
+        self.train_strip.append(train_loss)
+
+        # 3. Calculate Generalization Loss (GL)
+        # GL(t) = 100 * (E_val / E_opt - 1)
+        # Low GL is good. High GL means we are drifting away from best val score (overfitting).
+        gl = 100 * ((val_loss / self.best_val_loss) - 1)
+
+        # 4. Calculate Progress (Pk)
+        # Pk(t) = 1000 * (Sum(strip) / (k * min(strip)) - 1)
+        # High Pk is good (training loss is still dropping fast). Low Pk means training has stalled.
+        if len(self.train_strip) < self.k:
+            # Not enough data for progress yet
+            return
+
+        strip_sum = sum(self.train_strip)
+        strip_min = min(self.train_strip)
+
+        # Avoid division by zero
+        if strip_min == 0:
+            pk = 0.1 # Arbitrary small number
+        else:
+            pk = 1000 * ((strip_sum / (self.k * strip_min)) - 1)
+
+        # 5. The Quotient Criterion
+        # Stop if GL / Pk > alpha
+        # Intuition: Stop if Overfitting is high AND Progress is low.
+
+        # Avoid division by zero
+        if pk == 0:
+            pk = 1e-6
+
+        quotient = gl / pk
+
+        if self.verbose > 1:
+            _LOGGER.info(f"Epoch {epoch}: GL={gl:.3f} | Pk={pk:.3f} | Quotient={quotient:.3f} (Threshold={self.alpha})")
+
+        if quotient > self.alpha:
+            self._stop_training(epoch, f"Prechelt Criterion triggered. Generalization/Progress quotient ({quotient:.3f}) > alpha ({self.alpha}).")
 
 
 class DragonModelCheckpoint(_Callback):
     """
     Saves the model weights, optimizer state, LR scheduler state (if any), and epoch number to a directory with automated filename generation and rotation.
     """
-    def __init__(self,
-
+    def __init__(self,
+                 save_dir: Union[str, Path],
+                 monitor: Literal["Training Loss", "Validation Loss", "both"] = "Validation Loss",
+                 save_three_best: bool = True,
+                 mode: Literal['min', 'max'] = 'min',
+                 verbose: int = 0):
         """
-        - If `save_best_only` is True, it saves the single best model, deleting the previous best.
-        - If `save_best_only` is False, it keeps the 3 most recent checkpoints, deleting the oldest ones automatically.
-
         Args:
             save_dir (str): Directory where checkpoint files will be saved.
-            monitor (str): Metric to monitor.
-
-
+            monitor (str): Metric to monitor. If "both", the sum of training loss and validation loss is used.
+            save_three_best (bool):
+                - If True, keeps the top 3 best checkpoints found during training (based on metric).
+                - If False, keeps the 3 most recent checkpoints (rolling window).
+            mode (str): One of {'min', 'max'}.
             verbose (int): Verbosity mode.
         """
-
         super().__init__()
         self.save_dir = make_fullpath(save_dir, make=True, enforce="directory")
-        if not self.save_dir.is_dir():
-            _LOGGER.error(f"{save_dir} is not a valid directory.")
-            raise IOError()
 
-
-
+        # Standardize monitor key
+        if monitor == "Training Loss":
+            std_monitor = PyTorchLogKeys.TRAIN_LOSS
+        elif monitor == "Validation Loss":
+            std_monitor = PyTorchLogKeys.VAL_LOSS
+        elif monitor == "both":
+            std_monitor = "both"
+        else:
+            _LOGGER.error(f"Unknown monitor key: {monitor}.")
+            raise ValueError()
+
+        self.monitor = std_monitor
+        self.save_three_best = save_three_best
         self.verbose = verbose
         self._latest_checkpoint_path = None
         self._checkpoint_name = PyTorchCheckpointKeys.CHECKPOINT_NAME
 
-        # State variables
-
-        self.
+        # State variables
+        # stored as list of dicts: [{'path': Path, 'score': float, 'epoch': int}]
+        self.best_checkpoints = []
+        # For rolling check (save_three_best=False)
+        self.recent_checkpoints = []
 
-        if mode not in ['
-            _LOGGER.error(f"ModelCheckpoint mode {mode} is unknown.")
+        if mode not in ['min', 'max']:
+            _LOGGER.error(f"ModelCheckpoint mode {mode} is unknown. Use 'min' or 'max'.")
             raise ValueError()
         self.mode = mode
 
+        # Determine comparison operator
         if self.mode == 'min':
             self.monitor_op = np.less
-
-            self.monitor_op = np.greater
+            self.best = np.inf
         else:
-            self.monitor_op = np.
-
-        self.best = np.inf if self.monitor_op == np.less else -np.inf
+            self.monitor_op = np.greater
+            self.best = -np.inf
 
     def on_train_begin(self, logs=None):
-        """Reset state when training starts.
-        self.best
-        self.
-        self.
+        """Reset file tracking state when training starts.
+        NOTE: Do NOT reset self.best here if it differs from the default. This allows the Trainer to restore 'best' from a checkpoint before calling train()."""
+        self.best_checkpoints = []
+        self.recent_checkpoints = []
+
+        # Check if self.best is at default initialization value
+        is_default_min = (self.mode == 'min' and self.best == np.inf)
+        is_default_max = (self.mode == 'max' and self.best == -np.inf)
+
+        # If it is NOT default, it means it was restored.
+        if not (is_default_min or is_default_max):
+            _LOGGER.debug(f"Resuming with best score: {self.best:.4f}")
+
+    def _get_metric_value(self, logs):
+        """Extracts or calculates the metric value based on configuration."""
+        if self.monitor == "both":
+            t_loss = logs.get(PyTorchLogKeys.TRAIN_LOSS)
+            v_loss = logs.get(PyTorchLogKeys.VAL_LOSS)
+            if t_loss is None or v_loss is None:
+                return None
+            return t_loss + v_loss
+        else:
+            return logs.get(self.monitor)
 
     def on_epoch_end(self, epoch, logs=None):
         logs = logs or {}
+        current_score = self._get_metric_value(logs)
 
-        if
-            self.
-
-            self._save_rolling_checkpoints(epoch, logs)
-
-    def _save_best_model(self, epoch, logs):
-        """Saves a single best model and deletes the previous one."""
-        current = logs.get(self.monitor)
-        if current is None:
+        if current_score is None:
+            if self.verbose > 0:
+                _LOGGER.warning(f"Epoch {epoch}: Metric '{self.monitor}' not found in logs. Skipping checkpoint.")
             return
-
-
-
-
-        # Create a descriptive filename
-        self.save_dir.mkdir(parents=True, exist_ok=True)
-        current_string = str(round(current, ndigits=2)).replace('.', '_')
-        filename = f"epoch{epoch}_{self._checkpoint_name}-{current_string}.pth"
-        new_filepath = self.save_dir / filename
-
+
+        # 1. Update global best score (for logging/metadata)
+        if self.monitor_op(current_score, self.best):
             if self.verbose > 0:
-
-
-
-            self.best =
+                # Only log explicit "improvement" if we are beating the historical best
+                old_best_str = f"{self.best:.4f}" if not np.isinf(self.best) else "inf"
+                _LOGGER.info(f"Epoch {epoch}: {self.monitor} improved from {old_best_str} to {current_score:.4f}")
+            self.best = current_score
 
-
-
-
-
-            PyTorchCheckpointKeys.OPTIMIZER_STATE: self.trainer.optimizer.state_dict(), # type: ignore
-            PyTorchCheckpointKeys.BEST_SCORE: self.best,
-            PyTorchCheckpointKeys.HISTORY: self.trainer.history, # type: ignore
-        }
-
-        # Check for scheduler
-        if hasattr(self.trainer, 'scheduler') and self.trainer.scheduler is not None: # type: ignore
-            checkpoint_data[PyTorchCheckpointKeys.SCHEDULER_STATE] = self.trainer.scheduler.state_dict() # type: ignore
-
-        # Save the new best model
-        torch.save(checkpoint_data, new_filepath)
-        self._latest_checkpoint_path = new_filepath
-
-        # Delete the old best model file
-        if self.last_best_filepath and self.last_best_filepath.exists():
-            self.last_best_filepath.unlink()
-
-        # Update state
-        self.last_best_filepath = new_filepath
+        if self.save_three_best:
+            self._save_top_k_checkpoints(epoch, current_score)
+        else:
+            self._save_rolling_checkpoints(epoch, current_score)
 
-    def
-        """
-        current = logs.get(self.monitor)
-
+    def _save_checkpoint_file(self, epoch, current_score):
+        """Helper to physically save the file."""
         self.save_dir.mkdir(parents=True, exist_ok=True)
-        current_string = str(round(current, ndigits=2)).replace('.', '_')
-        filename = f"epoch{epoch}_{self._checkpoint_name}-{current_string}.pth"
-        filepath = self.save_dir / filename
 
-
-
+        # Create filename
+        score_str = f"{current_score:.4f}".replace('.', '_')
+        filename = f"epoch{epoch}_{self._checkpoint_name}-{score_str}.pth"
+        filepath = self.save_dir / filename
 
-        # Create
+        # Create checkpoint dict
         checkpoint_data = {
             PyTorchCheckpointKeys.EPOCH: epoch,
             PyTorchCheckpointKeys.MODEL_STATE: self.trainer.model.state_dict(), # type: ignore
             PyTorchCheckpointKeys.OPTIMIZER_STATE: self.trainer.optimizer.state_dict(), # type: ignore
-            PyTorchCheckpointKeys.BEST_SCORE:
+            PyTorchCheckpointKeys.BEST_SCORE: current_score,
             PyTorchCheckpointKeys.HISTORY: self.trainer.history, # type: ignore
         }
 
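To make the quotient criterion in `DragonPrecheltEarlyStopping.on_epoch_end` (hunk above) concrete, here is a standalone sketch of one epoch's arithmetic using the same formulas from the code; the loss values are invented for illustration:

# One epoch of the Prechelt check, outside the callback, with made-up losses.
train_strip = [0.50, 0.48, 0.47, 0.46, 0.45]  # last k=5 training losses
k, alpha = 5, 0.75
val_loss, best_val_loss = 0.42, 0.40

gl = 100 * ((val_loss / best_val_loss) - 1)                    # GL = 5.0 (5% above best)
pk = 1000 * ((sum(train_strip) / (k * min(train_strip))) - 1)  # Pk ~= 48.9 (still progressing)
quotient = gl / pk                                             # ~0.102

# 0.102 <= alpha, so training continues; the stop fires only once
# generalization loss (GL) grows while training progress (Pk) stalls.
print(f"GL={gl:.2f} Pk={pk:.2f} quotient={quotient:.3f} stop={quotient > alpha}")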
@@ -321,91 +433,269 @@ class DragonModelCheckpoint(_Callback):
             checkpoint_data[PyTorchCheckpointKeys.SCHEDULER_STATE] = self.trainer.scheduler.state_dict() # type: ignore
 
         torch.save(checkpoint_data, filepath)
-
         self._latest_checkpoint_path = filepath
+
+        return filepath
+
+    def _save_top_k_checkpoints(self, epoch, current_score):
+        """Logic for maintaining the top 3 best checkpoints."""
+
+        def sort_key(item): return item['score']
+
+        # Determine sort direction so that Index 0 is BEST and Index -1 is WORST
+        # Min mode (lower is better): Ascending (reverse=False) -> [0.1, 0.5, 0.9] (0.1 is best)
+        # Max mode (higher is better): Descending (reverse=True) -> [0.9, 0.5, 0.1] (0.9 is best)
+        is_reverse = (self.mode == 'max')
 
-
+        should_save = False
+
+        if len(self.best_checkpoints) < 3:
+            should_save = True
+        else:
+            # Sort current list to identify the worst (last item)
+            self.best_checkpoints.sort(key=sort_key, reverse=is_reverse)
+            worst_entry = self.best_checkpoints[-1]
+
+            # Check if current is better than the worst in the list
+            # min mode: current < worst['score']
+            # max mode: current > worst['score']
+            if self.monitor_op(current_score, worst_entry['score']):
+                should_save = True
+
+        if should_save:
+            filepath = self._save_checkpoint_file(epoch, current_score)
+
+            if self.verbose > 0:
+                _LOGGER.info(f"Epoch {epoch}: {self.monitor} ({current_score:.4f}) is in top 3. Saving to {filepath.name}")
 
-
-
-
+            self.best_checkpoints.append({'path': filepath, 'score': current_score, 'epoch': epoch})
+
+            # Prune if > 3
+            if len(self.best_checkpoints) > 3:
+                # Re-sort to ensure worst is at the end
+                self.best_checkpoints.sort(key=sort_key, reverse=is_reverse)
+
+                # Evict the last one (Worst)
+                entry_to_delete = self.best_checkpoints.pop(-1)
+
+                if entry_to_delete['path'].exists():
+                    if self.verbose > 0:
+                        _LOGGER.info(f" -> Deleting checkpoint outside top 3: {entry_to_delete['path'].name}")
+                    entry_to_delete['path'].unlink()
+
+    def _save_rolling_checkpoints(self, epoch, current_score):
+        """Saves the latest model and keeps only the 3 most recent ones."""
+        filepath = self._save_checkpoint_file(epoch, current_score)
+
+        if self.verbose > 0:
+            _LOGGER.info(f'Epoch {epoch}: saving rolling model to {filepath.name}')
+
+        self.recent_checkpoints.append(filepath)
+
+        # If we have more than 3 checkpoints, remove the oldest one
+        if len(self.recent_checkpoints) > 3:
+            file_to_delete = self.recent_checkpoints.pop(0)
             if file_to_delete.exists():
                 if self.verbose > 0:
-                    _LOGGER.info(f" -> Deleting old checkpoint: {file_to_delete.name}")
+                    _LOGGER.info(f" -> Deleting old rolling checkpoint: {file_to_delete.name}")
                 file_to_delete.unlink()
 
     @property
     def best_checkpoint_path(self):
-
+        # If tracking top 3, return the absolute best among them
+        if self.save_three_best and self.best_checkpoints:
+            def sort_key(item): return item['score']
+            is_reverse = (self.mode == 'max')
+            # Sort Best -> Worst
+            sorted_bests = sorted(self.best_checkpoints, key=sort_key, reverse=is_reverse)
+            # Index 0 is always the best based on the logic above
+            return sorted_bests[0]['path']
+
+        elif self._latest_checkpoint_path:
             return self._latest_checkpoint_path
         else:
             _LOGGER.error("No checkpoint paths saved.")
             raise ValueError()
 
 
-class DragonLRScheduler(_Callback):
+class _DragonLRScheduler(_Callback):
     """
-
+    Base class for Dragon LR Schedulers.
+    Handles common logic like logging and attaching to the trainer.
     """
-    def __init__(self
-
-
-
+    def __init__(self):
+        super().__init__()
+        self.scheduler = None
+        self.previous_lr = None
+
+    def set_trainer(self, trainer):
+        """Associates the callback with the trainer."""
+        super().set_trainer(trainer)
+        # Note: Subclasses must ensure self.scheduler is set before or during this call
+        # if they want to register it immediately.
+        if self.scheduler:
+            self.trainer.scheduler = self.scheduler # type: ignore
+
+    def on_train_begin(self, logs=None):
+        """Store the initial learning rate."""
+        if not self.trainer.optimizer: # type: ignore
+            _LOGGER.warning("No optimizer found in trainer. LRScheduler cannot track learning rate.")
+            return
+        self.previous_lr = self.trainer.optimizer.param_groups[0]['lr'] # type: ignore
+
+    def _check_and_log_lr(self, epoch, logs, verbose: bool):
+        """Helper to log LR changes and update history."""
+        if not self.trainer.optimizer: # type: ignore
+            return
 
+        current_lr = self.trainer.optimizer.param_groups[0]['lr'] # type: ignore
+
+        # Log change
+        if self.previous_lr is not None and current_lr != self.previous_lr:
+            if verbose:
+                print(f" > Epoch {epoch}: Learning rate changed to {current_lr:.6f}")
+            self.previous_lr = current_lr
+
+        # Log to dictionary
+        logs[PyTorchLogKeys.LEARNING_RATE] = current_lr
+
+        # Log to history
+        if hasattr(self.trainer, 'history'):
+            self.trainer.history.setdefault(PyTorchLogKeys.LEARNING_RATE, []).append(current_lr) # type: ignore
+
+
+class DragonScheduler(_DragonLRScheduler):
+    """
+    Callback for standard PyTorch Learning Rate Schedulers.
+
+    Compatible with: StepLR, MultiStepLR, ExponentialLR, CosineAnnealingLR, etc.
+
+    NOT Compatible with: ReduceLROnPlateau (Use `DragonReduceLROnPlateau` instead).
+    """
+    def __init__(self, scheduler, verbose: bool=True):
+        """
         Args:
-            scheduler: An initialized PyTorch learning rate scheduler.
-
+            scheduler: An initialized PyTorch learning rate scheduler instance.
+            verbose (bool): If True, logs learning rate changes to console.
         """
         super().__init__()
+        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+            raise ValueError(
+                "DragonLRScheduler does not support 'ReduceLROnPlateau'. "
+                "Please use the `DragonReduceLROnPlateau` callback instead."
+            )
         self.scheduler = scheduler
-        self.
-        self.previous_lr = None
+        self.verbose = verbose
 
     def set_trainer(self, trainer):
-        """This is called by the Trainer to associate itself with the callback."""
         super().set_trainer(trainer)
-        #
+        # Explicitly register the scheduler again to be safe
         self.trainer.scheduler = self.scheduler # type: ignore
-
-
-        """Store the initial learning rate."""
-        self.previous_lr = self.trainer.optimizer.param_groups[0]['lr'] # type: ignore
+        if self.verbose:
+            _LOGGER.info(f"Registered LR Scheduler: {self.scheduler.__class__.__name__}")
 
     def on_epoch_end(self, epoch, logs=None):
-        """Step the scheduler and log any change in learning rate."""
         logs = logs or {}
 
-        #
-
-        if self.monitor is None:
-            _LOGGER.error("LRScheduler needs a `monitor` metric for ReduceLROnPlateau.")
-            raise ValueError()
-
-        metric_val = logs.get(self.monitor) # type: ignore
-        if metric_val is not None:
-            self.scheduler.step(metric_val)
-        else:
-            _LOGGER.warning(f"LRScheduler could not find metric '{self.monitor}' in logs.")
+        # Standard step (no metrics needed)
+        self.scheduler.step()
 
-
+        self._check_and_log_lr(epoch, logs, self.verbose)
+
+
+class DragonReduceLROnPlateau(_DragonLRScheduler):
+    """
+    Specific callback for `torch.optim.lr_scheduler.ReduceLROnPlateau`. Reduces learning rate when a monitored metric has stopped improving.
+
+    This wrapper initializes the scheduler internally using the Trainer's optimizer, simplifying the setup process.
+    """
+    def __init__(self,
+                 monitor: Literal["Training Loss", "Validation Loss"] = "Validation Loss",
+                 mode: Literal['min', 'max'] = 'min',
+                 factor: float = 0.1,
+                 patience: int = 5,
+                 threshold: float = 1e-4,
+                 threshold_mode: Literal['rel', 'abs'] = 'rel',
+                 cooldown: int = 0,
+                 min_lr: float = 0,
+                 eps: float = 1e-8,
+                 verbose: bool = True):
+        """
+        Args:
+            monitor ("Training Loss", "Validation Loss"): Metric to monitor.
+            mode ('min', 'max'): One of 'min', 'max'.
+            factor (float): Factor by which the learning rate will be reduced. new_lr = lr * factor.
+            patience (int): Number of epochs with no improvement after which learning rate will be reduced.
+            threshold (float): Threshold for measuring the new optimum.
+            threshold_mode ('rel', 'abs'): One of 'rel', 'abs'.
+            cooldown (int): Number of epochs to wait before resuming normal operation after lr has been reduced.
+            min_lr (float or list): A scalar or a list of scalars.
+            eps (float): Minimal decay applied to lr.
+            verbose (bool): If True, logs learning rate changes to console.
+        """
+        super().__init__()
+
+        # Standardize monitor key
+        if monitor == "Training Loss":
+            std_monitor = PyTorchLogKeys.TRAIN_LOSS
+        elif monitor == "Validation Loss":
+            std_monitor = PyTorchLogKeys.VAL_LOSS
         else:
-
+            _LOGGER.error(f"Unknown monitor key: {monitor}.")
+            raise ValueError()
+
+        self.monitor = std_monitor
+        self.verbose = verbose
+
+        # Config storage for delayed initialization
+        self.config = {
+            'mode': mode,
+            'factor': factor,
+            'patience': patience,
+            'threshold': threshold,
+            'threshold_mode': threshold_mode,
+            'cooldown': cooldown,
+            'min_lr': min_lr,
+            'eps': eps,
+        }
+
+    def set_trainer(self, trainer):
+        """
+        Initializes the ReduceLROnPlateau scheduler using the trainer's optimizer and registers it.
+        """
+        super().set_trainer(trainer)
+
+        if not hasattr(self.trainer, 'optimizer'):
+            _LOGGER.error("Trainer has no optimizer. Cannot initialize ReduceLROnPlateau.")
+            raise ValueError()
 
-        #
-
+        # Initialize the actual scheduler with the optimizer
+        if self.verbose:
+            _LOGGER.info(f"Initializing ReduceLROnPlateau monitoring '{self.monitor}'")
+
+        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+            optimizer=self.trainer.optimizer, # type: ignore
+            **self.config
+        )
+
+        # Register with trainer for checkpointing
+        self.trainer.scheduler = self.scheduler # type: ignore
 
-
-
-            _LOGGER.info(f"Epoch {epoch}: Learning rate changed to {current_lr:.6f}")
-        self.previous_lr = current_lr
+    def on_epoch_end(self, epoch, logs=None):
+        logs = logs or {}
 
-
-        # Add to the logs dict for any subsequent callbacks
-        logs[PyTorchLogKeys.LEARNING_RATE] = current_lr
+        metric_val = logs.get(self.monitor)
 
-
-
-
+        if metric_val is None:
+            _LOGGER.warning(f"DragonReduceLROnPlateau could not find metric '{self.monitor}' in logs. Scheduler step skipped.")
+            # Still log LR to keep history consistent
+            self._check_and_log_lr(epoch, logs, self.verbose)
+            return
+
+        # Step with metric
+        self.scheduler.step(metric_val)
+
+        self._check_and_log_lr(epoch, logs, self.verbose)
 
 
 def info():