dragon-ml-toolbox 10.8.0__py3-none-any.whl → 10.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dragon-ml-toolbox might be problematic.
- {dragon_ml_toolbox-10.8.0.dist-info → dragon_ml_toolbox-10.9.0.dist-info}/METADATA +1 -1
- {dragon_ml_toolbox-10.8.0.dist-info → dragon_ml_toolbox-10.9.0.dist-info}/RECORD +11 -11
- ml_tools/ML_evaluation.py +6 -4
- ml_tools/SQL.py +4 -2
- ml_tools/ensemble_evaluation.py +48 -1
- ml_tools/keys.py +7 -0
- ml_tools/utilities.py +119 -20
- {dragon_ml_toolbox-10.8.0.dist-info → dragon_ml_toolbox-10.9.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-10.8.0.dist-info → dragon_ml_toolbox-10.9.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-10.8.0.dist-info → dragon_ml_toolbox-10.9.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-10.8.0.dist-info → dragon_ml_toolbox-10.9.0.dist-info}/top_level.txt +0 -0
{dragon_ml_toolbox-10.8.0.dist-info → dragon_ml_toolbox-10.9.0.dist-info}/RECORD
CHANGED
@@ -1,12 +1,12 @@
-dragon_ml_toolbox-10.8.0.dist-info/licenses/LICENSE,sha256=…
-dragon_ml_toolbox-10.8.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=…
+dragon_ml_toolbox-10.9.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
+dragon_ml_toolbox-10.9.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=lY4_rJPnLnMu7YBQaY-_iz1JRDcLdQzNCyeLAF1glJY,1837
 ml_tools/ETL_cleaning.py,sha256=lSP5q6-ukGhJBPV8dlsqJvPXAzj4du_0J-SbtEd0Pjg,19292
 ml_tools/ETL_engineering.py,sha256=a6KCWH6kRatZtjaFEF_o917ApPMK5_vRD-BjfCDAl-E,49400
 ml_tools/GUI_tools.py,sha256=kEQWg-bog3pB5tI22gMGKWaCGHnz9TB2Lvvfhf5F2CI,45412
 ml_tools/MICE_imputation.py,sha256=kVSythWfxJFR4-2mtcYCWQaQ1Oz5yyx_SJu5gjnS7H8,11670
 ml_tools/ML_callbacks.py,sha256=JPvEw_cW5tYNJ2rMSgnNrKLuni_UrmuhDFaOw-u2SvA,13926
 ml_tools/ML_datasetmaster.py,sha256=BMmdCVAZ-HSnnSPLzKla2TdZKvHkHj4t9A0V1Ba3i-I,30821
-ml_tools/ML_evaluation.py,sha256=…
+ml_tools/ML_evaluation.py,sha256=q4_RsBjmidc_yDX-DQvpJW8RCHrOCJbgXKBORQdt-TM,16111
 ml_tools/ML_evaluation_multi.py,sha256=2jTSNFCu8cz5C05EusnrDyffs59M2Fq3UXSHxo2TR1A,12515
 ml_tools/ML_inference.py,sha256=SGDPiPxs_OYDKKRZziFMyaWcC8A37c70W9t-dMP5niI,23066
 ml_tools/ML_models.py,sha256=FliuqGhxP7AWHCweTLlfssXFOjwvFhIYJsgj_w_-EI4,27901
@@ -15,22 +15,22 @@ ml_tools/ML_scaler.py,sha256=IrZsAr1xjvuLi8s5IKR-qbk2mS_awl3mn_xoXg5TJyA,7535
 ml_tools/ML_trainer.py,sha256=xw1zMgYpdqwsTt604xe3GTQNvpg6z6Ze-avmitGBFeU,23539
 ml_tools/PSO_optimization.py,sha256=q0VYpssQGbPum7xdnkDXlJQKhZMYZo8acHpKhajPK3c,22954
 ml_tools/RNN_forecast.py,sha256=8rNZr-eWOBXMiDQV22e_tQTPM5LM2IFggEAa1FaoXaI,1965
-ml_tools/SQL.py,sha256=…
+ml_tools/SQL.py,sha256=givoz6CGWRUdqnBem3VGZxzGdo3ZbX00kyHNjzI8kWE,10803
 ml_tools/VIF_factor.py,sha256=MkMh_RIdsN2XUPzKNGRiEcmB17R_MmvGV4ezpL5zD2E,10403
 ml_tools/__init__.py,sha256=q0y9faQ6e17XCQ7eUiCZ1FJ4Bg5EQqLjZ9f_l5REUUY,41
 ml_tools/_logger.py,sha256=wcImAiXEZKPNcwM30qBh3t7HvoPURonJY0nrgMGF0sM,4719
 ml_tools/_script_info.py,sha256=21r83LV3RubsNZ_RTEUON6RbDf7Mh4_udweNcvdF_Fk,212
 ml_tools/custom_logger.py,sha256=ry43hk54K6xKo8jRAgq1sFxUpOA9T0LIJ7sw0so2BW0,5880
 ml_tools/data_exploration.py,sha256=4McT2BR9muK4JVVTKUfvRyThe0m_o2vpy9RJ1f_1FeY,28692
-ml_tools/ensemble_evaluation.py,sha256=…
+ml_tools/ensemble_evaluation.py,sha256=FGHSe8LBI8_w8LjNeJWOcYQ1UK_mc6fVah8gmSvNVGg,26853
 ml_tools/ensemble_inference.py,sha256=EFHnbjbu31fcVp88NBx8lWAVdu2Gpg9MY9huVZJHFfM,9350
 ml_tools/ensemble_learning.py,sha256=3s0kH4i_naj0IVl_T4knst-Hwg4TScWjEdsXX5KAi7I,21929
 ml_tools/handle_excel.py,sha256=He4UT15sCGhaG-JKfs7uYVAubxWjrqgJ6U7OhMR2fuE,14005
-ml_tools/keys.py,sha256=…
+ml_tools/keys.py,sha256=FDpbS3Jb0pjrVvvp2_8nZi919mbob_-xwuy5OOtKM_A,1848
 ml_tools/optimization_tools.py,sha256=P3I6lIpvZ8Xf2kX5FvvBKBmrK2pB6idBpkTzfUJxTeE,5073
 ml_tools/path_manager.py,sha256=wLJlz3Y9_1-LB9em4B2VYDCVuTOX2eOc7D6hbbebjgM,14990
-ml_tools/utilities.py,sha256=…
-dragon_ml_toolbox-10.8.0.dist-info/METADATA,sha256=…
-dragon_ml_toolbox-10.8.0.dist-info/WHEEL,sha256=…
-dragon_ml_toolbox-10.8.0.dist-info/top_level.txt,sha256=…
-dragon_ml_toolbox-10.8.0.dist-info/RECORD,,
+ml_tools/utilities.py,sha256=30z0x1aDLyBGzF98_tgSaxwFafYwQS-GTFzXHopBSGc,29105
+dragon_ml_toolbox-10.9.0.dist-info/METADATA,sha256=NK8z4StYOVR0ByF_l-vNjyrFgbb2qddBa6lOzlQsZrg,6968
+dragon_ml_toolbox-10.9.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dragon_ml_toolbox-10.9.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+dragon_ml_toolbox-10.9.0.dist-info/RECORD,,
ml_tools/ML_evaluation.py
CHANGED
@@ -22,6 +22,7 @@ from .path_manager import make_fullpath
 from ._logger import _LOGGER
 from typing import Union, Optional, List
 from ._script_info import _script_info
+from .keys import SHAPKeys
 
 
 __all__ = [
@@ -333,7 +334,8 @@ def shap_summary_plot(model,
     plt.close()
 
     # Save Summary Data to CSV
-    …
+    shap_summary_filename = SHAPKeys.SAVENAME + ".csv"
+    summary_path = save_dir_path / shap_summary_filename
     # Ensure the array is 1D before creating the DataFrame
     mean_abs_shap = np.abs(shap_values).mean(axis=0).flatten()
 
@@ -341,9 +343,9 @@ def shap_summary_plot(model,
         feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
 
     summary_df = pd.DataFrame({
-        …
-        …
-    }).sort_values(…
+        SHAPKeys.FEATURE_COLUMN: feature_names,
+        SHAPKeys.SHAP_VALUE_COLUMN: mean_abs_shap
+    }).sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False)
 
     summary_df.to_csv(summary_path, index=False)
 
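With this change, shap_summary_plot takes the summary CSV's filename and column names from the shared SHAPKeys constants (see the ml_tools/keys.py diff below) instead of hard-coded strings. A minimal sketch of a downstream consumer, assuming only the key values shown in that diff; the results directory is hypothetical:

from pathlib import Path

import pandas as pd

# These literals mirror SHAPKeys in ml_tools/keys.py
SAVENAME = "shap_summary"
FEATURE_COLUMN = "feature"
SHAP_VALUE_COLUMN = "mean_abs_shap_value"

results_dir = Path("results/my_model")  # hypothetical output directory
summary = pd.read_csv(results_dir / f"{SAVENAME}.csv")

# The CSV is written pre-sorted by mean |SHAP|; sorting again is cheap insurance
top = summary.sort_values(SHAP_VALUE_COLUMN, ascending=False).head(10)
print(top[FEATURE_COLUMN].tolist())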
ml_tools/SQL.py
CHANGED
@@ -4,7 +4,7 @@ from pathlib import Path
 from typing import Union, Dict, Any, Optional, List, Literal
 from ._logger import _LOGGER
 from ._script_info import _script_info
-from .path_manager import make_fullpath
+from .path_manager import make_fullpath, sanitize_filename
 
 
 __all__ = [
@@ -94,11 +94,13 @@ class DatabaseManager:
         if not self.cursor:
             _LOGGER.error("Database connection is not open.")
             raise sqlite3.Error()
+
+        sanitized_table_name = sanitize_filename(table_name)
 
         columns_def = ", ".join([f'"{col_name}" {col_type}' for col_name, col_type in schema.items()])
         exists_clause = "IF NOT EXISTS" if if_not_exists else ""
 
-        query = f"CREATE TABLE {exists_clause} {…
+        query = f"CREATE TABLE {exists_clause} {sanitized_table_name} ({columns_def})"
 
         _LOGGER.info(f"➡️ Executing: {query}")
         self.cursor.execute(query)
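sqlite3 cannot bind identifiers (table or column names) as query parameters, so the table name inevitably ends up interpolated into the SQL string; the fix above routes it through sanitize_filename first to keep hostile input out of the statement. A rough standalone sketch of the same defensive pattern; _sanitize_identifier is a simple stand-in, not the actual implementation of ml_tools.path_manager.sanitize_filename:

import re
import sqlite3

def _sanitize_identifier(name: str) -> str:
    # Keep only characters that are safe in an unquoted SQL identifier
    return re.sub(r"[^A-Za-z0-9_]", "_", name)

conn = sqlite3.connect(":memory:")
cursor = conn.cursor()

schema = {"id": "INTEGER PRIMARY KEY", "score": "REAL"}
table_name = "runs; DROP TABLE runs; --"  # hostile input is neutralized

columns_def = ", ".join(f'"{col}" {col_type}' for col, col_type in schema.items())
query = f"CREATE TABLE IF NOT EXISTS {_sanitize_identifier(table_name)} ({columns_def})"
cursor.execute(query)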
ml_tools/ensemble_evaluation.py
CHANGED
@@ -25,6 +25,7 @@ from typing import Union, Optional, Literal
 from .path_manager import sanitize_filename, make_fullpath
 from ._script_info import _script_info
 from ._logger import _LOGGER
+from .keys import SHAPKeys
 
 
 __all__ = [
@@ -472,7 +473,7 @@ def get_shap_values(
         save_dir: Directory to save visualizations.
     """
     sanitized_target_name = sanitize_filename(target_name)
-    global_save_path = make_fullpath(save_dir, make=True)
+    global_save_path = make_fullpath(save_dir, make=True, enforce="directory")
 
     def _apply_plot_style():
         styles = ['seaborn', 'seaborn-v0_8-darkgrid', 'seaborn-v0_8', 'default']
@@ -539,6 +540,15 @@ def get_shap_values(
                 plot_type=plot_type,
                 title=f"{model_name} - {target_name} (Class {class_name})"
             )
+
+            # Save the summary data for the current class
+            summary_save_path = global_save_path / f"SHAP_{sanitized_target_name}_{class_name}.csv"
+            _save_summary_csv(
+                shap_values_for_summary=class_shap,
+                feature_names=feature_names,
+                save_path=summary_save_path
+            )
+
     else:
         values = shap_values[1] if isinstance(shap_values, list) else shap_values
         for plot_type in ["bar", "dot"]:
@@ -549,6 +559,15 @@ def get_shap_values(
                 plot_type=plot_type,
                 title=f"{model_name} - {target_name}"
             )
+
+        # Save the summary data for the positive class
+        shap_summary_filename = SHAPKeys.SAVENAME + ".csv"
+        summary_save_path = global_save_path / shap_summary_filename
+        _save_summary_csv(
+            shap_values_for_summary=values,
+            feature_names=feature_names,
+            save_path=summary_save_path
+        )
 
     def _plot_for_regression(shap_values):
         for plot_type in ["bar", "dot"]:
@@ -559,6 +578,34 @@ def get_shap_values(
                 plot_type=plot_type,
                 title=f"{model_name} - {target_name}"
             )
+
+        # Save the summary data to a CSV file
+        shap_summary_filename = SHAPKeys.SAVENAME + ".csv"
+        summary_save_path = global_save_path / shap_summary_filename
+        _save_summary_csv(
+            shap_values_for_summary=shap_values,
+            feature_names=feature_names,
+            save_path=summary_save_path
+        )
+
+    def _save_summary_csv(shap_values_for_summary: np.ndarray, feature_names: list[str], save_path: Path):
+        """Calculates and saves the SHAP summary data to a CSV file."""
+        mean_abs_shap = np.abs(shap_values_for_summary).mean(axis=0)
+
+        # Create default feature names if none are provided
+        current_feature_names = feature_names
+        if current_feature_names is None:
+            current_feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
+
+        summary_df = pd.DataFrame({
+            SHAPKeys.FEATURE_COLUMN: feature_names,
+            SHAPKeys.SHAP_VALUE_COLUMN: mean_abs_shap
+        }).sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False)
+
+        summary_df.to_csv(save_path, index=False)
+        # print(f"📝 SHAP summary data saved as '{save_path.name}'")
+
+
     #START_O
 
     explainer = shap.TreeExplainer(model)
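The new _save_summary_csv helper centralizes the CSV export for the classification and regression branches above. A self-contained approximation of what it computes, assuming SHAP values shaped (n_samples, n_features); note that unlike the helper in the diff, which builds current_feature_names but then passes the original feature_names into the DataFrame, this sketch uses the defaulted names:

from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd

def save_shap_summary(shap_values: np.ndarray,
                      feature_names: Optional[list],
                      save_path: Path) -> None:
    # Mean absolute SHAP value per feature, averaged over all samples
    mean_abs_shap = np.abs(shap_values).mean(axis=0)
    if feature_names is None:
        feature_names = [f"feature_{i}" for i in range(len(mean_abs_shap))]
    summary = pd.DataFrame({
        "feature": feature_names,              # SHAPKeys.FEATURE_COLUMN
        "mean_abs_shap_value": mean_abs_shap,  # SHAPKeys.SHAP_VALUE_COLUMN
    }).sort_values("mean_abs_shap_value", ascending=False)
    summary.to_csv(save_path, index=False)

# Example: 100 samples, 3 features
rng = np.random.default_rng(0)
save_shap_summary(rng.normal(size=(100, 3)), None, Path("shap_summary.csv"))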
ml_tools/keys.py
CHANGED
@@ -61,6 +61,13 @@ class DatasetKeys:
     SCALER_PREFIX = "scaler_"
 
 
+class SHAPKeys:
+    """Keys for SHAP functions"""
+    FEATURE_COLUMN = "feature"
+    SHAP_VALUE_COLUMN = "mean_abs_shap_value"
+    SAVENAME = "shap_summary"
+
+
 class _OneHotOtherPlaceholder:
     """Used internally by GUI_tools."""
     OTHER_GUI = "OTHER"
ml_tools/utilities.py
CHANGED
@@ -9,7 +9,7 @@ from joblib.externals.loky.process_executor import TerminatedWorkerError
 from .path_manager import sanitize_filename, make_fullpath, list_csv_paths, list_files_by_extension, list_subdirectories
 from ._script_info import _script_info
 from ._logger import _LOGGER
-from .keys import DatasetKeys, PytorchModelArchitectureKeys, PytorchArtifactPathKeys
+from .keys import DatasetKeys, PytorchModelArchitectureKeys, PytorchArtifactPathKeys, SHAPKeys
 
 
 # Keep track of available tools
@@ -26,7 +26,8 @@ __all__ = [
     "distribute_dataset_by_target",
     "train_dataset_orchestrator",
    "train_dataset_yielder",
-    "find_model_artifacts"
+    "find_model_artifacts",
+    "select_features_by_shap"
 ]
 
 
@@ -34,6 +35,7 @@ __all__ = [
 @overload
 def load_dataframe(
     df_path: Union[str, Path],
+    use_columns: Optional[list[str]] = None,
     kind: Literal["pandas"] = "pandas",
     all_strings: bool = False,
     verbose: bool = True
@@ -44,7 +46,8 @@ def load_dataframe(
 @overload
 def load_dataframe(
     df_path: Union[str, Path],
-    …
+    use_columns: Optional[list[str]] = None,
+    kind: Literal["polars"] = "polars",
     all_strings: bool = False,
     verbose: bool = True
 ) -> Tuple[pl.DataFrame, str]:
@@ -52,6 +55,7 @@ def load_dataframe(
 
 def load_dataframe(
     df_path: Union[str, Path],
+    use_columns: Optional[list[str]] = None,
     kind: Literal["pandas", "polars"] = "pandas",
     all_strings: bool = False,
     verbose: bool = True
@@ -60,11 +64,13 @@ def load_dataframe(
     Load a CSV file into a DataFrame and extract its base name.
 
     Can load data as either a pandas or a polars DataFrame. Allows for loading all
-    columns as string types to prevent type inference errors.
+    columns or a subset of columns as string types to prevent type inference errors.
 
     Args:
         df_path (str, Path):
             The path to the CSV file.
+        use_columns (list[str] | None):
+            If provided, only these columns will be loaded from the CSV.
         kind ("pandas", "polars"):
             The type of DataFrame to load. Defaults to "pandas".
         all_strings (bool):
@@ -78,28 +84,43 @@ def load_dataframe(
 
     Raises:
         FileNotFoundError: If the file does not exist at the given path.
-        ValueError: If the DataFrame is empty
+        ValueError: If the DataFrame is empty, an invalid 'kind' is provided, or a column in 'use_columns' is not found in the file.
     """
     path = make_fullpath(df_path)
 
     df_name = path.stem
 
-    …
-    if …
-    …
+    try:
+        if kind == "pandas":
+            pd_kwargs: dict[str,Any]
+            pd_kwargs = {'encoding': 'utf-8'}
+            if use_columns:
+                pd_kwargs['usecols'] = use_columns
+            if all_strings:
+                pd_kwargs['dtype'] = str
+
+            df = pd.read_csv(path, **pd_kwargs)
+
+        elif kind == "polars":
+            pl_kwargs: dict[str,Any]
+            pl_kwargs = {}
+            if use_columns:
+                pl_kwargs['columns'] = use_columns
+
+            if all_strings:
+                pl_kwargs['infer_schema'] = False
+            else:
+                pl_kwargs['infer_schema_length'] = 1000
+
+            df = pl.read_csv(path, **pl_kwargs)
+
         else:
-            …
+            _LOGGER.error(f"Invalid kind '{kind}'. Must be one of 'pandas' or 'polars'.")
+            raise ValueError()
 
-    …
-    _LOGGER.error(f"…
-    raise …
+    except (ValueError, pl.exceptions.ColumnNotFoundError) as e:
+        _LOGGER.error(f"Failed to load '{df_name}'. A specified column may not exist in the file.")
+        raise e
 
     # This check works for both pandas and polars DataFrames
     if df.shape[0] == 0:
@@ -111,7 +132,6 @@ def load_dataframe(
 
     return df, df_name # type: ignore
 
-
 def yield_dataframes_from_dir(datasets_dir: Union[str,Path], verbose: bool=True):
     """
     Iterates over all CSV files in a given directory, loading each into a Pandas DataFrame.
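A short usage sketch of the extended load_dataframe signature; the file path and column names are hypothetical:

from ml_tools.utilities import load_dataframe

# Load only two columns, forcing everything to strings so that mixed-type
# columns cannot trip the CSV reader's type inference
df, df_name = load_dataframe(
    "data/measurements.csv",              # hypothetical file
    use_columns=["sample_id", "target"],  # hypothetical columns
    kind="pandas",
    all_strings=True,
)
print(df_name, df.shape)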
@@ -683,5 +703,84 @@ def find_model_artifacts(target_directory: Union[str,Path], load_scaler: bool, v
     return all_artifacts
 
 
+def select_features_by_shap(
+    root_directory: Union[str, Path],
+    shap_threshold: float = 1.0,
+    verbose: bool = True) -> list[str]:
+    """
+    Scans subdirectories to find SHAP summary CSVs, then extracts feature
+    names whose mean absolute SHAP value meets a specified threshold.
+
+    This function is useful for automated feature selection based on feature
+    importance scores aggregated from multiple models.
+
+    Args:
+        root_directory (Union[str, Path]):
+            The path to the root directory that contains model subdirectories.
+        shap_threshold (float):
+            The minimum mean absolute SHAP value for a feature to be included
+            in the final list.
+
+    Returns:
+        list[str]:
+            A single, sorted list of unique feature names that meet the
+            threshold criteria across all found files.
+    """
+    if verbose:
+        _LOGGER.info(f"Starting feature selection with SHAP threshold >= {shap_threshold}")
+    root_path = make_fullpath(root_directory, enforce="directory")
+
+    # --- Step 2: Directory and File Discovery ---
+    subdirectories = list_subdirectories(root_dir=root_path, verbose=False)
+
+    shap_filename = SHAPKeys.SAVENAME + ".csv"
+
+    valid_csv_paths = []
+    for dir_name, dir_path in subdirectories.items():
+        expected_path = dir_path / shap_filename
+        if expected_path.is_file():
+            valid_csv_paths.append(expected_path)
+        else:
+            _LOGGER.warning(f"No '{shap_filename}' found in subdirectory '{dir_name}'.")
+
+    if not valid_csv_paths:
+        _LOGGER.error(f"Process halted: No '{shap_filename}' files were found in any subdirectory.")
+        return []
+
+    if verbose:
+        _LOGGER.info(f"Found {len(valid_csv_paths)} SHAP summary files to process.")
+
+    # --- Step 3: Data Processing and Feature Extraction ---
+    master_feature_set = set()
+    for csv_path in valid_csv_paths:
+        try:
+            df, _ = load_dataframe(csv_path, kind="pandas", verbose=False)
+
+            # Validate required columns
+            required_cols = {SHAPKeys.FEATURE_COLUMN, SHAPKeys.SHAP_VALUE_COLUMN}
+            if not required_cols.issubset(df.columns):
+                _LOGGER.warning(f"Skipping '{csv_path}': missing required columns.")
+                continue
+
+            # Filter by threshold and extract features
+            filtered_df = df[df[SHAPKeys.SHAP_VALUE_COLUMN] >= shap_threshold]
+            features = filtered_df[SHAPKeys.FEATURE_COLUMN].tolist()
+            master_feature_set.update(features)
+
+        except (ValueError, pd.errors.EmptyDataError):
+            _LOGGER.warning(f"Skipping '{csv_path}' because it is empty or malformed.")
+            continue
+        except Exception as e:
+            _LOGGER.error(f"An unexpected error occurred while processing '{csv_path}': {e}")
+            continue
+
+    # --- Step 4: Finalize and Return ---
+    final_features = sorted(list(master_feature_set))
+    if verbose:
+        _LOGGER.info(f"Selected {len(final_features)} unique features across all files.")
+
+    return final_features
+
+
 def info():
     _script_info(__all__)
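A usage sketch for the new select_features_by_shap, assuming a results tree in which each model's subdirectory holds the shap_summary.csv written by the evaluation functions above; the layout and threshold are hypothetical:

from ml_tools.utilities import select_features_by_shap

# results/
# ├── model_a/shap_summary.csv
# ├── model_b/shap_summary.csv
# └── model_c/              <- no summary file: logged as a warning, skipped
selected = select_features_by_shap(
    root_directory="results",
    shap_threshold=0.05,
)
print(f"{len(selected)} features met the threshold:", selected)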
Files without changes: WHEEL, licenses/LICENSE, licenses/LICENSE-THIRD-PARTY.md, top_level.txt
|