dragon-ml-toolbox 12.10.0__tar.gz → 14.1.0__tar.gz

This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
Files changed (56)
  1. {dragon_ml_toolbox-12.10.0/dragon_ml_toolbox.egg-info → dragon_ml_toolbox-14.1.0}/PKG-INFO +2 -1
  2. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0/dragon_ml_toolbox.egg-info}/PKG-INFO +2 -1
  3. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/dragon_ml_toolbox.egg-info/SOURCES.txt +8 -1
  4. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/dragon_ml_toolbox.egg-info/requires.txt +1 -0
  5. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/MICE_imputation.py +207 -5
  6. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/ML_callbacks.py +70 -37
  7. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/ML_datasetmaster.py +221 -266
  8. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/ML_evaluation.py +107 -49
  9. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/ML_evaluation_multi.py +106 -32
  10. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/ML_inference.py +14 -5
  11. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/ML_models.py +137 -57
  12. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/ML_optimization.py +49 -36
  13. dragon_ml_toolbox-14.1.0/ml_tools/ML_trainer.py +1078 -0
  14. dragon_ml_toolbox-14.1.0/ml_tools/ML_utilities.py +528 -0
  15. dragon_ml_toolbox-14.1.0/ml_tools/ML_vision_datasetmaster.py +1315 -0
  16. dragon_ml_toolbox-14.1.0/ml_tools/ML_vision_evaluation.py +260 -0
  17. dragon_ml_toolbox-14.1.0/ml_tools/ML_vision_inference.py +428 -0
  18. dragon_ml_toolbox-14.1.0/ml_tools/ML_vision_models.py +627 -0
  19. dragon_ml_toolbox-14.1.0/ml_tools/ML_vision_transformers.py +58 -0
  20. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/PSO_optimization.py +5 -1
  21. dragon_ml_toolbox-14.1.0/ml_tools/_ML_pytorch_tabular.py +543 -0
  22. dragon_ml_toolbox-14.1.0/ml_tools/_ML_vision_recipe.py +88 -0
  23. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/__init__.py +1 -0
  24. dragon_ml_toolbox-14.1.0/ml_tools/_schema.py +96 -0
  25. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/custom_logger.py +38 -15
  26. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/data_exploration.py +576 -138
  27. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/keys.py +51 -1
  28. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/math_utilities.py +1 -1
  29. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/optimization_tools.py +65 -86
  30. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/serde.py +83 -30
  31. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/utilities.py +192 -3
  32. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/pyproject.toml +12 -2
  33. dragon_ml_toolbox-12.10.0/ml_tools/ML_simple_optimization.py +0 -413
  34. dragon_ml_toolbox-12.10.0/ml_tools/ML_trainer.py +0 -537
  35. dragon_ml_toolbox-12.10.0/ml_tools/ML_utilities.py +0 -230
  36. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/LICENSE +0 -0
  37. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/LICENSE-THIRD-PARTY.md +0 -0
  38. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/README.md +0 -0
  39. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/dragon_ml_toolbox.egg-info/dependency_links.txt +0 -0
  40. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/dragon_ml_toolbox.egg-info/top_level.txt +0 -0
  41. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/ETL_cleaning.py +0 -0
  42. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/ETL_engineering.py +0 -0
  43. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/GUI_tools.py +0 -0
  44. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/ML_scaler.py +0 -0
  45. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/RNN_forecast.py +0 -0
  46. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/SQL.py +0 -0
  47. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/VIF_factor.py +0 -0
  48. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/_logger.py +0 -0
  49. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/_script_info.py +0 -0
  50. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/constants.py +0 -0
  51. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/ensemble_evaluation.py +0 -0
  52. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/ensemble_inference.py +0 -0
  53. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/ensemble_learning.py +0 -0
  54. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/handle_excel.py +0 -0
  55. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/ml_tools/path_manager.py +0 -0
  56. {dragon_ml_toolbox-12.10.0 → dragon_ml_toolbox-14.1.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dragon-ml-toolbox
3
- Version: 12.10.0
3
+ Version: 14.1.0
4
4
  Summary: A collection of tools for data science and machine learning projects.
5
5
  Author-email: "Karl L. Loza Vidaurre" <luigiloza@gmail.com>
6
6
  License-Expression: MIT
@@ -34,6 +34,7 @@ Requires-Dist: Pillow; extra == "ml"
34
34
  Requires-Dist: evotorch; extra == "ml"
35
35
  Requires-Dist: pyarrow; extra == "ml"
36
36
  Requires-Dist: colorlog; extra == "ml"
37
+ Requires-Dist: torchmetrics; extra == "ml"
37
38
  Provides-Extra: mice
38
39
  Requires-Dist: numpy<2.0; extra == "mice"
39
40
  Requires-Dist: pandas; extra == "mice"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dragon-ml-toolbox
3
- Version: 12.10.0
3
+ Version: 14.1.0
4
4
  Summary: A collection of tools for data science and machine learning projects.
5
5
  Author-email: "Karl L. Loza Vidaurre" <luigiloza@gmail.com>
6
6
  License-Expression: MIT
@@ -34,6 +34,7 @@ Requires-Dist: Pillow; extra == "ml"
34
34
  Requires-Dist: evotorch; extra == "ml"
35
35
  Requires-Dist: pyarrow; extra == "ml"
36
36
  Requires-Dist: colorlog; extra == "ml"
37
+ Requires-Dist: torchmetrics; extra == "ml"
37
38
  Provides-Extra: mice
38
39
  Requires-Dist: numpy<2.0; extra == "mice"
39
40
  Requires-Dist: pandas; extra == "mice"
@@ -19,15 +19,22 @@ ml_tools/ML_inference.py
19
19
  ml_tools/ML_models.py
20
20
  ml_tools/ML_optimization.py
21
21
  ml_tools/ML_scaler.py
22
- ml_tools/ML_simple_optimization.py
23
22
  ml_tools/ML_trainer.py
24
23
  ml_tools/ML_utilities.py
24
+ ml_tools/ML_vision_datasetmaster.py
25
+ ml_tools/ML_vision_evaluation.py
26
+ ml_tools/ML_vision_inference.py
27
+ ml_tools/ML_vision_models.py
28
+ ml_tools/ML_vision_transformers.py
25
29
  ml_tools/PSO_optimization.py
26
30
  ml_tools/RNN_forecast.py
27
31
  ml_tools/SQL.py
28
32
  ml_tools/VIF_factor.py
33
+ ml_tools/_ML_pytorch_tabular.py
34
+ ml_tools/_ML_vision_recipe.py
29
35
  ml_tools/__init__.py
30
36
  ml_tools/_logger.py
37
+ ml_tools/_schema.py
31
38
  ml_tools/_script_info.py
32
39
  ml_tools/constants.py
33
40
  ml_tools/custom_logger.py
@@ -21,6 +21,7 @@ Pillow
21
21
  evotorch
22
22
  pyarrow
23
23
  colorlog
24
+ torchmetrics
24
25
 
25
26
  [excel]
26
27
  pandas
@@ -7,19 +7,20 @@ from plotnine import ggplot, labs, theme, element_blank # type: ignore
7
7
  from typing import Optional, Union
8
8
 
9
9
  from .utilities import load_dataframe, merge_dataframes, save_dataframe_filename
10
- from .math_utilities import threshold_binary_values
10
+ from .math_utilities import threshold_binary_values, discretize_categorical_values
11
11
  from .path_manager import sanitize_filename, make_fullpath, list_csv_paths
12
12
  from ._logger import _LOGGER
13
13
  from ._script_info import _script_info
14
+ from ._schema import FeatureSchema
14
15
 
15
16
 
16
17
  __all__ = [
18
+ "MiceImputer",
17
19
  "apply_mice",
18
20
  "save_imputed_datasets",
19
- "get_na_column_names",
20
21
  "get_convergence_diagnostic",
21
22
  "get_imputed_distributions",
22
- "run_mice_pipeline"
23
+ "run_mice_pipeline",
23
24
  ]
24
25
 
25
26
 
@@ -79,7 +80,7 @@ def save_imputed_datasets(save_dir: Union[str, Path], imputed_datasets: list, df
79
80
 
80
81
 
81
82
  #Get names of features that had missing values before imputation
82
- def get_na_column_names(df: pd.DataFrame):
83
+ def _get_na_column_names(df: pd.DataFrame):
83
84
  return [col for col in df.columns if df[col].isna().any()]
84
85
 
85
86
 
@@ -264,7 +265,7 @@ def run_mice_pipeline(df_path_or_dir: Union[str,Path], target_columns: list[str]
264
265
 
265
266
  save_imputed_datasets(save_dir=save_datasets_path, imputed_datasets=imputed_datasets, df_targets=df_targets, imputed_dataset_names=imputed_dataset_names)
266
267
 
267
- imputed_column_names = get_na_column_names(df=df)
268
+ imputed_column_names = _get_na_column_names(df=df)
268
269
 
269
270
  get_convergence_diagnostic(kernel=kernel, imputed_dataset_names=imputed_dataset_names, column_names=imputed_column_names, root_dir=save_metrics_path)
270
271
 
@@ -278,5 +279,206 @@ def _skip_targets(df: pd.DataFrame, target_cols: list[str]):
278
279
  return df_feats, df_targets
279
280
 
280
281
 
282
+ # modern implementation
283
+ class MiceImputer:
284
+ """
285
+ A modern MICE imputation pipeline that uses a FeatureSchema
286
+ to correctly discretize categorical features after imputation.
287
+ """
288
+ def __init__(self,
289
+ schema: FeatureSchema,
290
+ iterations: int=20,
291
+ resulting_datasets: int = 1,
292
+ random_state: int = 101):
293
+
294
+ self.schema = schema
295
+ self.random_state = random_state
296
+ self.iterations = iterations
297
+ self.resulting_datasets = resulting_datasets
298
+
299
+ # --- Store schema info ---
300
+
301
+ # 1. Categorical info
302
+ if not self.schema.categorical_index_map:
303
+ _LOGGER.warning("FeatureSchema has no 'categorical_index_map'. No discretization will be applied.")
304
+ self.cat_info = {}
305
+ else:
306
+ self.cat_info = self.schema.categorical_index_map
307
+
308
+ # 2. Ordered feature names (critical for index mapping)
309
+ self.ordered_features = list(self.schema.feature_names)
310
+
311
+ # 3. Names of categorical features
312
+ self.categorical_features = list(self.schema.categorical_feature_names)
313
+
314
+ _LOGGER.info(f"MiceImputer initialized. Found {len(self.cat_info)} categorical features to discretize.")
315
+
316
+ def _post_process(self, imputed_df: pd.DataFrame) -> pd.DataFrame:
317
+ """
318
+ Applies schema-based discretization to a completed dataframe.
319
+
320
+ This method works around the behavior of `discretize_categorical_values`
321
+ (which returns a full int32 array) by:
322
+ 1. Calling it on the full, ordered feature array.
323
+ 2. Extracting *only* the valid discretized categorical columns.
324
+ 3. Updating the original float dataframe with these integer values.
325
+ """
326
+ # If no categorical features are defined, return the df as-is.
327
+ if not self.cat_info:
328
+ return imputed_df
329
+
330
+ try:
331
+ # 1. Ensure DataFrame columns match the schema order
332
+ # This is critical for the index-based categorical_info
333
+ df_ordered: pd.DataFrame = imputed_df[self.ordered_features] # type: ignore
334
+
335
+ # 2. Convert to NumPy array
336
+ array_ordered = df_ordered.to_numpy()
337
+
338
+ # 3. Apply discretization utility (which returns a full int32 array)
339
+ # This array has *correct* categorical values but *truncated* continuous values.
340
+ discretized_array_int32 = discretize_categorical_values(
341
+ array_ordered,
342
+ self.cat_info,
343
+ start_at_zero=True # Assuming 0-based indexing
344
+ )
345
+
346
+ # 4. Create a new DF from the int32 array, keeping the categorical columns.
347
+ df_discretized_cats = pd.DataFrame(
348
+ discretized_array_int32,
349
+ columns=self.ordered_features,
350
+ index=df_ordered.index # <-- Critical: align index
351
+ )[self.categorical_features] # <-- Select only cat features
352
+
353
+ # 5. "Rejoin": Start with a fresh copy of the *original* imputed DF (which has correct continuous floats).
354
+ final_df = df_ordered.copy()
355
+
356
+ # 6. Use .update() to "paste" the integer categorical values
357
+ # over the old float categorical values. Continuous floats are unaffected.
358
+ final_df.update(df_discretized_cats)
359
+
360
+ return final_df
361
+
362
+ except Exception as e:
363
+ _LOGGER.error(f"Failed during post-processing discretization:\n\tInput DF shape: {imputed_df.shape}\n\tSchema features: {len(self.ordered_features)}\n\tCategorical info keys: {list(self.cat_info.keys())}\n{e}")
364
+ raise
365
+
366
+ def _run_mice(self,
367
+ df: pd.DataFrame,
368
+ df_name: str) -> tuple[mf.ImputationKernel, list[pd.DataFrame], list[str]]:
369
+ """
370
+ Runs the MICE kernel and applies schema-based post-processing.
371
+
372
+ Parameters:
373
+ df (pd.DataFrame): The input dataframe *with NaNs*. Should only contain feature columns.
374
+ df_name (str): The base name for the dataset.
375
+
376
+ Returns:
377
+ tuple[mf.ImputationKernel, list[pd.DataFrame], list[str]]:
378
+ - The trained MICE kernel
379
+ - A list of imputed and processed DataFrames
380
+ - A list of names for the new DataFrames
381
+ """
382
+ # Ensure input df only contains features from the schema and is in the correct order.
383
+ try:
384
+ df_feats = df[self.ordered_features]
385
+ except KeyError as e:
386
+ _LOGGER.error(f"Input DataFrame is missing required schema columns: {e}")
387
+ raise
388
+
389
+ # 1. Initialize kernel
390
+ kernel = mf.ImputationKernel(
391
+ data=df_feats,
392
+ num_datasets=self.resulting_datasets,
393
+ random_state=self.random_state
394
+ )
395
+
396
+ _LOGGER.info("➡️ Schema-based MICE imputation running...")
397
+
398
+ # 2. Perform MICE
399
+ kernel.mice(self.iterations)
400
+
401
+ # 3. Retrieve, process, and collect datasets
402
+ imputed_datasets = []
403
+ for i in range(self.resulting_datasets):
404
+ # complete_data returns a pd.DataFrame
405
+ completed_df = kernel.complete_data(dataset=i)
406
+
407
+ # Apply our new discretization and ordering
408
+ processed_df = self._post_process(completed_df)
409
+ imputed_datasets.append(processed_df)
410
+
411
+ if not imputed_datasets:
412
+ _LOGGER.error("No imputed datasets were generated.")
413
+ raise ValueError()
414
+
415
+ # 4. Generate names
416
+ if self.resulting_datasets == 1:
417
+ imputed_dataset_names = [f"{df_name}_MICE"]
418
+ else:
419
+ imputed_dataset_names = [f"{df_name}_MICE_{i+1}" for i in range(self.resulting_datasets)]
420
+
421
+ # 5. Validate indexes
422
+ for imputed_df, subname in zip(imputed_datasets, imputed_dataset_names):
423
+ assert imputed_df.shape[0] == df.shape[0], f"❌ Row count mismatch in dataset {subname}"
424
+ assert all(imputed_df.index == df.index), f"❌ Index mismatch in dataset {subname}"
425
+
426
+ _LOGGER.info("Schema-based MICE imputation complete.")
427
+
428
+ return kernel, imputed_datasets, imputed_dataset_names
429
+
430
+ def run_pipeline(self,
431
+ df_path_or_dir: Union[str,Path],
432
+ save_datasets_dir: Union[str,Path],
433
+ save_metrics_dir: Union[str,Path],
434
+ ):
435
+ """
436
+ Runs the complete MICE imputation pipeline.
437
+
438
+ This method automates the entire workflow:
439
+ 1. Loads data from a CSV file path or a directory with CSV files.
440
+ 2. Separates features and targets based on the `FeatureSchema`.
441
+ 3. Runs the MICE algorithm on the feature set.
442
+ 4. Applies schema-based post-processing to discretize categorical features.
443
+ 5. Saves the final, processed, and imputed dataset(s) (re-joined with targets) to `save_datasets_dir`.
444
+ 6. Generates and saves convergence and distribution plots for all imputed columns to `save_metrics_dir`.
445
+
446
+ Parameters
447
+ ----------
448
+ df_path_or_dir :[str,Path]
449
+ Path to a single CSV file or a directory containing multiple CSV files to impute.
450
+ save_datasets_dir : [str,Path]
451
+ Directory where the final imputed and processed dataset(s) will be saved as CSVs.
452
+ save_metrics_dir : [str,Path]
453
+ Directory where convergence and distribution plots will be saved.
454
+ """
455
+ # Check paths
456
+ save_datasets_path = make_fullpath(save_datasets_dir, make=True)
457
+ save_metrics_path = make_fullpath(save_metrics_dir, make=True)
458
+
459
+ input_path = make_fullpath(df_path_or_dir)
460
+ if input_path.is_file():
461
+ all_file_paths = [input_path]
462
+ else:
463
+ all_file_paths = list(list_csv_paths(input_path).values())
464
+
465
+ for df_path in all_file_paths:
466
+
467
+ df, df_name = load_dataframe(df_path=df_path, kind="pandas")
468
+
469
+ df_features: pd.DataFrame = df[self.schema.feature_names] # type: ignore
470
+ df_targets = df.drop(columns=self.schema.feature_names)
471
+
472
+ imputed_column_names = _get_na_column_names(df=df_features)
473
+
474
+ kernel, imputed_datasets, imputed_dataset_names = self._run_mice(df=df_features, df_name=df_name)
475
+
476
+ save_imputed_datasets(save_dir=save_datasets_path, imputed_datasets=imputed_datasets, df_targets=df_targets, imputed_dataset_names=imputed_dataset_names)
477
+
478
+ get_convergence_diagnostic(kernel=kernel, imputed_dataset_names=imputed_dataset_names, column_names=imputed_column_names, root_dir=save_metrics_path)
479
+
480
+ get_imputed_distributions(kernel=kernel, df_name=df_name, root_dir=save_metrics_path, column_names=imputed_column_names)
481
+
482
+
281
483
  def info():
282
484
  _script_info(__all__)
@@ -5,7 +5,7 @@ from typing import Union, Literal, Optional
5
5
  from pathlib import Path
6
6
 
7
7
  from .path_manager import make_fullpath, sanitize_filename
8
- from .keys import PyTorchLogKeys
8
+ from .keys import PyTorchLogKeys, PyTorchCheckpointKeys
9
9
  from ._logger import _LOGGER
10
10
  from ._script_info import _script_info
11
11
 
@@ -113,18 +113,19 @@ class TqdmProgressBar(Callback):
113
113
  class EarlyStopping(Callback):
114
114
  """
115
115
  Stop training when a monitored metric has stopped improving.
116
-
117
- Args:
118
- monitor (str): Quantity to be monitored. Defaults to 'val_loss'.
119
- min_delta (float): Minimum change in the monitored quantity to qualify as an improvement.
120
- patience (int): Number of epochs with no improvement after which training will be stopped.
121
- mode (str): One of {'auto', 'min', 'max'}. In 'min' mode, training will stop when the quantity
122
- monitored has stopped decreasing; in 'max' mode it will stop when the quantity
123
- monitored has stopped increasing; in 'auto' mode, the direction is automatically
124
- inferred from the name of the monitored quantity.
125
- verbose (int): Verbosity mode.
126
116
  """
127
117
  def __init__(self, monitor: str=PyTorchLogKeys.VAL_LOSS, min_delta: float=0.0, patience: int=5, mode: Literal['auto', 'min', 'max']='auto', verbose: int=1):
118
+ """
119
+ Args:
120
+ monitor (str): Quantity to be monitored. Defaults to 'val_loss'.
121
+ min_delta (float): Minimum change in the monitored quantity to qualify as an improvement.
122
+ patience (int): Number of epochs with no improvement after which training will be stopped.
123
+ mode (str): One of {'auto', 'min', 'max'}. In 'min' mode, training will stop when the quantity
124
+ monitored has stopped decreasing; in 'max' mode it will stop when the quantity
125
+ monitored has stopped increasing; in 'auto' mode, the direction is automatically
126
+ inferred from the name of the monitored quantity.
127
+ verbose (int): Verbosity mode.
128
+ """
128
129
  super().__init__()
129
130
  self.monitor = monitor
130
131
  self.patience = patience
@@ -188,22 +189,23 @@ class EarlyStopping(Callback):
188
189
 
189
190
  class ModelCheckpoint(Callback):
190
191
  """
191
- Saves the model to a directory with automated filename generation and rotation. The filename includes the epoch and score.
192
-
193
- - If `save_best_only` is True, it saves the single best model, deleting the
194
- previous best.
195
- - If `save_best_only` is False, it keeps the 3 most recent checkpoints,
196
- deleting the oldest ones automatically.
197
-
198
- Args:
199
- save_dir (str): Directory where checkpoint files will be saved.
200
- monitor (str): Metric to monitor for `save_best_only=True`.
201
- save_best_only (bool): If true, save only the best model.
202
- mode (str): One of {'auto', 'min', 'max'}.
203
- verbose (int): Verbosity mode.
192
+ Saves the model weights, optimizer state, LR scheduler state (if any), and epoch number to a directory with automated filename generation and rotation.
204
193
  """
205
194
  def __init__(self, save_dir: Union[str,Path], checkpoint_name: Optional[str]=None, monitor: str = PyTorchLogKeys.VAL_LOSS,
206
195
  save_best_only: bool = True, mode: Literal['auto', 'min', 'max']= 'auto', verbose: int = 0):
196
+ """
197
+ - If `save_best_only` is True, it saves the single best model, deleting the previous best.
198
+ - If `save_best_only` is False, it keeps the 3 most recent checkpoints, deleting the oldest ones automatically.
199
+
200
+ Args:
201
+ save_dir (str): Directory where checkpoint files will be saved.
202
+ checkpoint_name (str| None): If None, the filename will include the epoch and score.
203
+ monitor (str): Metric to monitor.
204
+ save_best_only (bool): If true, save only the best model.
205
+ mode (str): One of {'auto', 'min', 'max'}.
206
+ verbose (int): Verbosity mode.
207
+ """
208
+
207
209
  super().__init__()
208
210
  self.save_dir = make_fullpath(save_dir, make=True, enforce="directory")
209
211
  if not self.save_dir.is_dir():
@@ -268,15 +270,29 @@ class ModelCheckpoint(Callback):
268
270
  if self.verbose > 0:
269
271
  _LOGGER.info(f"Epoch {epoch}: {self.monitor} improved from {old_best_str} to {current:.4f}, saving model to {new_filepath}")
270
272
 
273
+ # Update best score *before* saving
274
+ self.best = current
275
+
276
+ # Create a comprehensive checkpoint dictionary
277
+ checkpoint_data = {
278
+ PyTorchCheckpointKeys.EPOCH: epoch,
279
+ PyTorchCheckpointKeys.MODEL_STATE: self.trainer.model.state_dict(), # type: ignore
280
+ PyTorchCheckpointKeys.OPTIMIZER_STATE: self.trainer.optimizer.state_dict(), # type: ignore
281
+ PyTorchCheckpointKeys.BEST_SCORE: self.best,
282
+ }
283
+
284
+ # Check for scheduler
285
+ if hasattr(self.trainer, 'scheduler') and self.trainer.scheduler is not None: # type: ignore
286
+ checkpoint_data[PyTorchCheckpointKeys.SCHEDULER_STATE] = self.trainer.scheduler.state_dict() # type: ignore
287
+
271
288
  # Save the new best model
272
- torch.save(self.trainer.model.state_dict(), new_filepath) # type: ignore
289
+ torch.save(checkpoint_data, new_filepath)
273
290
 
274
291
  # Delete the old best model file
275
292
  if self.last_best_filepath and self.last_best_filepath.exists():
276
293
  self.last_best_filepath.unlink()
277
294
 
278
295
  # Update state
279
- self.best = current
280
296
  self.last_best_filepath = new_filepath
281
297
 
282
298
  def _save_rolling_checkpoints(self, epoch, logs):
@@ -290,7 +306,19 @@ class ModelCheckpoint(Callback):
290
306
 
291
307
  if self.verbose > 0:
292
308
  _LOGGER.info(f'Epoch {epoch}: saving model to {filepath}')
293
- torch.save(self.trainer.model.state_dict(), filepath) # type: ignore
309
+
310
+ # Create a comprehensive checkpoint dictionary
311
+ checkpoint_data = {
312
+ PyTorchCheckpointKeys.EPOCH: epoch,
313
+ PyTorchCheckpointKeys.MODEL_STATE: self.trainer.model.state_dict(), # type: ignore
314
+ PyTorchCheckpointKeys.OPTIMIZER_STATE: self.trainer.optimizer.state_dict(), # type: ignore
315
+ PyTorchCheckpointKeys.BEST_SCORE: self.best, # Save the current best score
316
+ }
317
+
318
+ if hasattr(self.trainer, 'scheduler') and self.trainer.scheduler is not None: # type: ignore
319
+ checkpoint_data[PyTorchCheckpointKeys.SCHEDULER_STATE] = self.trainer.scheduler.state_dict() # type: ignore
320
+
321
+ torch.save(checkpoint_data, filepath)
294
322
 
295
323
  self.saved_checkpoints.append(filepath)
296
324
 
@@ -306,21 +334,26 @@ class ModelCheckpoint(Callback):
306
334
  class LRScheduler(Callback):
307
335
  """
308
336
  Callback to manage a PyTorch learning rate scheduler.
309
-
310
- This callback automatically calls the scheduler's `step()` method at the
311
- end of each epoch. It also logs a message when the learning rate changes.
312
-
313
- Args:
314
- scheduler: An initialized PyTorch learning rate scheduler.
315
- monitor (str, optional): The metric to monitor for schedulers that
316
- require it, like `ReduceLROnPlateau`.
317
- Should match a key in the logs (e.g., 'val_loss').
318
337
  """
319
- def __init__(self, scheduler, monitor: Optional[str] = None):
338
+ def __init__(self, scheduler, monitor: Optional[str] = PyTorchLogKeys.VAL_LOSS):
339
+ """
340
+ This callback automatically calls the scheduler's `step()` method at the
341
+ end of each epoch. It also logs a message when the learning rate changes.
342
+
343
+ Args:
344
+ scheduler: An initialized PyTorch learning rate scheduler.
345
+ monitor (str): The metric to monitor for schedulers that require it, like `ReduceLROnPlateau`. Should match a key in the logs (e.g., 'val_loss').
346
+ """
320
347
  super().__init__()
321
348
  self.scheduler = scheduler
322
349
  self.monitor = monitor
323
350
  self.previous_lr = None
351
+
352
+ def set_trainer(self, trainer):
353
+ """This is called by the Trainer to associate itself with the callback."""
354
+ super().set_trainer(trainer)
355
+ # Register the scheduler with the trainer so it can be added to the checkpoint
356
+ self.trainer.scheduler = self.scheduler # type: ignore
324
357
 
325
358
  def on_train_begin(self, logs=None):
326
359
  """Store the initial learning rate."""