dragon-ml-toolbox 1.3.2__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Potentially problematic release: this version of dragon-ml-toolbox might be problematic.
- {dragon_ml_toolbox-1.3.2.dist-info → dragon_ml_toolbox-1.4.0.dist-info}/METADATA +2 -1
- {dragon_ml_toolbox-1.3.2.dist-info → dragon_ml_toolbox-1.4.0.dist-info}/RECORD +11 -10
- ml_tools/MICE_imputation.py +8 -5
- ml_tools/VIF_factor.py +209 -0
- ml_tools/data_exploration.py +5 -145
- ml_tools/ensemble_learning.py +111 -120
- ml_tools/utilities.py +39 -13
- {dragon_ml_toolbox-1.3.2.dist-info → dragon_ml_toolbox-1.4.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-1.3.2.dist-info → dragon_ml_toolbox-1.4.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-1.3.2.dist-info → dragon_ml_toolbox-1.4.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-1.3.2.dist-info → dragon_ml_toolbox-1.4.0.dist-info}/top_level.txt +0 -0
{dragon_ml_toolbox-1.3.2.dist-info → dragon_ml_toolbox-1.4.0.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dragon-ml-toolbox
-Version: 1.3.2
+Version: 1.4.0
 Summary: A collection of tools for data science and machine learning projects
 Author-email: Karl Loza <luigiloza@gmail.com>
 License-Expression: MIT
@@ -27,6 +27,7 @@ Requires-Dist: ipython
 Requires-Dist: ipykernel
 Requires-Dist: notebook
 Requires-Dist: jupyterlab
+Requires-Dist: ipywidgets
 Requires-Dist: joblib
 Requires-Dist: xgboost
 Requires-Dist: lightgbm<=4.5.0
{dragon_ml_toolbox-1.3.2.dist-info → dragon_ml_toolbox-1.4.0.dist-info}/RECORD CHANGED

@@ -1,18 +1,19 @@
-dragon_ml_toolbox-1.…
-dragon_ml_toolbox-1.…
-ml_tools/MICE_imputation.py,sha256=…
+dragon_ml_toolbox-1.4.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
+dragon_ml_toolbox-1.4.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=e1Hg5ZtaBpDV7ZvxhLe1ac28l7nMjvi1MSE5YvB1s-o,1472
+ml_tools/MICE_imputation.py,sha256=4kqZiesk8vyh4MBLnNE9grflG4fDusqzuYBElsbk4LY,9484
+ml_tools/VIF_factor.py,sha256=rHSAxQcXLrG8dIjCXBAvETsSkCBfYus9NqimOnm2Bvk,9559
 ml_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-ml_tools/data_exploration.py,sha256=…
+ml_tools/data_exploration.py,sha256=qtkGumckC2PmTpj3brVFi072ewX0OI6dwUF4Or7Yikg,21341
 ml_tools/datasetmaster.py,sha256=VUneKshnmjOGbtqVVGTFcIMRKF3s6ZDYrosIYKDjD80,28956
-ml_tools/ensemble_learning.py,sha256=…
+ml_tools/ensemble_learning.py,sha256=wK6mtOE4v9AWlxkcWhJj5XZjREChxb46kE0i2IxS-OE,28372
 ml_tools/handle_excel.py,sha256=IR0VQc3hYdmjwC31E5YxDnRcWig4jSIx7Y_7to-KZz4,11969
 ml_tools/logger.py,sha256=XwSpCUzw2Le24fJHyljBxNLgw63SwjZ0pMjTJqf0ylI,4622
 ml_tools/particle_swarm_optimization.py,sha256=jpkje4OETC9fyISxxUTx4XGrImSU6gDEcwz46ZDs2bQ,19250
 ml_tools/pytorch_models.py,sha256=Oykw02sOZLCjvSadQd64UGesBN7kq0x1EGXHusvYiQI,9908
 ml_tools/trainer.py,sha256=Zd7AaHeoNd8dEas2JChWoHaCUpWUVRDUMybuHaKJ0XY,16740
-ml_tools/utilities.py,sha256=…
+ml_tools/utilities.py,sha256=gr1cyRUfZcRo9fjWpCaQkrvWY0-xJnDJdrE8JEsOi8o,6309
 ml_tools/vision_helpers.py,sha256=lBAW6dzAK-HOswAt1fU_tfP9hkNLY5D8c_I_7hhEXno,7528
-dragon_ml_toolbox-1.…
-dragon_ml_toolbox-1.…
-dragon_ml_toolbox-1.…
-dragon_ml_toolbox-1.…
+dragon_ml_toolbox-1.4.0.dist-info/METADATA,sha256=V7Y96iAbgX6Xl6RWzEt4nGfKMZe4cuLs0BrFQghXxX8,2335
+dragon_ml_toolbox-1.4.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dragon_ml_toolbox-1.4.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+dragon_ml_toolbox-1.4.0.dist-info/RECORD,,
ml_tools/MICE_imputation.py CHANGED

@@ -120,7 +120,7 @@ def get_imputed_distributions(kernel: mf.ImputationKernel, df_name: str, root_di…
     '''
     # Check path
     os.makedirs(root_dir, exist_ok=True)
-    local_save_dir = os.path.join(root_dir, f"Distribution_Metrics_{df_name}")
+    local_save_dir = os.path.join(root_dir, f"Distribution_Metrics_{df_name}_imputed")
     if not os.path.isdir(local_save_dir):
         os.makedirs(local_save_dir)

@@ -169,8 +169,12 @@ def get_imputed_distributions(kernel: mf.ImputationKernel, df_name: str, root_di…
         # Adjust layout and save
         # fig.tight_layout()
         # fig.subplots_adjust(bottom=0.2, left=0.2) # Optional, depending on overflow
+
+        # sanitize savename
+        feature_save_name = sanitize_filename(filename)
+
         fig.savefig(
-            os.path.join(local_save_dir, …
+            os.path.join(local_save_dir, feature_save_name + ".svg"),
             format='svg',
             bbox_inches='tight',
             pad_inches=0.1

@@ -185,8 +189,7 @@ def get_imputed_distributions(kernel: mf.ImputationKernel, df_name: str, root_di…
     else:
         for feature in column_names:
             fig = kernel.plot_imputed_distributions(variables=[feature])
-
-            _process_figure(fig, feature_save_name)
+            _process_figure(fig, feature)

     print("\tImputed distributions saved successfully.")

@@ -207,7 +210,7 @@ def run_mice_pipeline(df_path_or_dir: str, save_datasets_dir: str, save_metrics_…
     if os.path.isfile(df_path_or_dir):
         all_file_paths = [df_path_or_dir]
     elif os.path.isdir(df_path_or_dir):
-        all_file_paths …
+        all_file_paths = list_csv_paths(df_path_or_dir).values()
     else:
         raise ValueError(f"Invalid path or directory: {df_path_or_dir}")
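Taken together, the last hunk means the directory branch of `run_mice_pipeline` now consumes the dict returned by the reworked `list_csv_paths` (see `ml_tools/utilities.py` below). A minimal sketch of the corrected path collection; the helper name `_collect_csv_paths` is illustrative, not part of the package:

```python
import os
from ml_tools.utilities import list_csv_paths

def _collect_csv_paths(df_path_or_dir: str) -> list[str]:
    # Illustrative helper mirroring run_mice_pipeline's branching
    if os.path.isfile(df_path_or_dir):
        return [df_path_or_dir]
    elif os.path.isdir(df_path_or_dir):
        # list_csv_paths() now returns {name: path}, so .values() yields the paths
        return list(list_csv_paths(df_path_or_dir).values())
    else:
        raise ValueError(f"Invalid path or directory: {df_path_or_dir}")
```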
ml_tools/VIF_factor.py ADDED

@@ -0,0 +1,209 @@
+
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+from typing import Optional
+from statsmodels.stats.outliers_influence import variance_inflation_factor
+from statsmodels.tools.tools import add_constant
+import warnings
+import os
+from .utilities import sanitize_filename, yield_dataframes_from_dir, save_dataframe
+
+
+def compute_vif(
+    df: pd.DataFrame,
+    target_columns: Optional[list[str]] = None,
+    ignore_columns: Optional[list[str]] = None,
+    max_features_to_plot: int = 20,
+    save_dir: Optional[str] = None,
+    filename: Optional[str] = None,
+    fontsize: int = 14,
+    show_plot: bool = True,
+) -> pd.DataFrame:
+    """
+    Computes Variance Inflation Factors (VIF) for numeric columns in a DataFrame. Optionally, generates a bar plot of VIF values.
+
+    Args:
+        df (pd.DataFrame): The input DataFrame.
+        target_columns (list[str] | None): Optional list of columns to include. Defaults to all numeric columns.
+        ignore_columns (list[str] | None): Optional list of columns to exclude from the VIF computation. Skipped if `target_columns` is provided.
+        max_features_to_plot (int): Adjust the number of features shown in the plot.
+        save_dir (str | None): Directory to save the plot as SVG. If None, the plot is not saved.
+        filename (str | None): Optional filename for saving the plot. Defaults to "VIF_plot.svg".
+        fontsize (int): Base fontsize to scale title and labels on the plot.
+        show_plot (bool): Display plot.
+
+    Returns:
+        pd.DataFrame: DataFrame with features and their corresponding VIF values.
+
+    NOTE:
+        **Variance Inflation Factor (VIF)** quantifies the degree of multicollinearity among features in a dataset.
+        A VIF value indicates how much the variance of a regression coefficient is inflated due to linear dependence with other features.
+        A VIF of 1 suggests no correlation, values between 1 and 5 indicate moderate correlation, and values greater than 10 typically signal high multicollinearity, which may distort model interpretation and degrade performance.
+    """
+    ground_truth_cols = df.columns.to_list()
+    if target_columns is None:
+        sanitized_columns = df.select_dtypes(include='number').columns.tolist()
+        missing_features = set(ground_truth_cols) - set(sanitized_columns)
+        if missing_features:
+            print(f"⚠️ These columns are not Numeric:\n{missing_features}")
+    else:
+        sanitized_columns = list()
+        for feature in target_columns:
+            if feature not in ground_truth_cols:
+                print(f"⚠️ The provided column '{feature}' is not in the DataFrame.")
+            else:
+                sanitized_columns.append(feature)
+
+    if ignore_columns is not None and target_columns is None:
+        missing_ignore = set(ignore_columns) - set(ground_truth_cols)
+        if missing_ignore:
+            print(f"⚠️ Warning: The following 'columns to ignore' are not in the Dataframe:\n{missing_ignore}")
+        sanitized_columns = [f for f in sanitized_columns if f not in ignore_columns]
+
+    X = df[sanitized_columns].copy()
+    X = add_constant(X, has_constant='add')
+
+    vif_data = pd.DataFrame()
+    vif_data["feature"] = X.columns # type: ignore
+
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", category=RuntimeWarning)
+
+        vif_data["VIF"] = [
+            variance_inflation_factor(X.values, i) for i in range(X.shape[1]) # type: ignore
+        ]
+
+    # Replace infinite values (perfect multicollinearity)
+    vif_data["VIF"] = vif_data["VIF"].replace([np.inf, -np.inf], 999.0)
+
+    # Drop the constant column
+    vif_data = vif_data[vif_data["feature"] != "const"]
+
+    # Add color coding
+    def vif_color(v: float) -> str:
+        if v >= 10:
+            return "red"
+        elif v >= 5:
+            return "gold"
+        else:
+            return "green"
+
+    vif_data["color"] = vif_data["VIF"].apply(vif_color)
+
+    # Sort by VIF descending
+    vif_data = vif_data.sort_values(by="VIF", ascending=False).reset_index(drop=True)
+
+    # Filter for plotting
+    plot_data = vif_data.head(max_features_to_plot)
+
+    if save_dir or show_plot:
+        if not plot_data.empty:
+            plt.figure(figsize=(10, 6))
+            plt.barh(
+                plot_data["feature"],
+                plot_data["VIF"],
+                color=plot_data["color"],
+                edgecolor='black'
+            )
+            plt.title("Variance Inflation Factor (VIF) per Feature", fontsize=fontsize+1)
+            plt.xlabel("VIF value", fontsize=fontsize)
+            plt.xticks(fontsize=fontsize)
+            plt.yticks(fontsize=fontsize)
+            plt.axvline(x=5, color='gold', linestyle='--', label='VIF = 5')
+            plt.axvline(x=10, color='red', linestyle='--', label='VIF = 10')
+            plt.xlim(0, 12)
+            plt.legend(loc='lower right', fontsize=fontsize-1)
+            plt.gca().invert_yaxis()
+            plt.grid(axis='x', linestyle='--', alpha=0.5)
+            plt.tight_layout()
+
+            if save_dir:
+                os.makedirs(save_dir, exist_ok=True)
+                if filename is None:
+                    filename = "VIF_plot.svg"
+                else:
+                    filename = sanitize_filename(filename)
+                    if not filename.endswith(".svg"):
+                        filename += ".svg"
+                save_path = os.path.join(save_dir, "VIF_" + filename)
+                plt.savefig(save_path, format='svg', bbox_inches='tight')
+                print(f"\tSaved VIF plot: '{filename}'")
+
+            if show_plot:
+                plt.show()
+            plt.close()
+
+    return vif_data.drop(columns="color")
+
+
+def drop_vif_based(df: pd.DataFrame, vif_df: pd.DataFrame, threshold: float = 10.0) -> pd.DataFrame:
+    """
+    Drops columns from the original DataFrame based on their VIF values exceeding a given threshold.
+
+    Args:
+        df (pd.DataFrame): Original DataFrame containing the columns to test.
+        vif_df (pd.DataFrame): DataFrame with 'feature' and 'VIF' columns as returned by `compute_vif()`.
+        threshold (float): VIF threshold above which columns will be dropped.
+
+    Returns:
+        pd.DataFrame: A new DataFrame with high-VIF columns removed.
+    """
+    # Ensure expected structure
+    if 'feature' not in vif_df.columns or 'VIF' not in vif_df.columns:
+        raise ValueError("`vif_df` must contain 'feature' and 'VIF' columns.")
+
+    # Identify features to drop
+    to_drop = vif_df[vif_df["VIF"] > threshold]["feature"].tolist()
+    print(f"\tDropping {len(to_drop)} column(s) with VIF > {threshold}: {to_drop}")
+
+    result_df = df.drop(columns=to_drop)
+
+    if result_df.empty:
+        print(f"\t⚠️ Warning: All columns were dropped.")
+
+    return result_df
+
+
+def compute_vif_multi(input_directory: str,
+                      output_plot_directory: str,
+                      output_dataset_directory: Optional[str] = None,
+                      target_columns: Optional[list[str]] = None,
+                      ignore_columns: Optional[list[str]] = None,
+                      max_features_to_plot: int = 20,
+                      fontsize: int = 14):
+    """
+    Computes Variance Inflation Factors (VIF) for numeric columns in a directory with CSV files (loaded as pandas DataFrames).
+    Generates a bar plot of VIF values. Optionally drops columns with VIF >= 10 and saves as a new CSV file.
+
+    Args:
+        input_directory (str): Target directory with CSV files able to be loaded as DataFrame.
+        output_plot_directory (str): Save plots to this directory.
+        output_dataset_directory (str | None): If provided, saves new CSV files to this directory.
+        target_columns (list[str] | None): Optional list of columns to include. Defaults to all numeric columns.
+        ignore_columns (list[str] | None): Optional list of columns to exclude from the VIF computation. Skipped if `target_columns` is provided.
+        max_features_to_plot (int): Adjust the number of features shown in the plot.
+        fontsize (int): Base fontsize to scale title and labels on the plot.
+
+    NOTE:
+        **Variance Inflation Factor (VIF)** quantifies the degree of multicollinearity among features in a dataset.
+        A VIF value indicates how much the variance of a regression coefficient is inflated due to linear dependence with other features.
+        A VIF of 1 suggests no correlation, values between 1 and 5 indicate moderate correlation, and values greater than 10 typically signal high multicollinearity, which may distort model interpretation and degrade performance.
+    """
+    if output_dataset_directory is not None:
+        os.makedirs(output_dataset_directory, exist_ok=True)
+
+    for df, df_name in yield_dataframes_from_dir(datasets_dir=input_directory):
+        vif_dataframe = compute_vif(df=df,
+                                    target_columns=target_columns,
+                                    ignore_columns=ignore_columns,
+                                    max_features_to_plot=max_features_to_plot,
+                                    fontsize=fontsize,
+                                    save_dir=output_plot_directory,
+                                    filename=df_name,
+                                    show_plot=False)
+
+        if output_dataset_directory is not None:
+            new_filename = 'VIF_' + df_name
+            result_df = drop_vif_based(df=df, vif_df=vif_dataframe)
+            save_dataframe(df=result_df, save_dir=output_dataset_directory, filename=new_filename)
ml_tools/data_exploration.py CHANGED

@@ -2,12 +2,10 @@ import pandas as pd
 import numpy as np
 import matplotlib.pyplot as plt
 import seaborn as sns
-from statsmodels.stats.outliers_influence import variance_inflation_factor
-from statsmodels.tools.tools import add_constant
 from IPython import get_ipython
 from IPython.display import clear_output
 import time
-from typing import Union, Literal, Dict, Tuple
+from typing import Union, Literal, Dict, Tuple
 import os
 import sys
 import textwrap
@@ -26,10 +24,7 @@ __all__ = ["summarize_dataframe",
            "plot_value_distributions",
            "clip_outliers_single",
            "clip_outliers_multi",
-           "merge_dataframes"
-           "save_dataframe",
-           "compute_vif",
-           "drop_vif_based"]
+           "merge_dataframes"]


 def summarize_dataframe(df: pd.DataFrame, round_digits: int = 2):
@@ -278,7 +273,7 @@ def check_value_distributions(df: pd.DataFrame, view_frequencies: bool=True, bin…
     Notes:
         - Binning is adaptive: if quantile binning results in ≤ 2 unique bins, raw values are used instead.
     """
-    # …
+    # cherry-pick columns
     if skip_cols_with_key is not None:
         columns = [col for col in df.columns if skip_cols_with_key not in col]
     else:
@@ -351,7 +346,7 @@ def plot_value_distributions(df: pd.DataFrame, save_dir: str, bin_threshold: int…
     dict_to_plot_std = dict()
     dict_to_plot_freq = dict()

-    # …
+    # cherry-pick columns
     if skip_cols_with_key is not None:
         columns = [col for col in df.columns if skip_cols_with_key not in col]
     else:
@@ -399,7 +394,7 @@ def plot_value_distributions(df: pd.DataFrame, save_dir: str, bin_threshold: int…
         labels = data.keys()

         plt.figure(figsize=(10, 6))
-        colors = plt.cm.tab20.colors if len(data) <= 20 else plt.cm.viridis(np.linspace(0, 1, len(data)))
+        colors = plt.cm.tab20.colors if len(data) <= 20 else plt.cm.viridis(np.linspace(0, 1, len(data))) # type: ignore

         plt.bar(labels, data.values(), color=colors[:len(data)], alpha=0.85)
         plt.xlabel("Values", fontsize=base_fontsize)
@@ -574,141 +569,6 @@ def merge_dataframes(
     return merged_df


-def save_dataframe(df: pd.DataFrame, save_dir: str, filename: str) -> None:
-    """
-    Save a pandas DataFrame to a CSV file.
-
-    Parameters:
-        df: pandas.DataFrame to save
-        save_dir: str, directory where the CSV file will be saved.
-        filename: str, CSV filename, extension will be added if missing.
-    """
-    os.makedirs(save_dir, exist_ok=True)
-
-    filename = sanitize_filename(filename)
-
-    if not filename.endswith('.csv'):
-        filename += '.csv'
-
-    output_path = os.path.join(save_dir, filename)
-
-    df.to_csv(output_path, index=False, encoding='utf-8')
-    print(f"Saved file: '{filename}'")
-
-
-def compute_vif(
-    df: pd.DataFrame,
-    features: Optional[list[str]] = None,
-    ignore_cols: Optional[list[str]] = None,
-    plot: bool = True,
-    save_dir: Union[str, None] = None
-) -> pd.DataFrame:
-    """
-    Computes Variance Inflation Factors (VIF) for numeric features, optionally plots and saves the results.
-
-    There cannot be empty values in the dataset.
-
-    Args:
-        df (pd.DataFrame): The input DataFrame.
-        features (list[str] | None): Optional list of column names to evaluate. Defaults to all numeric columns.
-        ignore_cols (list[str] | None): Optional list of column names to ignore.
-        plot (bool): Whether to display a barplot of VIF values.
-        save_dir (str | None): Directory to save the plot as SVG. If None, plot is not saved.
-
-    Returns:
-        pd.DataFrame: DataFrame with features and corresponding VIF values, sorted descending.
-
-    NOTE:
-        **Variance Inflation Factor (VIF)** quantifies the degree of multicollinearity among features in a dataset.
-        A VIF value indicates how much the variance of a regression coefficient is inflated due to linear dependence with other features.
-        A VIF of 1 suggests no correlation, values between 1 and 5 indicate moderate correlation, and values greater than 10 typically signal high multicollinearity, which may distort model interpretation and degrade performance.
-
-    """
-    if features is None:
-        features = df.select_dtypes(include='number').columns.tolist()
-
-    if ignore_cols is not None:
-        missing = set(ignore_cols) - set(features)
-        if missing:
-            raise ValueError(f"The following 'columns to ignore' are not in the Dataframe:\n{missing}")
-        features = [f for f in features if f not in ignore_cols]
-
-    X = df[features].copy()
-    X = add_constant(X, has_constant='add')
-
-    vif_data = pd.DataFrame()
-    vif_data["feature"] = X.columns
-    vif_data["VIF"] = [variance_inflation_factor(X.values, i) for i in range(X.shape[1])]
-
-    # Drop the constant column
-    vif_data = vif_data[vif_data["feature"] != "const"]
-    vif_data = vif_data.sort_values(by="VIF", ascending=False).reset_index(drop=True) # type: ignore
-
-    # Add color coding based on thresholds
-    def vif_color(v: float) -> str:
-        if v > 10:
-            return "red"
-        elif v > 5:
-            return "gold"
-        else:
-            return "green"
-
-    vif_data["color"] = vif_data["VIF"].apply(vif_color)
-
-    # Plot
-    if plot or save_dir:
-        plt.figure(figsize=(10, 6))
-        bars = plt.barh(
-            vif_data["feature"],
-            vif_data["VIF"],
-            color=vif_data["color"],
-            edgecolor='black'
-        )
-        plt.title("Variance Inflation Factor (VIF) per Feature")
-        plt.xlabel("VIF")
-        plt.axvline(x=5, color='gold', linestyle='--', label='VIF = 5')
-        plt.axvline(x=10, color='red', linestyle='--', label='VIF = 10')
-        plt.legend(loc='lower right')
-        plt.gca().invert_yaxis()
-        plt.grid(axis='x', linestyle='--', alpha=0.5)
-        plt.tight_layout()
-
-        if save_dir:
-            os.makedirs(save_dir, exist_ok=True)
-            save_path = os.path.join(save_dir, "VIF_plot.svg")
-            plt.savefig(save_path, format='svg', bbox_inches='tight')
-            print(f"Saved VIF plot to: {save_path}")
-
-        if plot:
-            plt.show()
-        plt.close()
-
-    return vif_data.drop(columns="color")
-
-
-def drop_vif_based(df: pd.DataFrame, vif_df: pd.DataFrame, threshold: float = 10.0) -> pd.DataFrame:
-    """
-    Drops features from the original DataFrame based on their VIF values exceeding a given threshold.
-
-    Args:
-        df (pd.DataFrame): Original DataFrame containing the features.
-        vif_df (pd.DataFrame): DataFrame with 'feature' and 'VIF' columns as returned by `compute_vif()`.
-        threshold (float): VIF threshold above which features will be dropped.
-
-    Returns:
-        pd.DataFrame: A new DataFrame with high-VIF features removed.
-    """
-    # Ensure expected structure
-    if 'feature' not in vif_df.columns or 'VIF' not in vif_df.columns:
-        raise ValueError("`vif_df` must contain 'feature' and 'VIF' columns.")
-
-    # Identify features to drop
-    to_drop = vif_df[vif_df["VIF"] > threshold]["feature"].tolist()
-    print(f"Dropping {len(to_drop)} feature(s) with VIF > {threshold}: {to_drop}")
-
-    return df.drop(columns=to_drop, errors="ignore")
-
-
 def _is_notebook():
     return get_ipython() is not None
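The practical effect for callers: `save_dataframe`, `compute_vif`, and `drop_vif_based` no longer live in `data_exploration`. A migration sketch, assuming the package is imported as `ml_tools`:

```python
# 1.3.2 (old locations):
# from ml_tools.data_exploration import compute_vif, drop_vif_based, save_dataframe

# 1.4.0 (new locations):
from ml_tools.VIF_factor import compute_vif, drop_vif_based
from ml_tools.utilities import save_dataframe
```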
ml_tools/ensemble_learning.py CHANGED

@@ -139,8 +139,9 @@ def get_models(task: Literal["classification", "regression"], random_state: int=…

 ###### 3. Process Dataset ######
 # function to split data into train and test
-def _split_data(features, target, test_size, random_state):
-    X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=test_size, random_state=random_state,
+def _split_data(features, target, test_size, random_state, task):
+    X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=test_size, random_state=random_state,
+                                                        stratify=target if task=="classification" else None)
     return X_train, X_test, y_train, y_test

 # function to standardize the data
@@ -176,7 +177,7 @@ def _resample(X_train_scaled: np.ndarray, y_train: pd.Series,…
     else:
         raise ValueError(f"Invalid resampling strategy: {strategy}")

-    X_res, y_res = resample_algorithm.fit_resample(X_train_scaled, y_train)
+    X_res, y_res, *_ = resample_algorithm.fit_resample(X_train_scaled, y_train)
     return X_res, y_res

 # DATASET PIPELINE
@@ -199,7 +200,7 @@ def dataset_pipeline(df_features: pd.DataFrame, df_target: pd.Series, task: Lite…
     print(f"\tUnique values for '{df_target.name}': {unique_values}")

     #Train test split
-    X_train, X_test, y_train, y_test = _split_data(features=df_features, target=df_target, test_size=test_size, random_state=random_state)
+    X_train, X_test, y_train, y_test = _split_data(features=df_features, target=df_target, test_size=test_size, random_state=random_state, task=task)

     #DEBUG
     if debug:
@@ -343,8 +344,7 @@ def plot_roc_curve(
     color: str = "darkorange",
     figure_size: tuple = (10, 10),
     linewidth: int = 2,
-
-    label_fontsize: int = 24,
+    base_fontsize: int = 24,
     input_features: Optional[np.ndarray] = None,
 ) -> plt.Figure: # type: ignore
     """
@@ -402,11 +402,11 @@ def plot_roc_curve(
     ax.plot(fpr, tpr, color=color, lw=linewidth, label=f"AUC = {auc_score:.2f}")
     ax.plot([0, 1], [0, 1], color="gray", linestyle="--", lw=1)

-    ax.set_title(f"{model_name} - {target_name}", fontsize=label_fontsize)
-    ax.set_xlabel("False Positive Rate", fontsize=label_fontsize)
-    ax.set_ylabel("True Positive Rate", fontsize=label_fontsize)
-    ax.tick_params(axis='both', labelsize=label_fontsize)
-    ax.legend(loc="lower right", fontsize=label_fontsize)
+    ax.set_title(f"{model_name} - {target_name}", fontsize=base_fontsize)
+    ax.set_xlabel("False Positive Rate", fontsize=base_fontsize)
+    ax.set_ylabel("True Positive Rate", fontsize=base_fontsize)
+    ax.tick_params(axis='both', labelsize=base_fontsize)
+    ax.legend(loc="lower right", fontsize=base_fontsize)
     ax.grid(True)

     # Save figure
@@ -416,6 +416,7 @@ def plot_roc_curve(

     return fig

+
 # function to evaluate the model and save metrics (Regression)
 def evaluate_model_regression(model, model_name: str,
                               save_dir: str,
@@ -423,8 +424,7 @@ def evaluate_model_regression(model, model_name: str,
                               target_id: str,
                               figure_size: tuple = (12, 8),
                               alpha_transparency: float = 0.5,
-
-                              normal_fontsize: int = 24):
+                              base_fontsize: int = 24):
     # Generate predictions
     y_pred = model.predict(x_test_scaled)

@@ -448,9 +448,9 @@ def evaluate_model_regression(model, model_name: str,
     plt.figure(figsize=figure_size)
     plt.scatter(y_pred, residuals, alpha=alpha_transparency)
     plt.axhline(0, color='red', linestyle='--')
-    plt.xlabel("Predicted Values", fontsize=normal_fontsize)
-    plt.ylabel("Residuals", fontsize=normal_fontsize)
-    plt.title(f"{model_name} - Residual Plot for {target_id}", fontsize=normal_fontsize)
+    plt.xlabel("Predicted Values", fontsize=base_fontsize)
+    plt.ylabel("Residuals", fontsize=base_fontsize)
+    plt.title(f"{model_name} - Residual Plot for {target_id}", fontsize=base_fontsize)
     plt.grid(True)
     plt.tight_layout()
     plt.savefig(os.path.join(save_dir, f"Residual_Plot_{target_id}.svg"), bbox_inches='tight', format="svg")
@@ -462,9 +462,9 @@ def evaluate_model_regression(model, model_name: str,
     plt.plot([single_y_test.min(), single_y_test.max()],
              [single_y_test.min(), single_y_test.max()],
              'k--', lw=2)
-    plt.xlabel('True Values', fontsize=normal_fontsize)
-    plt.ylabel('Predictions', fontsize=normal_fontsize)
-    plt.title(f"{model_name} - True vs Predicted for {target_id}", fontsize=normal_fontsize)
+    plt.xlabel('True Values', fontsize=base_fontsize)
+    plt.ylabel('Predictions', fontsize=base_fontsize)
+    plt.title(f"{model_name} - True vs Predicted for {target_id}", fontsize=base_fontsize)
     plt.grid(True)
     plot_path = os.path.join(save_dir, f"Regression_Plot_{target_id}.svg")
     plt.savefig(plot_path, bbox_inches='tight', format="svg")
@@ -473,52 +473,53 @@ def evaluate_model_regression(model, model_name: str,
     return y_pred

 # Get SHAP values
-def get_shap_values(model, model_name: str,
-                    …
+def get_shap_values(
+    model,
+    model_name: str,
+    save_dir: str,
+    features_to_explain: np.ndarray,
+    feature_names: list[str],
+    target_id: str,
+    task: Literal["classification", "regression"],
+    max_display_features: int = 10,
+    figsize: tuple = (16, 20),
+    base_fontsize: int = 38,
+):
     """
     Universal SHAP explainer for regression and classification.
-    …
+    * Use `X_train` (or a subsample of it) to see how the model explains the data it was trained on.
+
+    * Use `X_test` (or a hold-out set) to see how the model explains unseen data.
+
+    * Use the entire dataset to get the global view.

     Parameters:
-        …
+        task: 'regression' or 'classification'
+        features_to_explain: Should match the model's training data format, including scaling.
+        save_dir: Directory to save visualizations
     """
-    …
-    # Set style
-    preferred_styles = ['seaborn', 'seaborn-v0_8-darkgrid', 'seaborn-v0_8', 'default']
-    for style in preferred_styles:
+
+    def _apply_plot_style():
+        styles = ['seaborn', 'seaborn-v0_8-darkgrid', 'seaborn-v0_8', 'default']
+        for style in styles:
             if style in plt.style.available or style == 'default':
                 plt.style.use(style)
                 break
-
+
+    def _configure_rcparams():
+        plt.rc('font', size=base_fontsize)
+        plt.rc('axes', titlesize=base_fontsize)
+        plt.rc('axes', labelsize=base_fontsize)
+        plt.rc('xtick', labelsize=base_fontsize)
+        plt.rc('ytick', labelsize=base_fontsize + 2)
+        plt.rc('legend', fontsize=base_fontsize)
+        plt.rc('figure', titlesize=base_fontsize)
+
+    def _create_shap_plot(shap_values, features, save_path: str, plot_type: str, title: str):
+        _apply_plot_style()
+        _configure_rcparams()
         plt.figure(figsize=figsize)
-
-        #set rc parameters for better readability
-        plt.rc('font', size=label_fontsize)
-        plt.rc('axes', titlesize=title_fontsize)
-        plt.rc('axes', labelsize=label_fontsize)
-        plt.rc('xtick', labelsize=label_fontsize)
-        plt.rc('ytick', labelsize=label_fontsize)
-        plt.rc('legend', fontsize=label_fontsize)
-        plt.rc('figure', titlesize=title_fontsize)
-
-        # Create the SHAP plot
+
         shap.summary_plot(
             shap_values=shap_values,
             features=features,
@@ -528,85 +529,75 @@ def get_shap_values(model, model_name: str,
             plot_size=figsize,
             max_display=max_display_features,
             alpha=0.7,
-            color=…
+            # color='viridis'
         )
-
-        # Add professional styling
+
         ax = plt.gca()
-        ax.set_xlabel("SHAP Value Impact", fontsize=…
-        …
-        # Manually fix tick fonts
+        ax.set_xlabel("SHAP Value Impact", fontsize=base_fontsize + 2, weight='bold', labelpad=20)
+        plt.title(title, fontsize=base_fontsize + 2, pad=20, weight='bold')
+
         for tick in ax.get_xticklabels():
-            tick.set_fontsize(…
-            tick.set_rotation(…
+            tick.set_fontsize(base_fontsize)
+            tick.set_rotation(30)
         for tick in ax.get_yticklabels():
-            tick.set_fontsize(…
+            tick.set_fontsize(base_fontsize + 2)

-        # Handle colorbar for dot plots
         if plot_type == "dot":
             cb = plt.gcf().axes[-1]
-            # cb.set_ylabel("Feature Value", size=label_fontsize)
             cb.set_ylabel("", size=1)
-            cb.tick_params(labelsize=…
-
-        plt.savefig(
-            full_save_path,
-            bbox_inches='tight',
-            facecolor='white',
-            format="svg"
-        )
+            cb.tick_params(labelsize=base_fontsize - 2)
+
+        plt.savefig(save_path, bbox_inches='tight', facecolor='white', format="svg")
         plt.close()
-        rcdefaults()
-
-    …
-    # Handle different model types
-    if task == 'classification':
-        # Determine if multiclass
-        try:
-            is_multiclass = len(model.classes_) > 2
-            class_names = model.classes_
-        except AttributeError:
-            is_multiclass = isinstance(shap_values, list) and len(shap_values) > 1
-            class_names = list(range(len(shap_values))) if is_multiclass else [0, 1]
-
+        rcdefaults()
+
+    def _plot_for_classification(shap_values, class_names):
+        is_multiclass = isinstance(shap_values, list) and len(shap_values) > 1
+
         if is_multiclass:
-            for …
+            for class_shap, class_name in zip(shap_values, class_names):
+                for plot_type in ["bar", "dot"]:
+                    _create_shap_plot(
+                        shap_values=class_shap,
+                        features=features_to_explain,
+                        save_path=os.path.join(save_dir, f"SHAP_{target_id}_Class{class_name}_{plot_type}.svg"),
+                        plot_type=plot_type,
+                        title=f"{model_name} - {target_id} (Class {class_name})"
+                    )
+        else:
+            values = shap_values[1] if isinstance(shap_values, list) else shap_values
+            for plot_type in ["bar", "dot"]:
                 _create_shap_plot(
-                    shap_values=…
+                    shap_values=values,
                     features=features_to_explain,
-                    full_save_path=os.path.join(save_dir, f"SHAP_{target_id}_Class{class_name}.svg"),
+                    save_path=os.path.join(save_dir, f"SHAP_{target_id}_{plot_type}.svg"),
                     plot_type=plot_type,
-                    title=f"{model_name} - {target_id}…
+                    title=f"{model_name} - {target_id}"
                 )
-
+
+    def _plot_for_regression(shap_values):
+        for plot_type in ["bar", "dot"]:
             _create_shap_plot(
-                shap_values=…
+                shap_values=shap_values,
                 features=features_to_explain,
-                full_save_path=os.path.join(save_dir, f"SHAP_{target_id}.svg"),
+                save_path=os.path.join(save_dir, f"SHAP_{target_id}_{plot_type}.svg"),
                 plot_type=plot_type,
                 title=f"{model_name} - {target_id}"
            )
-            …
-        )
-
+
+    explainer = shap.TreeExplainer(model)
+    shap_values = explainer.shap_values(features_to_explain)
+
+    if task == 'classification':
+        try:
+            class_names = model.classes_ if hasattr(model, 'classes_') else list(range(len(shap_values)))
+        except Exception:
+            class_names = list(range(len(shap_values)))
+        _plot_for_classification(shap_values, class_names)
+    else:
+        _plot_for_regression(shap_values)
+
+
 # TRAIN TEST PIPELINE
 def train_test_pipeline(model, model_name: str, dataset_id: str, task: Literal["classification", "regression"],
                         train_features: np.ndarray, train_target: np.ndarray,
@@ -653,7 +644,7 @@ def train_test_pipeline(model, model_name: str, dataset_id: str, task: Literal["…
     return trained_model, y_pred

 ###### 5. Execution ######
-def run_pipeline(datasets_dir: str, save_dir: str, target_columns: list[str], task: Literal["classification", "regression"],
+def run_ensemble_pipeline(datasets_dir: str, save_dir: str, target_columns: list[str], task: Literal["classification", "regression"],
                  resample_strategy: Literal[r"ADASYN", r'SMOTE', r'RANDOM', r'UNDERSAMPLE', None]=None, scaler: Literal["standard", "minmax", "maxabs"]="minmax", save_model: bool=False,
                  test_size: float=0.2, debug:bool=False, L1_regularization: float=0.5, L2_regularization: float=0.5, learning_rate: float=0.005, random_state: int=101):
     #Check paths
@@ -672,15 +663,15 @@ def run_pipeline(datasets_dir: str, save_dir: str, target_columns: list[str], ta…
     #Train models
     for model_name, model in models_dict.items():
         train_test_pipeline(model=model, model_name=model_name, dataset_id=dataframe_name, task=task,
-                            train_features=X_train, train_target=y_train,
+                            train_features=X_train, train_target=y_train, # type: ignore
                             test_features=X_test, test_target=y_test,
                             feature_names=feature_names,target_id=target_name, scaler_object=scaler_object,
                             debug=debug, save_dir=save_dir, save_model=save_model)
-    print("\…
+    print("\n✅ Training and evaluation complete.")


 def _check_paths(datasets_dir: str, save_dir:str):
     if not os.path.isdir(save_dir):
         os.makedirs(save_dir)
     if not os.path.isdir(datasets_dir):
-        raise IOError(f"Datasets directory '{datasets_dir}' not found…
+        raise IOError(f"Datasets directory '{datasets_dir}' not found.")
ml_tools/utilities.py CHANGED

@@ -6,17 +6,15 @@ from pathlib import Path
 import re


-def list_csv_paths(directory: str) -> tuple[list[str], list[str]]:
+def list_csv_paths(directory: str) -> dict[str, str]:
     """
-    Lists all …
+    Lists all `.csv` files in the specified directory and returns a mapping: filenames (without extensions) to their absolute paths.

     Parameters:
         directory (str): Path to the directory containing `.csv` files.

     Returns:
-        …
-        - List of absolute paths to `.csv` files.
-        - List of corresponding base names (without extensions).
+        (dict[str, str]): Mapping {name, path}.
     """
     dir_path = Path(directory).expanduser().resolve()

@@ -26,11 +24,15 @@ def list_csv_paths(directory: str) -> tuple[list[str], list[str]]:
     csv_paths = list(dir_path.glob("*.csv"))
     if not csv_paths:
         raise IOError(f"No CSV files found in directory: {dir_path}")
+
+    # make a dictionary of paths and names
+    name_path_dict = {p.stem: str(p) for p in csv_paths}
+
+    print("🗂️ CSV files found:")
+    for name in name_path_dict.keys():
+        print(f"\t{name}")

-
-    names = [p.stem for p in csv_paths]
-
-    return paths, names
+    return name_path_dict


 def load_dataframe(df_path: str) -> tuple[pd.DataFrame, str]:
@@ -49,7 +51,7 @@ def load_dataframe(df_path: str) -> tuple[pd.DataFrame, str]:
     df_name = path.stem
     if df.empty:
         raise ValueError(f"DataFrame '{df_name}' is empty.")
-    print(f"Loaded dataset: '{df_name}' with shape: {df.shape}")
+    print(f"\n💿 Loaded dataset: '{df_name}' with shape: {df.shape}")
     return df, df_name

@@ -71,9 +73,8 @@ def yield_dataframes_from_dir(datasets_dir: str):
     - CSV files are read using UTF-8 encoding.
     - Output is streamed via a generator to support lazy loading of multiple datasets.
     """
-    for …
-        df = …
-        print(f"Loaded dataset: '{df_name}' with shape: {df.shape}")
+    for df_name, df_path in list_csv_paths(datasets_dir).items():
+        df, _ = load_dataframe(df_path)
         yield df, df_name

@@ -166,3 +167,28 @@ def sanitize_filename(filename: str) -> str:

     return sanitized

+
+def save_dataframe(df: pd.DataFrame, save_dir: str, filename: str) -> None:
+    """
+    Save a pandas DataFrame to a CSV file.
+
+    Parameters:
+        df: pandas.DataFrame to save
+        save_dir: str, directory where the CSV file will be saved.
+        filename: str, CSV filename, extension will be added if missing.
+    """
+    if df.empty:
+        print(f"⚠️ Attempting to save an empty DataFrame: '{filename}'. Process Skipped.")
+        return
+
+    os.makedirs(save_dir, exist_ok=True)
+
+    filename = sanitize_filename(filename)
+
+    if not filename.endswith('.csv'):
+        filename += '.csv'
+
+    output_path = os.path.join(save_dir, filename)
+
+    df.to_csv(output_path, index=False, encoding='utf-8')
+    print(f"✅ Saved file: '{filename}'")
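A quick sketch of the reworked `list_csv_paths` contract; `./datasets` is a placeholder path:

```python
from ml_tools.utilities import list_csv_paths

name_to_path = list_csv_paths("./datasets")  # e.g. {"my_data": "/abs/path/my_data.csv"}
for name, path in name_to_path.items():
    print(f"{name} -> {path}")
```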
{dragon_ml_toolbox-1.3.2.dist-info → dragon_ml_toolbox-1.4.0.dist-info}/WHEEL (file without changes)
{dragon_ml_toolbox-1.3.2.dist-info → dragon_ml_toolbox-1.4.0.dist-info}/licenses/LICENSE (file without changes)
{dragon_ml_toolbox-1.3.2.dist-info → dragon_ml_toolbox-1.4.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md (file without changes)
{dragon_ml_toolbox-1.3.2.dist-info → dragon_ml_toolbox-1.4.0.dist-info}/top_level.txt (file without changes)