dragon-ml-toolbox 1.4.8__py3-none-any.whl → 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dragon-ml-toolbox might be problematic. Click here for more details.
- {dragon_ml_toolbox-1.4.8.dist-info → dragon_ml_toolbox-2.1.0.dist-info}/METADATA +24 -14
- dragon_ml_toolbox-2.1.0.dist-info/RECORD +20 -0
- {dragon_ml_toolbox-1.4.8.dist-info → dragon_ml_toolbox-2.1.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +5 -4
- ml_tools/MICE_imputation.py +27 -28
- ml_tools/PSO_optimization.py +490 -0
- ml_tools/VIF_factor.py +20 -17
- ml_tools/{particle_swarm_optimization.py → _particle_swarm_optimization.py} +5 -0
- ml_tools/data_exploration.py +58 -32
- ml_tools/ensemble_learning.py +40 -42
- ml_tools/handle_excel.py +98 -78
- ml_tools/logger.py +13 -11
- ml_tools/utilities.py +134 -46
- dragon_ml_toolbox-1.4.8.dist-info/RECORD +0 -19
- {dragon_ml_toolbox-1.4.8.dist-info → dragon_ml_toolbox-2.1.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-1.4.8.dist-info → dragon_ml_toolbox-2.1.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-1.4.8.dist-info → dragon_ml_toolbox-2.1.0.dist-info}/top_level.txt +0 -0
ml_tools/data_exploration.py
CHANGED
|
@@ -5,9 +5,9 @@ import seaborn as sns
|
|
|
5
5
|
from IPython import get_ipython
|
|
6
6
|
from IPython.display import clear_output
|
|
7
7
|
import time
|
|
8
|
-
from typing import Union, Literal, Dict, Tuple, List
|
|
9
|
-
import
|
|
10
|
-
from .utilities import sanitize_filename, _script_info
|
|
8
|
+
from typing import Union, Literal, Dict, Tuple, List, Optional
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from .utilities import sanitize_filename, _script_info, make_fullpath
|
|
11
11
|
import re
|
|
12
12
|
|
|
13
13
|
|
|
@@ -59,26 +59,48 @@ def summarize_dataframe(df: pd.DataFrame, round_digits: int = 2):
|
|
|
59
59
|
return summary
|
|
60
60
|
|
|
61
61
|
|
|
62
|
-
def drop_rows_with_missing_data(df: pd.DataFrame, threshold: float = 0.7) -> pd.DataFrame:
|
|
62
|
+
def drop_rows_with_missing_data(df: pd.DataFrame, targets: Optional[list[str]], threshold: float = 0.7) -> pd.DataFrame:
|
|
63
63
|
"""
|
|
64
|
-
Drops rows
|
|
64
|
+
Drops rows from the DataFrame using a two-stage strategy:
|
|
65
|
+
|
|
66
|
+
1. If `targets`, remove any row where all target columns are missing.
|
|
67
|
+
2. Among features, drop those with more than `threshold` fraction of missing values.
|
|
65
68
|
|
|
66
69
|
Parameters:
|
|
67
70
|
df (pd.DataFrame): The input DataFrame.
|
|
68
|
-
|
|
71
|
+
targets (list[str] | None): List of target column names.
|
|
72
|
+
threshold (float): Maximum allowed fraction of missing values in feature columns.
|
|
69
73
|
|
|
70
74
|
Returns:
|
|
71
|
-
pd.DataFrame: A
|
|
75
|
+
pd.DataFrame: A cleaned DataFrame with problematic rows removed.
|
|
72
76
|
"""
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
if
|
|
77
|
-
|
|
77
|
+
df_clean = df.copy()
|
|
78
|
+
|
|
79
|
+
# Stage 1: Drop rows with all target columns missing
|
|
80
|
+
if targets is not None:
|
|
81
|
+
target_na = df_clean[targets].isnull().all(axis=1)
|
|
82
|
+
if target_na.any():
|
|
83
|
+
print(f"🧹 Dropping {target_na.sum()} rows with all target columns missing.")
|
|
84
|
+
df_clean = df_clean[~target_na]
|
|
85
|
+
else:
|
|
86
|
+
print("✅ No rows with all targets missing.")
|
|
78
87
|
else:
|
|
79
|
-
|
|
88
|
+
targets = []
|
|
89
|
+
|
|
90
|
+
# Stage 2: Drop rows based on feature column missing values
|
|
91
|
+
feature_cols = [col for col in df_clean.columns if col not in targets]
|
|
92
|
+
if feature_cols:
|
|
93
|
+
feature_na_frac = df_clean[feature_cols].isnull().mean(axis=1)
|
|
94
|
+
rows_to_drop = feature_na_frac[feature_na_frac > threshold].index
|
|
95
|
+
if len(rows_to_drop) > 0:
|
|
96
|
+
print(f"📉 Dropping {len(rows_to_drop)} rows with more than {threshold*100:.0f}% missing feature data.")
|
|
97
|
+
df_clean = df_clean.drop(index=rows_to_drop)
|
|
98
|
+
else:
|
|
99
|
+
print(f"✅ No rows exceed the {threshold*100:.0f}% missing feature data threshold.")
|
|
100
|
+
else:
|
|
101
|
+
print("⚠️ No feature columns available to evaluate.")
|
|
80
102
|
|
|
81
|
-
return
|
|
103
|
+
return df_clean
|
|
82
104
|
|
|
83
105
|
|
|
84
106
|
def split_features_targets(df: pd.DataFrame, targets: list[str]):
|
|
@@ -205,13 +227,16 @@ def split_continuous_binary(df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFram
|
|
|
205
227
|
return df_cont, df_bin # type: ignore
|
|
206
228
|
|
|
207
229
|
|
|
208
|
-
def plot_correlation_heatmap(df: pd.DataFrame,
|
|
230
|
+
def plot_correlation_heatmap(df: pd.DataFrame,
|
|
231
|
+
save_dir: Union[str, Path, None] = None,
|
|
232
|
+
plot_title: str="Correlation Heatmap",
|
|
233
|
+
method: Literal["pearson", "kendall", "spearman"]="pearson"):
|
|
209
234
|
"""
|
|
210
235
|
Plots a heatmap of pairwise correlations between numeric features in a DataFrame.
|
|
211
236
|
|
|
212
237
|
Args:
|
|
213
238
|
df (pd.DataFrame): The input dataset.
|
|
214
|
-
save_dir (str | None): If provided, the heatmap will be saved to this directory as a svg file.
|
|
239
|
+
save_dir (str | Path | None): If provided, the heatmap will be saved to this directory as a svg file.
|
|
215
240
|
plot_title: To make different plots, or overwrite existing ones.
|
|
216
241
|
method (str): Correlation method to use. Must be one of:
|
|
217
242
|
- 'pearson' (default): measures linear correlation (assumes normally distributed data),
|
|
@@ -254,10 +279,13 @@ def plot_correlation_heatmap(df: pd.DataFrame, save_dir: Union[str, None] = None
|
|
|
254
279
|
plt.tight_layout()
|
|
255
280
|
|
|
256
281
|
if save_dir:
|
|
282
|
+
save_path = make_fullpath(save_dir, make=True)
|
|
257
283
|
# sanitize the plot title to save the file
|
|
258
284
|
plot_title = sanitize_filename(plot_title)
|
|
259
|
-
|
|
260
|
-
|
|
285
|
+
plot_title = plot_title + ".svg"
|
|
286
|
+
|
|
287
|
+
full_path = save_path / plot_title
|
|
288
|
+
|
|
261
289
|
plt.savefig(full_path, bbox_inches="tight", format='svg')
|
|
262
290
|
print(f"Saved correlation heatmap: '{plot_title}.svg'")
|
|
263
291
|
|
|
@@ -322,7 +350,7 @@ def check_value_distributions(df: pd.DataFrame, view_frequencies: bool=True, bin
|
|
|
322
350
|
user_input_ = input("Press enter to continue")
|
|
323
351
|
|
|
324
352
|
|
|
325
|
-
def plot_value_distributions(df: pd.DataFrame, save_dir: str, bin_threshold: int=10, skip_cols_with_key: Union[str, None]=None):
|
|
353
|
+
def plot_value_distributions(df: pd.DataFrame, save_dir: Union[str, Path], bin_threshold: int=10, skip_cols_with_key: Union[str, None]=None):
|
|
326
354
|
"""
|
|
327
355
|
Plots and saves the value distributions for all (or selected) columns in a DataFrame,
|
|
328
356
|
with adaptive binning for numerical columns when appropriate.
|
|
@@ -335,7 +363,7 @@ def plot_value_distributions(df: pd.DataFrame, save_dir: str, bin_threshold: int
|
|
|
335
363
|
|
|
336
364
|
Args:
|
|
337
365
|
df (pd.DataFrame): The input DataFrame whose columns are to be analyzed.
|
|
338
|
-
save_dir (str): Directory path where the plots will be saved. Will be created if it does not exist.
|
|
366
|
+
save_dir (str | Path): Directory path where the plots will be saved. Will be created if it does not exist.
|
|
339
367
|
bin_threshold (int): Minimum number of unique values required to trigger binning
|
|
340
368
|
for numerical columns.
|
|
341
369
|
skip_cols_with_key (str | None): If provided, any column whose name contains this
|
|
@@ -346,8 +374,7 @@ def plot_value_distributions(df: pd.DataFrame, save_dir: str, bin_threshold: int
|
|
|
346
374
|
- All non-alphanumeric characters in column names are sanitized for safe file naming.
|
|
347
375
|
- Colormap is automatically adapted based on the number of categories or bins.
|
|
348
376
|
"""
|
|
349
|
-
|
|
350
|
-
os.makedirs(save_dir, exist_ok=True)
|
|
377
|
+
save_path = make_fullpath(save_dir, make=True)
|
|
351
378
|
|
|
352
379
|
dict_to_plot_std = dict()
|
|
353
380
|
dict_to_plot_freq = dict()
|
|
@@ -384,13 +411,12 @@ def plot_value_distributions(df: pd.DataFrame, save_dir: str, bin_threshold: int
|
|
|
384
411
|
view_freq = 100 * view_std / view_std.sum() # Percentage
|
|
385
412
|
# view_freq = df[col].value_counts(normalize=True, bins=10) # relative percentages
|
|
386
413
|
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
saved_plots += 1
|
|
414
|
+
dict_to_plot_std[col] = dict(view_std)
|
|
415
|
+
dict_to_plot_freq[col] = dict(view_freq)
|
|
416
|
+
saved_plots += 1
|
|
391
417
|
|
|
392
418
|
# plot helper
|
|
393
|
-
def _plot_helper(dict_: dict, target_dir:
|
|
419
|
+
def _plot_helper(dict_: dict, target_dir: Path, ylabel: Literal["Frequency", "Counts"], base_fontsize: int=12):
|
|
394
420
|
for col, data in dict_.items():
|
|
395
421
|
safe_col = sanitize_filename(col)
|
|
396
422
|
|
|
@@ -412,15 +438,15 @@ def plot_value_distributions(df: pd.DataFrame, save_dir: str, bin_threshold: int
|
|
|
412
438
|
plt.gca().set_facecolor('#f9f9f9')
|
|
413
439
|
plt.tight_layout()
|
|
414
440
|
|
|
415
|
-
plot_path =
|
|
441
|
+
plot_path = target_dir / f"{safe_col}.png"
|
|
416
442
|
plt.savefig(plot_path, dpi=300, bbox_inches="tight")
|
|
417
443
|
plt.close()
|
|
418
444
|
|
|
419
445
|
# Save plots
|
|
420
|
-
freq_dir =
|
|
421
|
-
std_dir =
|
|
422
|
-
|
|
423
|
-
|
|
446
|
+
freq_dir = save_path / "Distribution_Frequency"
|
|
447
|
+
std_dir = save_path / "Distribution_Counts"
|
|
448
|
+
freq_dir.mkdir(parents=True, exist_ok=True)
|
|
449
|
+
std_dir.mkdir(parents=True, exist_ok=True)
|
|
424
450
|
_plot_helper(dict_=dict_to_plot_std, target_dir=std_dir, ylabel="Counts")
|
|
425
451
|
_plot_helper(dict_=dict_to_plot_freq, target_dir=freq_dir, ylabel="Frequency")
|
|
426
452
|
|
ml_tools/ensemble_learning.py
CHANGED
|
@@ -5,7 +5,7 @@ import matplotlib.pyplot as plt
|
|
|
5
5
|
from matplotlib.colors import Colormap
|
|
6
6
|
from matplotlib import rcdefaults
|
|
7
7
|
|
|
8
|
-
import
|
|
8
|
+
from pathlib import Path
|
|
9
9
|
from typing import Literal, Union, Optional, Iterator, Tuple
|
|
10
10
|
|
|
11
11
|
from imblearn.over_sampling import ADASYN, SMOTE, RandomOverSampler
|
|
@@ -19,7 +19,7 @@ from sklearn.model_selection import train_test_split
|
|
|
19
19
|
from sklearn.metrics import accuracy_score, classification_report, ConfusionMatrixDisplay, mean_absolute_error, mean_squared_error, r2_score, roc_curve, roc_auc_score
|
|
20
20
|
import shap
|
|
21
21
|
|
|
22
|
-
from .utilities import yield_dataframes_from_dir, sanitize_filename, _script_info, serialize_object
|
|
22
|
+
from .utilities import yield_dataframes_from_dir, sanitize_filename, _script_info, serialize_object, make_fullpath
|
|
23
23
|
|
|
24
24
|
import warnings # Ignore warnings
|
|
25
25
|
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
|
@@ -469,30 +469,31 @@ def _train_model(model, train_features, train_target):
|
|
|
469
469
|
return model
|
|
470
470
|
|
|
471
471
|
# handle local directories
|
|
472
|
-
def _local_directories(model_name: str, dataset_id: str, save_dir: str):
|
|
473
|
-
|
|
474
|
-
if not os.path.isdir(dataset_dir):
|
|
475
|
-
os.makedirs(dataset_dir)
|
|
472
|
+
def _local_directories(model_name: str, dataset_id: str, save_dir: Union[str,Path]):
|
|
473
|
+
save_path = make_fullpath(save_dir, make=True)
|
|
476
474
|
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
475
|
+
dataset_dir = save_path / dataset_id
|
|
476
|
+
dataset_dir.mkdir(parents=True, exist_ok=True)
|
|
477
|
+
|
|
478
|
+
model_dir = dataset_dir / model_name
|
|
479
|
+
model_dir.mkdir(parents=True, exist_ok=True)
|
|
480
480
|
|
|
481
481
|
return model_dir
|
|
482
482
|
|
|
483
483
|
# save model
|
|
484
|
-
def _save_model(trained_model, model_name: str, target_name:str, feature_names: list[str], save_directory: str):
|
|
484
|
+
def _save_model(trained_model, model_name: str, target_name:str, feature_names: list[str], save_directory: Union[str,Path]):
|
|
485
485
|
#Sanitize filenames to save
|
|
486
486
|
sanitized_target_name = sanitize_filename(target_name)
|
|
487
487
|
filename = f"{model_name}_{sanitized_target_name}"
|
|
488
488
|
to_save = {'model': trained_model, 'feature_names': feature_names, 'target_name':target_name}
|
|
489
|
+
|
|
489
490
|
serialize_object(obj=to_save, save_dir=save_directory, filename=filename, verbose=False, raise_on_error=True)
|
|
490
491
|
|
|
491
492
|
# function to evaluate the model and save metrics (Classification)
|
|
492
493
|
def evaluate_model_classification(
|
|
493
494
|
model,
|
|
494
495
|
model_name: str,
|
|
495
|
-
save_dir: str,
|
|
496
|
+
save_dir: Union[str,Path],
|
|
496
497
|
x_test_scaled: np.ndarray,
|
|
497
498
|
single_y_test: np.ndarray,
|
|
498
499
|
target_name: str,
|
|
@@ -524,7 +525,7 @@ def evaluate_model_classification(
|
|
|
524
525
|
Returns:
|
|
525
526
|
y_pred: Predicted class labels
|
|
526
527
|
"""
|
|
527
|
-
|
|
528
|
+
save_path = make_fullpath(save_dir, make=True)
|
|
528
529
|
|
|
529
530
|
y_pred = model.predict(x_test_scaled)
|
|
530
531
|
accuracy = accuracy_score(single_y_test, y_pred)
|
|
@@ -538,7 +539,7 @@ def evaluate_model_classification(
|
|
|
538
539
|
|
|
539
540
|
# Save text report
|
|
540
541
|
sanitized_target_name = sanitize_filename(target_name)
|
|
541
|
-
report_path =
|
|
542
|
+
report_path = save_path / f"Classification_Report_{sanitized_target_name}.txt"
|
|
542
543
|
with open(report_path, "w") as f:
|
|
543
544
|
f.write(f"{model_name} - {target_name}\t\tAccuracy: {accuracy:.2f}\n")
|
|
544
545
|
f.write("Classification Report:\n")
|
|
@@ -568,7 +569,7 @@ def evaluate_model_classification(
|
|
|
568
569
|
text.set_fontsize(base_fontsize+4)
|
|
569
570
|
|
|
570
571
|
fig.tight_layout()
|
|
571
|
-
fig_path =
|
|
572
|
+
fig_path = save_path / f"Confusion_Matrix_{sanitized_target_name}.svg"
|
|
572
573
|
fig.savefig(fig_path, format="svg", bbox_inches="tight")
|
|
573
574
|
plt.close(fig)
|
|
574
575
|
|
|
@@ -580,7 +581,7 @@ def plot_roc_curve(
|
|
|
580
581
|
probabilities_or_model: Union[np.ndarray, xgb.XGBClassifier, lgb.LGBMClassifier, object],
|
|
581
582
|
model_name: str,
|
|
582
583
|
target_name: str,
|
|
583
|
-
save_directory: str,
|
|
584
|
+
save_directory: Union[str,Path],
|
|
584
585
|
color: str = "darkorange",
|
|
585
586
|
figure_size: tuple = (10, 10),
|
|
586
587
|
linewidth: int = 2,
|
|
@@ -594,7 +595,7 @@ def plot_roc_curve(
|
|
|
594
595
|
true_labels: np.ndarray of shape (n_samples,), ground truth binary labels (0 or 1).
|
|
595
596
|
probabilities_or_model: either predicted probabilities (ndarray), or a trained model with attribute `.predict_proba()`.
|
|
596
597
|
target_name: str, Target name.
|
|
597
|
-
save_directory: str, path to directory where figure is saved.
|
|
598
|
+
save_directory: str or Path, path to directory where figure is saved.
|
|
598
599
|
color: color of the ROC curve. Accepts any valid Matplotlib color specification. Examples:
|
|
599
600
|
- Named colors: "darkorange", "blue", "red", "green", "black"
|
|
600
601
|
- Hex codes: "#1f77b4", "#ff7f0e"
|
|
@@ -650,17 +651,17 @@ def plot_roc_curve(
|
|
|
650
651
|
ax.grid(True)
|
|
651
652
|
|
|
652
653
|
# Save figure
|
|
653
|
-
|
|
654
|
+
save_path = make_fullpath(save_directory, make=True)
|
|
654
655
|
sanitized_target_name = sanitize_filename(target_name)
|
|
655
|
-
|
|
656
|
-
fig.savefig(
|
|
656
|
+
full_save_path = save_path / f"ROC_{sanitized_target_name}.svg"
|
|
657
|
+
fig.savefig(full_save_path, bbox_inches="tight", format="svg")
|
|
657
658
|
|
|
658
659
|
return fig
|
|
659
660
|
|
|
660
661
|
|
|
661
662
|
# function to evaluate the model and save metrics (Regression)
|
|
662
663
|
def evaluate_model_regression(model, model_name: str,
|
|
663
|
-
save_dir: str,
|
|
664
|
+
save_dir: Union[str,Path],
|
|
664
665
|
x_test_scaled: np.ndarray, single_y_test: np.ndarray,
|
|
665
666
|
target_name: str,
|
|
666
667
|
figure_size: tuple = (12, 8),
|
|
@@ -677,7 +678,8 @@ def evaluate_model_regression(model, model_name: str,
|
|
|
677
678
|
|
|
678
679
|
# Create formatted report
|
|
679
680
|
sanitized_target_name = sanitize_filename(target_name)
|
|
680
|
-
|
|
681
|
+
save_path = make_fullpath(save_dir, make=True)
|
|
682
|
+
report_path = save_path / f"Regression_Report_{sanitized_target_name}.txt"
|
|
681
683
|
with open(report_path, "w") as f:
|
|
682
684
|
f.write(f"{model_name} - Regression Performance for '{target_name}'\n\n")
|
|
683
685
|
f.write(f"Mean Absolute Error (MAE): {mae:.4f}\n")
|
|
@@ -695,7 +697,8 @@ def evaluate_model_regression(model, model_name: str,
|
|
|
695
697
|
plt.title(f"{model_name} - Residual Plot for {target_name}", fontsize=base_fontsize)
|
|
696
698
|
plt.grid(True)
|
|
697
699
|
plt.tight_layout()
|
|
698
|
-
|
|
700
|
+
residual_path = save_path / f"Residual_Plot_{sanitized_target_name}.svg"
|
|
701
|
+
plt.savefig(residual_path, bbox_inches='tight', format="svg")
|
|
699
702
|
plt.close()
|
|
700
703
|
|
|
701
704
|
# Create true vs predicted values plot
|
|
@@ -708,7 +711,7 @@ def evaluate_model_regression(model, model_name: str,
|
|
|
708
711
|
plt.ylabel('Predictions', fontsize=base_fontsize)
|
|
709
712
|
plt.title(f"{model_name} - True vs Predicted for {target_name}", fontsize=base_fontsize)
|
|
710
713
|
plt.grid(True)
|
|
711
|
-
plot_path =
|
|
714
|
+
plot_path = save_path / f"Regression_Plot_{sanitized_target_name}.svg"
|
|
712
715
|
plt.savefig(plot_path, bbox_inches='tight', format="svg")
|
|
713
716
|
plt.close()
|
|
714
717
|
|
|
@@ -719,7 +722,7 @@ def evaluate_model_regression(model, model_name: str,
|
|
|
719
722
|
def get_shap_values(
|
|
720
723
|
model,
|
|
721
724
|
model_name: str,
|
|
722
|
-
save_dir: str,
|
|
725
|
+
save_dir: Union[str, Path],
|
|
723
726
|
features_to_explain: np.ndarray,
|
|
724
727
|
feature_names: list[str],
|
|
725
728
|
target_name: str,
|
|
@@ -737,11 +740,12 @@ def get_shap_values(
|
|
|
737
740
|
* Use the entire dataset to get the global view.
|
|
738
741
|
|
|
739
742
|
Parameters:
|
|
740
|
-
task: 'regression' or 'classification'
|
|
743
|
+
task: 'regression' or 'classification'.
|
|
741
744
|
features_to_explain: Should match the model's training data format, including scaling.
|
|
742
|
-
save_dir: Directory to save visualizations
|
|
745
|
+
save_dir: Directory to save visualizations.
|
|
743
746
|
"""
|
|
744
747
|
sanitized_target_name = sanitize_filename(target_name)
|
|
748
|
+
global_save_path = make_fullpath(save_dir, make=True)
|
|
745
749
|
|
|
746
750
|
def _apply_plot_style():
|
|
747
751
|
styles = ['seaborn', 'seaborn-v0_8-darkgrid', 'seaborn-v0_8', 'default']
|
|
@@ -759,7 +763,7 @@ def get_shap_values(
|
|
|
759
763
|
plt.rc('legend', fontsize=base_fontsize)
|
|
760
764
|
plt.rc('figure', titlesize=base_fontsize)
|
|
761
765
|
|
|
762
|
-
def _create_shap_plot(shap_values, features, save_path:
|
|
766
|
+
def _create_shap_plot(shap_values, features, save_path: Path, plot_type: str, title: str):
|
|
763
767
|
_apply_plot_style()
|
|
764
768
|
_configure_rcparams()
|
|
765
769
|
plt.figure(figsize=figsize)
|
|
@@ -804,7 +808,7 @@ def get_shap_values(
|
|
|
804
808
|
_create_shap_plot(
|
|
805
809
|
shap_values=class_shap,
|
|
806
810
|
features=features_to_explain,
|
|
807
|
-
save_path=
|
|
811
|
+
save_path=global_save_path / f"SHAP_{sanitized_target_name}_Class{class_name}_{plot_type}.svg",
|
|
808
812
|
plot_type=plot_type,
|
|
809
813
|
title=f"{model_name} - {target_name} (Class {class_name})"
|
|
810
814
|
)
|
|
@@ -814,7 +818,7 @@ def get_shap_values(
|
|
|
814
818
|
_create_shap_plot(
|
|
815
819
|
shap_values=values,
|
|
816
820
|
features=features_to_explain,
|
|
817
|
-
save_path=
|
|
821
|
+
save_path=global_save_path / f"SHAP_{sanitized_target_name}_{plot_type}.svg",
|
|
818
822
|
plot_type=plot_type,
|
|
819
823
|
title=f"{model_name} - {target_name}"
|
|
820
824
|
)
|
|
@@ -824,7 +828,7 @@ def get_shap_values(
|
|
|
824
828
|
_create_shap_plot(
|
|
825
829
|
shap_values=shap_values,
|
|
826
830
|
features=features_to_explain,
|
|
827
|
-
save_path=
|
|
831
|
+
save_path=global_save_path / f"SHAP_{sanitized_target_name}_{plot_type}.svg",
|
|
828
832
|
plot_type=plot_type,
|
|
829
833
|
title=f"{model_name} - {target_name}"
|
|
830
834
|
)
|
|
@@ -848,7 +852,7 @@ def train_test_pipeline(model, model_name: str, dataset_id: str, task: TaskType,
|
|
|
848
852
|
train_features: np.ndarray, train_target: np.ndarray,
|
|
849
853
|
test_features: np.ndarray, test_target: np.ndarray,
|
|
850
854
|
feature_names: list[str], target_name: str,
|
|
851
|
-
save_dir: str,
|
|
855
|
+
save_dir: Union[str,Path],
|
|
852
856
|
debug: bool=False, save_model: bool=False):
|
|
853
857
|
'''
|
|
854
858
|
1. Train model.
|
|
@@ -889,7 +893,7 @@ def train_test_pipeline(model, model_name: str, dataset_id: str, task: TaskType,
|
|
|
889
893
|
return trained_model, y_pred
|
|
890
894
|
|
|
891
895
|
###### 5. Execution ######
|
|
892
|
-
def run_ensemble_pipeline(datasets_dir: str, save_dir: str, target_columns: list[str], model_object: Union[RegressionTreeModels, ClassificationTreeModels],
|
|
896
|
+
def run_ensemble_pipeline(datasets_dir: Union[str,Path], save_dir: Union[str,Path], target_columns: list[str], model_object: Union[RegressionTreeModels, ClassificationTreeModels],
|
|
893
897
|
handle_classification_imbalance: HandleImbalanceStrategy=None, save_model: bool=False,
|
|
894
898
|
test_size: float=0.2, debug:bool=False):
|
|
895
899
|
#Check models
|
|
@@ -907,10 +911,11 @@ def run_ensemble_pipeline(datasets_dir: str, save_dir: str, target_columns: list
|
|
|
907
911
|
raise TypeError(f"Unrecognized model {type(model_object)}")
|
|
908
912
|
|
|
909
913
|
#Check paths
|
|
910
|
-
|
|
914
|
+
datasets_path = make_fullpath(datasets_dir)
|
|
915
|
+
save_path = make_fullpath(save_dir, make=True)
|
|
911
916
|
|
|
912
917
|
#Yield imputed dataset
|
|
913
|
-
for dataframe, dataframe_name in yield_dataframes_from_dir(
|
|
918
|
+
for dataframe, dataframe_name in yield_dataframes_from_dir(datasets_path):
|
|
914
919
|
#Yield features dataframe and target dataframe
|
|
915
920
|
for df_features, df_target, feature_names, target_name in dataset_yielder(df=dataframe, target_cols=target_columns):
|
|
916
921
|
#Dataset pipeline
|
|
@@ -925,15 +930,8 @@ def run_ensemble_pipeline(datasets_dir: str, save_dir: str, target_columns: list
|
|
|
925
930
|
train_features=X_train, train_target=y_train, # type: ignore
|
|
926
931
|
test_features=X_test, test_target=y_test,
|
|
927
932
|
feature_names=feature_names,target_name=target_name,
|
|
928
|
-
debug=debug, save_dir=
|
|
933
|
+
debug=debug, save_dir=save_path, save_model=save_model)
|
|
929
934
|
print("\n✅ Training and evaluation complete.")
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
def _check_paths(datasets_dir: str, save_dir:str):
|
|
933
|
-
if not os.path.isdir(save_dir):
|
|
934
|
-
os.makedirs(save_dir)
|
|
935
|
-
if not os.path.isdir(datasets_dir):
|
|
936
|
-
raise IOError(f"Datasets directory '{datasets_dir}' not found.")
|
|
937
935
|
|
|
938
936
|
|
|
939
937
|
def info():
|