dragon-ml-toolbox 2.4.0__py3-none-any.whl → 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-2.4.0.dist-info → dragon_ml_toolbox-3.0.0.dist-info}/METADATA +7 -4
- dragon_ml_toolbox-3.0.0.dist-info/RECORD +25 -0
- ml_tools/ETL_engineering.py +8 -7
- ml_tools/GUI_tools.py +24 -25
- ml_tools/MICE_imputation.py +8 -4
- ml_tools/ML_callbacks.py +341 -0
- ml_tools/ML_evaluation.py +255 -0
- ml_tools/ML_trainer.py +344 -0
- ml_tools/ML_tutorial.py +300 -0
- ml_tools/PSO_optimization.py +27 -20
- ml_tools/RNN_forecast.py +49 -0
- ml_tools/VIF_factor.py +6 -5
- ml_tools/datasetmaster.py +601 -527
- ml_tools/ensemble_learning.py +12 -9
- ml_tools/handle_excel.py +9 -10
- ml_tools/logger.py +45 -8
- ml_tools/utilities.py +18 -1
- dragon_ml_toolbox-2.4.0.dist-info/RECORD +0 -22
- ml_tools/trainer.py +0 -346
- ml_tools/vision_helpers.py +0 -231
- {dragon_ml_toolbox-2.4.0.dist-info → dragon_ml_toolbox-3.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-2.4.0.dist-info → dragon_ml_toolbox-3.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-2.4.0.dist-info → dragon_ml_toolbox-3.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-2.4.0.dist-info → dragon_ml_toolbox-3.0.0.dist-info}/top_level.txt +0 -0
- /ml_tools/{pytorch_models.py → _pytorch_models.py} +0 -0
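
One rename worth flagging: `ml_tools/pytorch_models.py` becomes `ml_tools/_pytorch_models.py`, the leading-underscore convention for a private module, so any direct import of the old path breaks in 3.0.0. A quick way to check which spelling a given installation exposes (illustrative snippet, not part of the package):

```python
import importlib.util

# find_spec returns a ModuleSpec if the module exists under that name, else None.
old = importlib.util.find_spec("ml_tools.pytorch_models")    # None on 3.0.0
new = importlib.util.find_spec("ml_tools._pytorch_models")   # ModuleSpec on 3.0.0
print("old module present:", old is not None)
print("new module present:", new is not None)
```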
{dragon_ml_toolbox-2.4.0.dist-info → dragon_ml_toolbox-3.0.0.dist-info}/METADATA
CHANGED

@@ -1,7 +1,7 @@
 Metadata-Version: 2.4
 Name: dragon-ml-toolbox
-Version: 2.4.0
-Summary: A collection of tools for data science and machine learning projects
+Version: 3.0.0
+Summary: A collection of tools for data science and machine learning projects.
 Author-email: Karl Loza <luigiloza@gmail.com>
 License-Expression: MIT
 Project-URL: Homepage, https://github.com/DrAg0n-BoRn/ML_tools

@@ -125,9 +125,12 @@ GUI_tools
 handle_excel
 logger
 MICE_imputation
+ML_callbacks
+ML_evaluation
+ML_trainer
+ML_tutorial
 PSO_optimization
-
+RNN_forecast
 utilities
 VIF_factor
-vision_helpers
 ```
dragon_ml_toolbox-3.0.0.dist-info/RECORD
ADDED

@@ -0,0 +1,25 @@
+dragon_ml_toolbox-3.0.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
+dragon_ml_toolbox-3.0.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=6cfpIeQ6D4Mcs10nkogQrkVyq1T7i2qXjjNHFoUMOyE,1892
+ml_tools/ETL_engineering.py,sha256=SRiloWhSpopS4ay8mzUu0H4e9-37Ox_jDHzODqsQ8pc,31642
+ml_tools/GUI_tools.py,sha256=uFx6zIrQZzDPSTtOSHz8ptz-fxZiQz-lXHcrqwuYV_E,20385
+ml_tools/MICE_imputation.py,sha256=ed-YeQkEAeHxTNkWIHs09T4YeYNF0aqAnrUTcdIEp9E,11372
+ml_tools/ML_callbacks.py,sha256=gHZk-lyzAax6iEtG26zHuoobdAZCFJ6BmI6pWoXkOrw,13189
+ml_tools/ML_evaluation.py,sha256=3xOqVXLJDhbioKZ922yxFnSuO4VDQ-HFzZyZZ1MskVM,10054
+ml_tools/ML_trainer.py,sha256=zRs3crz_z4B285iJhmY7m4AFwnvvq4urOyl4zDuCLtA,14456
+ml_tools/ML_tutorial.py,sha256=-9tJO9ISPxEjRINVaF_Bu7tiiJ2W3zznQ4gNlZeP1HQ,12238
+ml_tools/PSO_optimization.py,sha256=RCvIFGyf28voo2mpbRKC6LfDzKslzY-aYoPwgv9F4Bg,25458
+ml_tools/RNN_forecast.py,sha256=IZLcPs3by0Chei7ill_Grjxs7BBUnzau0Oavi3dWiyE,1886
+ml_tools/VIF_factor.py,sha256=5GVAldH69Vkei3WRUZN1uPBMzGoOOeEOA-bgmZXbbUw,10301
+ml_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ml_tools/_particle_swarm_optimization.py,sha256=b_eNNkA89Y40hj76KauivT8KLScH1B9wF2IXptOqkOw,22220
+ml_tools/_pytorch_models.py,sha256=bpWZsrSwCvHJQkR6UfoPpElsMv9AvmiNErNHC8NYB_I,10132
+ml_tools/data_exploration.py,sha256=Fzbz_DKZ7F2e3-JbahLqKr3aP6lt9aCK9rNOHvR7nlA,23665
+ml_tools/datasetmaster.py,sha256=N-uwfzWnl_qnoAqjbfS98I1pVNra5u6rhKLdWbFIReA,30122
+ml_tools/ensemble_learning.py,sha256=PPtBBLgLvaYOdY-MlcjXuxWWXf3JQavLNEysFgzjc_s,37470
+ml_tools/handle_excel.py,sha256=lwds7rDLlGSCWiWGI7xNg-Z7kxAepogp0lstSFa0590,12949
+ml_tools/logger.py,sha256=jC4Q2OqmDm8ZO9VpuZqBSWdXryqaJvLscqVJ6caNMOk,6009
+ml_tools/utilities.py,sha256=opNR-ACH6BnLkWAKcb19ef5tFxfx22TI6E2o0RYwiGA,21021
+dragon_ml_toolbox-3.0.0.dist-info/METADATA,sha256=nmhUu0bwN4z1letePaDzGIQlmDUaBQ32esqGB-OasU4,3273
+dragon_ml_toolbox-3.0.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dragon_ml_toolbox-3.0.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+dragon_ml_toolbox-3.0.0.dist-info/RECORD,,
ml_tools/ETL_engineering.py
CHANGED

@@ -3,17 +3,18 @@ import re
 from typing import Literal, Union, Optional, Any, Callable, List, Dict
 from .utilities import _script_info
 import pandas as pd
+from .logger import _LOGGER


 __all__ = [
     "ColumnCleaner",
-    "DataFrameCleaner"
+    "DataFrameCleaner",
     "TransformationRecipe",
     "DataProcessor",
     "KeywordDummifier",
     "NumberExtractor",
     "MultiNumberExtractor",
-    "RatioCalculator"
+    "RatioCalculator",
     "CategoryMapper",
     "RegexMapper",
     "ValueBinner",
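
The two added commas in `__all__` are bug fixes, not style: adjacent string literals in Python are implicitly concatenated, so the old list silently fused each comma-less pair into a single bogus export name. A minimal illustration (toy list, not the package's code):

```python
# Without a comma, adjacent string literals merge into one list element.
exports = [
    "DataFrameCleaner"       # missing comma here...
    "TransformationRecipe",  # ...fuses the two names into one string
]
assert exports == ["DataFrameCleanerTransformationRecipe"]
assert len(exports) == 1  # one fused export instead of two
```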
@@ -251,7 +252,7 @@ class DataProcessor:
             raise TypeError(f"Invalid 'transform' action for '{input_col_name}': {transform_action}")

         if not processed_columns:
-
+            _LOGGER.warning("The transformation resulted in an empty DataFrame.")
             return pl.DataFrame()

         return pl.DataFrame(processed_columns)
@@ -403,7 +404,7 @@ class NumberExtractor:
         if not isinstance(round_digits, int):
             raise TypeError("round_digits must be an integer.")
         if dtype == "int":
-
+            _LOGGER.warning(f"'round_digits' is specified but dtype is 'int'. Rounding will be ignored.")

         self.regex_pattern = regex_pattern
         self.dtype = dtype
@@ -561,9 +562,9 @@ class RatioCalculator:
         denominator = groups.struct.field("group_2").cast(pl.Float64, strict=False)

         # Safely perform division, returning null if denominator is 0
-
-
-        ).
+        final_expr = pl.when(denominator != 0).then(numerator / denominator).otherwise(None)
+
+        return pl.select(final_expr).to_series()


 class CategoryMapper:
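
The rewritten `RatioCalculator` body builds a Polars when/then/otherwise expression so a zero denominator yields null instead of raising. A self-contained sketch of the same pattern (toy column names; the class's regex-extraction step is omitted):

```python
import polars as pl

df = pl.DataFrame({"num": [10.0, 5.0, 3.0], "den": [2.0, 0.0, 1.5]})

# Null out the ratio wherever the denominator is zero.
safe_ratio = (
    pl.when(pl.col("den") != 0)
    .then(pl.col("num") / pl.col("den"))
    .otherwise(None)
    .alias("ratio")
)
print(df.select(safe_ratio).to_series().to_list())  # [5.0, None, 2.0]
```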
ml_tools/GUI_tools.py
CHANGED

@@ -7,6 +7,7 @@ from functools import wraps
 from typing import Any, Dict, Tuple, List
 from .utilities import _script_info
 import numpy as np
+from .logger import _LOGGER


 __all__ = [

@@ -46,7 +47,7 @@ class PathManager:
         if self._is_bundled:
             # In a Briefcase bundle, resource_path gives an absolute path
             # to the resource directory.
-            self.package_root = self._resource_path_func(self.package_name, "")
+            self.package_root = self._resource_path_func(self.package_name, "") # type: ignore
         else:
             # In development mode, the package root is the directory
             # containing the anchor file.

@@ -56,7 +57,7 @@
         """Checks if the app is running in a bundled environment."""
         try:
             # This is the function Briefcase provides in a bundled app
-            from briefcase.platforms.base import resource_path
+            from briefcase.platforms.base import resource_path # type: ignore
             return True, resource_path
         except ImportError:
             return False, None

@@ -147,7 +148,7 @@ class ConfigManager:
         """
         path = Path(file_path)
         if path.exists() and not force_overwrite:
-
+            _LOGGER.warning(f"Configuration file already exists at {path}. Aborting.")
             return

         config = configparser.ConfigParser()

@@ -205,7 +206,7 @@

         with open(path, 'w') as configfile:
             config.write(configfile)
-
+        _LOGGER.info(f"Successfully generated config template at: '{path}'")


 # --- GUI Factory ---

@@ -219,8 +220,8 @@ class GUIFactory:
         Initializes the factory with a configuration object.
         """
         self.config = config
-        sg.theme(self.config.general.theme)
-        sg.set_options(font=(self.config.general.font_family, 12))
+        sg.theme(self.config.general.theme) # type: ignore
+        sg.set_options(font=(self.config.general.font_family, 12)) # type: ignore

     # --- Atomic Element Generators ---
     def make_button(self, text: str, key: str, **kwargs) -> sg.Button:

@@ -234,13 +235,13 @@
             (e.g., `tooltip='Click me'`, `disabled=True`).
         """
         cfg = self.config
-        font = (cfg.fonts.font_family, cfg.fonts.button_size, cfg.fonts.button_style)
+        font = (cfg.fonts.font_family, cfg.fonts.button_size, cfg.fonts.button_style) # type: ignore

         style_args = {
-            "size": cfg.layout.button_size,
+            "size": cfg.layout.button_size, # type: ignore
             "font": font,
-            "button_color": (cfg.colors.button_text, cfg.colors.button_background),
-            "mouseover_colors": (cfg.colors.button_text, cfg.colors.button_background_hover),
+            "button_color": (cfg.colors.button_text, cfg.colors.button_background), # type: ignore
+            "mouseover_colors": (cfg.colors.button_text, cfg.colors.button_background_hover), # type: ignore
             "border_width": 0,
             **kwargs
         }

@@ -257,7 +258,7 @@
             (e.g., `title_color='red'`, `relief=sg.RELIEF_SUNKEN`).
         """
         cfg = self.config
-        font = (cfg.fonts.font_family, cfg.fonts.frame_size)
+        font = (cfg.fonts.font_family, cfg.fonts.frame_size) # type: ignore

         style_args = {
             "font": font,

@@ -289,7 +290,7 @@
         """
         cfg = self.config
         bg_color = sg.theme_background_color()
-        label_font = (cfg.fonts.font_family, cfg.fonts.label_size, cfg.fonts.label_style)
+        label_font = (cfg.fonts.font_family, cfg.fonts.label_size, cfg.fonts.label_style) # type: ignore

         columns = []
         for name, (val_min, val_max) in data_dict.items():

@@ -298,21 +299,21 @@

             label = sg.Text(name, font=label_font, background_color=bg_color, key=f"_text_{name}")

-            input_style = {"size": cfg.layout.input_size_cont, "justification": "center"}
+            input_style = {"size": cfg.layout.input_size_cont, "justification": "center"} # type: ignore
             if is_target:
-                input_style["text_color"] = cfg.colors.target_text
-                input_style["disabled_readonly_background_color"] = cfg.colors.target_background
+                input_style["text_color"] = cfg.colors.target_text # type: ignore
+                input_style["disabled_readonly_background_color"] = cfg.colors.target_background # type: ignore

             element = sg.Input(default_text, key=key, disabled=is_target, **input_style)

             if is_target:
                 layout = [[label], [element]]
             else:
-                range_font = (cfg.fonts.font_family, cfg.fonts.range_size)
+                range_font = (cfg.fonts.font_family, cfg.fonts.range_size) # type: ignore
                 range_text = sg.Text(f"Range: {int(val_min)}-{int(val_max)}", font=range_font, background_color=bg_color)
                 layout = [[label], [element], [range_text]]

-            layout.append([sg.Text(" ", font=(cfg.fonts.font_family, 2), background_color=bg_color)])
+            layout.append([sg.Text(" ", font=(cfg.fonts.font_family, 2), background_color=bg_color)]) # type: ignore
             columns.append(sg.Column(layout, background_color=bg_color))

         if layout_mode == 'row':

@@ -340,17 +341,17 @@
         """
         cfg = self.config
         bg_color = sg.theme_background_color()
-        label_font = (cfg.fonts.font_family, cfg.fonts.label_size, cfg.fonts.label_style)
+        label_font = (cfg.fonts.font_family, cfg.fonts.label_size, cfg.fonts.label_style) # type: ignore

         columns = []
         for name, values in data_dict.items():
             label = sg.Text(name, font=label_font, background_color=bg_color, key=f"_text_{name}")
             element = sg.Combo(
                 values, default_value=values[0], key=name,
-                size=cfg.layout.input_size_binary, readonly=True
+                size=cfg.layout.input_size_binary, readonly=True # type: ignore
             )
             layout = [[label], [element]]
-            layout.append([sg.Text(" ", font=(cfg.fonts.font_family, 2), background_color=bg_color)])
+            layout.append([sg.Text(" ", font=(cfg.fonts.font_family, 2), background_color=bg_color)]) # type: ignore
             columns.append(sg.Column(layout, background_color=bg_color))

         if layout_mode == 'row':

@@ -370,8 +371,8 @@
         **kwargs: Additional arguments to pass to the sg.Window constructor
             (e.g., `location=(100, 100)`, `keep_on_top=True`).
         """
-        cfg = self.config.general
-        version = getattr(self.config.meta, 'version', None)
+        cfg = self.config.general # type: ignore
+        version = getattr(self.config.meta, 'version', None) # type: ignore
         full_title = f"{title} v{version}" if version else title

         window_args = {

@@ -406,9 +407,7 @@ def catch_exceptions(show_popup: bool = True):
                 sg.popup_error("An error occurred:", error_msg, title="Error")
             else:
                 # Fallback for non-GUI contexts or if popup is disabled
-
-                print(error_msg)
-                print("-----------------------------")
+                _LOGGER.error(error_msg)
         return wrapper
     return decorator
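
After this change, `catch_exceptions` reports through the shared `_LOGGER` when popups are disabled. A hypothetical usage sketch; only the decorator factory and its `show_popup` parameter come from the hunk above, the decorated function is invented:

```python
from ml_tools.GUI_tools import catch_exceptions

@catch_exceptions(show_popup=False)  # errors go to _LOGGER.error, not a popup
def parse_threshold(values: dict) -> float:
    # Deliberately fails on empty input to exercise the fallback path.
    return float(values["threshold"])

parse_threshold({})  # KeyError is caught and logged; the app keeps running
```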
ml_tools/MICE_imputation.py
CHANGED

@@ -6,6 +6,7 @@ import numpy as np
 from .utilities import load_dataframe, list_csv_paths, sanitize_filename, _script_info, merge_dataframes, save_dataframe, threshold_binary_values, make_fullpath
 from plotnine import ggplot, labs, theme, element_blank # type: ignore
 from typing import Optional, Union
+from .logger import _LOGGER


 __all__ = [

@@ -40,7 +41,9 @@ def apply_mice(df: pd.DataFrame, df_name: str, binary_columns: Optional[list[str
     if binary_columns is not None:
         invalid_binary_columns = set(binary_columns) - set(df.columns)
         if invalid_binary_columns:
-
+            _LOGGER.warning(f"⚠️ These 'binary columns' are not in the dataset:")
+            for invalid_binary_col in invalid_binary_columns:
+                print(f"  - {invalid_binary_col}")
         valid_binary_columns = [col for col in binary_columns if col not in invalid_binary_columns]
         for imputed_df in imputed_datasets:
             for binary_column_name in valid_binary_columns:

@@ -125,7 +128,7 @@ def get_convergence_diagnostic(kernel: mf.ImputationKernel, imputed_dataset_name
     plt.savefig(save_path, bbox_inches='tight', format="svg")
     plt.close()

-
+    _LOGGER.info(f"{dataset_file_dir} completed.")


 # Imputed distributions

@@ -210,7 +213,7 @@ def get_imputed_distributions(kernel: mf.ImputationKernel, df_name: str, root_di
         fig = kernel.plot_imputed_distributions(variables=[feature])
         _process_figure(fig, feature)

-
+    _LOGGER.info(f"{local_dir_name} completed.")


 def run_mice_pipeline(df_path_or_dir: Union[str,Path], target_columns: list[str],

@@ -240,7 +243,8 @@ def run_mice_pipeline(df_path_or_dir: Union[str,Path], target_columns: list[str]
     all_file_paths = list(list_csv_paths(input_path).values())

     for df_path in all_file_paths:
-        df
+        df: pd.DataFrame
+        df, df_name = load_dataframe(df_path=df_path, kind="pandas") # type: ignore

         df, df_targets = _skip_targets(df, target_columns)
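
The new warning path in `apply_mice` hinges on a set difference between the requested binary columns and the DataFrame's actual columns. A tiny standalone illustration of that validation (toy names):

```python
requested = ["sex", "smoker", "typo_col"]
present = ["sex", "smoker", "age"]

# Columns asked for but absent from the data.
invalid = set(requested) - set(present)
# Keep the valid ones in their original order, as the hunk above does.
valid = [col for col in requested if col not in invalid]

print(invalid)  # {'typo_col'}
print(valid)    # ['sex', 'smoker']
```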
ml_tools/ML_callbacks.py
ADDED

@@ -0,0 +1,341 @@
+import numpy as np
+import torch
+from tqdm.auto import tqdm
+from .utilities import make_fullpath, LogKeys
+from .logger import _LOGGER
+from typing import Optional
+
+
+__all__ = [
+    "Callback",
+    "History",
+    "TqdmProgressBar",
+    "EarlyStopping",
+    "ModelCheckpoint",
+    "LRScheduler"
+]
+
+
+class Callback:
+    """
+    Abstract base class used to build new callbacks.
+
+    The methods of this class are automatically called by the Trainer at different
+    points during training. Subclasses can override these methods to implement
+    custom logic.
+    """
+    def __init__(self):
+        self.trainer = None
+
+    def set_trainer(self, trainer):
+        """This is called by the Trainer to associate itself with the callback."""
+        self.trainer = trainer
+
+    def on_train_begin(self, logs=None):
+        """Called at the beginning of training."""
+        pass
+
+    def on_train_end(self, logs=None):
+        """Called at the end of training."""
+        pass
+
+    def on_epoch_begin(self, epoch, logs=None):
+        """Called at the beginning of an epoch."""
+        pass
+
+    def on_epoch_end(self, epoch, logs=None):
+        """Called at the end of an epoch."""
+        pass
+
+    def on_batch_begin(self, batch, logs=None):
+        """Called at the beginning of a training batch."""
+        pass
+
+    def on_batch_end(self, batch, logs=None):
+        """Called at the end of a training batch."""
+        pass
+
+
+class History(Callback):
+    """
+    Callback that records events into a `history` dictionary.
+
+    This callback is automatically applied to every MyTrainer model.
+    The `history` attribute is a dictionary mapping metric names (e.g., 'val_loss')
+    to a list of metric values.
+    """
+    def on_train_begin(self, logs=None):
+        # Clear history at the beginning of training
+        self.trainer.history = {} # type: ignore
+
+    def on_epoch_end(self, epoch, logs=None):
+        logs = logs or {}
+        for k, v in logs.items():
+            # Append new log values to the history dictionary
+            self.trainer.history.setdefault(k, []).append(v) # type: ignore
+
+
+class TqdmProgressBar(Callback):
+    """Callback that provides a tqdm progress bar for training."""
+    def __init__(self):
+        self.epoch_bar = None
+        self.batch_bar = None
+
+    def on_train_begin(self, logs=None):
+        self.epochs = self.trainer.epochs # type: ignore
+        self.epoch_bar = tqdm(total=self.epochs, desc="Training Progress")
+
+    def on_epoch_begin(self, epoch, logs=None):
+        total_batches = len(self.trainer.train_loader) # type: ignore
+        self.batch_bar = tqdm(total=total_batches, desc=f"Epoch {epoch}/{self.epochs}", leave=False)
+
+    def on_batch_end(self, batch, logs=None):
+        self.batch_bar.update(1) # type: ignore
+        if logs:
+            self.batch_bar.set_postfix(loss=f"{logs.get(LogKeys.BATCH_LOSS, 0):.4f}") # type: ignore
+
+    def on_epoch_end(self, epoch, logs=None):
+        self.batch_bar.close() # type: ignore
+        self.epoch_bar.update(1) # type: ignore
+        if logs:
+            train_loss_str = f"{logs.get(LogKeys.TRAIN_LOSS, 0):.4f}"
+            val_loss_str = f"{logs.get(LogKeys.VAL_LOSS, 0):.4f}"
+            self.epoch_bar.set_postfix_str(f"Train Loss: {train_loss_str}, Val Loss: {val_loss_str}") # type: ignore
+
+    def on_train_end(self, logs=None):
+        self.epoch_bar.close() # type: ignore
+
+
+class EarlyStopping(Callback):
+    """
+    Stop training when a monitored metric has stopped improving.
+
+    Args:
+        monitor (str): Quantity to be monitored. Defaults to 'val_loss'.
+        min_delta (float): Minimum change in the monitored quantity to qualify as an improvement.
+        patience (int): Number of epochs with no improvement after which training will be stopped.
+        mode (str): One of {'auto', 'min', 'max'}. In 'min' mode, training will stop when the quantity
+            monitored has stopped decreasing; in 'max' mode it will stop when the quantity
+            monitored has stopped increasing; in 'auto' mode, the direction is automatically
+            inferred from the name of the monitored quantity.
+        verbose (int): Verbosity mode.
+    """
+    def __init__(self, monitor: str=LogKeys.VAL_LOSS, min_delta=0.0, patience=3, mode='auto', verbose=1):
+        super().__init__()
+        self.monitor = monitor
+        self.patience = patience
+        self.min_delta = min_delta
+        self.wait = 0
+        self.stopped_epoch = 0
+        self.verbose = verbose
+
+        if mode not in ['auto', 'min', 'max']:
+            raise ValueError(f"EarlyStopping mode {mode} is unknown, choose one of ('auto', 'min', 'max')")
+        self.mode = mode
+
+        # Determine the comparison operator based on the mode
+        if self.mode == 'min':
+            self.monitor_op = np.less
+        elif self.mode == 'max':
+            self.monitor_op = np.greater
+        else: # auto mode
+            if 'acc' in self.monitor.lower():
+                self.monitor_op = np.greater
+            else: # Default to min mode for loss or other metrics
+                self.monitor_op = np.less
+
+        self.best = np.Inf if self.monitor_op == np.less else -np.Inf
+
+    def on_train_begin(self, logs=None):
+        # Reset state at the beginning of training
+        self.wait = 0
+        self.stopped_epoch = 0
+        self.best = np.Inf if self.monitor_op == np.less else -np.Inf
+
+    def on_epoch_end(self, epoch, logs=None):
+        current = logs.get(self.monitor) # type: ignore
+        if current is None:
+            return
+
+        # Determine the comparison threshold based on the mode
+        if self.monitor_op == np.less:
+            # For 'min' mode, we need to be smaller than 'best' by at least 'min_delta'
+            # Correct check: current < self.best - self.min_delta
+            is_improvement = self.monitor_op(current, self.best - self.min_delta)
+        else:
+            # For 'max' mode, we need to be greater than 'best' by at least 'min_delta'
+            # Correct check: current > self.best + self.min_delta
+            is_improvement = self.monitor_op(current, self.best + self.min_delta)
+
+        if is_improvement:
+            if self.verbose > 1:
+                _LOGGER.info(f"EarlyStopping: {self.monitor} improved from {self.best:.4f} to {current:.4f}")
+            self.best = current
+            self.wait = 0
+        else:
+            self.wait += 1
+            if self.wait >= self.patience:
+                self.stopped_epoch = epoch
+                self.trainer.stop_training = True # type: ignore
+                if self.verbose > 0:
+                    print("")
+                    _LOGGER.info(f"Epoch {epoch+1}: early stopping after {self.wait} epochs with no improvement.")
+
+
+class ModelCheckpoint(Callback):
+    """
+    Saves the model to a directory with automated filename generation and rotation. The filename includes the epoch and score.
+
+    - If `save_best_only` is True, it saves the single best model, deleting the
+      previous best.
+    - If `save_best_only` is False, it keeps the 3 most recent checkpoints,
+      deleting the oldest ones automatically.
+
+    Args:
+        save_dir (str): Directory where checkpoint files will be saved.
+        monitor (str): Metric to monitor for `save_best_only=True`.
+        save_best_only (bool): If true, save only the best model.
+        mode (str): One of {'auto', 'min', 'max'}.
+        verbose (int): Verbosity mode.
+    """
+    def __init__(self, save_dir: str, monitor: str = LogKeys.VAL_LOSS,
+                 save_best_only: bool = False, mode: str = 'auto', verbose: int = 1):
+        super().__init__()
+        self.save_dir = make_fullpath(save_dir, make=True)
+        if not self.save_dir.is_dir():
+            _LOGGER.error(f"{save_dir} is not a valid directory.")
+            raise IOError()
+
+        self.monitor = monitor
+        self.save_best_only = save_best_only
+        self.verbose = verbose
+
+        # State variables to be managed during training
+        self.saved_checkpoints = []
+        self.last_best_filepath = None
+
+        if mode not in ['auto', 'min', 'max']:
+            raise ValueError(f"ModelCheckpoint mode {mode} is unknown.")
+        self.mode = mode
+
+        if self.mode == 'min':
+            self.monitor_op = np.less
+        elif self.mode == 'max':
+            self.monitor_op = np.greater
+        else:
+            self.monitor_op = np.less if 'loss' in self.monitor else np.greater
+
+        self.best = np.Inf if self.monitor_op == np.less else -np.Inf
+
+    def on_train_begin(self, logs=None):
+        """Reset state when training starts."""
+        self.best = np.Inf if self.monitor_op == np.less else -np.Inf
+        self.saved_checkpoints = []
+        self.last_best_filepath = None
+
+    def on_epoch_end(self, epoch, logs=None):
+        logs = logs or {}
+        self.save_dir.mkdir(parents=True, exist_ok=True)
+
+        if self.save_best_only:
+            self._save_best_model(epoch, logs)
+        else:
+            self._save_rolling_checkpoints(epoch, logs)
+
+    def _save_best_model(self, epoch, logs):
+        """Saves a single best model and deletes the previous one."""
+        current = logs.get(self.monitor)
+        if current is None:
+            return
+
+        if self.monitor_op(current, self.best):
+            old_best_str = f"{self.best:.4f}" if self.best not in [np.Inf, -np.Inf] else "inf"
+
+            # Create a descriptive filename
+            filename = f"epoch_{epoch}-{self.monitor}_{current:.4f}.pth"
+            new_filepath = self.save_dir / filename
+
+            if self.verbose > 0:
+                print("")
+                _LOGGER.info(f"Epoch {epoch}: {self.monitor} improved from {old_best_str} to {current:.4f}, saving model to {new_filepath}")
+
+            # Save the new best model
+            torch.save(self.trainer.model.state_dict(), new_filepath) # type: ignore
+
+            # Delete the old best model file
+            if self.last_best_filepath and self.last_best_filepath.exists():
+                self.last_best_filepath.unlink()
+
+            # Update state
+            self.best = current
+            self.last_best_filepath = new_filepath
+
+    def _save_rolling_checkpoints(self, epoch, logs):
+        """Saves the latest model and keeps only the last 5."""
+        filename = f"epoch_{epoch}.pth"
+        filepath = self.save_dir / filename
+
+        if self.verbose > 0:
+            print("")
+            _LOGGER.info(f'Epoch {epoch}: saving model to {filepath}')
+        torch.save(self.trainer.model.state_dict(), filepath) # type: ignore
+
+        self.saved_checkpoints.append(filepath)
+
+        # If we have more than n checkpoints, remove the oldest one
+        if len(self.saved_checkpoints) > 3:
+            file_to_delete = self.saved_checkpoints.pop(0)
+            if file_to_delete.exists():
+                if self.verbose > 0:
+                    _LOGGER.info(f"  -> Deleting old checkpoint: {file_to_delete.name}")
+                file_to_delete.unlink()
+
+
+class LRScheduler(Callback):
+    """
+    Callback to manage a PyTorch learning rate scheduler.
+
+    This callback automatically calls the scheduler's `step()` method at the
+    end of each epoch. It also logs a message when the learning rate changes.
+
+    Args:
+        scheduler: An initialized PyTorch learning rate scheduler.
+        monitor (str, optional): The metric to monitor for schedulers that
+            require it, like `ReduceLROnPlateau`.
+            Should match a key in the logs (e.g., 'val_loss').
+    """
+    def __init__(self, scheduler, monitor: Optional[str] = None):
+        super().__init__()
+        self.scheduler = scheduler
+        self.monitor = monitor
+        self.previous_lr = None
+
+    def on_train_begin(self, logs=None):
+        """Store the initial learning rate."""
+        self.previous_lr = self.trainer.optimizer.param_groups[0]['lr'] # type: ignore
+
+    def on_epoch_end(self, epoch, logs=None):
+        """Step the scheduler and log any change in learning rate."""
+        # For schedulers that need a metric (e.g., val_loss)
+        if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+            if self.monitor is None:
+                raise ValueError("LRScheduler needs a `monitor` metric for ReduceLROnPlateau.")
+
+            metric_val = logs.get(self.monitor) # type: ignore
+            if metric_val is not None:
+                self.scheduler.step(metric_val)
+            else:
+                print("")
+                _LOGGER.warning(f"LRScheduler could not find metric '{self.monitor}' in logs.")
+
+        # For all other schedulers
+        else:
+            self.scheduler.step()
+
+        # Log the change if the LR was updated
+        current_lr = self.trainer.optimizer.param_groups[0]['lr'] # type: ignore
+        if current_lr != self.previous_lr:
+            print("")
+            _LOGGER.info(f"Epoch {epoch}: Learning rate changed to {current_lr:.6f}")
+            self.previous_lr = current_lr
+
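
ML_trainer.py itself is not shown in this section, so the exact training API is an assumption; the sketch below wires the new callbacks together using only what ML_callbacks.py relies on (a trainer exposing `model`, `optimizer`, `train_loader`, `epochs`, `stop_training`, and `history`, named `MyTrainer` per the History docstring):

```python
import torch
from ml_tools.ML_callbacks import EarlyStopping, ModelCheckpoint, LRScheduler

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)

callbacks = [
    EarlyStopping(patience=5),                            # stop once val loss plateaus
    ModelCheckpoint("checkpoints", save_best_only=True),  # keep the single best .pth
    LRScheduler(plateau, monitor="val_loss"),             # ReduceLROnPlateau needs a metric
]

# Hypothetical wiring; the constructor and fit signature are assumptions:
# trainer = MyTrainer(model, optimizer, train_loader, ..., callbacks=callbacks)
# trainer.fit(epochs=50)
```

Note the division of labor this implies: the trainer only emits events and a `logs` dict; stopping, checkpoint rotation, and LR stepping are each owned by one callback.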