dragon-ml-toolbox 19.11.0__py3-none-any.whl → 19.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,6 @@
  import numpy as np
  import torch
+ from collections import deque
  from tqdm.auto import tqdm
  from typing import Union, Literal, Optional
  from pathlib import Path
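
The new `collections.deque` import backs the fixed-length "training strip" used by the `DragonPrecheltEarlyStopping` callback added later in this diff. A minimal sketch of the rolling-window behavior it relies on (the loss values below are illustrative only):

    from collections import deque

    # A deque with maxlen=k silently evicts the oldest entry once full,
    # which is exactly the "last k training losses" strip the callback keeps.
    strip = deque(maxlen=5)
    for loss in [0.90, 0.80, 0.72, 0.66, 0.61, 0.58]:
        strip.append(loss)

    print(list(strip))          # [0.80, 0.72, 0.66, 0.61, 0.58] - 0.90 has been evicted
    print(min(strip), sum(strip))  # the ingredients of the Pk progress term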
@@ -16,9 +17,11 @@ _LOGGER = get_logger("Callbacks")
  __all__ = [
  "History",
  "TqdmProgressBar",
- "DragonEarlyStopping",
+ "DragonPatienceEarlyStopping",
+ "DragonPrecheltEarlyStopping",
  "DragonModelCheckpoint",
- "DragonLRScheduler"
+ "DragonScheduler",
+ "DragonReduceLROnPlateau"
  ]


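Note for existing users: this release renames and splits the public callbacks. `DragonEarlyStopping` is replaced by `DragonPatienceEarlyStopping` (plus the new `DragonPrecheltEarlyStopping`), and `DragonLRScheduler` is split into `DragonScheduler` and `DragonReduceLROnPlateau`. A migration sketch, assuming the callbacks module import path below (the real path in your installation may differ):

    # Hypothetical import path - adjust to wherever this callbacks module lives in your install.
    from ml_tools.ML_callbacks import (
        DragonPatienceEarlyStopping,   # was: DragonEarlyStopping
        DragonScheduler,               # was: DragonLRScheduler (StepLR, CosineAnnealingLR, ...)
        DragonReduceLROnPlateau,       # was: DragonLRScheduler used with ReduceLROnPlateau
    )

    early_stop = DragonPatienceEarlyStopping(monitor="Validation Loss", patience=10, mode="min")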
@@ -112,67 +115,89 @@ class TqdmProgressBar(_Callback):
  self.epoch_bar.close() # type: ignore


- class DragonEarlyStopping(_Callback):
+ class _DragonEarlyStopping(_Callback):
  """
- Stop training when a monitored metric has stopped improving.
+ Base class for Early Stopping strategies.
+ Ensures type compatibility and shared logging logic.
  """
- def __init__(self, monitor: str=PyTorchLogKeys.VAL_LOSS, min_delta: float=0.0, patience: int=5, mode: Literal['auto', 'min', 'max']='auto', verbose: int=1):
- """
+ def __init__(self,
+ monitor: str,
+ verbose: int = 1):
+ super().__init__()
+ self.monitor = monitor
+ self.verbose = verbose
+ self.stopped_epoch = 0
+
+ def _stop_training(self, epoch: int, reason: str):
+ """Helper to trigger the stop."""
+ self.stopped_epoch = epoch
+ self.trainer.stop_training = True # type: ignore
+ if self.verbose > 0:
+ _LOGGER.info(f"Epoch {epoch}: Early stopping triggered. Reason: {reason}")
+
+
+ class DragonPatienceEarlyStopping(_DragonEarlyStopping):
+ """
+ Standard early stopping: Tracks minimum validation loss (or other metric) with a patience counter.
+ """
+ def __init__(self,
+ monitor: Literal["Training Loss", "Validation Loss"] = "Validation Loss",
+ min_delta: float = 0.0,
+ patience: int = 10,
+ mode: Literal['min', 'max'] = 'min',
+ verbose: int = 1):
+ """
  Args:
- monitor (str): Quantity to be monitored. Defaults to 'val_loss'.
- min_delta (float): Minimum change in the monitored quantity to qualify as an improvement.
+ monitor (str): Metric to monitor.
+ min_delta (float): Minimum change to qualify as an improvement.
  patience (int): Number of epochs with no improvement after which training will be stopped.
- mode (str): One of {'auto', 'min', 'max'}. In 'min' mode, training will stop when the quantity
- monitored has stopped decreasing; in 'max' mode it will stop when the quantity
- monitored has stopped increasing; in 'auto' mode, the direction is automatically
- inferred from the name of the monitored quantity.
+ mode (str): One of {'min', 'max'}. In 'min' mode, training will stop when the quantity monitored has stopped decreasing; in 'max' mode it will stop when the quantity monitored has stopped increasing.
  verbose (int): Verbosity mode.
  """
- super().__init__()
- self.monitor = monitor
+ # standardize monitor key
+ if monitor == "Training Loss":
+ std_monitor = PyTorchLogKeys.TRAIN_LOSS
+ elif monitor == "Validation Loss":
+ std_monitor = PyTorchLogKeys.VAL_LOSS
+ else:
+ _LOGGER.error(f"Unknown monitor key: {monitor}.")
+ raise ValueError()
+
+ super().__init__(std_monitor, verbose)
  self.patience = patience
  self.min_delta = min_delta
  self.wait = 0
- self.stopped_epoch = 0
- self.verbose = verbose
+ self.mode = mode

- if mode not in ['auto', 'min', 'max']:
- _LOGGER.error(f"EarlyStopping mode {mode} is unknown, choose one of ('auto', 'min', 'max')")
+ if mode not in ['min', 'max']:
+ _LOGGER.error(f"EarlyStopping mode {mode} is unknown, choose one of ('min', 'max')")
  raise ValueError()
- self.mode = mode

- # Determine the comparison operator based on the mode
+ # Determine the comparison operator
  if self.mode == 'min':
  self.monitor_op = np.less
  elif self.mode == 'max':
  self.monitor_op = np.greater
- else: # auto mode
- if 'acc' in self.monitor.lower():
- self.monitor_op = np.greater
- else: # Default to min mode for loss or other metrics
- self.monitor_op = np.less
+ else:
+ # raise error for unknown mode
+ _LOGGER.error(f"EarlyStopping mode {mode} is unknown, choose one of ('min', 'max')")
+ raise ValueError()

  self.best = np.inf if self.monitor_op == np.less else -np.inf

  def on_train_begin(self, logs=None):
- # Reset state at the beginning of training
  self.wait = 0
- self.stopped_epoch = 0
  self.best = np.inf if self.monitor_op == np.less else -np.inf
-
+
  def on_epoch_end(self, epoch, logs=None):
  current = logs.get(self.monitor) # type: ignore
  if current is None:
  return

- # Determine the comparison threshold based on the mode
+ # Check improvement
  if self.monitor_op == np.less:
- # For 'min' mode, we need to be smaller than 'best' by at least 'min_delta'
- # Correct check: current < self.best - self.min_delta
  is_improvement = self.monitor_op(current, self.best - self.min_delta)
  else:
- # For 'max' mode, we need to be greater than 'best' by at least 'min_delta'
- # Correct check: current > self.best + self.min_delta
  is_improvement = self.monitor_op(current, self.best + self.min_delta)

  if is_improvement:
@@ -183,137 +208,224 @@ class DragonEarlyStopping(_Callback):
  else:
  self.wait += 1
  if self.wait >= self.patience:
- self.stopped_epoch = epoch
- self.trainer.stop_training = True # type: ignore
- if self.verbose > 0:
- _LOGGER.info(f"Epoch {epoch+1}: early stopping after {self.wait} epochs with no improvement.")
+ self._stop_training(epoch, f"No improvement in {self.monitor} for {self.wait} epochs.")
+
+
+ class DragonPrecheltEarlyStopping(_DragonEarlyStopping):
+ """
+ Implements Prechelt's 'Progress-Modified GL' criterion.
+ Tracks the ratio between Generalization Loss (overfitting) and Training Progress.
+
+ References:
+ Prechelt, L. (1998). Early Stopping - But When?
+ """
+ def __init__(self,
+ alpha: float = 0.75,
+ k: int = 5,
+ verbose: int = 1):
+ """
+ This early stopping strategy monitors both validation loss and training loss to determine the optimal stopping point.
+
+ Args:
+ alpha (float): The threshold for the stopping criterion.
+ k (int): The window size for calculating training progress.
+ verbose (int): Verbosity mode.
+
+ NOTE:
+
+ - **The Strip Size (k)**:
+ - `5`: The empirical "gold standard." It is long enough to smooth out batch noise but short enough to react to convergence plateaus quickly.
+ - `10` to `20`: Use if the training curve is very jagged (e.g., noisy data, small batch sizes, high dropout, or Reinforcement Learning). A larger k value prevents premature stopping due to random volatility.
+ - **The threshold (alpha)**:
+ - `< 0.5`: Aggressive. Stops training very early.
+ - `0.75` to `0.80`: Prechelt found this range to be the most robust across different datasets. It typically yields the best trade-off between generalization and training cost.
+ - `1.0` to `1.2`: Useful for complex tasks (like Transformers) where training progress might dip temporarily before recovering. It risks slightly more overfitting but ensures potential is exhausted.
+ """
+ super().__init__(PyTorchLogKeys.VAL_LOSS, verbose)
+ self.train_monitor = PyTorchLogKeys.TRAIN_LOSS
+ self.alpha = alpha
+ self.k = k
+
+ self.best_val_loss = np.inf
+ self.train_strip = deque(maxlen=k)
+
+ def on_train_begin(self, logs=None):
+ self.best_val_loss = np.inf
+ self.train_strip.clear()
+
+ def on_epoch_end(self, epoch, logs=None):
+ val_loss = logs.get(self.monitor) # type: ignore
+ train_loss = logs.get(self.train_monitor) # type: ignore
+
+ if val_loss is None or train_loss is None:
+ return
+
+ # 1. Update Best Validation Loss
+ if val_loss < self.best_val_loss:
+ self.best_val_loss = val_loss
+
+ # 2. Update Training Strip
+ self.train_strip.append(train_loss)
+
+ # 3. Calculate Generalization Loss (GL)
+ # GL(t) = 100 * (E_val / E_opt - 1)
+ # Low GL is good. High GL means we are drifting away from best val score (overfitting).
+ gl = 100 * ((val_loss / self.best_val_loss) - 1)
+
+ # 4. Calculate Progress (Pk)
+ # Pk(t) = 1000 * (Sum(strip) / (k * min(strip)) - 1)
+ # High Pk is good (training loss is still dropping fast). Low Pk means training has stalled.
+ if len(self.train_strip) < self.k:
+ # Not enough data for progress yet
+ return
+
+ strip_sum = sum(self.train_strip)
+ strip_min = min(self.train_strip)
+
+ # Avoid division by zero
+ if strip_min == 0:
+ pk = 0.1 # Arbitrary small number
+ else:
+ pk = 1000 * ((strip_sum / (self.k * strip_min)) - 1)
+
+ # 5. The Quotient Criterion
+ # Stop if GL / Pk > alpha
+ # Intuition: Stop if Overfitting is high AND Progress is low.
+
+ # Avoid division by zero
+ if pk == 0:
+ pk = 1e-6
+
+ quotient = gl / pk
+
+ if self.verbose > 1:
+ _LOGGER.info(f"Epoch {epoch}: GL={gl:.3f} | Pk={pk:.3f} | Quotient={quotient:.3f} (Threshold={self.alpha})")
+
+ if quotient > self.alpha:
+ self._stop_training(epoch, f"Prechelt Criterion triggered. Generalization/Progress quotient ({quotient:.3f}) > alpha ({self.alpha}).")


  class DragonModelCheckpoint(_Callback):
  """
  Saves the model weights, optimizer state, LR scheduler state (if any), and epoch number to a directory with automated filename generation and rotation.
  """
- def __init__(self, save_dir: Union[str,Path], monitor: str = PyTorchLogKeys.VAL_LOSS,
- save_best_only: bool = True, mode: Literal['auto', 'min', 'max']= 'auto', verbose: int = 0):
+ def __init__(self,
+ save_dir: Union[str, Path],
+ monitor: Literal["Training Loss", "Validation Loss", "both"] = "Validation Loss",
+ save_three_best: bool = True,
+ mode: Literal['min', 'max'] = 'min',
+ verbose: int = 0):
  """
- - If `save_best_only` is True, it saves the single best model, deleting the previous best.
- - If `save_best_only` is False, it keeps the 3 most recent checkpoints, deleting the oldest ones automatically.
-
  Args:
  save_dir (str): Directory where checkpoint files will be saved.
- monitor (str): Metric to monitor.
- save_best_only (bool): If true, save only the best model.
- mode (str): One of {'auto', 'min', 'max'}.
+ monitor (str): Metric to monitor. If "both", the sum of training loss and validation loss is used.
+ save_three_best (bool):
+ - If True, keeps the top 3 best checkpoints found during training (based on metric).
+ - If False, keeps the 3 most recent checkpoints (rolling window).
+ mode (str): One of {'min', 'max'}.
  verbose (int): Verbosity mode.
  """
-
  super().__init__()
  self.save_dir = make_fullpath(save_dir, make=True, enforce="directory")
- if not self.save_dir.is_dir():
- _LOGGER.error(f"{save_dir} is not a valid directory.")
- raise IOError()

- self.monitor = monitor
- self.save_best_only = save_best_only
+ # Standardize monitor key
+ if monitor == "Training Loss":
+ std_monitor = PyTorchLogKeys.TRAIN_LOSS
+ elif monitor == "Validation Loss":
+ std_monitor = PyTorchLogKeys.VAL_LOSS
+ elif monitor == "both":
+ std_monitor = "both"
+ else:
+ _LOGGER.error(f"Unknown monitor key: {monitor}.")
+ raise ValueError()
+
+ self.monitor = std_monitor
+ self.save_three_best = save_three_best
  self.verbose = verbose
  self._latest_checkpoint_path = None
  self._checkpoint_name = PyTorchCheckpointKeys.CHECKPOINT_NAME

- # State variables to be managed during training
- self.saved_checkpoints = []
- self.last_best_filepath = None
+ # State variables
+ # stored as list of dicts: [{'path': Path, 'score': float, 'epoch': int}]
+ self.best_checkpoints = []
+ # For rolling check (save_three_best=False)
+ self.recent_checkpoints = []

- if mode not in ['auto', 'min', 'max']:
- _LOGGER.error(f"ModelCheckpoint mode {mode} is unknown.")
+ if mode not in ['min', 'max']:
+ _LOGGER.error(f"ModelCheckpoint mode {mode} is unknown. Use 'min' or 'max'.")
  raise ValueError()
  self.mode = mode

+ # Determine comparison operator
  if self.mode == 'min':
  self.monitor_op = np.less
- elif self.mode == 'max':
- self.monitor_op = np.greater
+ self.best = np.inf
  else:
- self.monitor_op = np.less if 'loss' in self.monitor else np.greater
-
- self.best = np.inf if self.monitor_op == np.less else -np.inf
+ self.monitor_op = np.greater
+ self.best = -np.inf

  def on_train_begin(self, logs=None):
- """Reset state when training starts."""
- self.best = np.inf if self.monitor_op == np.less else -np.inf
- self.saved_checkpoints = []
- self.last_best_filepath = None
+ """Reset file tracking state when training starts.
+ NOTE: Do NOT reset self.best here if it differs from the default. This allows the Trainer to restore 'best' from a checkpoint before calling train()."""
+ self.best_checkpoints = []
+ self.recent_checkpoints = []
+
+ # Check if self.best is at default initialization value
+ is_default_min = (self.mode == 'min' and self.best == np.inf)
+ is_default_max = (self.mode == 'max' and self.best == -np.inf)
+
+ # If it is NOT default, it means it was restored.
+ if not (is_default_min or is_default_max):
+ _LOGGER.debug(f"Resuming with best score: {self.best:.4f}")
+
+ def _get_metric_value(self, logs):
+ """Extracts or calculates the metric value based on configuration."""
+ if self.monitor == "both":
+ t_loss = logs.get(PyTorchLogKeys.TRAIN_LOSS)
+ v_loss = logs.get(PyTorchLogKeys.VAL_LOSS)
+ if t_loss is None or v_loss is None:
+ return None
+ return t_loss + v_loss
+ else:
+ return logs.get(self.monitor)

  def on_epoch_end(self, epoch, logs=None):
  logs = logs or {}
+ current_score = self._get_metric_value(logs)

- if self.save_best_only:
- self._save_best_model(epoch, logs)
- else:
- self._save_rolling_checkpoints(epoch, logs)
-
- def _save_best_model(self, epoch, logs):
- """Saves a single best model and deletes the previous one."""
- current = logs.get(self.monitor)
- if current is None:
+ if current_score is None:
+ if self.verbose > 0:
+ _LOGGER.warning(f"Epoch {epoch}: Metric '{self.monitor}' not found in logs. Skipping checkpoint.")
  return
-
- if self.monitor_op(current, self.best):
- old_best_str = f"{self.best:.4f}" if self.best not in [np.inf, -np.inf] else "inf"
-
- # Create a descriptive filename
- self.save_dir.mkdir(parents=True, exist_ok=True)
- current_string = str(round(current, ndigits=2)).replace('.', '_')
- filename = f"epoch{epoch}_{self._checkpoint_name}-{current_string}.pth"
- new_filepath = self.save_dir / filename
-
+
+ # 1. Update global best score (for logging/metadata)
+ if self.monitor_op(current_score, self.best):
  if self.verbose > 0:
- _LOGGER.info(f"Epoch {epoch}: {self.monitor} improved from {old_best_str} to {current:.4f}, saving model to {new_filepath}")
-
- # Update best score *before* saving
- self.best = current
+ # Only log explicit "improvement" if we are beating the historical best
+ old_best_str = f"{self.best:.4f}" if not np.isinf(self.best) else "inf"
+ _LOGGER.info(f"Epoch {epoch}: {self.monitor} improved from {old_best_str} to {current_score:.4f}")
+ self.best = current_score

- # Create a comprehensive checkpoint dictionary
- checkpoint_data = {
- PyTorchCheckpointKeys.EPOCH: epoch,
- PyTorchCheckpointKeys.MODEL_STATE: self.trainer.model.state_dict(), # type: ignore
- PyTorchCheckpointKeys.OPTIMIZER_STATE: self.trainer.optimizer.state_dict(), # type: ignore
- PyTorchCheckpointKeys.BEST_SCORE: self.best,
- PyTorchCheckpointKeys.HISTORY: self.trainer.history, # type: ignore
- }
-
- # Check for scheduler
- if hasattr(self.trainer, 'scheduler') and self.trainer.scheduler is not None: # type: ignore
- checkpoint_data[PyTorchCheckpointKeys.SCHEDULER_STATE] = self.trainer.scheduler.state_dict() # type: ignore
-
- # Save the new best model
- torch.save(checkpoint_data, new_filepath)
- self._latest_checkpoint_path = new_filepath
-
- # Delete the old best model file
- if self.last_best_filepath and self.last_best_filepath.exists():
- self.last_best_filepath.unlink()
-
- # Update state
- self.last_best_filepath = new_filepath
+ if self.save_three_best:
+ self._save_top_k_checkpoints(epoch, current_score)
+ else:
+ self._save_rolling_checkpoints(epoch, current_score)

- def _save_rolling_checkpoints(self, epoch, logs):
- """Saves the latest model and keeps only the most recent ones."""
- current = logs.get(self.monitor)
-
+ def _save_checkpoint_file(self, epoch, current_score):
+ """Helper to physically save the file."""
  self.save_dir.mkdir(parents=True, exist_ok=True)
- current_string = str(round(current, ndigits=2)).replace('.', '_')
- filename = f"epoch{epoch}_{self._checkpoint_name}-{current_string}.pth"
- filepath = self.save_dir / filename

- if self.verbose > 0:
- _LOGGER.info(f'Epoch {epoch}: saving model to {filepath}')
+ # Create filename
+ score_str = f"{current_score:.4f}".replace('.', '_')
+ filename = f"epoch{epoch}_{self._checkpoint_name}-{score_str}.pth"
+ filepath = self.save_dir / filename

- # Create a comprehensive checkpoint dictionary
+ # Create checkpoint dict
  checkpoint_data = {
  PyTorchCheckpointKeys.EPOCH: epoch,
  PyTorchCheckpointKeys.MODEL_STATE: self.trainer.model.state_dict(), # type: ignore
  PyTorchCheckpointKeys.OPTIMIZER_STATE: self.trainer.optimizer.state_dict(), # type: ignore
- PyTorchCheckpointKeys.BEST_SCORE: self.best, # Save the current best score
+ PyTorchCheckpointKeys.BEST_SCORE: current_score,
  PyTorchCheckpointKeys.HISTORY: self.trainer.history, # type: ignore
  }

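To make the Prechelt criterion added above concrete, here is a small worked example with made-up losses showing how GL, Pk, and the stopping quotient are computed (plain arithmetic, no toolbox APIs involved):

    # Illustrative numbers only.
    k = 5
    best_val_loss = 0.50                              # E_opt: best validation loss seen so far
    val_loss = 0.55                                   # current validation loss
    train_strip = [0.41, 0.40, 0.40, 0.40, 0.40]      # last k training losses

    gl = 100 * ((val_loss / best_val_loss) - 1)                    # 100 * (0.55/0.50 - 1) = 10.0
    pk = 1000 * ((sum(train_strip) / (k * min(train_strip))) - 1)  # 1000 * (2.01/2.00 - 1) = 5.0
    quotient = gl / pk                                             # 10.0 / 5.0 = 2.0

    # With the default alpha = 0.75 the quotient exceeds the threshold, so training would stop:
    # the validation loss is drifting away from its best value while training progress is nearly flat.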
@@ -321,91 +433,269 @@ class DragonModelCheckpoint(_Callback):
  checkpoint_data[PyTorchCheckpointKeys.SCHEDULER_STATE] = self.trainer.scheduler.state_dict() # type: ignore

  torch.save(checkpoint_data, filepath)
-
  self._latest_checkpoint_path = filepath
+
+ return filepath
+
+ def _save_top_k_checkpoints(self, epoch, current_score):
+ """Logic for maintaining the top 3 best checkpoints."""
+
+ def sort_key(item): return item['score']
+
+ # Determine sort direction so that Index 0 is BEST and Index -1 is WORST
+ # Min mode (lower is better): Ascending (reverse=False) -> [0.1, 0.5, 0.9] (0.1 is best)
+ # Max mode (higher is better): Descending (reverse=True) -> [0.9, 0.5, 0.1] (0.9 is best)
+ is_reverse = (self.mode == 'max')

- self.saved_checkpoints.append(filepath)
+ should_save = False
+
+ if len(self.best_checkpoints) < 3:
+ should_save = True
+ else:
+ # Sort current list to identify the worst (last item)
+ self.best_checkpoints.sort(key=sort_key, reverse=is_reverse)
+ worst_entry = self.best_checkpoints[-1]
+
+ # Check if current is better than the worst in the list
+ # min mode: current < worst['score']
+ # max mode: current > worst['score']
+ if self.monitor_op(current_score, worst_entry['score']):
+ should_save = True
+
+ if should_save:
+ filepath = self._save_checkpoint_file(epoch, current_score)
+
+ if self.verbose > 0:
+ _LOGGER.info(f"Epoch {epoch}: {self.monitor} ({current_score:.4f}) is in top 3. Saving to {filepath.name}")

- # If we have more than n checkpoints, remove the oldest one
- if len(self.saved_checkpoints) > 3:
- file_to_delete = self.saved_checkpoints.pop(0)
+ self.best_checkpoints.append({'path': filepath, 'score': current_score, 'epoch': epoch})
+
+ # Prune if > 3
+ if len(self.best_checkpoints) > 3:
+ # Re-sort to ensure worst is at the end
+ self.best_checkpoints.sort(key=sort_key, reverse=is_reverse)
+
+ # Evict the last one (Worst)
+ entry_to_delete = self.best_checkpoints.pop(-1)
+
+ if entry_to_delete['path'].exists():
+ if self.verbose > 0:
+ _LOGGER.info(f" -> Deleting checkpoint outside top 3: {entry_to_delete['path'].name}")
+ entry_to_delete['path'].unlink()
+
+ def _save_rolling_checkpoints(self, epoch, current_score):
+ """Saves the latest model and keeps only the 3 most recent ones."""
+ filepath = self._save_checkpoint_file(epoch, current_score)
+
+ if self.verbose > 0:
+ _LOGGER.info(f'Epoch {epoch}: saving rolling model to {filepath.name}')
+
+ self.recent_checkpoints.append(filepath)
+
+ # If we have more than 3 checkpoints, remove the oldest one
+ if len(self.recent_checkpoints) > 3:
+ file_to_delete = self.recent_checkpoints.pop(0)
  if file_to_delete.exists():
  if self.verbose > 0:
- _LOGGER.info(f" -> Deleting old checkpoint: {file_to_delete.name}")
+ _LOGGER.info(f" -> Deleting old rolling checkpoint: {file_to_delete.name}")
  file_to_delete.unlink()

  @property
  def best_checkpoint_path(self):
- if self._latest_checkpoint_path:
+ # If tracking top 3, return the absolute best among them
+ if self.save_three_best and self.best_checkpoints:
+ def sort_key(item): return item['score']
+ is_reverse = (self.mode == 'max')
+ # Sort Best -> Worst
+ sorted_bests = sorted(self.best_checkpoints, key=sort_key, reverse=is_reverse)
+ # Index 0 is always the best based on the logic above
+ return sorted_bests[0]['path']
+
+ elif self._latest_checkpoint_path:
  return self._latest_checkpoint_path
  else:
  _LOGGER.error("No checkpoint paths saved.")
  raise ValueError()


- class DragonLRScheduler(_Callback):
+ class _DragonLRScheduler(_Callback):
  """
- Callback to manage a PyTorch learning rate scheduler.
+ Base class for Dragon LR Schedulers.
+ Handles common logic like logging and attaching to the trainer.
  """
- def __init__(self, scheduler, monitor: Optional[str] = PyTorchLogKeys.VAL_LOSS):
- """
- This callback automatically calls the scheduler's `step()` method at the
- end of each epoch. It also logs a message when the learning rate changes.
+ def __init__(self):
+ super().__init__()
+ self.scheduler = None
+ self.previous_lr = None
+
+ def set_trainer(self, trainer):
+ """Associates the callback with the trainer."""
+ super().set_trainer(trainer)
+ # Note: Subclasses must ensure self.scheduler is set before or during this call
+ # if they want to register it immediately.
+ if self.scheduler:
+ self.trainer.scheduler = self.scheduler # type: ignore
+
+ def on_train_begin(self, logs=None):
+ """Store the initial learning rate."""
+ if not self.trainer.optimizer: # type: ignore
+ _LOGGER.warning("No optimizer found in trainer. LRScheduler cannot track learning rate.")
+ return
+ self.previous_lr = self.trainer.optimizer.param_groups[0]['lr'] # type: ignore
+
+ def _check_and_log_lr(self, epoch, logs, verbose: bool):
+ """Helper to log LR changes and update history."""
+ if not self.trainer.optimizer: # type: ignore
+ return

+ current_lr = self.trainer.optimizer.param_groups[0]['lr'] # type: ignore
+
+ # Log change
+ if self.previous_lr is not None and current_lr != self.previous_lr:
+ if verbose:
+ print(f" > Epoch {epoch}: Learning rate changed to {current_lr:.6f}")
+ self.previous_lr = current_lr
+
+ # Log to dictionary
+ logs[PyTorchLogKeys.LEARNING_RATE] = current_lr
+
+ # Log to history
+ if hasattr(self.trainer, 'history'):
+ self.trainer.history.setdefault(PyTorchLogKeys.LEARNING_RATE, []).append(current_lr) # type: ignore
+
+
+ class DragonScheduler(_DragonLRScheduler):
+ """
+ Callback for standard PyTorch Learning Rate Schedulers.
+
+ Compatible with: StepLR, MultiStepLR, ExponentialLR, CosineAnnealingLR, etc.
+
+ NOT Compatible with: ReduceLROnPlateau (Use `DragonReduceLROnPlateau` instead).
+ """
+ def __init__(self, scheduler, verbose: bool=True):
+ """
  Args:
- scheduler: An initialized PyTorch learning rate scheduler.
- monitor (str): The metric to monitor for schedulers that require it, like `ReduceLROnPlateau`. Should match a key in the logs (e.g., 'val_loss').
+ scheduler: An initialized PyTorch learning rate scheduler instance.
+ verbose (bool): If True, logs learning rate changes to console.
  """
  super().__init__()
+ if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
583
+ raise ValueError(
584
+ "DragonLRScheduler does not support 'ReduceLROnPlateau'. "
585
+ "Please use the `DragonReduceLROnPlateau` callback instead."
586
+ )
360
587
  self.scheduler = scheduler
361
- self.monitor = monitor
362
- self.previous_lr = None
588
+ self.verbose = verbose
363
589
 
364
590
  def set_trainer(self, trainer):
365
- """This is called by the Trainer to associate itself with the callback."""
366
591
  super().set_trainer(trainer)
367
- # Register the scheduler with the trainer so it can be added to the checkpoint
592
+ # Explicitly register the scheduler again to be safe
368
593
  self.trainer.scheduler = self.scheduler # type: ignore
369
-
370
- def on_train_begin(self, logs=None):
371
- """Store the initial learning rate."""
372
- self.previous_lr = self.trainer.optimizer.param_groups[0]['lr'] # type: ignore
594
+ if self.verbose:
595
+ _LOGGER.info(f"Registered LR Scheduler: {self.scheduler.__class__.__name__}")
373
596
 
374
597
  def on_epoch_end(self, epoch, logs=None):
375
- """Step the scheduler and log any change in learning rate."""
376
598
  logs = logs or {}
377
599
 
378
- # For schedulers that need a metric (e.g., val_loss)
379
- if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
380
- if self.monitor is None:
381
- _LOGGER.error("LRScheduler needs a `monitor` metric for ReduceLROnPlateau.")
382
- raise ValueError()
383
-
384
- metric_val = logs.get(self.monitor) # type: ignore
385
- if metric_val is not None:
386
- self.scheduler.step(metric_val)
387
- else:
388
- _LOGGER.warning(f"LRScheduler could not find metric '{self.monitor}' in logs.")
600
+ # Standard step (no metrics needed)
601
+ self.scheduler.step()
389
602
 
390
- # For all other schedulers
603
+ self._check_and_log_lr(epoch, logs, self.verbose)
604
+
605
+
606
+ class DragonReduceLROnPlateau(_DragonLRScheduler):
607
+ """
608
+ Specific callback for `torch.optim.lr_scheduler.ReduceLROnPlateau`. Reduces learning rate when a monitored metric has stopped improving.
609
+
610
+ This wrapper initializes the scheduler internally using the Trainer's optimizer, simplifying the setup process.
611
+ """
612
+ def __init__(self,
613
+ monitor: Literal["Training Loss", "Validation Loss"] = "Validation Loss",
614
+ mode: Literal['min', 'max'] = 'min',
615
+ factor: float = 0.1,
616
+ patience: int = 5,
617
+ threshold: float = 1e-4,
618
+ threshold_mode: Literal['rel', 'abs'] = 'rel',
619
+ cooldown: int = 0,
620
+ min_lr: float = 0,
621
+ eps: float = 1e-8,
622
+ verbose: bool = True):
623
+ """
624
+ Args:
625
+ monitor ("Training Loss", "Validation Loss"): Metric to monitor.
626
+ mode ('min', 'max'): One of 'min', 'max'.
627
+ factor (float): Factor by which the learning rate will be reduced. new_lr = lr * factor.
628
+ patience (int): Number of epochs with no improvement after which learning rate will be reduced.
629
+ threshold (float): Threshold for measuring the new optimum.
630
+ threshold_mode ('rel', 'abs'): One of 'rel', 'abs'.
631
+ cooldown (int): Number of epochs to wait before resuming normal operation after lr has been reduced.
632
+ min_lr (float or list): A scalar or a list of scalars.
633
+ eps (float): Minimal decay applied to lr.
634
+ verbose (bool): If True, logs learning rate changes to console.
635
+ """
636
+ super().__init__()
637
+
638
+ # Standardize monitor key
639
+ if monitor == "Training Loss":
640
+ std_monitor = PyTorchLogKeys.TRAIN_LOSS
641
+ elif monitor == "Validation Loss":
642
+ std_monitor = PyTorchLogKeys.VAL_LOSS
391
643
  else:
392
- self.scheduler.step()
644
+ _LOGGER.error(f"Unknown monitor key: {monitor}.")
645
+ raise ValueError()
646
+
647
+ self.monitor = std_monitor
648
+ self.verbose = verbose
649
+
650
+ # Config storage for delayed initialization
651
+ self.config = {
652
+ 'mode': mode,
653
+ 'factor': factor,
654
+ 'patience': patience,
655
+ 'threshold': threshold,
656
+ 'threshold_mode': threshold_mode,
657
+ 'cooldown': cooldown,
658
+ 'min_lr': min_lr,
659
+ 'eps': eps,
660
+ }
661
+
662
+ def set_trainer(self, trainer):
663
+ """
664
+ Initializes the ReduceLROnPlateau scheduler using the trainer's optimizer and registers it.
665
+ """
666
+ super().set_trainer(trainer)
667
+
668
+ if not hasattr(self.trainer, 'optimizer'):
669
+ _LOGGER.error("Trainer has no optimizer. Cannot initialize ReduceLROnPlateau.")
670
+ raise ValueError()
393
671
 
394
- # Get the current learning rate
395
- current_lr = self.trainer.optimizer.param_groups[0]['lr'] # type: ignore
672
+ # Initialize the actual scheduler with the optimizer
673
+ if self.verbose:
674
+ _LOGGER.info(f"Initializing ReduceLROnPlateau monitoring '{self.monitor}'")
675
+
676
+ self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
677
+ optimizer=self.trainer.optimizer, # type: ignore
678
+ **self.config
679
+ )
680
+
681
+ # Register with trainer for checkpointing
682
+ self.trainer.scheduler = self.scheduler # type: ignore
396
683
 
397
- # Log the change if the LR was updated
398
- if current_lr != self.previous_lr:
399
- _LOGGER.info(f"Epoch {epoch}: Learning rate changed to {current_lr:.6f}")
400
- self.previous_lr = current_lr
684
+ def on_epoch_end(self, epoch, logs=None):
685
+ logs = logs or {}
401
686
 
402
- # --- Add LR to logs and history ---
403
- # Add to the logs dict for any subsequent callbacks
404
- logs[PyTorchLogKeys.LEARNING_RATE] = current_lr
687
+ metric_val = logs.get(self.monitor)
405
688
 
406
- # Also add directly to the trainer's history dict
407
- if hasattr(self.trainer, 'history'):
408
- self.trainer.history.setdefault(PyTorchLogKeys.LEARNING_RATE, []).append(current_lr) # type: ignore
689
+ if metric_val is None:
690
+ _LOGGER.warning(f"DragonReduceLROnPlateau could not find metric '{self.monitor}' in logs. Scheduler step skipped.")
691
+ # Still log LR to keep history consistent
692
+ self._check_and_log_lr(epoch, logs, self.verbose)
693
+ return
694
+
695
+ # Step with metric
696
+ self.scheduler.step(metric_val)
697
+
698
+ self._check_and_log_lr(epoch, logs, self.verbose)
409
699
 
410
700
 
411
701
  def info():
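
Taken together, a training setup after this release might wire the new callbacks as sketched below. This is an illustration only: the import path and the `trainer` object are assumptions, not part of this diff; the callback names and arguments are the ones introduced above.

    # Assumed import path - adjust to your installation.
    from ml_tools.ML_callbacks import (
        DragonPrecheltEarlyStopping,
        DragonModelCheckpoint,
        DragonReduceLROnPlateau,
    )

    callbacks = [
        # Stop when the generalization/progress quotient exceeds alpha (Prechelt, 1998).
        DragonPrecheltEarlyStopping(alpha=0.75, k=5),
        # Keep the 3 best checkpoints ranked by validation loss.
        DragonModelCheckpoint(save_dir="checkpoints", monitor="Validation Loss", save_three_best=True),
        # Multiply the LR by 0.1 after 5 epochs without improvement; built internally from the trainer's optimizer.
        DragonReduceLROnPlateau(monitor="Validation Loss", factor=0.1, patience=5),
    ]

    # trainer = ...  # hypothetical: any toolbox trainer that accepts these callbacks
    # trainer.train(callbacks=callbacks)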