dragon-ml-toolbox 1.1.2__py3-none-any.whl


@@ -0,0 +1,751 @@
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.tools.tools import add_constant
from IPython import get_ipython
from IPython.display import clear_output
import time
from typing import Union, Literal, Dict, Tuple, Optional
import os
import sys
import textwrap
from utilities import sanitize_filename


# Keep track of all available functions; displayed by `info()`
__all__ = ["load_dataframe",
           "summarize_dataframe",
           "drop_rows_with_missing_data",
           "split_features_targets",
           "show_null_columns",
           "drop_columns_with_missing_data",
           "split_continuous_binary",
           "plot_correlation_heatmap",
           "check_value_distributions",
           "plot_value_distributions",
           "clip_outliers_single",
           "clip_outliers_multi",
           "merge_dataframes",
           "save_dataframe",
           "compute_vif",
           "drop_vif_based"]


def load_dataframe(df_path: str) -> pd.DataFrame:
    """
    Loads a DataFrame from a CSV file.

    Args:
        df_path (str): Path to the CSV file.

    Returns:
        pd.DataFrame: Loaded DataFrame.
    """
    df = pd.read_csv(df_path, encoding='utf-8')
    print(f"DataFrame shape: {df.shape}")
    return df


def summarize_dataframe(df: pd.DataFrame, round_digits: int = 2):
    """
    Returns a summary DataFrame with data types, non-null counts, number of unique values,
    missing value percentage, and basic statistics for each column.

    Parameters:
        df (pd.DataFrame): The input DataFrame.
        round_digits (int): Decimal places to round numerical statistics.

    Returns:
        pd.DataFrame: Summary table.
    """
    summary = pd.DataFrame({
        'Data Type': df.dtypes,
        'Non-Null Count': df.notnull().sum(),
        'Unique Values': df.nunique(),
        'Missing %': (df.isnull().mean() * 100).round(round_digits)
    })

    # For numeric columns, add summary statistics
    numeric_cols = df.select_dtypes(include='number').columns
    if not numeric_cols.empty:
        summary_numeric = df[numeric_cols].describe().T[
            ['mean', 'std', 'min', '25%', '50%', '75%', 'max']
        ].round(round_digits)
        summary = summary.join(summary_numeric, how='left')

    print(f"Shape: {df.shape}")
    return summary


def show_null_columns(df: pd.DataFrame, round_digits: int = 2):
    """
    Displays a table of columns with missing values, showing both the count and
    percentage of missing entries per column.

    Parameters:
        df (pd.DataFrame): The input DataFrame.
        round_digits (int): Number of decimal places for the percentage.

    Returns:
        pd.DataFrame: A DataFrame summarizing missing values in each column.
    """
    null_counts = df.isnull().sum()
    null_percent = df.isnull().mean() * 100

    # Keep only columns with at least one null
    mask = null_counts > 0
    null_summary = pd.DataFrame({
        'Missing Count': null_counts[mask],
        'Missing %': null_percent[mask].round(round_digits)
    })

    # Sort by descending percentage of missing values
    null_summary = null_summary.sort_values(by='Missing %', ascending=False)
    return null_summary
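

# Usage sketch (illustrative only; toy data and column names are hypothetical,
# not part of the released module):
#
#   >>> df = pd.DataFrame({"age": [22, 35, None, 41],
#   ...                    "income": [30_000, None, None, 52_000]})
#   >>> summarize_dataframe(df)   # dtypes, missing %, describe() stats per column
#   >>> show_null_columns(df)     # 'income' first (50% missing), then 'age' (25%)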


def drop_rows_with_missing_data(df: pd.DataFrame, threshold: float = 0.7) -> pd.DataFrame:
    """
    Drops rows with more than `threshold` fraction of missing values.

    Parameters:
        df (pd.DataFrame): The input DataFrame.
        threshold (float): Fraction of missing values above which rows are dropped.

    Returns:
        pd.DataFrame: A new DataFrame without the dropped rows.
    """
    missing_fraction = df.isnull().mean(axis=1)
    rows_to_drop = missing_fraction[missing_fraction > threshold].index

    if len(rows_to_drop) > 0:
        print(f"Dropping {len(rows_to_drop)} rows with more than {threshold*100:.0f}% missing data.")
    else:
        print(f"No rows have more than {threshold*100:.0f}% missing data.")

    return df.drop(index=rows_to_drop)


def split_features_targets(df: pd.DataFrame, targets: list[str]):
    """
    Splits a DataFrame's columns into features and targets.

    Args:
        df (pd.DataFrame): Pandas DataFrame containing the dataset.
        targets (list[str]): List of column names to be treated as target variables.

    Returns:
        tuple: A tuple containing:
            - pd.DataFrame: Targets dataframe.
            - pd.DataFrame: Features dataframe.

    Prints:
        - Shape of the original dataframe.
        - Shape of the targets dataframe.
        - Shape of the features dataframe.
    """
    df_targets = df[targets]
    df_features = df.drop(columns=targets)
    print(f"Original shape: {df.shape}\nTargets shape: {df_targets.shape}\nFeatures shape: {df_features.shape}")
    return df_targets, df_features


def split_continuous_binary(df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Split DataFrame into two DataFrames: one with continuous columns, one with binary columns.
    Normalizes float binary values like 0.0/1.0 to integer 0/1.

    Parameters:
        df (pd.DataFrame): Input DataFrame with only numeric columns.

    Returns:
        Tuple(pd.DataFrame, pd.DataFrame): (continuous_columns_df, binary_columns_df)

    Raises:
        TypeError: If any column is not numeric.
    """
    if not all(np.issubdtype(dtype, np.number) for dtype in df.dtypes):
        raise TypeError("All columns must be numeric (int or float).")

    binary_cols = []
    continuous_cols = []

    for col in df.columns:
        unique_values = set(df[col].dropna().unique())

        # Since 0 == 0.0 and 1 == 1.0 in Python, this matches int and float flags alike
        if unique_values.issubset({0, 1}):
            binary_cols.append(col)
        else:
            continuous_cols.append(col)

    binary_cols.sort()

    df_cont = df[continuous_cols]
    # Cast float 0.0/1.0 flags to nullable integer 0/1 without mutating the input
    df_bin = df[binary_cols].astype("Int64")

    print(f"Continuous columns shape: {df_cont.shape}")
    print(f"Binary columns shape: {df_bin.shape}")

    return df_cont, df_bin
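

# Usage sketch (illustrative only; toy data):
#
#   >>> df = pd.DataFrame({"height": [1.7, 1.6, 1.8], "smoker": [0.0, 1.0, 0.0]})
#   >>> df_cont, df_bin = split_continuous_binary(df)
#   >>> list(df_cont.columns), list(df_bin.columns)
#   (['height'], ['smoker'])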


def drop_columns_with_missing_data(df: pd.DataFrame, threshold: float = 0.7) -> pd.DataFrame:
    """
    Drops columns with more than `threshold` fraction of missing values.

    Parameters:
        df (pd.DataFrame): The input DataFrame.
        threshold (float): Fraction of missing values above which columns are dropped.

    Returns:
        pd.DataFrame: A new DataFrame without the dropped columns.
    """
    missing_fraction = df.isnull().mean()
    cols_to_drop = missing_fraction[missing_fraction > threshold].index

    if len(cols_to_drop) > 0:
        print(f"Dropping columns with more than {threshold*100:.0f}% missing data:")
        print(list(cols_to_drop))
    else:
        print(f"No columns have more than {threshold*100:.0f}% missing data.")

    return df.drop(columns=cols_to_drop)
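

# Usage sketch (illustrative only): a typical pruning pass, dropping sparse
# columns first and then mostly-empty rows. Thresholds are fractions (0.7 = 70%);
# the path below is hypothetical.
#
#   >>> df = load_dataframe("data.csv")
#   >>> df = drop_columns_with_missing_data(df, threshold=0.7)
#   >>> df = drop_rows_with_missing_data(df, threshold=0.7)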


def plot_correlation_heatmap(df: pd.DataFrame, save_dir: Union[str, None] = None, method: Literal["pearson", "kendall", "spearman"]="pearson", plot_title: str="Correlation Heatmap"):
    """
    Plots a heatmap of pairwise correlations between numeric features in a DataFrame.

    Args:
        df (pd.DataFrame): The input dataset.
        save_dir (str | None): If provided, the heatmap will be saved to this directory as an SVG file.
        method (str): Correlation method to use. Must be one of:
            - 'pearson' (default): measures linear correlation (assumes normally distributed data),
            - 'kendall': rank correlation (non-parametric),
            - 'spearman': monotonic relationship (non-parametric).
        plot_title (str): Title of the plot, also used as the file name when saving. Change it to create multiple plots instead of overwriting the same file.

    Notes:
        - Only numeric columns are included.
        - Annotations are disabled if there are more than 20 features.
        - Missing values are handled via pairwise complete observations.
    """
    numeric_df = df.select_dtypes(include='number')
    if numeric_df.empty:
        print("No numeric columns found. Heatmap not generated.")
        return

    corr = numeric_df.corr(method=method)

    # Create a mask for the upper triangle
    mask = np.triu(np.ones_like(corr, dtype=bool))

    # Plot setup
    size = max(10, numeric_df.shape[1])
    plt.figure(figsize=(size, size * 0.8))

    annot_bool = numeric_df.shape[1] <= 20
    sns.heatmap(
        corr,
        mask=mask,
        annot=annot_bool,
        cmap='coolwarm',
        fmt=".2f",
        cbar_kws={"shrink": 0.8}
    )

    # Sanitize the title so it is safe to reuse as a file name
    plot_title = sanitize_filename(plot_title)

    plt.title(plot_title)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)

    plt.tight_layout()

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        full_path = os.path.join(save_dir, plot_title + ".svg")
        plt.savefig(full_path, bbox_inches="tight", format='svg')
        print(f"Saved correlation heatmap to: {full_path}")

    plt.show()
    plt.close()


def check_value_distributions(df: pd.DataFrame, view_frequencies: bool=True, bin_threshold: int=10, skip_cols_with_key: Union[str, None]=None):
    """
    Interactively prints value counts (or relative frequencies) for each column
    in a DataFrame, binning high-cardinality numeric columns by quantile.

    Args:
        df (pd.DataFrame): The dataset to analyze.
        view_frequencies (bool): Print relative frequencies instead of value counts.
        bin_threshold (int): Number of unique values above which binning is used.
        skip_cols_with_key (str | None): Skip column names containing the key. If None, don't skip any column.

    Notes:
        - Binning is adaptive: if quantile binning results in ≤ 2 unique bins, raw values are used instead.
    """
    # Cherry-pick columns
    if skip_cols_with_key is not None:
        columns = [col for col in df.columns if skip_cols_with_key not in col]
    else:
        columns = df.columns.to_list()

    for col in columns:
        if _is_notebook():
            clear_output(wait=False)
        if pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() > bin_threshold:
            # Start with 10 quantile bins and reduce until at least 3 distinct bins remain
            bins_number = 10
            binned = pd.qcut(df[col], q=bins_number, duplicates='drop')
            while binned.nunique() <= 2:
                bins_number -= 1
                binned = pd.qcut(df[col], q=bins_number, duplicates='drop')
                if bins_number <= 2:
                    break

            if binned.nunique() <= 2:
                # Binning collapsed; fall back to raw value counts
                view_std = df[col].value_counts(ascending=False)
            else:
                view_std = binned.value_counts(sort=False)

        else:
            view_std = df[col].value_counts(ascending=False)

        view_std.name = col

        # Guard against the unlikely case of an empty series
        if view_std.sum() == 0:
            view_freq = view_std
        else:
            view_freq = view_std / view_std.sum()
        view_freq.name = col

        # Print frequencies or raw counts
        print(view_freq if view_frequencies else view_std)

        time.sleep(1)
        input("Press enter to continue")


def plot_value_distributions(df: pd.DataFrame, save_dir: str, bin_threshold: int=10, skip_cols_with_key: Union[str, None]=None):
    """
    Plots and saves the value distributions for all (or selected) columns in a DataFrame,
    with adaptive binning for numerical columns when appropriate.

    For each column both raw counts and relative frequencies are computed and plotted.

    Plots are saved as PNG files under two subdirectories in `save_dir`:
    - "Distribution_Counts" for absolute counts.
    - "Distribution_Frequency" for relative frequencies.

    Args:
        df (pd.DataFrame): The input DataFrame whose columns are to be analyzed.
        save_dir (str): Directory path where the plots will be saved. Will be created if it does not exist.
        bin_threshold (int): Minimum number of unique values required to trigger binning
            for numerical columns.
        skip_cols_with_key (str | None): If provided, any column whose name contains this
            substring will be excluded from analysis.

    Notes:
        - Binning is adaptive: if quantile binning results in ≤ 2 unique bins, raw values are used instead.
        - All non-alphanumeric characters in column names are sanitized for safe file naming.
        - Colormap is automatically adapted based on the number of categories or bins.
    """
    os.makedirs(save_dir, exist_ok=True)

    dict_to_plot_std = dict()
    dict_to_plot_freq = dict()

    # Cherry-pick columns
    if skip_cols_with_key is not None:
        columns = [col for col in df.columns if skip_cols_with_key not in col]
    else:
        columns = df.columns.to_list()

    saved_plots = 0
    for col in columns:
        if pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() > bin_threshold:
            # Start with 10 quantile bins and reduce until at least 3 distinct bins remain
            bins_number = 10
            binned = pd.qcut(df[col], q=bins_number, duplicates='drop')
            while binned.nunique() <= 2:
                bins_number -= 1
                binned = pd.qcut(df[col], q=bins_number, duplicates='drop')
                if bins_number <= 2:
                    break

            if binned.nunique() <= 2:
                # Binning collapsed; fall back to raw value counts
                view_std = df[col].value_counts(sort=False).sort_index()
            else:
                view_std = binned.value_counts(sort=False)

        else:
            view_std = df[col].value_counts(sort=False).sort_index()

        # Guard against the unlikely case of an empty series
        if view_std.sum() == 0:
            view_freq = view_std
        else:
            view_freq = 100 * view_std / view_std.sum()  # Percentage

        dict_to_plot_std[col] = dict(view_std)
        dict_to_plot_freq[col] = dict(view_freq)
        saved_plots += 1

    # Plot helper
    def _plot_helper(dict_: dict, target_dir: str, ylabel: Literal["Frequency", "Counts"], base_fontsize: int=12):
        for col, data in dict_.items():
            safe_col = sanitize_filename(col)

            if isinstance(list(data.keys())[0], pd.Interval):
                labels = [str(interval) for interval in data.keys()]
            else:
                labels = list(data.keys())

            plt.figure(figsize=(10, 6))
            colors = plt.cm.tab20.colors if len(data) <= 20 else plt.cm.viridis(np.linspace(0, 1, len(data)))

            plt.bar(labels, data.values(), color=colors[:len(data)], alpha=0.85)
            plt.xlabel("Values", fontsize=base_fontsize)
            plt.ylabel(ylabel, fontsize=base_fontsize)
            plt.title(f"Value Distribution for '{col}'", fontsize=base_fontsize+2)
            plt.xticks(rotation=45, ha='right', fontsize=base_fontsize-2)
            plt.yticks(fontsize=base_fontsize-2)
            plt.grid(axis='y', linestyle='--', alpha=0.6)
            plt.gca().set_facecolor('#f9f9f9')
            plt.tight_layout()

            plot_path = os.path.join(target_dir, f"{safe_col}.png")
            plt.savefig(plot_path, dpi=300, bbox_inches="tight")
            plt.close()

    # Save plots
    freq_dir = os.path.join(save_dir, "Distribution_Frequency")
    std_dir = os.path.join(save_dir, "Distribution_Counts")
    os.makedirs(freq_dir, exist_ok=True)
    os.makedirs(std_dir, exist_ok=True)
    _plot_helper(dict_=dict_to_plot_std, target_dir=std_dir, ylabel="Counts")
    _plot_helper(dict_=dict_to_plot_freq, target_dir=freq_dir, ylabel="Frequency")

    print(f"Saved {saved_plots} plot(s)")
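

# Usage sketch (illustrative only; 'plots' is a hypothetical directory):
#
#   >>> plot_value_distributions(df, save_dir="plots")
#   # -> plots/Distribution_Counts/<column>.png
#   # -> plots/Distribution_Frequency/<column>.png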


def clip_outliers_single(
    df: pd.DataFrame,
    column: str,
    min_val: float,
    max_val: float
) -> Union[pd.DataFrame, None]:
    """
    Clips values in the specified numeric column to the range [min_val, max_val],
    and returns a new DataFrame where the original column is replaced by the clipped version.

    Args:
        df (pd.DataFrame): The input DataFrame.
        column (str): The name of the column to clip.
        min_val (float): Minimum allowable value; values below are clipped to this.
        max_val (float): Maximum allowable value; values above are clipped to this.

    Returns:
        pd.DataFrame: A new DataFrame with the specified column clipped in place.
        None: If the column is missing or not numeric.
    """
    if column not in df.columns:
        print(f"Column '{column}' not found in DataFrame.")
        return None

    if not pd.api.types.is_numeric_dtype(df[column]):
        print(f"Column '{column}' must be numeric.")
        return None

    new_df = df.copy(deep=True)
    new_df[column] = new_df[column].clip(lower=min_val, upper=max_val)

    print(f"Column '{column}' clipped to range [{min_val}, {max_val}].")
    return new_df


def clip_outliers_multi(
    df: pd.DataFrame,
    clip_dict: Dict[str, Tuple[Union[int, float], Union[int, float]]],
    verbose: bool=False
) -> pd.DataFrame:
    """
    Clips values in multiple specified numeric columns to given [min, max] ranges,
    working on a copy of the input and skipping invalid entries.

    Args:
        df (pd.DataFrame): The input DataFrame.
        clip_dict (dict): A dictionary where keys are column names and values are (min_val, max_val) tuples.
        verbose (bool): Prints the clipped range for each column.

    Returns:
        pd.DataFrame: A new DataFrame with the specified columns clipped.

    Notes:
        - Invalid specifications (missing column, non-numeric type, wrong tuple length)
          are reported and skipped.
    """
    new_df = df.copy()
    skipped_columns = []
    clipped_columns = 0

    for col, bounds in clip_dict.items():
        try:
            if col not in df.columns:
                raise ValueError(f"Column '{col}' not found in DataFrame.")

            if not pd.api.types.is_numeric_dtype(df[col]):
                raise TypeError(f"Column '{col}' is not numeric.")

            if not (isinstance(bounds, tuple) and len(bounds) == 2):
                raise ValueError(f"Bounds for '{col}' must be a tuple of (min, max).")

            min_val, max_val = bounds
            new_df[col] = new_df[col].clip(lower=min_val, upper=max_val)
            if verbose:
                print(f"Clipped '{col}' to range [{min_val}, {max_val}].")
            clipped_columns += 1

        except Exception as e:
            skipped_columns.append((col, str(e)))
            continue

    print(f"Clipped {clipped_columns} columns.")

    if skipped_columns:
        print("\n⚠️ Skipped columns:")
        for col, msg in skipped_columns:
            print(f"  - {col}: {msg}")

    return new_df
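

# Usage sketch (illustrative only; toy bounds, hypothetical column names):
#
#   >>> clip_dict = {"age": (0, 90), "income": (0, 200_000)}
#   >>> df_clipped = clip_outliers_multi(df, clip_dict, verbose=True)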


def merge_dataframes(
    *dfs: pd.DataFrame,
    reset_index: bool = False,
    direction: Literal["horizontal", "vertical"] = "horizontal"
) -> pd.DataFrame:
    """
    Merges multiple DataFrames either horizontally or vertically.

    Parameters:
        *dfs (pd.DataFrame): Variable number of DataFrames to merge.
        reset_index (bool): Whether to reset the index in the final merged DataFrame.
        direction ("horizontal" | "vertical"):
            - "horizontal": Merge on index, adding columns.
            - "vertical": Append rows; all DataFrames must have identical columns.

    Returns:
        pd.DataFrame: A single merged DataFrame.

    Raises:
        ValueError:
            - If fewer than 2 DataFrames are provided.
            - If indexes do not match for horizontal merge.
            - If column names or order differ for vertical merge.
    """
    if len(dfs) < 2:
        raise ValueError("At least 2 DataFrames must be provided.")

    for i, df in enumerate(dfs, start=1):
        print(f"DataFrame {i} shape: {df.shape}")

    if direction == "horizontal":
        reference_index = dfs[0].index
        for i, df in enumerate(dfs[1:], start=2):
            if not df.index.equals(reference_index):
                raise ValueError(f"Indexes do not match: Dataset 1 and Dataset {i}.")
        merged_df = pd.concat(dfs, axis=1)

    elif direction == "vertical":
        reference_columns = dfs[0].columns
        for i, df in enumerate(dfs[1:], start=2):
            if not df.columns.equals(reference_columns):
                raise ValueError(f"Column names/order do not match: Dataset 1 and Dataset {i}.")
        merged_df = pd.concat(dfs, axis=0)

    else:
        raise ValueError(f"Invalid merge direction: {direction}")

    if reset_index:
        merged_df = merged_df.reset_index(drop=True)

    print(f"Merged DataFrame shape: {merged_df.shape}")

    return merged_df
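

# Usage sketch (illustrative only; toy frames):
#
#   >>> a = pd.DataFrame({"x": [1, 2]})
#   >>> b = pd.DataFrame({"y": [3, 4]})
#   >>> merge_dataframes(a, b, direction="horizontal")                  # columns x, y; 2 rows
#   >>> merge_dataframes(a, a, direction="vertical", reset_index=True)  # 4 rows of x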


def save_dataframe(df: pd.DataFrame, save_dir: str, filename: str) -> None:
    """
    Save a pandas DataFrame to a CSV file.

    Parameters:
        df: pandas.DataFrame to save.
        save_dir: str, directory where the CSV file will be saved.
        filename: str, CSV filename; the extension will be added if missing.
    """
    os.makedirs(save_dir, exist_ok=True)

    filename = sanitize_filename(filename)

    if not filename.endswith('.csv'):
        filename += '.csv'

    output_path = os.path.join(save_dir, filename)

    df.to_csv(output_path, index=False, encoding='utf-8')
    print(f"Saved file: '{filename}'")
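

# Usage sketch (illustrative only; hypothetical paths):
#
#   >>> df = load_dataframe("raw/data.csv")
#   >>> save_dataframe(df, save_dir="processed", filename="data_clean")  # '.csv' appended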


def compute_vif(
    df: pd.DataFrame,
    features: Optional[list[str]] = None,
    ignore_cols: Optional[list[str]] = None,
    plot: bool = True,
    save_dir: Union[str, None] = None
) -> pd.DataFrame:
    """
    Computes Variance Inflation Factors (VIF) for numeric features, optionally plots and saves the results.

    The dataset must not contain missing values.

    Args:
        df (pd.DataFrame): The input DataFrame.
        features (list[str] | None): Optional list of column names to evaluate. Defaults to all numeric columns.
        ignore_cols (list[str] | None): Optional list of column names to ignore.
        plot (bool): Whether to display a barplot of VIF values.
        save_dir (str | None): Directory to save the plot as SVG. If None, the plot is not saved.

    Returns:
        pd.DataFrame: DataFrame with features and corresponding VIF values, sorted descending.

    NOTE:
        **Variance Inflation Factor (VIF)** quantifies the degree of multicollinearity among features in a dataset.
        A VIF value indicates how much the variance of a regression coefficient is inflated due to linear dependence with other features.
        A VIF of 1 suggests no correlation, values between 1 and 5 indicate moderate correlation, and values greater than 10 typically signal high multicollinearity, which may distort model interpretation and degrade performance.
    """
    if features is None:
        features = df.select_dtypes(include='number').columns.tolist()

    if ignore_cols is not None:
        missing = set(ignore_cols) - set(features)
        if missing:
            raise ValueError(f"The following 'columns to ignore' are not among the evaluated features:\n{missing}")
        features = [f for f in features if f not in ignore_cols]

    X = df[features].copy()
    X = add_constant(X, has_constant='add')

    vif_data = pd.DataFrame()
    vif_data["feature"] = X.columns
    vif_data["VIF"] = [variance_inflation_factor(X.values, i) for i in range(X.shape[1])]

    # Drop the constant column
    vif_data = vif_data[vif_data["feature"] != "const"]
    vif_data = vif_data.sort_values(by="VIF", ascending=False).reset_index(drop=True)

    # Add color coding based on thresholds
    def vif_color(v: float) -> str:
        if v > 10:
            return "red"
        elif v > 5:
            return "gold"
        else:
            return "green"

    vif_data["color"] = vif_data["VIF"].apply(vif_color)

    # Plot
    if plot or save_dir:
        plt.figure(figsize=(10, 6))
        plt.barh(
            vif_data["feature"],
            vif_data["VIF"],
            color=vif_data["color"],
            edgecolor='black'
        )
        plt.title("Variance Inflation Factor (VIF) per Feature")
        plt.xlabel("VIF")
        plt.axvline(x=5, color='gold', linestyle='--', label='VIF = 5')
        plt.axvline(x=10, color='red', linestyle='--', label='VIF = 10')
        plt.legend(loc='lower right')
        plt.gca().invert_yaxis()
        plt.grid(axis='x', linestyle='--', alpha=0.5)
        plt.tight_layout()

        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
            save_path = os.path.join(save_dir, "VIF_plot.svg")
            plt.savefig(save_path, format='svg', bbox_inches='tight')
            print(f"Saved VIF plot to: {save_path}")

        if plot:
            plt.show()
        plt.close()

    return vif_data.drop(columns="color")


def drop_vif_based(df: pd.DataFrame, vif_df: pd.DataFrame, threshold: float = 10.0) -> pd.DataFrame:
    """
    Drops features from the original DataFrame based on their VIF values exceeding a given threshold.

    Args:
        df (pd.DataFrame): Original DataFrame containing the features.
        vif_df (pd.DataFrame): DataFrame with 'feature' and 'VIF' columns as returned by `compute_vif()`.
        threshold (float): VIF threshold above which features will be dropped.

    Returns:
        pd.DataFrame: A new DataFrame with high-VIF features removed.
    """
    # Ensure expected structure
    if 'feature' not in vif_df.columns or 'VIF' not in vif_df.columns:
        raise ValueError("`vif_df` must contain 'feature' and 'VIF' columns.")

    # Identify features to drop
    to_drop = vif_df[vif_df["VIF"] > threshold]["feature"].tolist()
    print(f"Dropping {len(to_drop)} feature(s) with VIF > {threshold}: {to_drop}")

    return df.drop(columns=to_drop, errors="ignore")
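

# Usage sketch (illustrative only): one VIF pruning pass. Since VIF values shift
# after every drop, recomputing until all values fall below the threshold is a
# common follow-up; the loop below is a hypothetical example, not part of the module:
#
#   >>> vif = compute_vif(df, plot=False)
#   >>> while (vif["VIF"] > 10).any():
#   ...     df = drop_vif_based(df, vif, threshold=10.0)
#   ...     vif = compute_vif(df, plot=False)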


def _is_notebook():
    # True when running inside an IPython/Jupyter session
    return get_ipython() is not None


def info(full_info: bool=True):
    """
    List available functions and their descriptions.
    """
    print("Available functions for data exploration:")
    if full_info:
        module = sys.modules[__name__]
        for name in __all__:
            obj = getattr(module, name, None)
            if callable(obj):
                doc = obj.__doc__ or "No docstring provided."
                formatted_doc = textwrap.indent(textwrap.dedent(doc.strip()), prefix="    ")
                print(f"\n{name}:\n{formatted_doc}")
    else:
        for i, name in enumerate(__all__, start=1):
            print(f"{i} - {name}")


if __name__ == "__main__":
    info()