dragon-ml-toolbox 10.1.1__py3-none-any.whl → 14.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dragon-ml-toolbox might be problematic.
- {dragon_ml_toolbox-10.1.1.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/METADATA +38 -63
- dragon_ml_toolbox-14.2.0.dist-info/RECORD +48 -0
- {dragon_ml_toolbox-10.1.1.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/licenses/LICENSE +1 -1
- {dragon_ml_toolbox-10.1.1.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +11 -0
- ml_tools/ETL_cleaning.py +175 -59
- ml_tools/ETL_engineering.py +506 -70
- ml_tools/GUI_tools.py +2 -1
- ml_tools/MICE_imputation.py +212 -7
- ml_tools/ML_callbacks.py +73 -40
- ml_tools/ML_datasetmaster.py +267 -284
- ml_tools/ML_evaluation.py +119 -58
- ml_tools/ML_evaluation_multi.py +107 -32
- ml_tools/ML_inference.py +15 -5
- ml_tools/ML_models.py +234 -170
- ml_tools/ML_models_advanced.py +323 -0
- ml_tools/ML_optimization.py +321 -97
- ml_tools/ML_scaler.py +10 -5
- ml_tools/ML_trainer.py +585 -40
- ml_tools/ML_utilities.py +528 -0
- ml_tools/ML_vision_datasetmaster.py +1315 -0
- ml_tools/ML_vision_evaluation.py +260 -0
- ml_tools/ML_vision_inference.py +428 -0
- ml_tools/ML_vision_models.py +627 -0
- ml_tools/ML_vision_transformers.py +58 -0
- ml_tools/PSO_optimization.py +10 -7
- ml_tools/RNN_forecast.py +2 -0
- ml_tools/SQL.py +22 -9
- ml_tools/VIF_factor.py +4 -3
- ml_tools/_ML_vision_recipe.py +88 -0
- ml_tools/__init__.py +1 -0
- ml_tools/_logger.py +0 -2
- ml_tools/_schema.py +96 -0
- ml_tools/constants.py +79 -0
- ml_tools/custom_logger.py +164 -16
- ml_tools/data_exploration.py +1092 -109
- ml_tools/ensemble_evaluation.py +48 -1
- ml_tools/ensemble_inference.py +6 -7
- ml_tools/ensemble_learning.py +4 -3
- ml_tools/handle_excel.py +1 -0
- ml_tools/keys.py +80 -0
- ml_tools/math_utilities.py +259 -0
- ml_tools/optimization_tools.py +198 -24
- ml_tools/path_manager.py +144 -45
- ml_tools/serde.py +192 -0
- ml_tools/utilities.py +287 -227
- dragon_ml_toolbox-10.1.1.dist-info/RECORD +0 -36
- {dragon_ml_toolbox-10.1.1.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-10.1.1.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/top_level.txt +0 -0
ml_tools/GUI_tools.py
CHANGED
```diff
@@ -4,8 +4,9 @@ import traceback
 import FreeSimpleGUI as sg
 from functools import wraps
 from typing import Any, Dict, Tuple, List, Literal, Union, Optional, Callable
-from ._script_info import _script_info
 import numpy as np
+
+from ._script_info import _script_info
 from ._logger import _LOGGER
 from .keys import _OneHotOtherPlaceholder
 
```
ml_tools/MICE_imputation.py
CHANGED
```diff
@@ -3,20 +3,24 @@ import miceforest as mf
 from pathlib import Path
 import matplotlib.pyplot as plt
 import numpy as np
-from .utilities import load_dataframe, merge_dataframes, save_dataframe, threshold_binary_values
-from .path_manager import sanitize_filename, make_fullpath, list_csv_paths
 from plotnine import ggplot, labs, theme, element_blank # type: ignore
 from typing import Optional, Union
+
+from .utilities import load_dataframe, merge_dataframes, save_dataframe_filename
+from .math_utilities import threshold_binary_values, discretize_categorical_values
+from .path_manager import sanitize_filename, make_fullpath, list_csv_paths
 from ._logger import _LOGGER
 from ._script_info import _script_info
+from ._schema import FeatureSchema
+
 
 __all__ = [
+    "MiceImputer",
     "apply_mice",
     "save_imputed_datasets",
-    "get_na_column_names",
     "get_convergence_diagnostic",
     "get_imputed_distributions",
-    "run_mice_pipeline"
+    "run_mice_pipeline",
 ]
 
 
@@ -72,11 +76,11 @@ def apply_mice(df: pd.DataFrame, df_name: str, binary_columns: Optional[list[str
 def save_imputed_datasets(save_dir: Union[str, Path], imputed_datasets: list, df_targets: pd.DataFrame, imputed_dataset_names: list[str]):
     for imputed_df, subname in zip(imputed_datasets, imputed_dataset_names):
         merged_df = merge_dataframes(imputed_df, df_targets, direction="horizontal", verbose=False)
-
+        save_dataframe_filename(df=merged_df, save_dir=save_dir, filename=subname)
 
 
 #Get names of features that had missing values before imputation
-def
+def _get_na_column_names(df: pd.DataFrame):
     return [col for col in df.columns if df[col].isna().any()]
 
 
@@ -261,7 +265,7 @@ def run_mice_pipeline(df_path_or_dir: Union[str,Path], target_columns: list[str]
 
     save_imputed_datasets(save_dir=save_datasets_path, imputed_datasets=imputed_datasets, df_targets=df_targets, imputed_dataset_names=imputed_dataset_names)
 
-    imputed_column_names =
+    imputed_column_names = _get_na_column_names(df=df)
 
     get_convergence_diagnostic(kernel=kernel, imputed_dataset_names=imputed_dataset_names, column_names=imputed_column_names, root_dir=save_metrics_path)
 
@@ -275,5 +279,206 @@ def _skip_targets(df: pd.DataFrame, target_cols: list[str]):
     return df_feats, df_targets
 
 
+# modern implementation
+class MiceImputer:
+    """
+    A modern MICE imputation pipeline that uses a FeatureSchema
+    to correctly discretize categorical features after imputation.
+    """
+    def __init__(self,
+                 schema: FeatureSchema,
+                 iterations: int=20,
+                 resulting_datasets: int = 1,
+                 random_state: int = 101):
+
+        self.schema = schema
+        self.random_state = random_state
+        self.iterations = iterations
+        self.resulting_datasets = resulting_datasets
+
+        # --- Store schema info ---
+
+        # 1. Categorical info
+        if not self.schema.categorical_index_map:
+            _LOGGER.warning("FeatureSchema has no 'categorical_index_map'. No discretization will be applied.")
+            self.cat_info = {}
+        else:
+            self.cat_info = self.schema.categorical_index_map
+
+        # 2. Ordered feature names (critical for index mapping)
+        self.ordered_features = list(self.schema.feature_names)
+
+        # 3. Names of categorical features
+        self.categorical_features = list(self.schema.categorical_feature_names)
+
+        _LOGGER.info(f"MiceImputer initialized. Found {len(self.cat_info)} categorical features to discretize.")
+
+    def _post_process(self, imputed_df: pd.DataFrame) -> pd.DataFrame:
+        """
+        Applies schema-based discretization to a completed dataframe.
+
+        This method works around the behavior of `discretize_categorical_values`
+        (which returns a full int32 array) by:
+        1. Calling it on the full, ordered feature array.
+        2. Extracting *only* the valid discretized categorical columns.
+        3. Updating the original float dataframe with these integer values.
+        """
+        # If no categorical features are defined, return the df as-is.
+        if not self.cat_info:
+            return imputed_df
+
+        try:
+            # 1. Ensure DataFrame columns match the schema order
+            # This is critical for the index-based categorical_info
+            df_ordered: pd.DataFrame = imputed_df[self.ordered_features] # type: ignore
+
+            # 2. Convert to NumPy array
+            array_ordered = df_ordered.to_numpy()
+
+            # 3. Apply discretization utility (which returns a full int32 array)
+            # This array has *correct* categorical values but *truncated* continuous values.
+            discretized_array_int32 = discretize_categorical_values(
+                array_ordered,
+                self.cat_info,
+                start_at_zero=True # Assuming 0-based indexing
+            )
+
+            # 4. Create a new DF from the int32 array, keeping the categorical columns.
+            df_discretized_cats = pd.DataFrame(
+                discretized_array_int32,
+                columns=self.ordered_features,
+                index=df_ordered.index # <-- Critical: align index
+            )[self.categorical_features] # <-- Select only cat features
+
+            # 5. "Rejoin": Start with a fresh copy of the *original* imputed DF (which has correct continuous floats).
+            final_df = df_ordered.copy()
+
+            # 6. Use .update() to "paste" the integer categorical values
+            # over the old float categorical values. Continuous floats are unaffected.
+            final_df.update(df_discretized_cats)
+
+            return final_df
+
+        except Exception as e:
+            _LOGGER.error(f"Failed during post-processing discretization:\n\tInput DF shape: {imputed_df.shape}\n\tSchema features: {len(self.ordered_features)}\n\tCategorical info keys: {list(self.cat_info.keys())}\n{e}")
+            raise
+
+    def _run_mice(self,
+                  df: pd.DataFrame,
+                  df_name: str) -> tuple[mf.ImputationKernel, list[pd.DataFrame], list[str]]:
+        """
+        Runs the MICE kernel and applies schema-based post-processing.
+
+        Parameters:
+            df (pd.DataFrame): The input dataframe *with NaNs*. Should only contain feature columns.
+            df_name (str): The base name for the dataset.
+
+        Returns:
+            tuple[mf.ImputationKernel, list[pd.DataFrame], list[str]]:
+                - The trained MICE kernel
+                - A list of imputed and processed DataFrames
+                - A list of names for the new DataFrames
+        """
+        # Ensure input df only contains features from the schema and is in the correct order.
+        try:
+            df_feats = df[self.ordered_features]
+        except KeyError as e:
+            _LOGGER.error(f"Input DataFrame is missing required schema columns: {e}")
+            raise
+
+        # 1. Initialize kernel
+        kernel = mf.ImputationKernel(
+            data=df_feats,
+            num_datasets=self.resulting_datasets,
+            random_state=self.random_state
+        )
+
+        _LOGGER.info("➡️ Schema-based MICE imputation running...")
+
+        # 2. Perform MICE
+        kernel.mice(self.iterations)
+
+        # 3. Retrieve, process, and collect datasets
+        imputed_datasets = []
+        for i in range(self.resulting_datasets):
+            # complete_data returns a pd.DataFrame
+            completed_df = kernel.complete_data(dataset=i)
+
+            # Apply our new discretization and ordering
+            processed_df = self._post_process(completed_df)
+            imputed_datasets.append(processed_df)
+
+        if not imputed_datasets:
+            _LOGGER.error("No imputed datasets were generated.")
+            raise ValueError()
+
+        # 4. Generate names
+        if self.resulting_datasets == 1:
+            imputed_dataset_names = [f"{df_name}_MICE"]
+        else:
+            imputed_dataset_names = [f"{df_name}_MICE_{i+1}" for i in range(self.resulting_datasets)]
+
+        # 5. Validate indexes
+        for imputed_df, subname in zip(imputed_datasets, imputed_dataset_names):
+            assert imputed_df.shape[0] == df.shape[0], f"❌ Row count mismatch in dataset {subname}"
+            assert all(imputed_df.index == df.index), f"❌ Index mismatch in dataset {subname}"
+
+        _LOGGER.info("Schema-based MICE imputation complete.")
+
+        return kernel, imputed_datasets, imputed_dataset_names
+
+    def run_pipeline(self,
+                     df_path_or_dir: Union[str,Path],
+                     save_datasets_dir: Union[str,Path],
+                     save_metrics_dir: Union[str,Path],
+                     ):
+        """
+        Runs the complete MICE imputation pipeline.
+
+        This method automates the entire workflow:
+        1. Loads data from a CSV file path or a directory with CSV files.
+        2. Separates features and targets based on the `FeatureSchema`.
+        3. Runs the MICE algorithm on the feature set.
+        4. Applies schema-based post-processing to discretize categorical features.
+        5. Saves the final, processed, and imputed dataset(s) (re-joined with targets) to `save_datasets_dir`.
+        6. Generates and saves convergence and distribution plots for all imputed columns to `save_metrics_dir`.
+
+        Parameters
+        ----------
+        df_path_or_dir : [str, Path]
+            Path to a single CSV file or a directory containing multiple CSV files to impute.
+        save_datasets_dir : [str, Path]
+            Directory where the final imputed and processed dataset(s) will be saved as CSVs.
+        save_metrics_dir : [str, Path]
+            Directory where convergence and distribution plots will be saved.
+        """
+        # Check paths
+        save_datasets_path = make_fullpath(save_datasets_dir, make=True)
+        save_metrics_path = make_fullpath(save_metrics_dir, make=True)
+
+        input_path = make_fullpath(df_path_or_dir)
+        if input_path.is_file():
+            all_file_paths = [input_path]
+        else:
+            all_file_paths = list(list_csv_paths(input_path).values())
+
+        for df_path in all_file_paths:
+
+            df, df_name = load_dataframe(df_path=df_path, kind="pandas")
+
+            df_features: pd.DataFrame = df[self.schema.feature_names] # type: ignore
+            df_targets = df.drop(columns=self.schema.feature_names)
+
+            imputed_column_names = _get_na_column_names(df=df_features)
+
+            kernel, imputed_datasets, imputed_dataset_names = self._run_mice(df=df_features, df_name=df_name)
+
+            save_imputed_datasets(save_dir=save_datasets_path, imputed_datasets=imputed_datasets, df_targets=df_targets, imputed_dataset_names=imputed_dataset_names)
+
+            get_convergence_diagnostic(kernel=kernel, imputed_dataset_names=imputed_dataset_names, column_names=imputed_column_names, root_dir=save_metrics_path)
+
+            get_imputed_distributions(kernel=kernel, df_name=df_name, root_dir=save_metrics_path, column_names=imputed_column_names)
+
+
 def info():
     _script_info(__all__)
```
ml_tools/ML_callbacks.py
CHANGED
```diff
@@ -1,13 +1,13 @@
 import numpy as np
 import torch
 from tqdm.auto import tqdm
+from typing import Union, Literal, Optional
+from pathlib import Path
+
 from .path_manager import make_fullpath, sanitize_filename
-from .keys import PyTorchLogKeys
+from .keys import PyTorchLogKeys, PyTorchCheckpointKeys
 from ._logger import _LOGGER
-from typing import Optional
 from ._script_info import _script_info
-from typing import Union, Literal
-from pathlib import Path
 
 
 __all__ = [
@@ -113,18 +113,19 @@ class TqdmProgressBar(Callback):
 class EarlyStopping(Callback):
     """
     Stop training when a monitored metric has stopped improving.
-
-    Args:
-        monitor (str): Quantity to be monitored. Defaults to 'val_loss'.
-        min_delta (float): Minimum change in the monitored quantity to qualify as an improvement.
-        patience (int): Number of epochs with no improvement after which training will be stopped.
-        mode (str): One of {'auto', 'min', 'max'}. In 'min' mode, training will stop when the quantity
-            monitored has stopped decreasing; in 'max' mode it will stop when the quantity
-            monitored has stopped increasing; in 'auto' mode, the direction is automatically
-            inferred from the name of the monitored quantity.
-        verbose (int): Verbosity mode.
     """
     def __init__(self, monitor: str=PyTorchLogKeys.VAL_LOSS, min_delta: float=0.0, patience: int=5, mode: Literal['auto', 'min', 'max']='auto', verbose: int=1):
+        """
+        Args:
+            monitor (str): Quantity to be monitored. Defaults to 'val_loss'.
+            min_delta (float): Minimum change in the monitored quantity to qualify as an improvement.
+            patience (int): Number of epochs with no improvement after which training will be stopped.
+            mode (str): One of {'auto', 'min', 'max'}. In 'min' mode, training will stop when the quantity
+                monitored has stopped decreasing; in 'max' mode it will stop when the quantity
+                monitored has stopped increasing; in 'auto' mode, the direction is automatically
+                inferred from the name of the monitored quantity.
+            verbose (int): Verbosity mode.
+        """
         super().__init__()
         self.monitor = monitor
         self.patience = patience
@@ -188,22 +189,23 @@ class EarlyStopping(Callback):
 
 class ModelCheckpoint(Callback):
     """
-    Saves the model to a directory with automated filename generation and rotation.
-
-    - If `save_best_only` is True, it saves the single best model, deleting the
-      previous best.
-    - If `save_best_only` is False, it keeps the 3 most recent checkpoints,
-      deleting the oldest ones automatically.
-
-    Args:
-        save_dir (str): Directory where checkpoint files will be saved.
-        monitor (str): Metric to monitor for `save_best_only=True`.
-        save_best_only (bool): If true, save only the best model.
-        mode (str): One of {'auto', 'min', 'max'}.
-        verbose (int): Verbosity mode.
+    Saves the model weights, optimizer state, LR scheduler state (if any), and epoch number to a directory with automated filename generation and rotation.
     """
     def __init__(self, save_dir: Union[str,Path], checkpoint_name: Optional[str]=None, monitor: str = PyTorchLogKeys.VAL_LOSS,
                  save_best_only: bool = True, mode: Literal['auto', 'min', 'max']= 'auto', verbose: int = 0):
+        """
+        - If `save_best_only` is True, it saves the single best model, deleting the previous best.
+        - If `save_best_only` is False, it keeps the 3 most recent checkpoints, deleting the oldest ones automatically.
+
+        Args:
+            save_dir (str): Directory where checkpoint files will be saved.
+            checkpoint_name (str | None): If None, the filename will include the epoch and score.
+            monitor (str): Metric to monitor.
+            save_best_only (bool): If true, save only the best model.
+            mode (str): One of {'auto', 'min', 'max'}.
+            verbose (int): Verbosity mode.
+        """
+
         super().__init__()
         self.save_dir = make_fullpath(save_dir, make=True, enforce="directory")
         if not self.save_dir.is_dir():
@@ -268,15 +270,29 @@ class ModelCheckpoint(Callback):
         if self.verbose > 0:
             _LOGGER.info(f"Epoch {epoch}: {self.monitor} improved from {old_best_str} to {current:.4f}, saving model to {new_filepath}")
 
+        # Update best score *before* saving
+        self.best = current
+
+        # Create a comprehensive checkpoint dictionary
+        checkpoint_data = {
+            PyTorchCheckpointKeys.EPOCH: epoch,
+            PyTorchCheckpointKeys.MODEL_STATE: self.trainer.model.state_dict(), # type: ignore
+            PyTorchCheckpointKeys.OPTIMIZER_STATE: self.trainer.optimizer.state_dict(), # type: ignore
+            PyTorchCheckpointKeys.BEST_SCORE: self.best,
+        }
+
+        # Check for scheduler
+        if hasattr(self.trainer, 'scheduler') and self.trainer.scheduler is not None: # type: ignore
+            checkpoint_data[PyTorchCheckpointKeys.SCHEDULER_STATE] = self.trainer.scheduler.state_dict() # type: ignore
+
         # Save the new best model
-        torch.save(
+        torch.save(checkpoint_data, new_filepath)
 
         # Delete the old best model file
         if self.last_best_filepath and self.last_best_filepath.exists():
             self.last_best_filepath.unlink()
 
         # Update state
-        self.best = current
         self.last_best_filepath = new_filepath
 
     def _save_rolling_checkpoints(self, epoch, logs):
@@ -290,7 +306,19 @@ class ModelCheckpoint(Callback):
 
         if self.verbose > 0:
             _LOGGER.info(f'Epoch {epoch}: saving model to {filepath}')
-
+
+        # Create a comprehensive checkpoint dictionary
+        checkpoint_data = {
+            PyTorchCheckpointKeys.EPOCH: epoch,
+            PyTorchCheckpointKeys.MODEL_STATE: self.trainer.model.state_dict(), # type: ignore
+            PyTorchCheckpointKeys.OPTIMIZER_STATE: self.trainer.optimizer.state_dict(), # type: ignore
+            PyTorchCheckpointKeys.BEST_SCORE: self.best, # Save the current best score
+        }
+
+        if hasattr(self.trainer, 'scheduler') and self.trainer.scheduler is not None: # type: ignore
+            checkpoint_data[PyTorchCheckpointKeys.SCHEDULER_STATE] = self.trainer.scheduler.state_dict() # type: ignore
+
+        torch.save(checkpoint_data, filepath)
 
         self.saved_checkpoints.append(filepath)
 
@@ -306,21 +334,26 @@ class ModelCheckpoint(Callback):
 class LRScheduler(Callback):
     """
     Callback to manage a PyTorch learning rate scheduler.
-
-    This callback automatically calls the scheduler's `step()` method at the
-    end of each epoch. It also logs a message when the learning rate changes.
-
-    Args:
-        scheduler: An initialized PyTorch learning rate scheduler.
-        monitor (str, optional): The metric to monitor for schedulers that
-            require it, like `ReduceLROnPlateau`.
-            Should match a key in the logs (e.g., 'val_loss').
     """
-    def __init__(self, scheduler, monitor: Optional[str] =
+    def __init__(self, scheduler, monitor: Optional[str] = PyTorchLogKeys.VAL_LOSS):
+        """
+        This callback automatically calls the scheduler's `step()` method at the
+        end of each epoch. It also logs a message when the learning rate changes.
+
+        Args:
+            scheduler: An initialized PyTorch learning rate scheduler.
+            monitor (str): The metric to monitor for schedulers that require it, like `ReduceLROnPlateau`. Should match a key in the logs (e.g., 'val_loss').
+        """
        super().__init__()
         self.scheduler = scheduler
         self.monitor = monitor
         self.previous_lr = None
+
+    def set_trainer(self, trainer):
+        """This is called by the Trainer to associate itself with the callback."""
+        super().set_trainer(trainer)
+        # Register the scheduler with the trainer so it can be added to the checkpoint
+        self.trainer.scheduler = self.scheduler # type: ignore
 
     def on_train_begin(self, logs=None):
         """Store the initial learning rate."""
```
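Because `ModelCheckpoint` now writes a dictionary (epoch, model state, optimizer state, best score, and optionally the scheduler state) instead of a bare model `state_dict`, resuming from a checkpoint means unpacking that dictionary. A minimal resume sketch, assuming the `PyTorchCheckpointKeys` members shown in the diff above; the model, optimizer, scheduler, and checkpoint path below are stand-ins for illustration, not part of this release.

```python
import torch
import torch.nn as nn

from ml_tools.keys import PyTorchCheckpointKeys

# Stand-in model/optimizer/scheduler; in practice these are rebuilt exactly as
# they were configured for the original training run.
model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

# Illustrative path; ModelCheckpoint generates the real filename from the epoch/score.
checkpoint = torch.load("checkpoints/best_model.pth", map_location="cpu")

model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])

# The scheduler entry is only present when an LRScheduler callback registered
# a scheduler with the trainer before the checkpoint was written.
if PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint:
    scheduler.load_state_dict(checkpoint[PyTorchCheckpointKeys.SCHEDULER_STATE])

start_epoch = checkpoint[PyTorchCheckpointKeys.EPOCH] + 1
best_score = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
```

Restoring `best_score` into the `ModelCheckpoint`/`EarlyStopping` callbacks for a true mid-training resume is not shown in this diff, so the sketch stops at loading the saved states.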