dragon-ml-toolbox 10.1.1__py3-none-any.whl → 14.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dragon-ml-toolbox has been flagged as potentially problematic by the registry.
- {dragon_ml_toolbox-10.1.1.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/METADATA +38 -63
- dragon_ml_toolbox-14.2.0.dist-info/RECORD +48 -0
- {dragon_ml_toolbox-10.1.1.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/licenses/LICENSE +1 -1
- {dragon_ml_toolbox-10.1.1.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +11 -0
- ml_tools/ETL_cleaning.py +175 -59
- ml_tools/ETL_engineering.py +506 -70
- ml_tools/GUI_tools.py +2 -1
- ml_tools/MICE_imputation.py +212 -7
- ml_tools/ML_callbacks.py +73 -40
- ml_tools/ML_datasetmaster.py +267 -284
- ml_tools/ML_evaluation.py +119 -58
- ml_tools/ML_evaluation_multi.py +107 -32
- ml_tools/ML_inference.py +15 -5
- ml_tools/ML_models.py +234 -170
- ml_tools/ML_models_advanced.py +323 -0
- ml_tools/ML_optimization.py +321 -97
- ml_tools/ML_scaler.py +10 -5
- ml_tools/ML_trainer.py +585 -40
- ml_tools/ML_utilities.py +528 -0
- ml_tools/ML_vision_datasetmaster.py +1315 -0
- ml_tools/ML_vision_evaluation.py +260 -0
- ml_tools/ML_vision_inference.py +428 -0
- ml_tools/ML_vision_models.py +627 -0
- ml_tools/ML_vision_transformers.py +58 -0
- ml_tools/PSO_optimization.py +10 -7
- ml_tools/RNN_forecast.py +2 -0
- ml_tools/SQL.py +22 -9
- ml_tools/VIF_factor.py +4 -3
- ml_tools/_ML_vision_recipe.py +88 -0
- ml_tools/__init__.py +1 -0
- ml_tools/_logger.py +0 -2
- ml_tools/_schema.py +96 -0
- ml_tools/constants.py +79 -0
- ml_tools/custom_logger.py +164 -16
- ml_tools/data_exploration.py +1092 -109
- ml_tools/ensemble_evaluation.py +48 -1
- ml_tools/ensemble_inference.py +6 -7
- ml_tools/ensemble_learning.py +4 -3
- ml_tools/handle_excel.py +1 -0
- ml_tools/keys.py +80 -0
- ml_tools/math_utilities.py +259 -0
- ml_tools/optimization_tools.py +198 -24
- ml_tools/path_manager.py +144 -45
- ml_tools/serde.py +192 -0
- ml_tools/utilities.py +287 -227
- dragon_ml_toolbox-10.1.1.dist-info/RECORD +0 -36
- {dragon_ml_toolbox-10.1.1.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-10.1.1.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/top_level.txt +0 -0
ml_tools/data_exploration.py
CHANGED
@@ -1,17 +1,17 @@
 import pandas as pd
-from pandas.api.types import is_numeric_dtype
+from pandas.api.types import is_numeric_dtype, is_object_dtype
 import numpy as np
 import matplotlib.pyplot as plt
 import seaborn as sns
-from typing import Union, Literal, Dict, Tuple, List, Optional
+from typing import Union, Literal, Dict, Tuple, List, Optional, Any
 from pathlib import Path
 import re

 from .path_manager import sanitize_filename, make_fullpath
 from ._script_info import _script_info
 from ._logger import _LOGGER
-from .utilities import
-[1 removed line not recoverable from this render]
+from .utilities import save_dataframe_filename
+from ._schema import FeatureSchema

 # Keep track of all available tools, show using `info()`
 __all__ = [
@@ -21,14 +21,22 @@ __all__ = [
     "show_null_columns",
     "drop_columns_with_missing_data",
     "drop_macro",
+    "clean_column_names",
+    "plot_value_distributions",
+    "plot_continuous_vs_target",
+    "plot_categorical_vs_target",
+    "encode_categorical_features",
     "split_features_targets",
     "split_continuous_binary",
-    "plot_correlation_heatmap",
-    "plot_value_distributions",
     "clip_outliers_single",
     "clip_outliers_multi",
+    "drop_outlier_samples",
+    "plot_correlation_heatmap",
     "match_and_filter_columns_by_regex",
-    "standardize_percentages"
+    "standardize_percentages",
+    "reconstruct_one_hot",
+    "reconstruct_binary",
+    "finalize_feature_schema"
 ]


@@ -263,7 +271,7 @@ def drop_macro(df: pd.DataFrame,

     # Log initial state
     missing_data = show_null_columns(df=df_clean)
-    [1 removed line not recoverable from this render]
+    save_dataframe_filename(df=missing_data.reset_index(drop=False),
                            save_dir=log_directory,
                            filename="Missing_Data_start")

@@ -292,7 +300,7 @@ def drop_macro(df: pd.DataFrame,

     # log final state
     missing_data = show_null_columns(df=df_clean)
-    [1 removed line not recoverable from this render]
+    save_dataframe_filename(df=missing_data.reset_index(drop=False),
                            save_dir=log_directory,
                            filename="Missing_Data_final")

@@ -300,6 +308,547 @@ def drop_macro(df: pd.DataFrame,
     return df_clean


+def clean_column_names(df: pd.DataFrame, replacement_char: str = '-', replacement_pattern: str = r'[\[\]{}<>,:"]', verbose: bool = True) -> pd.DataFrame:
+    """
+    Cleans DataFrame column names by replacing special characters.
+
+    This function is useful for ensuring compatibility with libraries like LightGBM,
+    which do not support special JSON characters such as `[]{}<>,:"` in feature names.
+
+    Args:
+        df (pd.DataFrame): The input DataFrame.
+        replacement_char (str): The character to use for replacing characters.
+        replacement_pattern (str): Regex pattern to use for the replacement logic.
+        verbose (bool): If True, prints the renamed columns.
+
+    Returns:
+        pd.DataFrame: A new DataFrame with cleaned column names.
+    """
+    new_df = df.copy()
+
+    original_columns = new_df.columns
+    new_columns = original_columns.str.replace(replacement_pattern, replacement_char, regex=True)
+
+    # Create a map of changes for logging
+    rename_map = {old: new for old, new in zip(original_columns, new_columns) if old != new}
+
+    if verbose:
+        if rename_map:
+            _LOGGER.info(f"Cleaned {len(rename_map)} column name(s) containing special characters:")
+            for old, new in rename_map.items():
+                print(f" '{old}' -> '{new}'")
+        else:
+            _LOGGER.info("No column names required cleaning.")
+
+    new_df.columns = new_columns
+    return new_df
+
+
+def plot_value_distributions(
+    df: pd.DataFrame,
+    save_dir: Union[str, Path],
+    categorical_columns: Optional[List[str]] = None,
+    categorical_cardinality_threshold: int = 10,
+    max_categories: int = 50,
+    fill_na_with: str = "Missing"
+):
+    """
+    Plots and saves the value distributions for all columns in a DataFrame,
+    using the best plot type for each column (histogram or count plot).
+
+    Plots are saved as SVG files under two subdirectories in `save_dir`:
+    - "Distribution_Continuous" for continuous numeric features (histograms).
+    - "Distribution_Categorical" for categorical features (count plots).
+
+    Args:
+        df (pd.DataFrame): The input DataFrame to analyze.
+        save_dir (str | Path): Directory path to save the plots.
+        categorical_columns (List[str] | None): If provided, this list
+            of column names will be treated as categorical, and all other columns will be treated as continuous. This
+            overrides the `continuous_cardinality_threshold` logic.
+        categorical_cardinality_threshold (int): A numeric column will be treated
+            as 'categorical' if its number of unique values is less than or equal to this threshold. (Ignored if `categorical_columns` is set).
+        max_categories (int): The maximum number of unique categories a
+            categorical feature can have to be plotted. Features exceeding this limit will be skipped.
+        fill_na_with (str): A string to replace NaN values in categorical columns. This allows plotting 'missingness' as its
+            own category. Defaults to "Missing".
+
+    Notes:
+        - `seaborn.histplot` with KDE is used for continuous features.
+        - `seaborn.countplot` is used for categorical features.
+    """
+    # 1. Setup save directories
+    base_save_path = make_fullpath(save_dir, make=True, enforce="directory")
+    numeric_dir = base_save_path / "Distribution_Continuous"
+    categorical_dir = base_save_path / "Distribution_Categorical"
+    numeric_dir.mkdir(parents=True, exist_ok=True)
+    categorical_dir.mkdir(parents=True, exist_ok=True)
+
+    # 2. Filter columns to plot
+    columns_to_plot = df.columns.to_list()
+
+    # Setup for forced categorical logic
+    categorical_set = set(categorical_columns) if categorical_columns is not None else None
+
+    numeric_plots_saved = 0
+    categorical_plots_saved = 0
+
+    for col_name in columns_to_plot:
+        try:
+            is_numeric = is_numeric_dtype(df[col_name])
+            n_unique = df[col_name].nunique()
+
+            # --- 3. Determine Plot Type ---
+            is_continuous = False
+            if categorical_set is not None:
+                # Use the explicit list
+                if col_name not in categorical_set:
+                    is_continuous = True
+            else:
+                # Use auto-detection
+                if is_numeric and n_unique > categorical_cardinality_threshold:
+                    is_continuous = True
+
+            # --- Case 1: Continuous Numeric (Histogram) ---
+            if is_continuous:
+                plt.figure(figsize=(10, 6))
+                # Drop NaNs for histogram, as they can't be plotted on a numeric axis
+                sns.histplot(x=df[col_name].dropna(), kde=True, bins=30)
+                plt.title(f"Distribution of '{col_name}' (Continuous)")
+                plt.xlabel(col_name)
+                plt.ylabel("Count")
+
+                save_path = numeric_dir / f"{sanitize_filename(col_name)}.svg"
+                numeric_plots_saved += 1
+
+            # --- Case 2: Categorical or Low-Cardinality Numeric (Count Plot) ---
+            else:
+                # Check max categories
+                if n_unique > max_categories:
+                    _LOGGER.warning(f"Skipping plot for '{col_name}': {n_unique} unique values > {max_categories} max_categories.")
+                    continue
+
+                # Adaptive figure size
+                fig_width = max(10, n_unique * 0.5)
+                plt.figure(figsize=(fig_width, 7))
+
+                # Make a temporary copy for plotting to handle NaNs
+                temp_series = df[col_name].copy()
+
+                # Handle NaNs by replacing them with the specified string
+                if temp_series.isnull().any():
+                    # Convert to object type first to allow string replacement
+                    temp_series = temp_series.astype(object).fillna(fill_na_with)
+
+                # Convert all to string to be safe (handles low-card numeric)
+                temp_series = temp_series.astype(str)
+
+                # Get category order by frequency
+                order = temp_series.value_counts().index
+                sns.countplot(x=temp_series, order=order, palette="viridis")
+
+                plt.title(f"Distribution of '{col_name}' (Categorical)")
+                plt.xlabel(col_name)
+                plt.ylabel("Count")
+
+                # Smart tick rotation
+                max_label_len = 0
+                if n_unique > 0:
+                    max_label_len = max(len(str(s)) for s in order)
+
+                # Rotate if labels are long OR there are many categories
+                if max_label_len > 10 or n_unique > 25:
+                    plt.xticks(rotation=45, ha='right')
+
+                save_path = categorical_dir / f"{sanitize_filename(col_name)}.svg"
+                categorical_plots_saved += 1
+
+            # --- 4. Save Plot ---
+            plt.grid(True, linestyle='--', alpha=0.6, axis='y')
+            plt.tight_layout()
+            # Save as .svg
+            plt.savefig(save_path, format='svg', bbox_inches="tight")
+            plt.close()
+
+        except Exception as e:
+            _LOGGER.error(f"Failed to plot distribution for '{col_name}'. Error: {e}")
+            plt.close()
+
+    _LOGGER.info(f"Saved {numeric_plots_saved} continuous distribution plots to '{numeric_dir.name}'.")
+    _LOGGER.info(f"Saved {categorical_plots_saved} categorical distribution plots to '{categorical_dir.name}'.")
+
+
+def plot_continuous_vs_target(
+    df: pd.DataFrame,
+    targets: List[str],
+    save_dir: Union[str, Path],
+    features: Optional[List[str]] = None
+):
+    """
+    Plots each continuous feature against each target to visualize linear relationships.
+
+    This function is a common EDA step for regression tasks. It creates a
+    scatter plot for each feature-target pair, overlays a simple linear
+    regression line, and saves each plot as an individual .svg file.
+
+    Plots are saved in a structured way, with a subdirectory created for
+    each target variable.
+
+    Args:
+        df (pd.DataFrame): The input DataFrame.
+        targets (List[str]): A list of target column names to plot (y-axis).
+        save_dir (str | Path): The base directory where plots will be saved. A subdirectory will be created here for each target.
+        features (List[str] | None): A list of feature column names to plot (x-axis). If None, all non-target columns in the
+            DataFrame will be used.
+
+    Notes:
+        - Only numeric features and numeric targets are processed. Non-numeric
+          columns in the lists will be skipped with a warning.
+        - Rows with NaN in either the feature or the target are dropped
+          pairwise for each plot.
+    """
+    # 1. Validate the base save directory
+    base_save_path = make_fullpath(save_dir, make=True, enforce="directory")
+
+    # 2. Validate helper
+    def _validate_numeric_cols(col_list: List[str], col_type: str) -> List[str]:
+        valid_cols = []
+        for col in col_list:
+            if col not in df.columns:
+                _LOGGER.warning(f"{col_type} column '{col}' not found. Skipping.")
+            elif not is_numeric_dtype(df[col]):
+                _LOGGER.warning(f"{col_type} column '{col}' is not numeric. Skipping.")
+            else:
+                valid_cols.append(col)
+        return valid_cols
+
+    # 3. Validate target columns FIRST
+    valid_targets = _validate_numeric_cols(targets, "Target")
+    if not valid_targets:
+        _LOGGER.error("No valid numeric target columns provided to plot.")
+        return
+
+    # 4. Determine and validate feature columns
+    if features is None:
+        _LOGGER.info("No 'features' list provided. Using all non-target columns as features.")
+        target_set = set(valid_targets)
+        # Get all columns that are not in the valid_targets set
+        features_to_validate = [col for col in df.columns if col not in target_set]
+    else:
+        features_to_validate = features
+
+    valid_features = _validate_numeric_cols(features_to_validate, "Feature")
+
+    if not valid_features:
+        _LOGGER.error("No valid numeric feature columns found to plot.")
+        return
+
+    # 5. Main plotting loop
+    total_plots_saved = 0
+
+    for target_name in valid_targets:
+        # Create a sanitized subdirectory for this target
+        safe_target_dir_name = sanitize_filename(f"{target_name}_vs_Continuous")
+        target_save_dir = base_save_path / safe_target_dir_name
+        target_save_dir.mkdir(parents=True, exist_ok=True)
+
+        _LOGGER.info(f"Generating plots for target: '{target_name}' -> Saving to '{target_save_dir.name}'")
+
+        for feature_name in valid_features:
+
+            # Drop NaNs pairwise for this specific plot
+            temp_df = df[[feature_name, target_name]].dropna()
+
+            if temp_df.empty:
+                _LOGGER.warning(f"No non-null data for '{feature_name}' vs '{target_name}'. Skipping plot.")
+                continue
+
+            x = temp_df[feature_name]
+            y = temp_df[target_name]
+
+            # 6. Perform linear fit
+            try:
+                # Use numpy's polyfit to get the slope (pf[0]) and intercept (pf[1])
+                pf = np.polyfit(x, y, 1)
+                # Create a polynomial function p(x)
+                p = np.poly1d(pf)
+                plot_regression_line = True
+            except (np.linalg.LinAlgError, ValueError):
+                _LOGGER.warning(f"Linear regression failed for '{feature_name}' vs '{target_name}'. Plotting scatter only.")
+                plot_regression_line = False
+
+            # 7. Create the plot
+            plt.figure(figsize=(10, 6))
+            ax = plt.gca()
+
+            # Plot the raw data points
+            ax.plot(x, y, 'o', alpha=0.5, label='Data points', markersize=5)
+
+            # Plot the regression line
+            if plot_regression_line:
+                ax.plot(x, p(x), "r--", label='Linear Fit') # type: ignore
+
+            ax.set_title(f'{feature_name} vs {target_name}')
+            ax.set_xlabel(feature_name)
+            ax.set_ylabel(target_name)
+            ax.legend()
+            plt.grid(True, linestyle='--', alpha=0.6)
+            plt.tight_layout()
+
+            # 8. Save the plot
+            safe_feature_name = sanitize_filename(feature_name)
+            plot_filename = f"{safe_feature_name}_vs_{safe_target_dir_name}.svg"
+            plot_path = target_save_dir / plot_filename
+
+            try:
+                plt.savefig(plot_path, bbox_inches="tight", format='svg')
+                total_plots_saved += 1
+            except Exception as e:
+                _LOGGER.error(f"Failed to save plot: {plot_path}. Error: {e}")
+
+            # Close the figure to free up memory
+            plt.close()
+
+    _LOGGER.info(f"Successfully saved {total_plots_saved} feature-vs-target plots to '{base_save_path}'.")
+
+
+def plot_categorical_vs_target(
+    df: pd.DataFrame,
+    targets: List[str],
+    save_dir: Union[str, Path],
+    features: Optional[List[str]] = None,
+    plot_type: Literal["box", "violin"] = "box",
+    max_categories: int = 20,
+    fill_na_with: str = "Missing"
+):
+    """
+    Plots each categorical feature against each numeric target using box or violin plots.
+
+    This function is a core EDA step for regression tasks to understand the
+    relationship between a categorical independent variable and a continuous
+    dependent variable.
+
+    Args:
+        df (pd.DataFrame): The input DataFrame.
+        targets (List[str]): A list of numeric target column names (y-axis).
+        save_dir (str | Path): The base directory where plots will be saved. A subdirectory will be created here for each target.
+        features (List[str] | None): A list of categorical feature column names (x-axis). If None, all non-numeric (object) columns will be used.
+        plot_type (Literal["box", "violin"]): The type of plot to generate.
+        max_categories (int): The maximum number of unique categories a feature can have to be plotted. Features exceeding this limit will be skipped.
+        fill_na_with (str): A string to replace NaN values in categorical columns. This allows plotting 'missingness' as its own category. Defaults to "Missing".
+
+    Notes:
+        - Only numeric targets are processed.
+        - Features are automatically identified as categorical if they are 'object' dtype.
+    """
+    # 1. Validate the base save directory and inputs
+    base_save_path = make_fullpath(save_dir, make=True, enforce="directory")
+
+    if plot_type not in ["box", "violin"]:
+        _LOGGER.error(f"Invalid plot type '{plot_type}'")
+        raise ValueError()
+
+    # 2. Validate target columns (must be numeric)
+    valid_targets = []
+    for col in targets:
+        if col not in df.columns:
+            _LOGGER.warning(f"Target column '{col}' not found. Skipping.")
+        elif not is_numeric_dtype(df[col]):
+            _LOGGER.warning(f"Target column '{col}' is not numeric. Skipping.")
+        else:
+            valid_targets.append(col)
+
+    if not valid_targets:
+        _LOGGER.error("No valid numeric target columns provided to plot.")
+        return
+
+    # 3. Determine and validate feature columns
+    features_to_plot = []
+    if features is None:
+        _LOGGER.info("No 'features' list provided. Auto-detecting categorical features.")
+        for col in df.columns:
+            if col in valid_targets:
+                continue
+
+            # Auto-include object dtypes
+            if is_object_dtype(df[col]):
+                features_to_plot.append(col)
+            # Auto-include low-cardinality numeric features - REMOVED
+            # elif is_numeric_dtype(df[col]) and df[col].nunique() <= max_categories:
+            #     _LOGGER.info(f"Treating low-cardinality numeric column '{col}' as categorical.")
+            #     features_to_plot.append(col)
+    else:
+        # Validate user-provided list
+        for col in features:
+            if col not in df.columns:
+                _LOGGER.warning(f"Feature column '{col}' not found in DataFrame. Skipping.")
+            else:
+                features_to_plot.append(col)
+
+    if not features_to_plot:
+        _LOGGER.error("No valid categorical feature columns found to plot.")
+        return
+
+    # 4. Main plotting loop
+    total_plots_saved = 0
+
+    for target_name in valid_targets:
+        # Create a sanitized subdirectory for this target
+        safe_target_dir_name = sanitize_filename(f"{target_name}_vs_Categorical_{plot_type}")
+        target_save_dir = base_save_path / safe_target_dir_name
+        target_save_dir.mkdir(parents=True, exist_ok=True)
+
+        _LOGGER.info(f"Generating '{plot_type}' plots for target: '{target_name}' -> Saving to '{target_save_dir.name}'")
+
+        for feature_name in features_to_plot:
+
+            # Make a temporary copy for plotting to handle NaNs and dtypes
+            temp_df = df[[feature_name, target_name]].copy()
+
+            # Check cardinality
+            n_unique = temp_df[feature_name].nunique()
+            if n_unique > max_categories:
+                _LOGGER.warning(f"Skipping '{feature_name}': {n_unique} unique values > {max_categories} max_categories.")
+                continue
+
+            # Handle NaNs by replacing them with the specified string
+            if temp_df[feature_name].isnull().any():
+                # Convert to object type first to allow string replacement
+                temp_df[feature_name] = temp_df[feature_name].astype(object).fillna(fill_na_with)
+
+            # Convert feature to string to ensure correct plotting order
+            temp_df[feature_name] = temp_df[feature_name].astype(str)
+
+            # 5. Create the plot
+            # Increase figure width for categories
+            plt.figure(figsize=(max(10, n_unique * 1.2), 7))
+
+            if plot_type == "box":
+                sns.boxplot(x=feature_name, y=target_name, data=temp_df)
+            elif plot_type == "violin":
+                sns.violinplot(x=feature_name, y=target_name, data=temp_df)
+
+            plt.title(f'{target_name} vs {feature_name}')
+            plt.xlabel(feature_name)
+            plt.ylabel(target_name)
+            plt.xticks(rotation=45, ha='right')
+            plt.grid(True, linestyle='--', alpha=0.6, axis='y')
+            plt.tight_layout()
+
+            # 6. Save the plot
+            safe_feature_name = sanitize_filename(feature_name)
+            plot_filename = f"{safe_feature_name}_vs_{safe_target_dir_name}.svg"
+            plot_path = target_save_dir / plot_filename
+
+            try:
+                plt.savefig(plot_path, bbox_inches="tight", format='svg')
+                total_plots_saved += 1
+            except Exception as e:
+                _LOGGER.error(f"Failed to save plot: {plot_path}. Error: {e}")
+
+            plt.close()
+
+    _LOGGER.info(f"Successfully saved {total_plots_saved} categorical-vs-target plots to '{base_save_path}'.")
+
+
+def encode_categorical_features(
+    df: pd.DataFrame,
+    columns_to_encode: List[str],
+    encode_nulls: bool,
+    null_label: str = "Other",
+    split_resulting_dataset: bool = True,
+    verbose: bool = True
+) -> Tuple[Dict[str, Dict[str, int]], pd.DataFrame, Optional[pd.DataFrame]]:
+    """
+    Finds unique values in specified categorical columns, encodes them into integers,
+    and returns a dictionary containing the mappings for each column.
+
+    This function automates the label encoding process and generates a simple,
+    human-readable dictionary of the mappings.
+
+    Args:
+        df (pd.DataFrame): The input DataFrame.
+        columns_to_encode (List[str]): A list of column names to be encoded.
+        encode_nulls (bool):
+            - If True, encodes Null values as a distinct category 'null_label' with a value of 0. Other categories start from 1.
+            - If False, Nulls are ignored and categories start from 0.
+
+        null_label (str): Category to encode Nulls to if `encode_nulls` is True. If a name collision with `null_label` occurs, the fallback key will be "__NULL__".
+        split_resulting_dataset (bool):
+            - If True, returns two separate DataFrames, one with non-categorical columns and one with the encoded columns.
+            - If False, returns a single DataFrame with all columns.
+        verbose (bool): If True, prints encoding progress.
+
+    Returns:
+        Tuple:
+
+        - Dict[str, Dict[str, int]]: A dictionary where each key is a column name and the value is its category-to-integer mapping.
+
+        - pd.DataFrame: The original dataframe with or without encoded columns (see `split_resulting_dataset`).
+
+        - pd.DataFrame | None: If `split_resulting_dataset` is True, the encoded columns as a new dataframe.
+
+    ## **Note:**
+    Use `encode_nulls=False` when encoding binary values with missing entries or a malformed encoding will be returned silently.
+    """
+    df_encoded = df.copy()
+
+    # Validate columns
+    valid_columns = [col for col in columns_to_encode if col in df_encoded.columns]
+    missing_columns = set(columns_to_encode) - set(valid_columns)
+    if missing_columns:
+        _LOGGER.warning(f"Columns not found and will be skipped: {list(missing_columns)}")
+
+    mappings: Dict[str, Dict[str, int]] = {}
+
+    _LOGGER.info(f"Encoding {len(valid_columns)} categorical column(s).")
+    for col_name in valid_columns:
+        has_nulls = df_encoded[col_name].isnull().any()
+
+        if encode_nulls and has_nulls:
+            # Handle nulls: "Other" -> 0, other categories -> 1, 2, 3...
+            categories = sorted([str(cat) for cat in df_encoded[col_name].dropna().unique()])
+            # Start mapping from 1 for non-null values
+            mapping = {category: i + 1 for i, category in enumerate(categories)}
+
+            # Apply mapping and fill remaining NaNs with 0
+            mapped_series = df_encoded[col_name].astype(str).map(mapping)
+            df_encoded[col_name] = mapped_series.fillna(0).astype(int)
+
+            # --- Validate nulls category---
+            # Ensure the key for 0 doesn't collide with a real category.
+            if null_label in mapping.keys():
+                # COLLISION! null_label is a real category
+                original_label = null_label
+                null_label = "__NULL__" # fallback
+                _LOGGER.warning(f"Column '{col_name}': '{original_label}' is a real category. Mapping nulls (0) to '{null_label}' instead.")
+
+            # Create the complete user-facing map including "Other"
+            user_mapping = {**mapping, null_label: 0}
+            mappings[col_name] = user_mapping
+        else:
+            # ignore nulls
+            categories = sorted([str(cat) for cat in df_encoded[col_name].dropna().unique()])
+
+            mapping = {category: i for i, category in enumerate(categories)}
+
+            df_encoded[col_name] = df_encoded[col_name].astype(str).map(mapping)
+
+            mappings[col_name] = mapping
+
+        if verbose:
+            cardinality = len(mappings[col_name])
+            print(f" - Encoded '{col_name}' with {cardinality} unique values.")
+
+    # Handle the dataset splitting logic
+    if split_resulting_dataset:
+        df_categorical = df_encoded[valid_columns].to_frame() # type: ignore
+        df_non_categorical = df.drop(columns=valid_columns)
+        return mappings, df_non_categorical, df_categorical
+    else:
+        return mappings, df_encoded, None
+
+
 def split_features_targets(df: pd.DataFrame, targets: list[str]):
     """
     Splits a DataFrame's columns into features and targets.
@@ -369,9 +918,9 @@ def split_continuous_binary(df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFram
     return df_cont, df_bin # type: ignore


-def plot_correlation_heatmap(df: pd.DataFrame,
+def plot_correlation_heatmap(df: pd.DataFrame,
+                             plot_title: str,
                              save_dir: Union[str, Path, None] = None,
-                             plot_title: str="Correlation Heatmap",
                              method: Literal["pearson", "kendall", "spearman"]="pearson"):
     """
     Plots a heatmap of pairwise correlations between numeric features in a DataFrame.
@@ -379,7 +928,7 @@ def plot_correlation_heatmap(df: pd.DataFrame,
     Args:
         df (pd.DataFrame): The input dataset.
        save_dir (str | Path | None): If provided, the heatmap will be saved to this directory as a svg file.
-        plot_title:
+        plot_title: The suffix "`method` Correlation Heatmap" will be automatically appended.
        method (str): Correlation method to use. Must be one of:
            - 'pearson' (default): measures linear correlation (assumes normally distributed data),
            - 'kendall': rank correlation (non-parametric),
@@ -394,6 +943,9 @@ def plot_correlation_heatmap(df: pd.DataFrame,
     if numeric_df.empty:
         _LOGGER.warning("No numeric columns found. Heatmap not generated.")
         return
+    if method not in ["pearson", "kendall", "spearman"]:
+        _LOGGER.error(f"'method' must be pearson, kendall, or spearman.")
+        raise ValueError()

     corr = numeric_df.corr(method=method)

@@ -414,7 +966,10 @@ def plot_correlation_heatmap(df: pd.DataFrame,
         cbar_kws={"shrink": 0.8}
     )

-    [1 removed line not recoverable from this render]
+    # add suffix to title
+    full_plot_title = f"{plot_title} - {method.title()} Correlation Heatmap"
+
+    plt.title(full_plot_title)
     plt.xticks(rotation=45, ha='right')
     plt.yticks(rotation=0)

@@ -423,119 +978,119 @@
     if save_dir:
         save_path = make_fullpath(save_dir, make=True)
         # sanitize the plot title to save the file
-        [2 removed lines not recoverable from this render]
+        sanitized_plot_title = sanitize_filename(plot_title)
+        plot_filename = sanitized_plot_title + ".svg"

-        full_path = save_path /
+        full_path = save_path / plot_filename

         plt.savefig(full_path, bbox_inches="tight", format='svg')
-        _LOGGER.info(f"Saved correlation heatmap: '{
+        _LOGGER.info(f"Saved correlation heatmap: '{plot_filename}'")

     plt.show()
     plt.close()

-[1 removed line not recoverable from this render]
-def plot_value_distributions(df: pd.DataFrame, save_dir: Union[str, Path], bin_threshold: int=10, skip_cols_with_key: Union[str, None]=None):
-[24 removed lines (body of the old implementation) not recoverable from this render; the new version re-adds the old code as the commented-out "OLD IMPLEMENTATION" block below]
+# OLD IMPLEMENTATION
+# def plot_value_distributions(df: pd.DataFrame, save_dir: Union[str, Path], bin_threshold: int=10, skip_cols_with_key: Union[str, None]=None):
+#     """
+#     Plots and saves the value distributions for all (or selected) columns in a DataFrame,
+#     with adaptive binning for numerical columns when appropriate.
+
+#     For each column both raw counts and relative frequencies are computed and plotted.
+
+#     Plots are saved as PNG files under two subdirectories in `save_dir`:
+#     - "Distribution_Counts" for absolute counts.
+#     - "Distribution_Frequency" for relative frequencies.
+
+#     Args:
+#         df (pd.DataFrame): The input DataFrame whose columns are to be analyzed.
+#         save_dir (str | Path): Directory path where the plots will be saved. Will be created if it does not exist.
+#         bin_threshold (int): Minimum number of unique values required to trigger binning
+#             for numerical columns.
+#         skip_cols_with_key (str | None): If provided, any column whose name contains this
+#             substring will be excluded from analysis.
+
+#     Notes:
+#         - Binning is adaptive: if quantile binning results in ≤ 2 unique bins, raw values are used instead.
+#         - All non-alphanumeric characters in column names are sanitized for safe file naming.
+#         - Colormap is automatically adapted based on the number of categories or bins.
+#     """
+#     save_path = make_fullpath(save_dir, make=True)

-[2 removed lines not recoverable from this render]
+#     dict_to_plot_std = dict()
+#     dict_to_plot_freq = dict()

-[5 removed lines not recoverable from this render]
+#     # cherry-pick columns
+#     if skip_cols_with_key is not None:
+#         columns = [col for col in df.columns if skip_cols_with_key not in col]
+#     else:
+#         columns = df.columns.to_list()

-[11 removed lines not recoverable from this render]
-            if binned.nunique() <= 2:
-                view_std = df[col].value_counts(sort=False).sort_index()
-            else:
-                view_std = binned.value_counts(sort=False)
+#     saved_plots = 0
+#     for col in columns:
+#         if pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() > bin_threshold:
+#             bins_number = 10
+#             binned = pd.qcut(df[col], q=bins_number, duplicates='drop')
+#             while binned.nunique() <= 2:
+#                 bins_number -= 1
+#                 binned = pd.qcut(df[col], q=bins_number, duplicates='drop')
+#                 if bins_number <= 2:
+#                     break

-[9 removed lines not recoverable from this render]
+#             if binned.nunique() <= 2:
+#                 view_std = df[col].value_counts(sort=False).sort_index()
+#             else:
+#                 view_std = binned.value_counts(sort=False)
+
+#         else:
+#             view_std = df[col].value_counts(sort=False).sort_index()
+
+#         # unlikely scenario where the series is empty
+#         if view_std.sum() == 0:
+#             view_freq = view_std
+#         else:
+#             view_freq = 100 * view_std / view_std.sum() # Percentage
+#             # view_freq = df[col].value_counts(normalize=True, bins=10) # relative percentages

-[3 removed lines not recoverable from this render]
+#         dict_to_plot_std[col] = dict(view_std)
+#         dict_to_plot_freq[col] = dict(view_freq)
+#         saved_plots += 1

-[4 removed lines not recoverable from this render]
+#     # plot helper
+#     def _plot_helper(dict_: dict, target_dir: Path, ylabel: Literal["Frequency", "Counts"], base_fontsize: int=12):
+#         for col, data in dict_.items():
+#             safe_col = sanitize_filename(col)

-[4 removed lines not recoverable from this render]
+#             if isinstance(list(data.keys())[0], pd.Interval):
+#                 labels = [str(interval) for interval in data.keys()]
+#             else:
+#                 labels = data.keys()

-[2 removed lines not recoverable from this render]
+#             plt.figure(figsize=(10, 6))
+#             colors = plt.cm.tab20.colors if len(data) <= 20 else plt.cm.viridis(np.linspace(0, 1, len(data))) # type: ignore

-[9 removed lines not recoverable from this render]
+#             plt.bar(labels, data.values(), color=colors[:len(data)], alpha=0.85)
+#             plt.xlabel("Values", fontsize=base_fontsize)
+#             plt.ylabel(ylabel, fontsize=base_fontsize)
+#             plt.title(f"Value Distribution for '{col}'", fontsize=base_fontsize+2)
+#             plt.xticks(rotation=45, ha='right', fontsize=base_fontsize-2)
+#             plt.yticks(fontsize=base_fontsize-2)
+#             plt.grid(axis='y', linestyle='--', alpha=0.6)
+#             plt.gca().set_facecolor('#f9f9f9')
+#             plt.tight_layout()

-[3 removed lines not recoverable from this render]
+#             plot_path = target_dir / f"{safe_col}.png"
+#             plt.savefig(plot_path, dpi=300, bbox_inches="tight")
+#             plt.close()

-[7 removed lines not recoverable from this render]
+#     # Save plots
+#     freq_dir = save_path / "Distribution_Frequency"
+#     std_dir = save_path / "Distribution_Counts"
+#     freq_dir.mkdir(parents=True, exist_ok=True)
+#     std_dir.mkdir(parents=True, exist_ok=True)
+#     _plot_helper(dict_=dict_to_plot_std, target_dir=std_dir, ylabel="Counts")
+#     _plot_helper(dict_=dict_to_plot_freq, target_dir=freq_dir, ylabel="Frequency")

-[1 removed line not recoverable from this render]
+#     _LOGGER.info(f"Saved {saved_plots} value distribution plots.")


 def clip_outliers_single(
@@ -628,7 +1183,99 @@ def clip_outliers_multi(
     if skipped_columns:
         _LOGGER.warning("Skipped columns:")
         for col, msg in skipped_columns:
-            print(f" - {col}
+            print(f" - {col}")
+
+    return new_df
+
+
+def drop_outlier_samples(
+    df: pd.DataFrame,
+    bounds_dict: Dict[str, Tuple[Union[int, float], Union[int, float]]],
+    drop_on_nulls: bool = False,
+    verbose: bool = True
+) -> pd.DataFrame:
+    """
+    Drops entire rows where values in specified numeric columns fall outside
+    a given [min, max] range.
+
+    This function processes a copy of the DataFrame, ensuring the original is
+    not modified. It skips columns with invalid specifications.
+
+    Args:
+        df (pd.DataFrame): The input DataFrame.
+        bounds_dict (dict): A dictionary where keys are column names and values
+            are (min_val, max_val) tuples defining the valid range.
+        drop_on_nulls (bool): If True, rows with NaN/None in a checked column
+            will also be dropped. If False, NaN/None are ignored.
+        verbose (bool): If True, prints the number of rows dropped for each column.
+
+    Returns:
+        pd.DataFrame: A new DataFrame with the outlier rows removed.
+
+    Notes:
+        - Invalid specifications (e.g., missing column, non-numeric type,
+          incorrectly formatted bounds) will be reported and skipped.
+    """
+    new_df = df.copy()
+    skipped_columns: List[Tuple[str, str]] = []
+    initial_rows = len(new_df)
+
+    for col, bounds in bounds_dict.items():
+        try:
+            # --- Validation Checks ---
+            if col not in df.columns:
+                _LOGGER.error(f"Column '{col}' not found in DataFrame.")
+                raise ValueError()
+
+            if not pd.api.types.is_numeric_dtype(df[col]):
+                _LOGGER.error(f"Column '{col}' is not of a numeric data type.")
+                raise TypeError()
+
+            if not (isinstance(bounds, tuple) and len(bounds) == 2):
+                _LOGGER.error(f"Bounds for '{col}' must be a tuple of (min, max).")
+                raise ValueError()
+
+            # --- Filtering Logic ---
+            min_val, max_val = bounds
+            rows_before_drop = len(new_df)
+
+            # Create the base mask for values within the specified range
+            # .between() is inclusive and evaluates to False for NaN
+            mask_in_bounds = new_df[col].between(min_val, max_val)
+
+            if drop_on_nulls:
+                # Keep only rows that are within bounds.
+                # Since mask_in_bounds is False for NaN, nulls are dropped.
+                final_mask = mask_in_bounds
+            else:
+                # Keep rows that are within bounds OR are null.
+                mask_is_null = new_df[col].isnull()
+                final_mask = mask_in_bounds | mask_is_null
+
+            # Apply the final mask
+            new_df = new_df[final_mask]
+
+            rows_after_drop = len(new_df)
+
+            if verbose:
+                dropped_count = rows_before_drop - rows_after_drop
+                if dropped_count > 0:
+                    print(
+                        f" - Column '{col}': Dropped {dropped_count} rows with values outside range [{min_val}, {max_val}]."
+                    )
+
+        except (ValueError, TypeError) as e:
+            skipped_columns.append((col, str(e)))
+            continue
+
+    total_dropped = initial_rows - len(new_df)
+    _LOGGER.info(f"Finished processing. Total rows dropped: {total_dropped}.")
+
+    if skipped_columns:
+        _LOGGER.warning("Skipped the following columns due to errors:")
+        for col, msg in skipped_columns:
+            # Only print the column name for cleaner output as the error was already logged
+            print(f" - {col}")

     return new_df

@@ -667,7 +1314,8 @@ def standardize_percentages(
     df: pd.DataFrame,
     columns: list[str],
     treat_one_as_proportion: bool = True,
-    round_digits: int = 2
+    round_digits: int = 2,
+    verbose: bool=True
 ) -> pd.DataFrame:
     """
     Standardizes numeric columns containing mixed-format percentages.
@@ -708,6 +1356,8 @@ def standardize_percentages(

         # Otherwise, the value is assumed to be a correctly formatted percentage
         return x
+
+    fixed_columns: list[str] = list()

     for col in columns:
         # --- Robustness Checks ---
@@ -725,10 +1375,343 @@ def standardize_percentages(

         # Round the result
         df_copy[col] = df_copy[col].round(round_digits)
+
+        fixed_columns.append(col)
+
+    if verbose:
+        _LOGGER.info(f"Columns standardized:")
+        for fixed_col in fixed_columns:
+            print(f" '{fixed_col}'")

     return df_copy


+def reconstruct_one_hot(
+    df: pd.DataFrame,
+    features_to_reconstruct: List[Union[str, Tuple[str, Optional[str]]]],
+    separator: str = '_',
+    baseline_category_name: Optional[str] = "Other",
+    drop_original: bool = True,
+    verbose: bool = True
+) -> pd.DataFrame:
+    """
+    Reconstructs original categorical columns from a one-hot encoded DataFrame.
+
+    This function identifies groups of one-hot encoded columns based on a common
+    prefix (base feature name) and a separator. It then collapses each group
+    into a single column containing the categorical value.
+
+    Args:
+        df (pd.DataFrame):
+            The input DataFrame with one-hot encoded columns.
+        features_to_reconstruct (List[str | Tuple[str, str | None]]):
+            A list defining the features to reconstruct. This list can contain:
+
+            - A string: (e.g., "Color")
+              This reconstructs the feature 'Color' and assumes all-zero rows represent the baseline category ("Other" by default).
+            - A tuple: (e.g., ("Pet", "Dog"))
+              This reconstructs 'Pet' and maps all-zero rows to the baseline category "Dog".
+            - A tuple with None: (e.g., ("Size", None))
+              This reconstructs 'Size' and maps all-zero rows to the NaN value.
+            Example:
+                [
+                    "Mood",            # All-zeros -> "Other"
+                    ("Color", "Red"),  # All-zeros -> "Red"
+                    ("Size", None)     # All-zeros -> NaN
+                ]
+        separator (str):
+            The character separating the base name from the categorical value in
+            the column names (e.g., '_' in 'B_a').
+        baseline_category_name (str | None):
+            The baseline category name to use by default if it is not explicitly provided.
+        drop_original (bool):
+            If True, the original one-hot encoded columns will be dropped from
+            the returned DataFrame.
+
+    Returns:
+        pd.DataFrame:
+            A new DataFrame with the specified one-hot encoded features
+            reconstructed into single categorical columns.
+
+    <br>
+
+    ## Note:
+
+    This function is designed to be robust, but users should be aware of two key edge cases:
+
+    1. **Ambiguous Base Feature Prefixes**: If `base_feature_names` list contains names where one is a prefix of another (e.g., `['feat', 'feat_ext']`), the order is critical. The function will match columns greedily. To avoid incorrect grouping, always list the **most specific base names first** (e.g., `['feat_ext', 'feat']`).
+
+    2. **Malformed One-Hot Data**: If a row contains multiple `1`s within the same feature group (e.g., both `B_a` and `B_c` are `1`), the function will not raise an error. It uses `.idxmax()`, which returns the first column that contains the maximum value. This means it will silently select the first category it encounters and ignore the others, potentially masking an upstream data issue.
+    """
+    if not isinstance(df, pd.DataFrame):
+        _LOGGER.error("Input must be a pandas DataFrame.")
+        raise TypeError()
+
+    if not (baseline_category_name is None or isinstance(baseline_category_name, str)):
+        _LOGGER.error("The baseline_category must be None or a string.")
+        raise TypeError()
+
+    new_df = df.copy()
+    all_ohe_cols_to_drop = []
+    reconstructed_count = 0
+
+    # --- 1. Parse and validate the reconstruction config ---
+    # This normalizes the input into a clean {base_name: baseline_val} dict
+    reconstruction_config: Dict[str, Optional[str]] = {}
+    try:
+        for item in features_to_reconstruct:
+            if isinstance(item, str):
+                # Case 1: "Color"
+                base_name = item
+                baseline_val = baseline_category_name
+            elif isinstance(item, tuple) and len(item) == 2:
+                # Case 2: ("Pet", "dog") or ("Size", None)
+                base_name, baseline_val = item
+                if not (isinstance(base_name, str) and (isinstance(baseline_val, str) or baseline_val is None)):
+                    _LOGGER.error(f"Invalid tuple format for '{item}'. Must be (str, str|None).")
+                    raise ValueError()
+            else:
+                _LOGGER.error(f"Invalid item '{item}'. Must be str or (str, str|None) tuple.")
+                raise ValueError()
+
+            if base_name in reconstruction_config and verbose:
+                _LOGGER.warning(f"Duplicate entry for '{base_name}' found. Using the last provided configuration.")
+
+            reconstruction_config[base_name] = baseline_val
+
+    except Exception as e:
+        _LOGGER.error(f"Failed to parse 'features_to_reconstruct' argument: {e}")
+        raise ValueError("Invalid configuration for 'features_to_reconstruct'.") from e
+
+    _LOGGER.info(f"Attempting to reconstruct {len(reconstruction_config)} one-hot encoded feature(s).")
+
+    # Main logic
+    for base_name, baseline_category in reconstruction_config.items():
+        # Regex to find all columns belonging to this base feature.
+        pattern = f"^{re.escape(base_name)}{re.escape(separator)}"
+
+        # Find matching columns
+        ohe_cols = [col for col in df.columns if re.match(pattern, col)]
+
+        if not ohe_cols:
+            _LOGGER.warning(f"No one-hot encoded columns found for base feature '{base_name}'. Skipping.")
+            continue
+
+        # For each row, find the column name with the maximum value (which is 1)
+        reconstructed_series = new_df[ohe_cols].idxmax(axis=1) # type: ignore
+
+        # Extract the categorical value (the suffix) from the column name
+        # Use n=1 in split to handle cases where the category itself might contain the separator
+        new_column_values = reconstructed_series.str.split(separator, n=1).str[1]
+
+        # Handle rows where all OHE columns were 0 (e.g., original value was NaN or a dropped baseline).
+        all_zero_mask = new_df[ohe_cols].sum(axis=1) == 0 # type: ignore
+
+        if baseline_category is not None:
+            # A baseline category was provided
+            new_column_values.loc[all_zero_mask] = baseline_category
+        else:
+            # No baseline provided: assign NaN
+            new_column_values.loc[all_zero_mask] = np.nan # type: ignore
+
+        if verbose:
+            print(f" - Mapped 'all-zero' rows for '{base_name}' to baseline: '{baseline_category}'.")
+
+        # Assign the new reconstructed column to the DataFrame
+        new_df[base_name] = new_column_values
+
+        all_ohe_cols_to_drop.extend(ohe_cols)
+        reconstructed_count += 1
+        if verbose:
+            print(f" - Reconstructed '{base_name}' from {len(ohe_cols)} columns.")
+
+    # Cleanup
+    if drop_original and all_ohe_cols_to_drop:
+        # Drop the original OHE columns, ensuring no duplicates in the drop list
+        unique_cols_to_drop = list(set(all_ohe_cols_to_drop))
+        new_df.drop(columns=unique_cols_to_drop, inplace=True)
+        _LOGGER.info(f"Dropped {len(unique_cols_to_drop)} original one-hot encoded columns.")
+
+    _LOGGER.info(f"Successfully reconstructed {reconstructed_count} feature(s).")
+
+    return new_df
+
+
+def reconstruct_binary(
+    df: pd.DataFrame,
+    reconstruction_map: Dict[str, Tuple[str, Any, Any]],
+    drop_original: bool = True,
+    verbose: bool = True
+) -> pd.DataFrame:
+    """
+    Reconstructs new categorical columns from existing binary (0/1) columns.
+
+    Used to reverse a binary encoding by mapping 0 and 1 back to
+    descriptive categorical labels.
+
+    Args:
+        df (pd.DataFrame):
+            The input DataFrame.
+        reconstruction_map (Dict[str, Tuple[str, Any, Any]]):
+            A dictionary defining the reconstructions.
+            Format:
+                { "new_col_name": ("source_col_name", "label_for_0", "label_for_1") }
+            Example:
+                {
+                    "Sex": ("Sex_male", "Female", "Male"),
+                    "Smoker": ("Is_Smoker", "No", "Yes")
+                }
+        drop_original (bool):
+            If True, the original binary source columns (e.g., "Sex_male")
+            will be dropped from the returned DataFrame.
+        verbose (bool):
+            If True, prints the details of each reconstruction.
+
+    Returns:
+        pd.DataFrame:
+            A new DataFrame with the reconstructed categorical columns.
+
+    Raises:
+        TypeError: If `df` is not a pandas DataFrame.
+        ValueError: If `reconstruction_map` is not a dictionary or a
+            configuration is invalid (e.g., column name collision).
+
+    Notes:
+        - The function operates on a copy of the DataFrame.
+        - Rows with `NaN` in the source column will have `NaN` in the
+          new column.
+        - Values in the source column other than 0 or 1 (e.g., 2) will
+          result in `NaN` in the new column.
+    """
+    if not isinstance(df, pd.DataFrame):
+        _LOGGER.error("Input must be a pandas DataFrame.")
+        raise TypeError()
+
+    if not isinstance(reconstruction_map, dict):
+        _LOGGER.error("`reconstruction_map` must be a dictionary with the required format.")
+        raise ValueError()
+
+    new_df = df.copy()
+    source_cols_to_drop: List[str] = []
+    reconstructed_count = 0
+
+    _LOGGER.info(f"Attempting to reconstruct {len(reconstruction_map)} binary feature(s).")
+
+    for new_col_name, config in reconstruction_map.items():
+
+        # --- 1. Validation ---
+        if not (isinstance(config, tuple) and len(config) == 3):
+            _LOGGER.error(f"Config for '{new_col_name}' is invalid. Must be a 3-item tuple. Skipping.")
+            raise ValueError()
+
+        source_col, label_for_0, label_for_1 = config
+
+        if source_col not in new_df.columns:
+            _LOGGER.error(f"Source column '{source_col}' for new column '{new_col_name}' not found. Skipping.")
+            raise ValueError()
+
+        if new_col_name in new_df.columns and verbose:
+            _LOGGER.warning(f"New column '{new_col_name}' already exists and will be overwritten.")
+
+        if new_col_name == source_col:
+            _LOGGER.error(f"New column name '{new_col_name}' cannot be the same as source column '{source_col}'.")
+            raise ValueError()
+
+        # --- 2. Reconstruction ---
+        # .map() handles 0, 1, preserves NaNs, and converts any other value to NaN.
+        mapping_dict = {0: label_for_0, 1: label_for_1}
+        new_df[new_col_name] = new_df[source_col].map(mapping_dict)
+
+        # --- 3. Logging/Tracking ---
+        source_cols_to_drop.append(source_col)
+        reconstructed_count += 1
+        if verbose:
+            print(f" - Reconstructed '{new_col_name}' from '{source_col}' (0='{label_for_0}', 1='{label_for_1}').")
+
+    # --- 4. Cleanup ---
+    if drop_original and source_cols_to_drop:
+        # Use set() to avoid duplicates if the same source col was used
+        unique_cols_to_drop = list(set(source_cols_to_drop))
+        new_df.drop(columns=unique_cols_to_drop, inplace=True)
+        _LOGGER.info(f"Dropped {len(unique_cols_to_drop)} original binary source column(s).")
+
+    _LOGGER.info(f"Successfully reconstructed {reconstructed_count} feature(s).")
+
+    return new_df
+
+
+def finalize_feature_schema(
+    df_features: pd.DataFrame,
+    categorical_mappings: Optional[Dict[str, Dict[str, int]]]
+) -> FeatureSchema:
+    """
+    Analyzes the final features DataFrame to create a definitive schema.
+
+    This function is the "single source of truth" for column order
+    and type (categorical vs. continuous) for the entire ML pipeline.
+
+    It should be called at the end of the feature engineering process.
+
+    Args:
+        df_features (pd.DataFrame):
+            The final, processed DataFrame containing *only* feature columns
+            in the exact order they will be fed to the model.
+        categorical_mappings (Dict[str, Dict[str, int]] | None):
+            The mappings dictionary generated by
+            `encode_categorical_features`. Can be None if no
+            categorical features exist.
+
+    Returns:
+        FeatureSchema: A NamedTuple containing all necessary metadata for the pipeline.
+    """
+    feature_names: List[str] = df_features.columns.to_list()
+
+    # Intermediate lists for building
+    continuous_feature_names_list: List[str] = []
+    categorical_feature_names_list: List[str] = []
+    categorical_index_map_dict: Dict[int, int] = {}
+
+    # _LOGGER.info("Finalizing feature schema...")
+
+    if categorical_mappings:
+        # --- Categorical features are present ---
+        categorical_names_set = set(categorical_mappings.keys())
+
+        for index, name in enumerate(feature_names):
+            if name in categorical_names_set:
+                # This is a categorical feature
+                cardinality = len(categorical_mappings[name])
+                categorical_index_map_dict[index] = cardinality
+                categorical_feature_names_list.append(name)
+            else:
+                # This is a continuous feature
+                continuous_feature_names_list.append(name)
+
+        # Use the populated dict, or None if it's empty
+        final_index_map = categorical_index_map_dict if categorical_index_map_dict else None
+
+    else:
+        # --- No categorical features ---
+        _LOGGER.info("No categorical mappings provided. Treating all features as continuous.")
+        continuous_feature_names_list = list(feature_names)
+        # categorical_feature_names_list remains empty
+        # categorical_index_map_dict remains empty
+        final_index_map = None # Explicitly set to None to match Optional type
+
+    _LOGGER.info(f"Schema created: {len(continuous_feature_names_list)} continuous, {len(categorical_feature_names_list)} categorical.")
+
+    # Create the final immutable instance
+    schema_instance = FeatureSchema(
+        feature_names=tuple(feature_names),
+        continuous_feature_names=tuple(continuous_feature_names_list),
+        categorical_feature_names=tuple(categorical_feature_names_list),
+        categorical_index_map=final_index_map,
+        categorical_mappings=categorical_mappings
+    )
+
+    return schema_instance
+
+
 def _validate_columns(df: pd.DataFrame, columns: list[str]):
     valid_columns = [column for column in columns if column in df.columns]
     return valid_columns