dragon-ml-toolbox 13.3.0__py3-none-any.whl → 16.2.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to one of the supported public registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/METADATA +20 -6
- dragon_ml_toolbox-16.2.0.dist-info/RECORD +51 -0
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
- ml_tools/ETL_cleaning.py +20 -20
- ml_tools/ETL_engineering.py +23 -25
- ml_tools/GUI_tools.py +20 -20
- ml_tools/MICE_imputation.py +207 -5
- ml_tools/ML_callbacks.py +43 -26
- ml_tools/ML_configuration.py +788 -0
- ml_tools/ML_datasetmaster.py +303 -448
- ml_tools/ML_evaluation.py +351 -93
- ml_tools/ML_evaluation_multi.py +139 -42
- ml_tools/ML_inference.py +290 -209
- ml_tools/ML_models.py +33 -106
- ml_tools/ML_models_advanced.py +323 -0
- ml_tools/ML_optimization.py +12 -12
- ml_tools/ML_scaler.py +11 -11
- ml_tools/ML_sequence_datasetmaster.py +341 -0
- ml_tools/ML_sequence_evaluation.py +219 -0
- ml_tools/ML_sequence_inference.py +391 -0
- ml_tools/ML_sequence_models.py +139 -0
- ml_tools/ML_trainer.py +1604 -179
- ml_tools/ML_utilities.py +351 -4
- ml_tools/ML_vision_datasetmaster.py +1540 -0
- ml_tools/ML_vision_evaluation.py +284 -0
- ml_tools/ML_vision_inference.py +405 -0
- ml_tools/ML_vision_models.py +641 -0
- ml_tools/ML_vision_transformers.py +284 -0
- ml_tools/PSO_optimization.py +6 -6
- ml_tools/SQL.py +4 -4
- ml_tools/_keys.py +171 -0
- ml_tools/_schema.py +1 -1
- ml_tools/custom_logger.py +37 -14
- ml_tools/data_exploration.py +502 -93
- ml_tools/ensemble_evaluation.py +54 -11
- ml_tools/ensemble_inference.py +7 -33
- ml_tools/ensemble_learning.py +1 -1
- ml_tools/math_utilities.py +1 -1
- ml_tools/optimization_tools.py +2 -2
- ml_tools/path_manager.py +5 -5
- ml_tools/serde.py +2 -2
- ml_tools/utilities.py +192 -4
- dragon_ml_toolbox-13.3.0.dist-info/RECORD +0 -41
- ml_tools/RNN_forecast.py +0 -56
- ml_tools/keys.py +0 -87
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/top_level.txt +0 -0
ml_tools/MICE_imputation.py
CHANGED
```diff
@@ -7,19 +7,20 @@ from plotnine import ggplot, labs, theme, element_blank # type: ignore
 from typing import Optional, Union
 
 from .utilities import load_dataframe, merge_dataframes, save_dataframe_filename
-from .math_utilities import threshold_binary_values
+from .math_utilities import threshold_binary_values, discretize_categorical_values
 from .path_manager import sanitize_filename, make_fullpath, list_csv_paths
 from ._logger import _LOGGER
 from ._script_info import _script_info
+from ._schema import FeatureSchema
 
 
 __all__ = [
+    "DragonMICE",
     "apply_mice",
     "save_imputed_datasets",
-    "get_na_column_names",
     "get_convergence_diagnostic",
     "get_imputed_distributions",
-    "run_mice_pipeline"
+    "run_mice_pipeline",
 ]
 
 
@@ -79,7 +80,7 @@ def save_imputed_datasets(save_dir: Union[str, Path], imputed_datasets: list, df
 
 
 #Get names of features that had missing values before imputation
-def get_na_column_names(df: pd.DataFrame):
+def _get_na_column_names(df: pd.DataFrame):
     return [col for col in df.columns if df[col].isna().any()]
 
 
@@ -264,7 +265,7 @@ def run_mice_pipeline(df_path_or_dir: Union[str,Path], target_columns: list[str]
 
     save_imputed_datasets(save_dir=save_datasets_path, imputed_datasets=imputed_datasets, df_targets=df_targets, imputed_dataset_names=imputed_dataset_names)
 
-    imputed_column_names = get_na_column_names(df=df)
+    imputed_column_names = _get_na_column_names(df=df)
 
     get_convergence_diagnostic(kernel=kernel, imputed_dataset_names=imputed_dataset_names, column_names=imputed_column_names, root_dir=save_metrics_path)
 
@@ -278,5 +279,206 @@ def _skip_targets(df: pd.DataFrame, target_cols: list[str]):
     return df_feats, df_targets
 
 
+# modern implementation
+class DragonMICE:
+    """
+    A modern MICE imputation pipeline that uses a FeatureSchema
+    to correctly discretize categorical features after imputation.
+    """
+    def __init__(self,
+                 schema: FeatureSchema,
+                 iterations: int=20,
+                 resulting_datasets: int = 1,
+                 random_state: int = 101):
+
+        self.schema = schema
+        self.random_state = random_state
+        self.iterations = iterations
+        self.resulting_datasets = resulting_datasets
+
+        # --- Store schema info ---
+
+        # 1. Categorical info
+        if not self.schema.categorical_index_map:
+            _LOGGER.warning("FeatureSchema has no 'categorical_index_map'. No discretization will be applied.")
+            self.cat_info = {}
+        else:
+            self.cat_info = self.schema.categorical_index_map
+
+        # 2. Ordered feature names (critical for index mapping)
+        self.ordered_features = list(self.schema.feature_names)
+
+        # 3. Names of categorical features
+        self.categorical_features = list(self.schema.categorical_feature_names)
+
+        _LOGGER.info(f"DragonMICE initialized. Found {len(self.cat_info)} categorical features to discretize.")
+
+    def _post_process(self, imputed_df: pd.DataFrame) -> pd.DataFrame:
+        """
+        Applies schema-based discretization to a completed dataframe.
+
+        This method works around the behavior of `discretize_categorical_values`
+        (which returns a full int32 array) by:
+        1. Calling it on the full, ordered feature array.
+        2. Extracting *only* the valid discretized categorical columns.
+        3. Updating the original float dataframe with these integer values.
+        """
+        # If no categorical features are defined, return the df as-is.
+        if not self.cat_info:
+            return imputed_df
+
+        try:
+            # 1. Ensure DataFrame columns match the schema order
+            # This is critical for the index-based categorical_info
+            df_ordered: pd.DataFrame = imputed_df[self.ordered_features] # type: ignore
+
+            # 2. Convert to NumPy array
+            array_ordered = df_ordered.to_numpy()
+
+            # 3. Apply discretization utility (which returns a full int32 array)
+            # This array has *correct* categorical values but *truncated* continuous values.
+            discretized_array_int32 = discretize_categorical_values(
+                array_ordered,
+                self.cat_info,
+                start_at_zero=True # Assuming 0-based indexing
+            )
+
+            # 4. Create a new DF from the int32 array, keeping the categorical columns.
+            df_discretized_cats = pd.DataFrame(
+                discretized_array_int32,
+                columns=self.ordered_features,
+                index=df_ordered.index # <-- Critical: align index
+            )[self.categorical_features] # <-- Select only cat features
+
+            # 5. "Rejoin": Start with a fresh copy of the *original* imputed DF (which has correct continuous floats).
+            final_df = df_ordered.copy()
+
+            # 6. Use .update() to "paste" the integer categorical values
+            # over the old float categorical values. Continuous floats are unaffected.
+            final_df.update(df_discretized_cats)
+
+            return final_df
+
+        except Exception as e:
+            _LOGGER.error(f"Failed during post-processing discretization:\n\tInput DF shape: {imputed_df.shape}\n\tSchema features: {len(self.ordered_features)}\n\tCategorical info keys: {list(self.cat_info.keys())}\n{e}")
+            raise
+
+    def _run_mice(self,
+                  df: pd.DataFrame,
+                  df_name: str) -> tuple[mf.ImputationKernel, list[pd.DataFrame], list[str]]:
+        """
+        Runs the MICE kernel and applies schema-based post-processing.
+
+        Parameters:
+            df (pd.DataFrame): The input dataframe *with NaNs*. Should only contain feature columns.
+            df_name (str): The base name for the dataset.
+
+        Returns:
+            tuple[mf.ImputationKernel, list[pd.DataFrame], list[str]]:
+                - The trained MICE kernel
+                - A list of imputed and processed DataFrames
+                - A list of names for the new DataFrames
+        """
+        # Ensure input df only contains features from the schema and is in the correct order.
+        try:
+            df_feats = df[self.ordered_features]
+        except KeyError as e:
+            _LOGGER.error(f"Input DataFrame is missing required schema columns: {e}")
+            raise
+
+        # 1. Initialize kernel
+        kernel = mf.ImputationKernel(
+            data=df_feats,
+            num_datasets=self.resulting_datasets,
+            random_state=self.random_state
+        )
+
+        _LOGGER.info("➡️ Schema-based MICE imputation running...")
+
+        # 2. Perform MICE
+        kernel.mice(self.iterations)
+
+        # 3. Retrieve, process, and collect datasets
+        imputed_datasets = []
+        for i in range(self.resulting_datasets):
+            # complete_data returns a pd.DataFrame
+            completed_df = kernel.complete_data(dataset=i)
+
+            # Apply our new discretization and ordering
+            processed_df = self._post_process(completed_df)
+            imputed_datasets.append(processed_df)
+
+        if not imputed_datasets:
+            _LOGGER.error("No imputed datasets were generated.")
+            raise ValueError()
+
+        # 4. Generate names
+        if self.resulting_datasets == 1:
+            imputed_dataset_names = [f"{df_name}_MICE"]
+        else:
+            imputed_dataset_names = [f"{df_name}_MICE_{i+1}" for i in range(self.resulting_datasets)]
+
+        # 5. Validate indexes
+        for imputed_df, subname in zip(imputed_datasets, imputed_dataset_names):
+            assert imputed_df.shape[0] == df.shape[0], f"❌ Row count mismatch in dataset {subname}"
+            assert all(imputed_df.index == df.index), f"❌ Index mismatch in dataset {subname}"
+
+        _LOGGER.info("Schema-based MICE imputation complete.")
+
+        return kernel, imputed_datasets, imputed_dataset_names
+
+    def run_pipeline(self,
+                     df_path_or_dir: Union[str,Path],
+                     save_datasets_dir: Union[str,Path],
+                     save_metrics_dir: Union[str,Path],
+                     ):
+        """
+        Runs the complete MICE imputation pipeline.
+
+        This method automates the entire workflow:
+        1. Loads data from a CSV file path or a directory with CSV files.
+        2. Separates features and targets based on the `FeatureSchema`.
+        3. Runs the MICE algorithm on the feature set.
+        4. Applies schema-based post-processing to discretize categorical features.
+        5. Saves the final, processed, and imputed dataset(s) (re-joined with targets) to `save_datasets_dir`.
+        6. Generates and saves convergence and distribution plots for all imputed columns to `save_metrics_dir`.
+
+        Parameters
+        ----------
+        df_path_or_dir : [str, Path]
+            Path to a single CSV file or a directory containing multiple CSV files to impute.
+        save_datasets_dir : [str, Path]
+            Directory where the final imputed and processed dataset(s) will be saved as CSVs.
+        save_metrics_dir : [str, Path]
+            Directory where convergence and distribution plots will be saved.
+        """
+        # Check paths
+        save_datasets_path = make_fullpath(save_datasets_dir, make=True)
+        save_metrics_path = make_fullpath(save_metrics_dir, make=True)
+
+        input_path = make_fullpath(df_path_or_dir)
+        if input_path.is_file():
+            all_file_paths = [input_path]
+        else:
+            all_file_paths = list(list_csv_paths(input_path).values())
+
+        for df_path in all_file_paths:
+
+            df, df_name = load_dataframe(df_path=df_path, kind="pandas")
+
+            df_features: pd.DataFrame = df[self.schema.feature_names] # type: ignore
+            df_targets = df.drop(columns=self.schema.feature_names)
+
+            imputed_column_names = _get_na_column_names(df=df_features)
+
+            kernel, imputed_datasets, imputed_dataset_names = self._run_mice(df=df_features, df_name=df_name)
+
+            save_imputed_datasets(save_dir=save_datasets_path, imputed_datasets=imputed_datasets, df_targets=df_targets, imputed_dataset_names=imputed_dataset_names)
+
+            get_convergence_diagnostic(kernel=kernel, imputed_dataset_names=imputed_dataset_names, column_names=imputed_column_names, root_dir=save_metrics_path)
+
+            get_imputed_distributions(kernel=kernel, df_name=df_name, root_dir=save_metrics_path, column_names=imputed_column_names)
+
+
 def info():
     _script_info(__all__)
```
ml_tools/ML_callbacks.py
CHANGED
```diff
@@ -4,23 +4,22 @@ from tqdm.auto import tqdm
 from typing import Union, Literal, Optional
 from pathlib import Path
 
-from .path_manager import make_fullpath
-from .
+from .path_manager import make_fullpath
+from ._keys import PyTorchLogKeys, PyTorchCheckpointKeys
 from ._logger import _LOGGER
 from ._script_info import _script_info
 
 
 __all__ = [
-    "Callback",
     "History",
     "TqdmProgressBar",
-    "EarlyStopping",
-    "ModelCheckpoint",
-    "LRScheduler"
+    "DragonEarlyStopping",
+    "DragonModelCheckpoint",
+    "DragonLRScheduler"
 ]
 
 
-class Callback:
+class _Callback:
     """
     Abstract base class used to build new callbacks.
 
@@ -60,7 +59,7 @@ class Callback:
         pass
 
 
-class History(Callback):
+class History(_Callback):
     """
     Callback that records events into a `history` dictionary.
 
@@ -79,7 +78,7 @@ class History(Callback):
             self.trainer.history.setdefault(k, []).append(v) # type: ignore
 
 
-class TqdmProgressBar(Callback):
+class TqdmProgressBar(_Callback):
     """Callback that provides a tqdm progress bar for training."""
     def __init__(self):
         self.epoch_bar = None
@@ -110,7 +109,7 @@ class TqdmProgressBar(Callback):
         self.epoch_bar.close() # type: ignore
 
 
-class EarlyStopping(Callback):
+class DragonEarlyStopping(_Callback):
     """
     Stop training when a monitored metric has stopped improving.
     """
@@ -187,11 +186,11 @@ class EarlyStopping(Callback):
            _LOGGER.info(f"Epoch {epoch+1}: early stopping after {self.wait} epochs with no improvement.")
 
 
-class ModelCheckpoint(Callback):
+class DragonModelCheckpoint(_Callback):
     """
     Saves the model weights, optimizer state, LR scheduler state (if any), and epoch number to a directory with automated filename generation and rotation.
     """
-    def __init__(self, save_dir: Union[str,Path],
+    def __init__(self, save_dir: Union[str,Path], monitor: str = PyTorchLogKeys.VAL_LOSS,
                  save_best_only: bool = True, mode: Literal['auto', 'min', 'max']= 'auto', verbose: int = 0):
         """
        - If `save_best_only` is True, it saves the single best model, deleting the previous best.
@@ -199,7 +198,6 @@ class ModelCheckpoint(Callback):
 
         Args:
             save_dir (str): Directory where checkpoint files will be saved.
-            checkpoint_name (str| None): If None, the filename will include the epoch and score.
             monitor (str): Metric to monitor.
             save_best_only (bool): If true, save only the best model.
             mode (str): One of {'auto', 'min', 'max'}.
@@ -215,9 +213,8 @@ class ModelCheckpoint(Callback):
         self.monitor = monitor
         self.save_best_only = save_best_only
         self.verbose = verbose
-
-
-        self.checkpoint_name = checkpoint_name
+        self._latest_checkpoint_path = None
+        self._checkpoint_name = PyTorchCheckpointKeys.CHECKPOINT_NAME
 
         # State variables to be managed during training
         self.saved_checkpoints = []
@@ -261,10 +258,7 @@ class ModelCheckpoint(Callback):
            old_best_str = f"{self.best:.4f}" if self.best not in [np.inf, -np.inf] else "inf"
 
            # Create a descriptive filename
-
-                filename = f"epoch_{epoch}-{self.monitor}_{current:.4f}.pth"
-            else:
-                filename = f"epoch{epoch}_{self.checkpoint_name}.pth"
+            filename = f"epoch{epoch}_{self._checkpoint_name}_{current:.4f}.pth"
            new_filepath = self.save_dir / filename
 
            if self.verbose > 0:
@@ -279,6 +273,7 @@ class ModelCheckpoint(Callback):
                PyTorchCheckpointKeys.MODEL_STATE: self.trainer.model.state_dict(), # type: ignore
                PyTorchCheckpointKeys.OPTIMIZER_STATE: self.trainer.optimizer.state_dict(), # type: ignore
                PyTorchCheckpointKeys.BEST_SCORE: self.best,
+                PyTorchCheckpointKeys.HISTORY: self.trainer.history, # type: ignore
            }
 
            # Check for scheduler
@@ -287,6 +282,7 @@ class ModelCheckpoint(Callback):
 
            # Save the new best model
            torch.save(checkpoint_data, new_filepath)
+            self._latest_checkpoint_path = new_filepath
 
            # Delete the old best model file
            if self.last_best_filepath and self.last_best_filepath.exists():
@@ -298,10 +294,8 @@ class ModelCheckpoint(Callback):
     def _save_rolling_checkpoints(self, epoch, logs):
         """Saves the latest model and keeps only the most recent ones."""
         current = logs.get(self.monitor)
-
-
-        else:
-            filename = f"epoch{epoch}_{self.checkpoint_name}.pth"
+
+        filename = f"epoch{epoch}_{self._checkpoint_name}_{current:.4f}.pth"
         filepath = self.save_dir / filename
 
         if self.verbose > 0:
@@ -313,12 +307,15 @@ class ModelCheckpoint(Callback):
            PyTorchCheckpointKeys.MODEL_STATE: self.trainer.model.state_dict(), # type: ignore
            PyTorchCheckpointKeys.OPTIMIZER_STATE: self.trainer.optimizer.state_dict(), # type: ignore
            PyTorchCheckpointKeys.BEST_SCORE: self.best, # Save the current best score
+            PyTorchCheckpointKeys.HISTORY: self.trainer.history, # type: ignore
         }
 
         if hasattr(self.trainer, 'scheduler') and self.trainer.scheduler is not None: # type: ignore
            checkpoint_data[PyTorchCheckpointKeys.SCHEDULER_STATE] = self.trainer.scheduler.state_dict() # type: ignore
 
         torch.save(checkpoint_data, filepath)
+
+        self._latest_checkpoint_path = filepath
 
         self.saved_checkpoints.append(filepath)
 
@@ -330,8 +327,16 @@ class ModelCheckpoint(Callback):
                _LOGGER.info(f" -> Deleting old checkpoint: {file_to_delete.name}")
                file_to_delete.unlink()
 
+    @property
+    def best_checkpoint_path(self):
+        if self._latest_checkpoint_path:
+            return self._latest_checkpoint_path
+        else:
+            _LOGGER.error("No checkpoint paths saved.")
+            raise ValueError()
+
 
-class LRScheduler(Callback):
+class DragonLRScheduler(_Callback):
     """
     Callback to manage a PyTorch learning rate scheduler.
     """
@@ -361,6 +366,8 @@ class LRScheduler(Callback):
 
     def on_epoch_end(self, epoch, logs=None):
         """Step the scheduler and log any change in learning rate."""
+        logs = logs or {}
+
         # For schedulers that need a metric (e.g., val_loss)
         if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            if self.monitor is None:
@@ -376,12 +383,22 @@ class LRScheduler(Callback):
         # For all other schedulers
         else:
            self.scheduler.step()
+
+        # Get the current learning rate
+        current_lr = self.trainer.optimizer.param_groups[0]['lr'] # type: ignore
 
         # Log the change if the LR was updated
-        current_lr = self.trainer.optimizer.param_groups[0]['lr'] # type: ignore
         if current_lr != self.previous_lr:
            _LOGGER.info(f"Epoch {epoch}: Learning rate changed to {current_lr:.6f}")
            self.previous_lr = current_lr
+
+        # --- Add LR to logs and history ---
+        # Add to the logs dict for any subsequent callbacks
+        logs[PyTorchLogKeys.LEARNING_RATE] = current_lr
+
+        # Also add directly to the trainer's history dict
+        if hasattr(self.trainer, 'history'):
+            self.trainer.history.setdefault(PyTorchLogKeys.LEARNING_RATE, []).append(current_lr) # type: ignore
 
 
 def info():
```