dragon-ml-toolbox 14.7.0__py3-none-any.whl → 16.2.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (44)
  1. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/METADATA +9 -5
  2. dragon_ml_toolbox-16.2.0.dist-info/RECORD +51 -0
  3. ml_tools/ETL_cleaning.py +20 -20
  4. ml_tools/ETL_engineering.py +23 -25
  5. ml_tools/GUI_tools.py +20 -20
  6. ml_tools/MICE_imputation.py +3 -3
  7. ml_tools/ML_callbacks.py +43 -26
  8. ml_tools/ML_configuration.py +704 -24
  9. ml_tools/ML_datasetmaster.py +235 -280
  10. ml_tools/ML_evaluation.py +144 -39
  11. ml_tools/ML_evaluation_multi.py +103 -35
  12. ml_tools/ML_inference.py +290 -208
  13. ml_tools/ML_models.py +13 -102
  14. ml_tools/ML_models_advanced.py +1 -1
  15. ml_tools/ML_optimization.py +12 -12
  16. ml_tools/ML_scaler.py +11 -11
  17. ml_tools/ML_sequence_datasetmaster.py +341 -0
  18. ml_tools/ML_sequence_evaluation.py +219 -0
  19. ml_tools/ML_sequence_inference.py +391 -0
  20. ml_tools/ML_sequence_models.py +139 -0
  21. ml_tools/ML_trainer.py +1342 -386
  22. ml_tools/ML_utilities.py +1 -1
  23. ml_tools/ML_vision_datasetmaster.py +120 -72
  24. ml_tools/ML_vision_evaluation.py +30 -6
  25. ml_tools/ML_vision_inference.py +129 -152
  26. ml_tools/ML_vision_models.py +1 -1
  27. ml_tools/ML_vision_transformers.py +121 -40
  28. ml_tools/PSO_optimization.py +6 -6
  29. ml_tools/SQL.py +4 -4
  30. ml_tools/{keys.py → _keys.py} +45 -0
  31. ml_tools/_schema.py +1 -1
  32. ml_tools/ensemble_evaluation.py +1 -1
  33. ml_tools/ensemble_inference.py +7 -33
  34. ml_tools/ensemble_learning.py +1 -1
  35. ml_tools/optimization_tools.py +2 -2
  36. ml_tools/path_manager.py +5 -5
  37. ml_tools/utilities.py +1 -2
  38. dragon_ml_toolbox-14.7.0.dist-info/RECORD +0 -49
  39. ml_tools/RNN_forecast.py +0 -56
  40. ml_tools/_ML_vision_recipe.py +0 -88
  41. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/WHEEL +0 -0
  42. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE +0 -0
  43. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  44. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_callbacks.py CHANGED
@@ -4,23 +4,22 @@ from tqdm.auto import tqdm
 from typing import Union, Literal, Optional
 from pathlib import Path
 
-from .path_manager import make_fullpath, sanitize_filename
-from .keys import PyTorchLogKeys, PyTorchCheckpointKeys
+from .path_manager import make_fullpath
+from ._keys import PyTorchLogKeys, PyTorchCheckpointKeys
 from ._logger import _LOGGER
 from ._script_info import _script_info
 
 
 __all__ = [
-    "Callback",
     "History",
     "TqdmProgressBar",
-    "EarlyStopping",
-    "ModelCheckpoint",
-    "LRScheduler"
+    "DragonEarlyStopping",
+    "DragonModelCheckpoint",
+    "DragonLRScheduler"
 ]
 
 
-class Callback:
+class _Callback:
     """
     Abstract base class used to build new callbacks.
 
@@ -60,7 +59,7 @@ class Callback:
         pass
 
 
-class History(Callback):
+class History(_Callback):
     """
     Callback that records events into a `history` dictionary.
 
@@ -79,7 +78,7 @@ class History(Callback):
             self.trainer.history.setdefault(k, []).append(v) # type: ignore
 
 
-class TqdmProgressBar(Callback):
+class TqdmProgressBar(_Callback):
     """Callback that provides a tqdm progress bar for training."""
     def __init__(self):
         self.epoch_bar = None
@@ -110,7 +109,7 @@ class TqdmProgressBar(Callback):
         self.epoch_bar.close() # type: ignore
 
 
-class EarlyStopping(Callback):
+class DragonEarlyStopping(_Callback):
     """
     Stop training when a monitored metric has stopped improving.
     """
@@ -187,11 +186,11 @@ class EarlyStopping(Callback):
             _LOGGER.info(f"Epoch {epoch+1}: early stopping after {self.wait} epochs with no improvement.")
 
 
-class ModelCheckpoint(Callback):
+class DragonModelCheckpoint(_Callback):
     """
     Saves the model weights, optimizer state, LR scheduler state (if any), and epoch number to a directory with automated filename generation and rotation.
     """
-    def __init__(self, save_dir: Union[str,Path], checkpoint_name: Optional[str]=None, monitor: str = PyTorchLogKeys.VAL_LOSS,
+    def __init__(self, save_dir: Union[str,Path], monitor: str = PyTorchLogKeys.VAL_LOSS,
                  save_best_only: bool = True, mode: Literal['auto', 'min', 'max']= 'auto', verbose: int = 0):
         """
        - If `save_best_only` is True, it saves the single best model, deleting the previous best.
@@ -199,7 +198,6 @@ class ModelCheckpoint(Callback):
 
         Args:
             save_dir (str): Directory where checkpoint files will be saved.
-            checkpoint_name (str| None): If None, the filename will include the epoch and score.
             monitor (str): Metric to monitor.
             save_best_only (bool): If true, save only the best model.
             mode (str): One of {'auto', 'min', 'max'}.
@@ -215,9 +213,8 @@ class ModelCheckpoint(Callback):
         self.monitor = monitor
         self.save_best_only = save_best_only
         self.verbose = verbose
-        if checkpoint_name:
-            checkpoint_name = sanitize_filename(checkpoint_name)
-        self.checkpoint_name = checkpoint_name
+        self._latest_checkpoint_path = None
+        self._checkpoint_name = PyTorchCheckpointKeys.CHECKPOINT_NAME
 
         # State variables to be managed during training
         self.saved_checkpoints = []
@@ -261,10 +258,7 @@ class ModelCheckpoint(Callback):
             old_best_str = f"{self.best:.4f}" if self.best not in [np.inf, -np.inf] else "inf"
 
             # Create a descriptive filename
-            if self.checkpoint_name is None:
-                filename = f"epoch_{epoch}-{self.monitor}_{current:.4f}.pth"
-            else:
-                filename = f"epoch{epoch}_{self.checkpoint_name}.pth"
+            filename = f"epoch{epoch}_{self._checkpoint_name}_{current:.4f}.pth"
             new_filepath = self.save_dir / filename
 
             if self.verbose > 0:
@@ -279,6 +273,7 @@ class ModelCheckpoint(Callback):
                 PyTorchCheckpointKeys.MODEL_STATE: self.trainer.model.state_dict(), # type: ignore
                 PyTorchCheckpointKeys.OPTIMIZER_STATE: self.trainer.optimizer.state_dict(), # type: ignore
                 PyTorchCheckpointKeys.BEST_SCORE: self.best,
+                PyTorchCheckpointKeys.HISTORY: self.trainer.history, # type: ignore
             }
 
             # Check for scheduler
@@ -287,6 +282,7 @@ class ModelCheckpoint(Callback):
 
             # Save the new best model
             torch.save(checkpoint_data, new_filepath)
+            self._latest_checkpoint_path = new_filepath
 
             # Delete the old best model file
             if self.last_best_filepath and self.last_best_filepath.exists():
@@ -298,10 +294,8 @@ class ModelCheckpoint(Callback):
     def _save_rolling_checkpoints(self, epoch, logs):
         """Saves the latest model and keeps only the most recent ones."""
         current = logs.get(self.monitor)
-        if self.checkpoint_name is None:
-            filename = f"epoch_{epoch}-{self.monitor}_{current:.4f}.pth"
-        else:
-            filename = f"epoch{epoch}_{self.checkpoint_name}.pth"
+
+        filename = f"epoch{epoch}_{self._checkpoint_name}_{current:.4f}.pth"
         filepath = self.save_dir / filename
 
         if self.verbose > 0:
@@ -313,12 +307,15 @@ class ModelCheckpoint(Callback):
             PyTorchCheckpointKeys.MODEL_STATE: self.trainer.model.state_dict(), # type: ignore
             PyTorchCheckpointKeys.OPTIMIZER_STATE: self.trainer.optimizer.state_dict(), # type: ignore
             PyTorchCheckpointKeys.BEST_SCORE: self.best, # Save the current best score
+            PyTorchCheckpointKeys.HISTORY: self.trainer.history, # type: ignore
         }
 
         if hasattr(self.trainer, 'scheduler') and self.trainer.scheduler is not None: # type: ignore
             checkpoint_data[PyTorchCheckpointKeys.SCHEDULER_STATE] = self.trainer.scheduler.state_dict() # type: ignore
 
         torch.save(checkpoint_data, filepath)
+
+        self._latest_checkpoint_path = filepath
 
         self.saved_checkpoints.append(filepath)
 
@@ -330,8 +327,16 @@ class ModelCheckpoint(Callback):
                 _LOGGER.info(f" -> Deleting old checkpoint: {file_to_delete.name}")
             file_to_delete.unlink()
 
+    @property
+    def best_checkpoint_path(self):
+        if self._latest_checkpoint_path:
+            return self._latest_checkpoint_path
+        else:
+            _LOGGER.error("No checkpoint paths saved.")
+            raise ValueError()
+
 
-class LRScheduler(Callback):
+class DragonLRScheduler(_Callback):
     """
     Callback to manage a PyTorch learning rate scheduler.
     """
@@ -361,6 +366,8 @@ class LRScheduler(Callback):
 
     def on_epoch_end(self, epoch, logs=None):
         """Step the scheduler and log any change in learning rate."""
+        logs = logs or {}
+
         # For schedulers that need a metric (e.g., val_loss)
         if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
             if self.monitor is None:
@@ -376,12 +383,22 @@ class LRScheduler(Callback):
         # For all other schedulers
         else:
             self.scheduler.step()
+
+        # Get the current learning rate
+        current_lr = self.trainer.optimizer.param_groups[0]['lr'] # type: ignore
 
         # Log the change if the LR was updated
-        current_lr = self.trainer.optimizer.param_groups[0]['lr'] # type: ignore
         if current_lr != self.previous_lr:
             _LOGGER.info(f"Epoch {epoch}: Learning rate changed to {current_lr:.6f}")
             self.previous_lr = current_lr
+
+        # --- Add LR to logs and history ---
+        # Add to the logs dict for any subsequent callbacks
+        logs[PyTorchLogKeys.LEARNING_RATE] = current_lr
+
+        # Also add directly to the trainer's history dict
+        if hasattr(self.trainer, 'history'):
+            self.trainer.history.setdefault(PyTorchLogKeys.LEARNING_RATE, []).append(current_lr) # type: ignore
 
 
 def info():
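
The main behavioural change in this file is the renamed checkpoint callback: `checkpoint_name` is gone from the constructor, filenames now always embed the epoch and the monitored score, and a new `best_checkpoint_path` property exposes the last checkpoint written. A minimal usage sketch follows, assuming the package is installed; the directory name is illustrative and the trainer that would consume the callback is not part of this diff.

from pathlib import Path

from ml_tools.ML_callbacks import DragonModelCheckpoint

# Per this diff, `checkpoint_name` was removed in 16.x; filenames are generated
# from an internal key as f"epoch{epoch}_{name}_{score:.4f}.pth".
checkpoint_cb = DragonModelCheckpoint(
    save_dir=Path("runs/demo/checkpoints"),  # illustrative path, not from the diff
    save_best_only=True,   # keep a single best file, deleting the previous best
    mode="min",            # monitored metric defaults to the validation-loss key
    verbose=1,
)

# ... train with the callback attached; the trainer wiring is assumed, not shown in this diff ...

# New property: path of the most recently written checkpoint.
# It logs an error and raises ValueError if no checkpoint has been saved yet.
print(checkpoint_cb.best_checkpoint_path)

Relatedly, `DragonLRScheduler` now records the current learning rate under `PyTorchLogKeys.LEARNING_RATE` in both the per-epoch `logs` dict and the trainer's `history`, and both checkpoint paths now persist `PyTorchCheckpointKeys.HISTORY` alongside the model, optimizer, and scheduler states.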