dragon-ml-toolbox 1.3.2__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


dragon_ml_toolbox-1.4.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dragon-ml-toolbox
-Version: 1.3.2
+Version: 1.4.0
 Summary: A collection of tools for data science and machine learning projects
 Author-email: Karl Loza <luigiloza@gmail.com>
 License-Expression: MIT
@@ -27,6 +27,7 @@ Requires-Dist: ipython
 Requires-Dist: ipykernel
 Requires-Dist: notebook
 Requires-Dist: jupyterlab
+Requires-Dist: ipywidgets
 Requires-Dist: joblib
 Requires-Dist: xgboost
 Requires-Dist: lightgbm<=4.5.0
dragon_ml_toolbox-1.4.0.dist-info/RECORD CHANGED
@@ -1,18 +1,19 @@
-dragon_ml_toolbox-1.3.2.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
-dragon_ml_toolbox-1.3.2.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=e1Hg5ZtaBpDV7ZvxhLe1ac28l7nMjvi1MSE5YvB1s-o,1472
-ml_tools/MICE_imputation.py,sha256=71Kdi5rhPePIT5rJKIyRCM7ORPSjeujQCzKcLIwXs90,9428
+dragon_ml_toolbox-1.4.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
+dragon_ml_toolbox-1.4.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=e1Hg5ZtaBpDV7ZvxhLe1ac28l7nMjvi1MSE5YvB1s-o,1472
+ml_tools/MICE_imputation.py,sha256=4kqZiesk8vyh4MBLnNE9grflG4fDusqzuYBElsbk4LY,9484
+ml_tools/VIF_factor.py,sha256=rHSAxQcXLrG8dIjCXBAvETsSkCBfYus9NqimOnm2Bvk,9559
 ml_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-ml_tools/data_exploration.py,sha256=laTNbN5_xlhqWiKfF-cJ9yMZ8zAM2a-AryqgiIQBBLg,26649
+ml_tools/data_exploration.py,sha256=qtkGumckC2PmTpj3brVFi072ewX0OI6dwUF4Or7Yikg,21341
 ml_tools/datasetmaster.py,sha256=VUneKshnmjOGbtqVVGTFcIMRKF3s6ZDYrosIYKDjD80,28956
-ml_tools/ensemble_learning.py,sha256=5UmlXI3Orm5zL0P07Ub_Y0gwjruH-REHY-cFWQpJWb0,29085
+ml_tools/ensemble_learning.py,sha256=wK6mtOE4v9AWlxkcWhJj5XZjREChxb46kE0i2IxS-OE,28372
 ml_tools/handle_excel.py,sha256=IR0VQc3hYdmjwC31E5YxDnRcWig4jSIx7Y_7to-KZz4,11969
 ml_tools/logger.py,sha256=XwSpCUzw2Le24fJHyljBxNLgw63SwjZ0pMjTJqf0ylI,4622
 ml_tools/particle_swarm_optimization.py,sha256=jpkje4OETC9fyISxxUTx4XGrImSU6gDEcwz46ZDs2bQ,19250
 ml_tools/pytorch_models.py,sha256=Oykw02sOZLCjvSadQd64UGesBN7kq0x1EGXHusvYiQI,9908
 ml_tools/trainer.py,sha256=Zd7AaHeoNd8dEas2JChWoHaCUpWUVRDUMybuHaKJ0XY,16740
-ml_tools/utilities.py,sha256=mG_--EFplfI9H7OhrWI8VkdNJtTbs4Wbz32xvcFWps8,5518
+ml_tools/utilities.py,sha256=gr1cyRUfZcRo9fjWpCaQkrvWY0-xJnDJdrE8JEsOi8o,6309
 ml_tools/vision_helpers.py,sha256=lBAW6dzAK-HOswAt1fU_tfP9hkNLY5D8c_I_7hhEXno,7528
-dragon_ml_toolbox-1.3.2.dist-info/METADATA,sha256=NgNKZD1v97kBBdE96OJELolvlAXviJ-DgJvZAjjy5Ik,2309
-dragon_ml_toolbox-1.3.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-dragon_ml_toolbox-1.3.2.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
-dragon_ml_toolbox-1.3.2.dist-info/RECORD,,
+dragon_ml_toolbox-1.4.0.dist-info/METADATA,sha256=V7Y96iAbgX6Xl6RWzEt4nGfKMZe4cuLs0BrFQghXxX8,2335
+dragon_ml_toolbox-1.4.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dragon_ml_toolbox-1.4.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+dragon_ml_toolbox-1.4.0.dist-info/RECORD,,
ml_tools/MICE_imputation.py CHANGED
@@ -120,7 +120,7 @@ def get_imputed_distributions(kernel: mf.ImputationKernel, df_name: str, root_di
     '''
     # Check path
     os.makedirs(root_dir, exist_ok=True)
-    local_save_dir = os.path.join(root_dir, f"Distribution_Metrics_{df_name}")
+    local_save_dir = os.path.join(root_dir, f"Distribution_Metrics_{df_name}_imputed")
     if not os.path.isdir(local_save_dir):
         os.makedirs(local_save_dir)

@@ -169,8 +169,12 @@ def get_imputed_distributions(kernel: mf.ImputationKernel, df_name: str, root_di
         # Adjust layout and save
         # fig.tight_layout()
         # fig.subplots_adjust(bottom=0.2, left=0.2) # Optional, depending on overflow
+
+        # sanitize savename
+        feature_save_name = sanitize_filename(filename)
+
         fig.savefig(
-            os.path.join(local_save_dir, filename + ".svg"),
+            os.path.join(local_save_dir, feature_save_name + ".svg"),
             format='svg',
             bbox_inches='tight',
             pad_inches=0.1
@@ -185,8 +189,7 @@ def get_imputed_distributions(kernel: mf.ImputationKernel, df_name: str, root_di
     else:
         for feature in column_names:
             fig = kernel.plot_imputed_distributions(variables=[feature])
-            feature_save_name = sanitize_filename(feature)
-            _process_figure(fig, feature_save_name)
+            _process_figure(fig, feature)

     print("\tImputed distributions saved successfully.")

@@ -207,7 +210,7 @@ def run_mice_pipeline(df_path_or_dir: str, save_datasets_dir: str, save_metrics_
     if os.path.isfile(df_path_or_dir):
         all_file_paths = [df_path_or_dir]
     elif os.path.isdir(df_path_or_dir):
-        all_file_paths, _ = list_csv_paths(df_path_or_dir)
+        all_file_paths = list_csv_paths(df_path_or_dir).values()
     else:
         raise ValueError(f"Invalid path or directory: {df_path_or_dir}")
ml_tools/VIF_factor.py ADDED
@@ -0,0 +1,209 @@
+
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+from typing import Optional
+from statsmodels.stats.outliers_influence import variance_inflation_factor
+from statsmodels.tools.tools import add_constant
+import warnings
+import os
+from .utilities import sanitize_filename, yield_dataframes_from_dir, save_dataframe
+
+
+def compute_vif(
+    df: pd.DataFrame,
+    target_columns: Optional[list[str]] = None,
+    ignore_columns: Optional[list[str]] = None,
+    max_features_to_plot: int = 20,
+    save_dir: Optional[str] = None,
+    filename: Optional[str] = None,
+    fontsize: int = 14,
+    show_plot: bool = True,
+) -> pd.DataFrame:
+    """
+    Computes Variance Inflation Factors (VIF) for numeric columns in a DataFrame. Optionally generates a bar plot of VIF values.
+
+    Args:
+        df (pd.DataFrame): The input DataFrame.
+        target_columns (list[str] | None): Optional list of columns to include. Defaults to all numeric columns.
+        ignore_columns (list[str] | None): Optional list of columns to exclude from the VIF computation. Skipped if `target_columns` is provided.
+        max_features_to_plot (int): Adjust the number of features shown in the plot.
+        save_dir (str | None): Directory to save the plot as SVG. If None, the plot is not saved.
+        filename (str | None): Optional filename for saving the plot. Defaults to "VIF_plot.svg".
+        fontsize (int): Base fontsize to scale title and labels on the plot.
+        show_plot (bool): Display plot.
+
+    Returns:
+        pd.DataFrame: DataFrame with features and their corresponding VIF values.
+
+    NOTE:
+        **Variance Inflation Factor (VIF)** quantifies the degree of multicollinearity among features in a dataset.
+        A VIF value indicates how much the variance of a regression coefficient is inflated due to linear dependence with other features.
+        A VIF of 1 suggests no correlation, values between 1 and 5 indicate moderate correlation, and values greater than 10 typically signal high multicollinearity, which may distort model interpretation and degrade performance.
+    """
+    ground_truth_cols = df.columns.to_list()
+    if target_columns is None:
+        sanitized_columns = df.select_dtypes(include='number').columns.tolist()
+        missing_features = set(ground_truth_cols) - set(sanitized_columns)
+        if missing_features:
+            print(f"⚠️ These columns are not numeric:\n{missing_features}")
+    else:
+        sanitized_columns = list()
+        for feature in target_columns:
+            if feature not in ground_truth_cols:
+                print(f"⚠️ The provided column '{feature}' is not in the DataFrame.")
+            else:
+                sanitized_columns.append(feature)
+
+    if ignore_columns is not None and target_columns is None:
+        missing_ignore = set(ignore_columns) - set(ground_truth_cols)
+        if missing_ignore:
+            print(f"⚠️ Warning: The following 'columns to ignore' are not in the DataFrame:\n{missing_ignore}")
+        sanitized_columns = [f for f in sanitized_columns if f not in ignore_columns]
+
+    X = df[sanitized_columns].copy()
+    X = add_constant(X, has_constant='add')
+
+    vif_data = pd.DataFrame()
+    vif_data["feature"] = X.columns # type: ignore
+
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", category=RuntimeWarning)
+        vif_data["VIF"] = [
+            variance_inflation_factor(X.values, i) for i in range(X.shape[1]) # type: ignore
+        ]
+
+    # Replace infinite values (perfect multicollinearity)
+    vif_data["VIF"] = vif_data["VIF"].replace([np.inf, -np.inf], 999.0)
+
+    # Drop the constant column
+    vif_data = vif_data[vif_data["feature"] != "const"]
+
+    # Add color coding
+    def vif_color(v: float) -> str:
+        if v >= 10:
+            return "red"
+        elif v >= 5:
+            return "gold"
+        else:
+            return "green"
+
+    vif_data["color"] = vif_data["VIF"].apply(vif_color)
+
+    # Sort by VIF descending
+    vif_data = vif_data.sort_values(by="VIF", ascending=False).reset_index(drop=True)
+
+    # Filter for plotting
+    plot_data = vif_data.head(max_features_to_plot)
+
+    if save_dir or show_plot:
+        if not plot_data.empty:
+            plt.figure(figsize=(10, 6))
+            plt.barh(
+                plot_data["feature"],
+                plot_data["VIF"],
+                color=plot_data["color"],
+                edgecolor='black'
+            )
+            plt.title("Variance Inflation Factor (VIF) per Feature", fontsize=fontsize+1)
+            plt.xlabel("VIF value", fontsize=fontsize)
+            plt.xticks(fontsize=fontsize)
+            plt.yticks(fontsize=fontsize)
+            plt.axvline(x=5, color='gold', linestyle='--', label='VIF = 5')
+            plt.axvline(x=10, color='red', linestyle='--', label='VIF = 10')
+            plt.xlim(0, 12)
+            plt.legend(loc='lower right', fontsize=fontsize-1)
+            plt.gca().invert_yaxis()
+            plt.grid(axis='x', linestyle='--', alpha=0.5)
+            plt.tight_layout()
+
+            if save_dir:
+                os.makedirs(save_dir, exist_ok=True)
+                if filename is None:
+                    filename = "VIF_plot.svg"
+                else:
+                    filename = sanitize_filename(filename)
+                    if not filename.endswith(".svg"):
+                        filename += ".svg"
+                save_path = os.path.join(save_dir, "VIF_" + filename)
+                plt.savefig(save_path, format='svg', bbox_inches='tight')
+                print(f"\tSaved VIF plot: '{filename}'")
+
+            if show_plot:
+                plt.show()
+            plt.close()
+
+    return vif_data.drop(columns="color")
+
+
+def drop_vif_based(df: pd.DataFrame, vif_df: pd.DataFrame, threshold: float = 10.0) -> pd.DataFrame:
+    """
+    Drops columns from the original DataFrame whose VIF values exceed a given threshold.
+
+    Args:
+        df (pd.DataFrame): Original DataFrame containing the columns to test.
+        vif_df (pd.DataFrame): DataFrame with 'feature' and 'VIF' columns as returned by `compute_vif()`.
+        threshold (float): VIF threshold above which columns will be dropped.
+
+    Returns:
+        pd.DataFrame: A new DataFrame with high-VIF columns removed.
+    """
+    # Ensure expected structure
+    if 'feature' not in vif_df.columns or 'VIF' not in vif_df.columns:
+        raise ValueError("`vif_df` must contain 'feature' and 'VIF' columns.")
+
+    # Identify features to drop
+    to_drop = vif_df[vif_df["VIF"] > threshold]["feature"].tolist()
+    print(f"\tDropping {len(to_drop)} column(s) with VIF > {threshold}: {to_drop}")
+
+    result_df = df.drop(columns=to_drop)
+
+    if result_df.empty:
+        print("\t⚠️ Warning: All columns were dropped.")
+
+    return result_df
+
+
+def compute_vif_multi(input_directory: str,
+                      output_plot_directory: str,
+                      output_dataset_directory: Optional[str] = None,
+                      target_columns: Optional[list[str]] = None,
+                      ignore_columns: Optional[list[str]] = None,
+                      max_features_to_plot: int = 20,
+                      fontsize: int = 14):
+    """
+    Computes Variance Inflation Factors (VIF) for numeric columns in each CSV file of a directory (loaded as pandas DataFrames).
+    Generates a bar plot of VIF values per file. Optionally drops columns with VIF > 10 and saves the result as a new CSV file.
+
+    Args:
+        input_directory (str): Target directory with CSV files that can be loaded as DataFrames.
+        output_plot_directory (str): Save plots to this directory.
+        output_dataset_directory (str | None): If provided, saves new CSV files to this directory.
+        target_columns (list[str] | None): Optional list of columns to include. Defaults to all numeric columns.
+        ignore_columns (list[str] | None): Optional list of columns to exclude from the VIF computation. Skipped if `target_columns` is provided.
+        max_features_to_plot (int): Adjust the number of features shown in the plot.
+        fontsize (int): Base fontsize to scale title and labels on the plot.
+
+    NOTE:
+        **Variance Inflation Factor (VIF)** quantifies the degree of multicollinearity among features in a dataset.
+        A VIF value indicates how much the variance of a regression coefficient is inflated due to linear dependence with other features.
+        A VIF of 1 suggests no correlation, values between 1 and 5 indicate moderate correlation, and values greater than 10 typically signal high multicollinearity, which may distort model interpretation and degrade performance.
+    """
+    if output_dataset_directory is not None:
+        os.makedirs(output_dataset_directory, exist_ok=True)
+
+    for df, df_name in yield_dataframes_from_dir(datasets_dir=input_directory):
+        vif_dataframe = compute_vif(df=df,
+                                    target_columns=target_columns,
+                                    ignore_columns=ignore_columns,
+                                    max_features_to_plot=max_features_to_plot,
+                                    fontsize=fontsize,
+                                    save_dir=output_plot_directory,
+                                    filename=df_name,
+                                    show_plot=False)

+        if output_dataset_directory is not None:
+            new_filename = 'VIF_' + df_name
+            result_df = drop_vif_based(df=df, vif_df=vif_dataframe)
+            save_dataframe(df=result_df, save_dir=output_dataset_directory, filename=new_filename)
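
The new module supports a two-step workflow: score multicollinearity with `compute_vif`, then prune with `drop_vif_based`. A minimal usage sketch (the CSV path and output directory are hypothetical):

```python
import pandas as pd
from ml_tools.VIF_factor import compute_vif, drop_vif_based

df = pd.read_csv("data/features.csv")  # hypothetical dataset

# Score every numeric column; save the bar plot instead of displaying it.
vif_df = compute_vif(df, save_dir="plots", filename="features", show_plot=False)

# Drop columns whose VIF exceeds the default threshold of 10.0.
df_pruned = drop_vif_based(df, vif_df)
```

`compute_vif_multi` wires these same steps over every CSV file in a directory.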
ml_tools/data_exploration.py CHANGED
@@ -2,12 +2,10 @@ import pandas as pd
 import numpy as np
 import matplotlib.pyplot as plt
 import seaborn as sns
-from statsmodels.stats.outliers_influence import variance_inflation_factor
-from statsmodels.tools.tools import add_constant
 from IPython import get_ipython
 from IPython.display import clear_output
 import time
-from typing import Union, Literal, Dict, Tuple, Optional
+from typing import Union, Literal, Dict, Tuple
 import os
 import sys
 import textwrap
@@ -26,10 +24,7 @@ __all__ = ["summarize_dataframe",
            "plot_value_distributions",
            "clip_outliers_single",
            "clip_outliers_multi",
-           "merge_dataframes",
-           "save_dataframe",
-           "compute_vif",
-           "drop_vif_based"]
+           "merge_dataframes"]


 def summarize_dataframe(df: pd.DataFrame, round_digits: int = 2):
@@ -278,7 +273,7 @@ def check_value_distributions(df: pd.DataFrame, view_frequencies: bool=True, bin
     Notes:
         - Binning is adaptive: if quantile binning results in ≤ 2 unique bins, raw values are used instead.
     """
-    # cherrypick columns
+    # cherry-pick columns
     if skip_cols_with_key is not None:
         columns = [col for col in df.columns if skip_cols_with_key not in col]
     else:
@@ -351,7 +346,7 @@ def plot_value_distributions(df: pd.DataFrame, save_dir: str, bin_threshold: int
     dict_to_plot_std = dict()
     dict_to_plot_freq = dict()

-    # cherrypick columns
+    # cherry-pick columns
    if skip_cols_with_key is not None:
         columns = [col for col in df.columns if skip_cols_with_key not in col]
     else:
@@ -399,7 +394,7 @@ def plot_value_distributions(df: pd.DataFrame, save_dir: str, bin_threshold: int
     labels = data.keys()

     plt.figure(figsize=(10, 6))
-    colors = plt.cm.tab20.colors if len(data) <= 20 else plt.cm.viridis(np.linspace(0, 1, len(data)))
+    colors = plt.cm.tab20.colors if len(data) <= 20 else plt.cm.viridis(np.linspace(0, 1, len(data))) # type: ignore

     plt.bar(labels, data.values(), color=colors[:len(data)], alpha=0.85)
     plt.xlabel("Values", fontsize=base_fontsize)
@@ -574,141 +569,6 @@ def merge_dataframes(
     return merged_df


-def save_dataframe(df: pd.DataFrame, save_dir: str, filename: str) -> None:
-    """
-    Save a pandas DataFrame to a CSV file.
-
-    Parameters:
-        df: pandas.DataFrame to save
-        save_dir: str, directory where the CSV file will be saved.
-        filename: str, CSV filename, extension will be added if missing.
-    """
-    os.makedirs(save_dir, exist_ok=True)
-
-    filename = sanitize_filename(filename)
-
-    if not filename.endswith('.csv'):
-        filename += '.csv'
-
-    output_path = os.path.join(save_dir, filename)
-
-    df.to_csv(output_path, index=False, encoding='utf-8')
-    print(f"Saved file: '{filename}'")
-
-
-def compute_vif(
-    df: pd.DataFrame,
-    features: Optional[list[str]] = None,
-    ignore_cols: Optional[list[str]] = None,
-    plot: bool = True,
-    save_dir: Union[str, None] = None
-) -> pd.DataFrame:
-    """
-    Computes Variance Inflation Factors (VIF) for numeric features, optionally plots and saves the results.
-
-    There cannot be empty values in the dataset.
-
-    Args:
-        df (pd.DataFrame): The input DataFrame.
-        features (list[str] | None): Optional list of column names to evaluate. Defaults to all numeric columns.
-        ignore_cols (list[str] | None): Optional list of column names to ignore.
-        plot (bool): Whether to display a barplot of VIF values.
-        save_dir (str | None): Directory to save the plot as SVG. If None, plot is not saved.
-
-    Returns:
-        pd.DataFrame: DataFrame with features and corresponding VIF values, sorted descending.
-
-    NOTE:
-        **Variance Inflation Factor (VIF)** quantifies the degree of multicollinearity among features in a dataset.
-        A VIF value indicates how much the variance of a regression coefficient is inflated due to linear dependence with other features.
-        A VIF of 1 suggests no correlation, values between 1 and 5 indicate moderate correlation, and values greater than 10 typically signal high multicollinearity, which may distort model interpretation and degrade performance.
-
-    """
-    if features is None:
-        features = df.select_dtypes(include='number').columns.tolist()
-
-    if ignore_cols is not None:
-        missing = set(ignore_cols) - set(features)
-        if missing:
-            raise ValueError(f"The following 'columns to ignore' are not in the Dataframe:\n{missing}")
-        features = [f for f in features if f not in ignore_cols]
-
-    X = df[features].copy()
-    X = add_constant(X, has_constant='add')
-
-    vif_data = pd.DataFrame()
-    vif_data["feature"] = X.columns
-    vif_data["VIF"] = [variance_inflation_factor(X.values, i) for i in range(X.shape[1])]
-
-    # Drop the constant column
-    vif_data = vif_data[vif_data["feature"] != "const"]
-    vif_data = vif_data.sort_values(by="VIF", ascending=False).reset_index(drop=True) # type: ignore
-
-    # Add color coding based on thresholds
-    def vif_color(v: float) -> str:
-        if v > 10:
-            return "red"
-        elif v > 5:
-            return "gold"
-        else:
-            return "green"
-
-    vif_data["color"] = vif_data["VIF"].apply(vif_color)
-
-    # Plot
-    if plot or save_dir:
-        plt.figure(figsize=(10, 6))
-        bars = plt.barh(
-            vif_data["feature"],
-            vif_data["VIF"],
-            color=vif_data["color"],
-            edgecolor='black'
-        )
-        plt.title("Variance Inflation Factor (VIF) per Feature")
-        plt.xlabel("VIF")
-        plt.axvline(x=5, color='gold', linestyle='--', label='VIF = 5')
-        plt.axvline(x=10, color='red', linestyle='--', label='VIF = 10')
-        plt.legend(loc='lower right')
-        plt.gca().invert_yaxis()
-        plt.grid(axis='x', linestyle='--', alpha=0.5)
-        plt.tight_layout()
-
-        if save_dir:
-            os.makedirs(save_dir, exist_ok=True)
-            save_path = os.path.join(save_dir, "VIF_plot.svg")
-            plt.savefig(save_path, format='svg', bbox_inches='tight')
-            print(f"Saved VIF plot to: {save_path}")
-
-        if plot:
-            plt.show()
-        plt.close()
-
-    return vif_data.drop(columns="color")
-
-
-def drop_vif_based(df: pd.DataFrame, vif_df: pd.DataFrame, threshold: float = 10.0) -> pd.DataFrame:
-    """
-    Drops features from the original DataFrame based on their VIF values exceeding a given threshold.
-
-    Args:
-        df (pd.DataFrame): Original DataFrame containing the features.
-        vif_df (pd.DataFrame): DataFrame with 'feature' and 'VIF' columns as returned by `compute_vif()`.
-        threshold (float): VIF threshold above which features will be dropped.
-
-    Returns:
-        pd.DataFrame: A new DataFrame with high-VIF features removed.
-    """
-    # Ensure expected structure
-    if 'feature' not in vif_df.columns or 'VIF' not in vif_df.columns:
-        raise ValueError("`vif_df` must contain 'feature' and 'VIF' columns.")
-
-    # Identify features to drop
-    to_drop = vif_df[vif_df["VIF"] > threshold]["feature"].tolist()
-    print(f"Dropping {len(to_drop)} feature(s) with VIF > {threshold}: {to_drop}")
-
-    return df.drop(columns=to_drop, errors="ignore")
-
-
 def _is_notebook():
     return get_ipython() is not None

ml_tools/ensemble_learning.py CHANGED
@@ -139,8 +139,9 @@ def get_models(task: Literal["classification", "regression"], random_state: int=

 ###### 3. Process Dataset ######
 # function to split data into train and test
-def _split_data(features, target, test_size, random_state):
-    X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=test_size, random_state=random_state, stratify=target)
+def _split_data(features, target, test_size, random_state, task):
+    X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=test_size, random_state=random_state,
+                                                        stratify=target if task=="classification" else None)
     return X_train, X_test, y_train, y_test

 # function to standardize the data
@@ -176,7 +177,7 @@ def _resample(X_train_scaled: np.ndarray, y_train: pd.Series,
     else:
         raise ValueError(f"Invalid resampling strategy: {strategy}")

-    X_res, y_res = resample_algorithm.fit_resample(X_train_scaled, y_train)
+    X_res, y_res, *_ = resample_algorithm.fit_resample(X_train_scaled, y_train)
     return X_res, y_res

 # DATASET PIPELINE
@@ -199,7 +200,7 @@ def dataset_pipeline(df_features: pd.DataFrame, df_target: pd.Series, task: Lite
     print(f"\tUnique values for '{df_target.name}': {unique_values}")

     #Train test split
-    X_train, X_test, y_train, y_test = _split_data(features=df_features, target=df_target, test_size=test_size, random_state=random_state)
+    X_train, X_test, y_train, y_test = _split_data(features=df_features, target=df_target, test_size=test_size, random_state=random_state, task=task)

     #DEBUG
     if debug:
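
The reworked `_split_data` only stratifies for classification: scikit-learn's `train_test_split` raises when `stratify` receives a continuous regression target, since stratification needs class labels. A sketch of the distinction on toy data, assuming scikit-learn's standard API:

```python
import numpy as np
from sklearn.model_selection import train_test_split

X = np.random.rand(100, 4)               # toy feature matrix
y_class = np.random.randint(0, 2, 100)   # class labels
y_reg = np.random.rand(100)              # continuous target

# Classification: preserve class proportions across the split.
X_tr, X_te, y_tr, y_te = train_test_split(
    X, y_class, test_size=0.2, random_state=101, stratify=y_class)

# Regression: continuous targets cannot be stratified, so pass None.
X_tr, X_te, y_tr, y_te = train_test_split(
    X, y_reg, test_size=0.2, random_state=101, stratify=None)
```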
@@ -343,8 +344,7 @@ def plot_roc_curve(
     color: str = "darkorange",
     figure_size: tuple = (10, 10),
     linewidth: int = 2,
-    title_fontsize: int = 24,
-    label_fontsize: int = 24,
+    base_fontsize: int = 24,
     input_features: Optional[np.ndarray] = None,
 ) -> plt.Figure: # type: ignore
     """
@@ -402,11 +402,11 @@ def plot_roc_curve(
     ax.plot(fpr, tpr, color=color, lw=linewidth, label=f"AUC = {auc_score:.2f}")
     ax.plot([0, 1], [0, 1], color="gray", linestyle="--", lw=1)

-    ax.set_title(f"{model_name} - {target_name}", fontsize=title_fontsize)
-    ax.set_xlabel("False Positive Rate", fontsize=label_fontsize)
-    ax.set_ylabel("True Positive Rate", fontsize=label_fontsize)
-    ax.tick_params(axis='both', labelsize=label_fontsize)
-    ax.legend(loc="lower right", fontsize=label_fontsize)
+    ax.set_title(f"{model_name} - {target_name}", fontsize=base_fontsize)
+    ax.set_xlabel("False Positive Rate", fontsize=base_fontsize)
+    ax.set_ylabel("True Positive Rate", fontsize=base_fontsize)
+    ax.tick_params(axis='both', labelsize=base_fontsize)
+    ax.legend(loc="lower right", fontsize=base_fontsize)
     ax.grid(True)

     # Save figure
@@ -416,6 +416,7 @@ def plot_roc_curve(

     return fig

+
 # function to evaluate the model and save metrics (Regression)
 def evaluate_model_regression(model, model_name: str,
                               save_dir: str,
@@ -423,8 +424,7 @@ def evaluate_model_regression(model, model_name: str,
                               target_id: str,
                               figure_size: tuple = (12, 8),
                               alpha_transparency: float = 0.5,
-                              title_fontsize: int = 24,
-                              normal_fontsize: int = 24):
+                              base_fontsize: int = 24):
     # Generate predictions
     y_pred = model.predict(x_test_scaled)
@@ -448,9 +448,9 @@ def evaluate_model_regression(model, model_name: str,
     plt.figure(figsize=figure_size)
     plt.scatter(y_pred, residuals, alpha=alpha_transparency)
     plt.axhline(0, color='red', linestyle='--')
-    plt.xlabel("Predicted Values", fontsize=normal_fontsize)
-    plt.ylabel("Residuals", fontsize=normal_fontsize)
-    plt.title(f"{model_name} - Residual Plot for {target_id}", fontsize=title_fontsize)
+    plt.xlabel("Predicted Values", fontsize=base_fontsize)
+    plt.ylabel("Residuals", fontsize=base_fontsize)
+    plt.title(f"{model_name} - Residual Plot for {target_id}", fontsize=base_fontsize)
     plt.grid(True)
     plt.tight_layout()
     plt.savefig(os.path.join(save_dir, f"Residual_Plot_{target_id}.svg"), bbox_inches='tight', format="svg")
@@ -462,9 +462,9 @@ def evaluate_model_regression(model, model_name: str,
     plt.plot([single_y_test.min(), single_y_test.max()],
              [single_y_test.min(), single_y_test.max()],
              'k--', lw=2)
-    plt.xlabel('True Values', fontsize=normal_fontsize)
-    plt.ylabel('Predictions', fontsize=normal_fontsize)
-    plt.title(f"{model_name} - True vs Predicted for {target_id}", fontsize=title_fontsize)
+    plt.xlabel('True Values', fontsize=base_fontsize)
+    plt.ylabel('Predictions', fontsize=base_fontsize)
+    plt.title(f"{model_name} - True vs Predicted for {target_id}", fontsize=base_fontsize)
     plt.grid(True)
     plot_path = os.path.join(save_dir, f"Regression_Plot_{target_id}.svg")
     plt.savefig(plot_path, bbox_inches='tight', format="svg")
@@ -473,52 +473,53 @@ def evaluate_model_regression(model, model_name: str,
     return y_pred

 # Get SHAP values
-def get_shap_values(model, model_name: str,
-                    save_dir: str,
-                    features_to_explain: np.ndarray,
-                    feature_names: list[str],
-                    target_id: str,
-                    task: Literal["classification", "regression"],
-                    max_display_features: int=8,
-                    figsize: tuple=(14, 20),
-                    title_fontsize: int=38,
-                    label_fontsize: int=38,
-                    plot_type: Literal["bar", "dot"] = "dot"
-                    ):
+def get_shap_values(
+    model,
+    model_name: str,
+    save_dir: str,
+    features_to_explain: np.ndarray,
+    feature_names: list[str],
+    target_id: str,
+    task: Literal["classification", "regression"],
+    max_display_features: int = 10,
+    figsize: tuple = (16, 20),
+    base_fontsize: int = 38,
+):
     """
     Universal SHAP explainer for regression and classification.
-    - Use `X_train` (or a subsample of it) to see how the model explains the data it was trained on.
-    - Use `X_test` (or a hold-out set) to see how the model explains unseen data.
-    - Use the entire dataset to get the global view.
+    * Use `X_train` (or a subsample of it) to see how the model explains the data it was trained on.
+
+    * Use `X_test` (or a hold-out set) to see how the model explains unseen data.
+
+    * Use the entire dataset to get the global view.

     Parameters:
-    - 'task': 'regression' or 'classification'
-    - 'features_to_explain': Should match the model's training data format, including scaling.
-    - 'save_dir': Directory to save visualizations
+        task: 'regression' or 'classification'
+        features_to_explain: Should match the model's training data format, including scaling.
+        save_dir: Directory to save visualizations
     """
-    def _create_shap_plot(shap_values, features, feature_names,
-                          full_save_path: str, plot_type: str,
-                          title: str):
-        """Helper function to create and save SHAP plots"""
-        # Set style
-        preferred_styles = ['seaborn', 'seaborn-v0_8-darkgrid', 'seaborn-v0_8', 'default']
-        for style in preferred_styles:
+
+    def _apply_plot_style():
+        styles = ['seaborn', 'seaborn-v0_8-darkgrid', 'seaborn-v0_8', 'default']
+        for style in styles:
             if style in plt.style.available or style == 'default':
                 plt.style.use(style)
                 break
-
+
+    def _configure_rcparams():
+        plt.rc('font', size=base_fontsize)
+        plt.rc('axes', titlesize=base_fontsize)
+        plt.rc('axes', labelsize=base_fontsize)
+        plt.rc('xtick', labelsize=base_fontsize)
+        plt.rc('ytick', labelsize=base_fontsize + 2)
+        plt.rc('legend', fontsize=base_fontsize)
+        plt.rc('figure', titlesize=base_fontsize)
+
+    def _create_shap_plot(shap_values, features, save_path: str, plot_type: str, title: str):
+        _apply_plot_style()
+        _configure_rcparams()
         plt.figure(figsize=figsize)
-
-        #set rc parameters for better readability
-        plt.rc('font', size=label_fontsize)
-        plt.rc('axes', titlesize=title_fontsize)
-        plt.rc('axes', labelsize=label_fontsize)
-        plt.rc('xtick', labelsize=label_fontsize)
-        plt.rc('ytick', labelsize=label_fontsize)
-        plt.rc('legend', fontsize=label_fontsize)
-        plt.rc('figure', titlesize=title_fontsize)
-
-        # Create the SHAP plot
+
         shap.summary_plot(
             shap_values=shap_values,
             features=features,
@@ -528,85 +529,75 @@ def get_shap_values(model, model_name: str,
             plot_size=figsize,
             max_display=max_display_features,
             alpha=0.7,
-            color=plt.get_cmap('viridis') # type: ignore
+            # color='viridis'
         )
-
-        # Add professional styling
+
         ax = plt.gca()
-        ax.set_xlabel("SHAP Value Impact", fontsize=title_fontsize, weight='bold')
-        ax.set_ylabel("Features", fontsize=title_fontsize, weight='bold')
-        plt.title(title, fontsize=title_fontsize, pad=20, weight='bold')
-
-        # Manually fix tick fonts
+        ax.set_xlabel("SHAP Value Impact", fontsize=base_fontsize + 2, weight='bold', labelpad=20)
+        plt.title(title, fontsize=base_fontsize + 2, pad=20, weight='bold')
+
         for tick in ax.get_xticklabels():
-            tick.set_fontsize(label_fontsize)
-            tick.set_rotation(45)
+            tick.set_fontsize(base_fontsize)
+            tick.set_rotation(30)
         for tick in ax.get_yticklabels():
-            tick.set_fontsize(label_fontsize)
+            tick.set_fontsize(base_fontsize + 2)

-        # Handle colorbar for dot plots
         if plot_type == "dot":
             cb = plt.gcf().axes[-1]
-            # cb.set_ylabel("Feature Value", size=label_fontsize)
             cb.set_ylabel("", size=1)
-            cb.tick_params(labelsize=label_fontsize - 2)
-
-        # Save and clean up
-        plt.savefig(
-            full_save_path,
-            bbox_inches='tight',
-            facecolor='white',
-            format="svg"
-        )
+            cb.tick_params(labelsize=base_fontsize - 2)
+
+        plt.savefig(save_path, bbox_inches='tight', facecolor='white', format="svg")
         plt.close()
-        rcdefaults() # Reset rc parameters to default
-
-    # START
-    explainer = shap.TreeExplainer(model)
-    shap_values = explainer.shap_values(features_to_explain)
-
-    # Handle different model types
-    if task == 'classification':
-        # Determine if multiclass
-        try:
-            is_multiclass = len(model.classes_) > 2
-            class_names = model.classes_
-        except AttributeError:
-            is_multiclass = isinstance(shap_values, list) and len(shap_values) > 1
-            class_names = list(range(len(shap_values))) if is_multiclass else [0, 1]
-
+        rcdefaults()
+
+    def _plot_for_classification(shap_values, class_names):
+        is_multiclass = isinstance(shap_values, list) and len(shap_values) > 1
+
         if is_multiclass:
-            for class_idx, (class_shap, class_name) in enumerate(zip(shap_values, class_names)):
+            for class_shap, class_name in zip(shap_values, class_names):
+                for plot_type in ["bar", "dot"]:
+                    _create_shap_plot(
+                        shap_values=class_shap,
+                        features=features_to_explain,
+                        save_path=os.path.join(save_dir, f"SHAP_{target_id}_Class{class_name}_{plot_type}.svg"),
+                        plot_type=plot_type,
+                        title=f"{model_name} - {target_id} (Class {class_name})"
+                    )
+        else:
+            values = shap_values[1] if isinstance(shap_values, list) else shap_values
+            for plot_type in ["bar", "dot"]:
                 _create_shap_plot(
-                    shap_values=class_shap,
+                    shap_values=values,
                     features=features_to_explain,
-                    feature_names=feature_names,
-                    full_save_path=os.path.join(save_dir, f"SHAP_{target_id}_Class{class_name}.svg"),
+                    save_path=os.path.join(save_dir, f"SHAP_{target_id}_{plot_type}.svg"),
                     plot_type=plot_type,
-                    title=f"{model_name} - {target_id} (Class {class_name})"
+                    title=f"{model_name} - {target_id}"
                 )
-        else:
-            # Handle binary classification (single array case)
-            plot_vals = shap_values[1] if isinstance(shap_values, list) else shap_values
+
+    def _plot_for_regression(shap_values):
+        for plot_type in ["bar", "dot"]:
             _create_shap_plot(
-                shap_values=plot_vals,
+                shap_values=shap_values,
                 features=features_to_explain,
-                feature_names=feature_names,
-                full_save_path=os.path.join(save_dir, f"SHAP_{target_id}.svg"),
+                save_path=os.path.join(save_dir, f"SHAP_{target_id}_{plot_type}.svg"),
                 plot_type=plot_type,
                 title=f"{model_name} - {target_id}"
             )
-
-    else: # Regression
-        _create_shap_plot(
-            shap_values=shap_values,
-            features=features_to_explain,
-            feature_names=feature_names,
-            full_save_path=os.path.join(save_dir, f"SHAP_{target_id}.svg"),
-            plot_type=plot_type,
-            title=f"{model_name} - {target_id}"
-        )
-
+
+    explainer = shap.TreeExplainer(model)
+    shap_values = explainer.shap_values(features_to_explain)
+
+    if task == 'classification':
+        try:
+            class_names = model.classes_ if hasattr(model, 'classes_') else list(range(len(shap_values)))
+        except Exception:
+            class_names = list(range(len(shap_values)))
+        _plot_for_classification(shap_values, class_names)
+    else:
+        _plot_for_regression(shap_values)
+
+
 # TRAIN TEST PIPELINE
 def train_test_pipeline(model, model_name: str, dataset_id: str, task: Literal["classification", "regression"],
                         train_features: np.ndarray, train_target: np.ndarray,
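
After the refactor, `get_shap_values` takes a single `base_fontsize` and always writes both a bar and a dot summary plot per target, so the old `plot_type` argument is gone. A hypothetical call under the new signature, assuming the function is importable from `ml_tools.ensemble_learning`:

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from ml_tools.ensemble_learning import get_shap_values  # assumed import path

X = np.random.rand(200, 5)
y = X @ np.array([1.0, 0.5, 0.0, -0.3, 2.0])  # toy linear target
model = RandomForestRegressor(random_state=101).fit(X, y)

# Writes SHAP_target_A_bar.svg and SHAP_target_A_dot.svg into save_dir.
get_shap_values(
    model=model,
    model_name="RandomForest",
    save_dir="shap_plots",
    features_to_explain=X,
    feature_names=[f"f{i}" for i in range(5)],
    target_id="target_A",
    task="regression",
)
```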
@@ -653,7 +644,7 @@ def train_test_pipeline(model, model_name: str, dataset_id: str, task: Literal["
     return trained_model, y_pred

 ###### 5. Execution ######
-def run_pipeline(datasets_dir: str, save_dir: str, target_columns: list[str], task: Literal["classification", "regression"]="regression",
+def run_ensemble_pipeline(datasets_dir: str, save_dir: str, target_columns: list[str], task: Literal["classification", "regression"],
                  resample_strategy: Literal[r"ADASYN", r'SMOTE', r'RANDOM', r'UNDERSAMPLE', None]=None, scaler: Literal["standard", "minmax", "maxabs"]="minmax", save_model: bool=False,
                  test_size: float=0.2, debug:bool=False, L1_regularization: float=0.5, L2_regularization: float=0.5, learning_rate: float=0.005, random_state: int=101):
     #Check paths
@@ -672,15 +663,15 @@ def run_pipeline(datasets_dir: str, save_dir: str, target_columns: list[str], ta
     #Train models
     for model_name, model in models_dict.items():
         train_test_pipeline(model=model, model_name=model_name, dataset_id=dataframe_name, task=task,
-                            train_features=X_train, train_target=y_train,
+                            train_features=X_train, train_target=y_train, # type: ignore
                             test_features=X_test, test_target=y_test,
                             feature_names=feature_names,target_id=target_name, scaler_object=scaler_object,
                             debug=debug, save_dir=save_dir, save_model=save_model)
-    print("\nTraining and evaluation complete.")
+    print("\n✅ Training and evaluation complete.")


 def _check_paths(datasets_dir: str, save_dir:str):
     if not os.path.isdir(save_dir):
         os.makedirs(save_dir)
     if not os.path.isdir(datasets_dir):
-        raise IOError(f"Datasets directory '{datasets_dir}' not found.\nCheck path or run MICE script first.")
+        raise IOError(f"Datasets directory '{datasets_dir}' not found.")
ml_tools/utilities.py CHANGED
@@ -6,17 +6,15 @@ from pathlib import Path
 import re


-def list_csv_paths(directory: str) -> tuple[list[str], list[str]]:
+def list_csv_paths(directory: str) -> dict[str, str]:
     """
-    Lists all CSV files in a given directory and returns their paths with corresponding base names.
+    Lists all `.csv` files in the specified directory and returns a mapping of filenames (without extensions) to their absolute paths.

     Parameters:
         directory (str): Path to the directory containing `.csv` files.

     Returns:
-        Tuple ([List[str], List[str]]):
-            - List of absolute paths to `.csv` files.
-            - List of corresponding base names (without extensions).
+        (dict[str, str]): Mapping {name: path}.
     """
     dir_path = Path(directory).expanduser().resolve()
@@ -26,11 +24,15 @@ def list_csv_paths(directory: str) -> tuple[list[str], list[str]]:
     csv_paths = list(dir_path.glob("*.csv"))
     if not csv_paths:
         raise IOError(f"No CSV files found in directory: {dir_path}")
+
+    # make a dictionary of names and paths
+    name_path_dict = {p.stem: str(p) for p in csv_paths}
+
+    print("🗂️ CSV files found:")
+    for name in name_path_dict.keys():
+        print(f"\t{name}")

-    paths = [str(p) for p in csv_paths]
-    names = [p.stem for p in csv_paths]
-
-    return paths, names
+    return name_path_dict


 def load_dataframe(df_path: str) -> tuple[pd.DataFrame, str]:
@@ -49,7 +51,7 @@ def load_dataframe(df_path: str) -> tuple[pd.DataFrame, str]:
     df_name = path.stem
     if df.empty:
         raise ValueError(f"DataFrame '{df_name}' is empty.")
-    print(f"Loaded dataset: '{df_name}' with shape: {df.shape}")
+    print(f"\n💿 Loaded dataset: '{df_name}' with shape: {df.shape}")
     return df, df_name


@@ -71,9 +73,8 @@ def yield_dataframes_from_dir(datasets_dir: str):
     - CSV files are read using UTF-8 encoding.
     - Output is streamed via a generator to support lazy loading of multiple datasets.
     """
-    for df_path, df_name in list_csv_paths(datasets_dir):
-        df = pd.read_csv(df_path)
-        print(f"Loaded dataset: '{df_name}' with shape: {df.shape}")
+    for df_name, df_path in list_csv_paths(datasets_dir).items():
+        df, _ = load_dataframe(df_path)
         yield df, df_name


@@ -166,3 +167,28 @@ def sanitize_filename(filename: str) -> str:

     return sanitized

+
+def save_dataframe(df: pd.DataFrame, save_dir: str, filename: str) -> None:
+    """
+    Save a pandas DataFrame to a CSV file.
+
+    Parameters:
+        df: pandas.DataFrame to save.
+        save_dir: str, directory where the CSV file will be saved.
+        filename: str, CSV filename; the extension will be added if missing.
+    """
+    if df.empty:
+        print(f"⚠️ Attempted to save an empty DataFrame: '{filename}'. Process skipped.")
+        return
+
+    os.makedirs(save_dir, exist_ok=True)
+
+    filename = sanitize_filename(filename)
+
+    if not filename.endswith('.csv'):
+        filename += '.csv'
+
+    output_path = os.path.join(save_dir, filename)
+
+    df.to_csv(output_path, index=False, encoding='utf-8')
+    print(f"✅ Saved file: '{filename}'")