dragon-ml-toolbox 2.4.0__py3-none-any.whl → 3.0.0__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
{dragon_ml_toolbox-2.4.0.dist-info → dragon_ml_toolbox-3.0.0.dist-info}/METADATA RENAMED
@@ -1,7 +1,7 @@
 Metadata-Version: 2.4
 Name: dragon-ml-toolbox
-Version: 2.4.0
-Summary: A collection of tools for data science and machine learning projects
+Version: 3.0.0
+Summary: A collection of tools for data science and machine learning projects.
 Author-email: Karl Loza <luigiloza@gmail.com>
 License-Expression: MIT
 Project-URL: Homepage, https://github.com/DrAg0n-BoRn/ML_tools
@@ -125,9 +125,12 @@ GUI_tools
 handle_excel
 logger
 MICE_imputation
+ML_callbacks
+ML_evaluation
+ML_trainer
+ML_tutorial
 PSO_optimization
-trainer
+RNN_forecast
 utilities
 VIF_factor
-vision_helpers
 ```
dragon_ml_toolbox-3.0.0.dist-info/RECORD ADDED
@@ -0,0 +1,25 @@
+dragon_ml_toolbox-3.0.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
+dragon_ml_toolbox-3.0.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=6cfpIeQ6D4Mcs10nkogQrkVyq1T7i2qXjjNHFoUMOyE,1892
+ml_tools/ETL_engineering.py,sha256=SRiloWhSpopS4ay8mzUu0H4e9-37Ox_jDHzODqsQ8pc,31642
+ml_tools/GUI_tools.py,sha256=uFx6zIrQZzDPSTtOSHz8ptz-fxZiQz-lXHcrqwuYV_E,20385
+ml_tools/MICE_imputation.py,sha256=ed-YeQkEAeHxTNkWIHs09T4YeYNF0aqAnrUTcdIEp9E,11372
+ml_tools/ML_callbacks.py,sha256=gHZk-lyzAax6iEtG26zHuoobdAZCFJ6BmI6pWoXkOrw,13189
+ml_tools/ML_evaluation.py,sha256=3xOqVXLJDhbioKZ922yxFnSuO4VDQ-HFzZyZZ1MskVM,10054
+ml_tools/ML_trainer.py,sha256=zRs3crz_z4B285iJhmY7m4AFwnvvq4urOyl4zDuCLtA,14456
+ml_tools/ML_tutorial.py,sha256=-9tJO9ISPxEjRINVaF_Bu7tiiJ2W3zznQ4gNlZeP1HQ,12238
+ml_tools/PSO_optimization.py,sha256=RCvIFGyf28voo2mpbRKC6LfDzKslzY-aYoPwgv9F4Bg,25458
+ml_tools/RNN_forecast.py,sha256=IZLcPs3by0Chei7ill_Grjxs7BBUnzau0Oavi3dWiyE,1886
+ml_tools/VIF_factor.py,sha256=5GVAldH69Vkei3WRUZN1uPBMzGoOOeEOA-bgmZXbbUw,10301
+ml_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ml_tools/_particle_swarm_optimization.py,sha256=b_eNNkA89Y40hj76KauivT8KLScH1B9wF2IXptOqkOw,22220
+ml_tools/_pytorch_models.py,sha256=bpWZsrSwCvHJQkR6UfoPpElsMv9AvmiNErNHC8NYB_I,10132
+ml_tools/data_exploration.py,sha256=Fzbz_DKZ7F2e3-JbahLqKr3aP6lt9aCK9rNOHvR7nlA,23665
+ml_tools/datasetmaster.py,sha256=N-uwfzWnl_qnoAqjbfS98I1pVNra5u6rhKLdWbFIReA,30122
+ml_tools/ensemble_learning.py,sha256=PPtBBLgLvaYOdY-MlcjXuxWWXf3JQavLNEysFgzjc_s,37470
+ml_tools/handle_excel.py,sha256=lwds7rDLlGSCWiWGI7xNg-Z7kxAepogp0lstSFa0590,12949
+ml_tools/logger.py,sha256=jC4Q2OqmDm8ZO9VpuZqBSWdXryqaJvLscqVJ6caNMOk,6009
+ml_tools/utilities.py,sha256=opNR-ACH6BnLkWAKcb19ef5tFxfx22TI6E2o0RYwiGA,21021
+dragon_ml_toolbox-3.0.0.dist-info/METADATA,sha256=nmhUu0bwN4z1letePaDzGIQlmDUaBQ32esqGB-OasU4,3273
+dragon_ml_toolbox-3.0.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dragon_ml_toolbox-3.0.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+dragon_ml_toolbox-3.0.0.dist-info/RECORD,,
ml_tools/ETL_engineering.py CHANGED
@@ -3,17 +3,18 @@ import re
 from typing import Literal, Union, Optional, Any, Callable, List, Dict
 from .utilities import _script_info
 import pandas as pd
+from .logger import _LOGGER


 __all__ = [
     "ColumnCleaner",
-    "DataFrameCleaner"
+    "DataFrameCleaner",
     "TransformationRecipe",
     "DataProcessor",
     "KeywordDummifier",
     "NumberExtractor",
     "MultiNumberExtractor",
-    "RatioCalculator"
+    "RatioCalculator",
     "CategoryMapper",
     "RegexMapper",
     "ValueBinner",
@@ -251,7 +252,7 @@ class DataProcessor:
             raise TypeError(f"Invalid 'transform' action for '{input_col_name}': {transform_action}")

         if not processed_columns:
-            print("Warning: The transformation resulted in an empty DataFrame.")
+            _LOGGER.warning("The transformation resulted in an empty DataFrame.")
             return pl.DataFrame()

         return pl.DataFrame(processed_columns)
@@ -403,7 +404,7 @@ class NumberExtractor:
         if not isinstance(round_digits, int):
             raise TypeError("round_digits must be an integer.")
         if dtype == "int":
-            print(f"Warning: 'round_digits' is specified but dtype is 'int'. Rounding will be ignored.")
+            _LOGGER.warning(f"'round_digits' is specified but dtype is 'int'. Rounding will be ignored.")

         self.regex_pattern = regex_pattern
         self.dtype = dtype
@@ -561,9 +562,9 @@ class RatioCalculator:
         denominator = groups.struct.field("group_2").cast(pl.Float64, strict=False)

         # Safely perform division, returning null if denominator is 0
-        return pl.when(denominator != 0).then(
-            numerator / denominator
-        ).otherwise(None)
+        final_expr = pl.when(denominator != 0).then(numerator / denominator).otherwise(None)
+
+        return pl.select(final_expr).to_series()


 class CategoryMapper:
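The hunk above changes `RatioCalculator` to evaluate its guarded division eagerly, so callers receive a concrete `pl.Series` rather than a lazy Polars expression. A minimal sketch of the new return path, with illustrative stand-in Series in place of the struct-field extraction (which the diff leaves unchanged):

```python
# Minimal sketch of the pattern above; the input Series are illustrative.
import polars as pl

numerator = pl.Series("num", [10.0, 6.0, 9.0])
denominator = pl.Series("den", [2.0, 0.0, 3.0])

# Division guarded against zero denominators, as in the diff:
final_expr = pl.when(denominator != 0).then(numerator / denominator).otherwise(None)

# pl.select(...) evaluates the expression, so the caller gets a real Series.
result = pl.select(final_expr).to_series()
print(result.to_list())  # [5.0, None, 3.0]
```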
ml_tools/GUI_tools.py CHANGED
@@ -7,6 +7,7 @@ from functools import wraps
 from typing import Any, Dict, Tuple, List
 from .utilities import _script_info
 import numpy as np
+from .logger import _LOGGER


 __all__ = [
@@ -46,7 +47,7 @@ class PathManager:
         if self._is_bundled:
             # In a Briefcase bundle, resource_path gives an absolute path
             # to the resource directory.
-            self.package_root = self._resource_path_func(self.package_name, "")
+            self.package_root = self._resource_path_func(self.package_name, "") # type: ignore
         else:
             # In development mode, the package root is the directory
             # containing the anchor file.
@@ -56,7 +57,7 @@ class PathManager:
         """Checks if the app is running in a bundled environment."""
         try:
             # This is the function Briefcase provides in a bundled app
-            from briefcase.platforms.base import resource_path
+            from briefcase.platforms.base import resource_path # type: ignore
             return True, resource_path
         except ImportError:
             return False, None
@@ -147,7 +148,7 @@ class ConfigManager:
         """
         path = Path(file_path)
         if path.exists() and not force_overwrite:
-            print(f"Configuration file already exists at {path}. Aborting.")
+            _LOGGER.warning(f"Configuration file already exists at {path}. Aborting.")
             return

         config = configparser.ConfigParser()
@@ -205,7 +206,7 @@ class ConfigManager:

         with open(path, 'w') as configfile:
             config.write(configfile)
-        print(f"Successfully generated config template at: '{path}'")
+        _LOGGER.info(f"Successfully generated config template at: '{path}'")


 # --- GUI Factory ---
@@ -219,8 +220,8 @@ class GUIFactory:
         Initializes the factory with a configuration object.
         """
         self.config = config
-        sg.theme(self.config.general.theme)
-        sg.set_options(font=(self.config.general.font_family, 12))
+        sg.theme(self.config.general.theme) # type: ignore
+        sg.set_options(font=(self.config.general.font_family, 12)) # type: ignore

     # --- Atomic Element Generators ---
     def make_button(self, text: str, key: str, **kwargs) -> sg.Button:
@@ -234,13 +235,13 @@ class GUIFactory:
         (e.g., `tooltip='Click me'`, `disabled=True`).
         """
         cfg = self.config
-        font = (cfg.fonts.font_family, cfg.fonts.button_size, cfg.fonts.button_style)
+        font = (cfg.fonts.font_family, cfg.fonts.button_size, cfg.fonts.button_style) # type: ignore

         style_args = {
-            "size": cfg.layout.button_size,
+            "size": cfg.layout.button_size, # type: ignore
             "font": font,
-            "button_color": (cfg.colors.button_text, cfg.colors.button_background),
-            "mouseover_colors": (cfg.colors.button_text, cfg.colors.button_background_hover),
+            "button_color": (cfg.colors.button_text, cfg.colors.button_background), # type: ignore
+            "mouseover_colors": (cfg.colors.button_text, cfg.colors.button_background_hover), # type: ignore
             "border_width": 0,
             **kwargs
         }
@@ -257,7 +258,7 @@ class GUIFactory:
         (e.g., `title_color='red'`, `relief=sg.RELIEF_SUNKEN`).
         """
         cfg = self.config
-        font = (cfg.fonts.font_family, cfg.fonts.frame_size)
+        font = (cfg.fonts.font_family, cfg.fonts.frame_size) # type: ignore

         style_args = {
             "font": font,
@@ -289,7 +290,7 @@ class GUIFactory:
         """
         cfg = self.config
         bg_color = sg.theme_background_color()
-        label_font = (cfg.fonts.font_family, cfg.fonts.label_size, cfg.fonts.label_style)
+        label_font = (cfg.fonts.font_family, cfg.fonts.label_size, cfg.fonts.label_style) # type: ignore

         columns = []
         for name, (val_min, val_max) in data_dict.items():
@@ -298,21 +299,21 @@

             label = sg.Text(name, font=label_font, background_color=bg_color, key=f"_text_{name}")

-            input_style = {"size": cfg.layout.input_size_cont, "justification": "center"}
+            input_style = {"size": cfg.layout.input_size_cont, "justification": "center"} # type: ignore
             if is_target:
-                input_style["text_color"] = cfg.colors.target_text
-                input_style["disabled_readonly_background_color"] = cfg.colors.target_background
+                input_style["text_color"] = cfg.colors.target_text # type: ignore
+                input_style["disabled_readonly_background_color"] = cfg.colors.target_background # type: ignore

             element = sg.Input(default_text, key=key, disabled=is_target, **input_style)

             if is_target:
                 layout = [[label], [element]]
             else:
-                range_font = (cfg.fonts.font_family, cfg.fonts.range_size)
+                range_font = (cfg.fonts.font_family, cfg.fonts.range_size) # type: ignore
                 range_text = sg.Text(f"Range: {int(val_min)}-{int(val_max)}", font=range_font, background_color=bg_color)
                 layout = [[label], [element], [range_text]]

-            layout.append([sg.Text(" ", font=(cfg.fonts.font_family, 2), background_color=bg_color)])
+            layout.append([sg.Text(" ", font=(cfg.fonts.font_family, 2), background_color=bg_color)]) # type: ignore
             columns.append(sg.Column(layout, background_color=bg_color))

         if layout_mode == 'row':
@@ -340,17 +341,17 @@ class GUIFactory:
         """
         cfg = self.config
         bg_color = sg.theme_background_color()
-        label_font = (cfg.fonts.font_family, cfg.fonts.label_size, cfg.fonts.label_style)
+        label_font = (cfg.fonts.font_family, cfg.fonts.label_size, cfg.fonts.label_style) # type: ignore

         columns = []
         for name, values in data_dict.items():
             label = sg.Text(name, font=label_font, background_color=bg_color, key=f"_text_{name}")
             element = sg.Combo(
                 values, default_value=values[0], key=name,
-                size=cfg.layout.input_size_binary, readonly=True
+                size=cfg.layout.input_size_binary, readonly=True # type: ignore
             )
             layout = [[label], [element]]
-            layout.append([sg.Text(" ", font=(cfg.fonts.font_family, 2), background_color=bg_color)])
+            layout.append([sg.Text(" ", font=(cfg.fonts.font_family, 2), background_color=bg_color)]) # type: ignore
             columns.append(sg.Column(layout, background_color=bg_color))

         if layout_mode == 'row':
@@ -370,8 +371,8 @@ class GUIFactory:
         **kwargs: Additional arguments to pass to the sg.Window constructor
             (e.g., `location=(100, 100)`, `keep_on_top=True`).
         """
-        cfg = self.config.general
-        version = getattr(self.config.meta, 'version', None)
+        cfg = self.config.general # type: ignore
+        version = getattr(self.config.meta, 'version', None) # type: ignore
         full_title = f"{title} v{version}" if version else title

         window_args = {
@@ -406,9 +407,7 @@ def catch_exceptions(show_popup: bool = True):
                 sg.popup_error("An error occurred:", error_msg, title="Error")
             else:
                 # Fallback for non-GUI contexts or if popup is disabled
-                print("--- An exception occurred ---")
-                print(error_msg)
-                print("-----------------------------")
+                _LOGGER.error(error_msg)
         return wrapper
     return decorator

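For context, a short usage sketch of the decorator whose fallback branch changed above. The handler below is hypothetical, not part of the package; with `show_popup=False`, the formatted traceback now flows through `_LOGGER.error` instead of three `print` calls:

```python
# Hypothetical handler wrapped with the decorator from GUI_tools.
from ml_tools.GUI_tools import catch_exceptions

@catch_exceptions(show_popup=False)
def on_submit(values: dict):
    # Any exception raised here (e.g., KeyError, ZeroDivisionError)
    # is caught by the wrapper and logged instead of crashing the event loop.
    return 100 / int(values["divisor"])
```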
ml_tools/MICE_imputation.py CHANGED
@@ -6,6 +6,7 @@ import numpy as np
 from .utilities import load_dataframe, list_csv_paths, sanitize_filename, _script_info, merge_dataframes, save_dataframe, threshold_binary_values, make_fullpath
 from plotnine import ggplot, labs, theme, element_blank # type: ignore
 from typing import Optional, Union
+from .logger import _LOGGER


 __all__ = [
@@ -40,7 +41,9 @@ def apply_mice(df: pd.DataFrame, df_name: str, binary_columns: Optional[list[str
     if binary_columns is not None:
         invalid_binary_columns = set(binary_columns) - set(df.columns)
         if invalid_binary_columns:
-            print(f"⚠️ These 'binary columns' are not in the dataset: {invalid_binary_columns}")
+            _LOGGER.warning("⚠️ These 'binary columns' are not in the dataset:")
+            for invalid_binary_col in invalid_binary_columns:
+                print(f" - {invalid_binary_col}")
         valid_binary_columns = [col for col in binary_columns if col not in invalid_binary_columns]
         for imputed_df in imputed_datasets:
             for binary_column_name in valid_binary_columns:
@@ -125,7 +128,7 @@ def get_convergence_diagnostic(kernel: mf.ImputationKernel, imputed_dataset_name
         plt.savefig(save_path, bbox_inches='tight', format="svg")
         plt.close()

-    print(f"{dataset_file_dir} completed.")
+    _LOGGER.info(f"{dataset_file_dir} completed.")


 # Imputed distributions
@@ -210,7 +213,7 @@ def get_imputed_distributions(kernel: mf.ImputationKernel, df_name: str, root_di
         fig = kernel.plot_imputed_distributions(variables=[feature])
         _process_figure(fig, feature)

-    print(f"{local_dir_name} completed.")
+    _LOGGER.info(f"{local_dir_name} completed.")


 def run_mice_pipeline(df_path_or_dir: Union[str,Path], target_columns: list[str],
@@ -240,7 +243,8 @@ def run_mice_pipeline(df_path_or_dir: Union[str,Path], target_columns: list[str],
     all_file_paths = list(list_csv_paths(input_path).values())

     for df_path in all_file_paths:
-        df, df_name = load_dataframe(df_path=df_path)
+        df: pd.DataFrame
+        df, df_name = load_dataframe(df_path=df_path, kind="pandas") # type: ignore

         df, df_targets = _skip_targets(df, target_columns)

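The `kind="pandas"` argument pins the shared loader's return type, which is what the explicit `df: pd.DataFrame` annotation relies on. A minimal sketch of the call as used above; the CSV path is hypothetical, and only the `df_path` and `kind` parameters plus the `(df, name)` return shape are confirmed by this hunk:

```python
# Hypothetical path; the call shape mirrors the hunk above.
import pandas as pd
from ml_tools.utilities import load_dataframe

df, df_name = load_dataframe(df_path="data/my_table.csv", kind="pandas")
assert isinstance(df, pd.DataFrame)
```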
ml_tools/ML_callbacks.py ADDED
@@ -0,0 +1,341 @@
+import numpy as np
+import torch
+from tqdm.auto import tqdm
+from .utilities import make_fullpath, LogKeys
+from .logger import _LOGGER
+from typing import Optional
+
+
+__all__ = [
+    "Callback",
+    "History",
+    "TqdmProgressBar",
+    "EarlyStopping",
+    "ModelCheckpoint",
+    "LRScheduler"
+]
+
+
+class Callback:
+    """
+    Abstract base class used to build new callbacks.
+
+    The methods of this class are automatically called by the Trainer at different
+    points during training. Subclasses can override these methods to implement
+    custom logic.
+    """
+    def __init__(self):
+        self.trainer = None
+
+    def set_trainer(self, trainer):
+        """This is called by the Trainer to associate itself with the callback."""
+        self.trainer = trainer
+
+    def on_train_begin(self, logs=None):
+        """Called at the beginning of training."""
+        pass
+
+    def on_train_end(self, logs=None):
+        """Called at the end of training."""
+        pass
+
+    def on_epoch_begin(self, epoch, logs=None):
+        """Called at the beginning of an epoch."""
+        pass
+
+    def on_epoch_end(self, epoch, logs=None):
+        """Called at the end of an epoch."""
+        pass
+
+    def on_batch_begin(self, batch, logs=None):
+        """Called at the beginning of a training batch."""
+        pass
+
+    def on_batch_end(self, batch, logs=None):
+        """Called at the end of a training batch."""
+        pass
+
+
+class History(Callback):
+    """
+    Callback that records events into a `history` dictionary.
+
+    This callback is automatically applied to every MyTrainer model.
+    The `history` attribute is a dictionary mapping metric names (e.g., 'val_loss')
+    to a list of metric values.
+    """
+    def on_train_begin(self, logs=None):
+        # Clear history at the beginning of training
+        self.trainer.history = {} # type: ignore
+
+    def on_epoch_end(self, epoch, logs=None):
+        logs = logs or {}
+        for k, v in logs.items():
+            # Append new log values to the history dictionary
+            self.trainer.history.setdefault(k, []).append(v) # type: ignore
+
+
+class TqdmProgressBar(Callback):
+    """Callback that provides a tqdm progress bar for training."""
+    def __init__(self):
+        self.epoch_bar = None
+        self.batch_bar = None
+
+    def on_train_begin(self, logs=None):
+        self.epochs = self.trainer.epochs # type: ignore
+        self.epoch_bar = tqdm(total=self.epochs, desc="Training Progress")
+
+    def on_epoch_begin(self, epoch, logs=None):
+        total_batches = len(self.trainer.train_loader) # type: ignore
+        self.batch_bar = tqdm(total=total_batches, desc=f"Epoch {epoch}/{self.epochs}", leave=False)
+
+    def on_batch_end(self, batch, logs=None):
+        self.batch_bar.update(1) # type: ignore
+        if logs:
+            self.batch_bar.set_postfix(loss=f"{logs.get(LogKeys.BATCH_LOSS, 0):.4f}") # type: ignore
+
+    def on_epoch_end(self, epoch, logs=None):
+        self.batch_bar.close() # type: ignore
+        self.epoch_bar.update(1) # type: ignore
+        if logs:
+            train_loss_str = f"{logs.get(LogKeys.TRAIN_LOSS, 0):.4f}"
+            val_loss_str = f"{logs.get(LogKeys.VAL_LOSS, 0):.4f}"
+            self.epoch_bar.set_postfix_str(f"Train Loss: {train_loss_str}, Val Loss: {val_loss_str}") # type: ignore
+
+    def on_train_end(self, logs=None):
+        self.epoch_bar.close() # type: ignore
+
+
+class EarlyStopping(Callback):
+    """
+    Stop training when a monitored metric has stopped improving.
+
+    Args:
+        monitor (str): Quantity to be monitored. Defaults to 'val_loss'.
+        min_delta (float): Minimum change in the monitored quantity to qualify as an improvement.
+        patience (int): Number of epochs with no improvement after which training will be stopped.
+        mode (str): One of {'auto', 'min', 'max'}. In 'min' mode, training will stop when the quantity
+            monitored has stopped decreasing; in 'max' mode it will stop when the quantity
+            monitored has stopped increasing; in 'auto' mode, the direction is automatically
+            inferred from the name of the monitored quantity.
+        verbose (int): Verbosity mode.
+    """
+    def __init__(self, monitor: str=LogKeys.VAL_LOSS, min_delta=0.0, patience=3, mode='auto', verbose=1):
+        super().__init__()
+        self.monitor = monitor
+        self.patience = patience
+        self.min_delta = min_delta
+        self.wait = 0
+        self.stopped_epoch = 0
+        self.verbose = verbose
+
+        if mode not in ['auto', 'min', 'max']:
+            raise ValueError(f"EarlyStopping mode {mode} is unknown, choose one of ('auto', 'min', 'max')")
+        self.mode = mode
+
+        # Determine the comparison operator based on the mode
+        if self.mode == 'min':
+            self.monitor_op = np.less
+        elif self.mode == 'max':
+            self.monitor_op = np.greater
+        else:  # auto mode
+            if 'acc' in self.monitor.lower():
+                self.monitor_op = np.greater
+            else:  # Default to min mode for loss or other metrics
+                self.monitor_op = np.less
+
+        self.best = np.Inf if self.monitor_op == np.less else -np.Inf
+
+    def on_train_begin(self, logs=None):
+        # Reset state at the beginning of training
+        self.wait = 0
+        self.stopped_epoch = 0
+        self.best = np.Inf if self.monitor_op == np.less else -np.Inf
+
+    def on_epoch_end(self, epoch, logs=None):
+        current = logs.get(self.monitor) # type: ignore
+        if current is None:
+            return
+
+        # Determine the comparison threshold based on the mode
+        if self.monitor_op == np.less:
+            # For 'min' mode, we need to be smaller than 'best' by at least 'min_delta'
+            # Correct check: current < self.best - self.min_delta
+            is_improvement = self.monitor_op(current, self.best - self.min_delta)
+        else:
+            # For 'max' mode, we need to be greater than 'best' by at least 'min_delta'
+            # Correct check: current > self.best + self.min_delta
+            is_improvement = self.monitor_op(current, self.best + self.min_delta)
+
+        if is_improvement:
+            if self.verbose > 1:
+                _LOGGER.info(f"EarlyStopping: {self.monitor} improved from {self.best:.4f} to {current:.4f}")
+            self.best = current
+            self.wait = 0
+        else:
+            self.wait += 1
+            if self.wait >= self.patience:
+                self.stopped_epoch = epoch
+                self.trainer.stop_training = True # type: ignore
+                if self.verbose > 0:
+                    print("")
+                    _LOGGER.info(f"Epoch {epoch+1}: early stopping after {self.wait} epochs with no improvement.")
+
+
+class ModelCheckpoint(Callback):
+    """
+    Saves the model to a directory with automated filename generation and rotation. The filename includes the epoch and score.
+
+    - If `save_best_only` is True, it saves the single best model, deleting the
+      previous best.
+    - If `save_best_only` is False, it keeps the 3 most recent checkpoints,
+      deleting the oldest ones automatically.
+
+    Args:
+        save_dir (str): Directory where checkpoint files will be saved.
+        monitor (str): Metric to monitor for `save_best_only=True`.
+        save_best_only (bool): If true, save only the best model.
+        mode (str): One of {'auto', 'min', 'max'}.
+        verbose (int): Verbosity mode.
+    """
+    def __init__(self, save_dir: str, monitor: str = LogKeys.VAL_LOSS,
+                 save_best_only: bool = False, mode: str = 'auto', verbose: int = 1):
+        super().__init__()
+        self.save_dir = make_fullpath(save_dir, make=True)
+        if not self.save_dir.is_dir():
+            _LOGGER.error(f"{save_dir} is not a valid directory.")
+            raise IOError()
+
+        self.monitor = monitor
+        self.save_best_only = save_best_only
+        self.verbose = verbose
+
+        # State variables to be managed during training
+        self.saved_checkpoints = []
+        self.last_best_filepath = None
+
+        if mode not in ['auto', 'min', 'max']:
+            raise ValueError(f"ModelCheckpoint mode {mode} is unknown.")
+        self.mode = mode
+
+        if self.mode == 'min':
+            self.monitor_op = np.less
+        elif self.mode == 'max':
+            self.monitor_op = np.greater
+        else:
+            self.monitor_op = np.less if 'loss' in self.monitor else np.greater
+
+        self.best = np.Inf if self.monitor_op == np.less else -np.Inf
+
+    def on_train_begin(self, logs=None):
+        """Reset state when training starts."""
+        self.best = np.Inf if self.monitor_op == np.less else -np.Inf
+        self.saved_checkpoints = []
+        self.last_best_filepath = None
+
+    def on_epoch_end(self, epoch, logs=None):
+        logs = logs or {}
+        self.save_dir.mkdir(parents=True, exist_ok=True)
+
+        if self.save_best_only:
+            self._save_best_model(epoch, logs)
+        else:
+            self._save_rolling_checkpoints(epoch, logs)
+
+    def _save_best_model(self, epoch, logs):
+        """Saves a single best model and deletes the previous one."""
+        current = logs.get(self.monitor)
+        if current is None:
+            return
+
+        if self.monitor_op(current, self.best):
+            old_best_str = f"{self.best:.4f}" if self.best not in [np.Inf, -np.Inf] else "inf"
+
+            # Create a descriptive filename
+            filename = f"epoch_{epoch}-{self.monitor}_{current:.4f}.pth"
+            new_filepath = self.save_dir / filename
+
+            if self.verbose > 0:
+                print("")
+                _LOGGER.info(f"Epoch {epoch}: {self.monitor} improved from {old_best_str} to {current:.4f}, saving model to {new_filepath}")
+
+            # Save the new best model
+            torch.save(self.trainer.model.state_dict(), new_filepath) # type: ignore
+
+            # Delete the old best model file
+            if self.last_best_filepath and self.last_best_filepath.exists():
+                self.last_best_filepath.unlink()
+
+            # Update state
+            self.best = current
+            self.last_best_filepath = new_filepath
+
+    def _save_rolling_checkpoints(self, epoch, logs):
+        """Saves the latest model and keeps only the last 3."""
+        filename = f"epoch_{epoch}.pth"
+        filepath = self.save_dir / filename
+
+        if self.verbose > 0:
+            print("")
+            _LOGGER.info(f'Epoch {epoch}: saving model to {filepath}')
+        torch.save(self.trainer.model.state_dict(), filepath) # type: ignore
+
+        self.saved_checkpoints.append(filepath)
+
+        # If we have more than n checkpoints, remove the oldest one
+        if len(self.saved_checkpoints) > 3:
+            file_to_delete = self.saved_checkpoints.pop(0)
+            if file_to_delete.exists():
+                if self.verbose > 0:
+                    _LOGGER.info(f" -> Deleting old checkpoint: {file_to_delete.name}")
+                file_to_delete.unlink()
+
+
+class LRScheduler(Callback):
+    """
+    Callback to manage a PyTorch learning rate scheduler.
+
+    This callback automatically calls the scheduler's `step()` method at the
+    end of each epoch. It also logs a message when the learning rate changes.
+
+    Args:
+        scheduler: An initialized PyTorch learning rate scheduler.
+        monitor (str, optional): The metric to monitor for schedulers that
+            require it, like `ReduceLROnPlateau`.
+            Should match a key in the logs (e.g., 'val_loss').
+    """
+    def __init__(self, scheduler, monitor: Optional[str] = None):
+        super().__init__()
+        self.scheduler = scheduler
+        self.monitor = monitor
+        self.previous_lr = None
+
+    def on_train_begin(self, logs=None):
+        """Store the initial learning rate."""
+        self.previous_lr = self.trainer.optimizer.param_groups[0]['lr'] # type: ignore
+
+    def on_epoch_end(self, epoch, logs=None):
+        """Step the scheduler and log any change in learning rate."""
+        # For schedulers that need a metric (e.g., val_loss)
+        if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+            if self.monitor is None:
+                raise ValueError("LRScheduler needs a `monitor` metric for ReduceLROnPlateau.")
+
+            metric_val = logs.get(self.monitor) # type: ignore
+            if metric_val is not None:
+                self.scheduler.step(metric_val)
+            else:
+                print("")
+                _LOGGER.warning(f"LRScheduler could not find metric '{self.monitor}' in logs.")
+
+        # For all other schedulers
+        else:
+            self.scheduler.step()
+
+        # Log the change if the LR was updated
+        current_lr = self.trainer.optimizer.param_groups[0]['lr'] # type: ignore
+        if current_lr != self.previous_lr:
+            print("")
+            _LOGGER.info(f"Epoch {epoch}: Learning rate changed to {current_lr:.6f}")
+            self.previous_lr = current_lr
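To close, a hedged sketch of how these callbacks could be wired together. The trainer object and its `fit` call are assumptions for illustration (the real API lives in ml_tools/ML_trainer.py, which this diff does not show); the callback constructors and the `LogKeys` import are taken from the file above:

```python
# MLTrainer and trainer.fit(...) are hypothetical; the callback constructors
# and LogKeys come from ML_callbacks.py / utilities.py as added in 3.0.0.
import torch
from ml_tools.utilities import LogKeys
from ml_tools.ML_callbacks import EarlyStopping, ModelCheckpoint, LRScheduler

model = torch.nn.Linear(10, 1)  # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")

callbacks = [
    EarlyStopping(patience=5, mode="min"),           # monitors LogKeys.VAL_LOSS by default
    ModelCheckpoint(save_dir="checkpoints", save_best_only=True),
    LRScheduler(plateau, monitor=LogKeys.VAL_LOSS),  # ReduceLROnPlateau needs a metric
]

# trainer = MLTrainer(model, train_loader, val_loader,  # hypothetical API
#                     optimizer=optimizer, callbacks=callbacks)
# trainer.fit(epochs=100)
```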