dragon-ml-toolbox 10.2.0__py3-none-any.whl → 14.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dragon-ml-toolbox might be problematic.

Files changed (48)
  1. {dragon_ml_toolbox-10.2.0.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/METADATA +38 -63
  2. dragon_ml_toolbox-14.2.0.dist-info/RECORD +48 -0
  3. {dragon_ml_toolbox-10.2.0.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/licenses/LICENSE +1 -1
  4. {dragon_ml_toolbox-10.2.0.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +11 -0
  5. ml_tools/ETL_cleaning.py +72 -34
  6. ml_tools/ETL_engineering.py +506 -70
  7. ml_tools/GUI_tools.py +2 -1
  8. ml_tools/MICE_imputation.py +212 -7
  9. ml_tools/ML_callbacks.py +73 -40
  10. ml_tools/ML_datasetmaster.py +267 -284
  11. ml_tools/ML_evaluation.py +119 -58
  12. ml_tools/ML_evaluation_multi.py +107 -32
  13. ml_tools/ML_inference.py +15 -5
  14. ml_tools/ML_models.py +234 -170
  15. ml_tools/ML_models_advanced.py +323 -0
  16. ml_tools/ML_optimization.py +321 -97
  17. ml_tools/ML_scaler.py +10 -5
  18. ml_tools/ML_trainer.py +585 -40
  19. ml_tools/ML_utilities.py +528 -0
  20. ml_tools/ML_vision_datasetmaster.py +1315 -0
  21. ml_tools/ML_vision_evaluation.py +260 -0
  22. ml_tools/ML_vision_inference.py +428 -0
  23. ml_tools/ML_vision_models.py +627 -0
  24. ml_tools/ML_vision_transformers.py +58 -0
  25. ml_tools/PSO_optimization.py +10 -7
  26. ml_tools/RNN_forecast.py +2 -0
  27. ml_tools/SQL.py +22 -9
  28. ml_tools/VIF_factor.py +4 -3
  29. ml_tools/_ML_vision_recipe.py +88 -0
  30. ml_tools/__init__.py +1 -0
  31. ml_tools/_logger.py +0 -2
  32. ml_tools/_schema.py +96 -0
  33. ml_tools/constants.py +79 -0
  34. ml_tools/custom_logger.py +164 -16
  35. ml_tools/data_exploration.py +1092 -109
  36. ml_tools/ensemble_evaluation.py +48 -1
  37. ml_tools/ensemble_inference.py +6 -7
  38. ml_tools/ensemble_learning.py +4 -3
  39. ml_tools/handle_excel.py +1 -0
  40. ml_tools/keys.py +80 -0
  41. ml_tools/math_utilities.py +259 -0
  42. ml_tools/optimization_tools.py +198 -24
  43. ml_tools/path_manager.py +144 -45
  44. ml_tools/serde.py +192 -0
  45. ml_tools/utilities.py +287 -227
  46. dragon_ml_toolbox-10.2.0.dist-info/RECORD +0 -36
  47. {dragon_ml_toolbox-10.2.0.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/WHEEL +0 -0
  48. {dragon_ml_toolbox-10.2.0.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/top_level.txt +0 -0
@@ -1,17 +1,17 @@
  import pandas as pd
- from pandas.api.types import is_numeric_dtype
+ from pandas.api.types import is_numeric_dtype, is_object_dtype
  import numpy as np
  import matplotlib.pyplot as plt
  import seaborn as sns
- from typing import Union, Literal, Dict, Tuple, List, Optional
+ from typing import Union, Literal, Dict, Tuple, List, Optional, Any
  from pathlib import Path
  import re

  from .path_manager import sanitize_filename, make_fullpath
  from ._script_info import _script_info
  from ._logger import _LOGGER
- from .utilities import save_dataframe
-
+ from .utilities import save_dataframe_filename
+ from ._schema import FeatureSchema

  # Keep track of all available tools, show using `info()`
  __all__ = [
@@ -21,14 +21,22 @@ __all__ = [
  "show_null_columns",
  "drop_columns_with_missing_data",
  "drop_macro",
+ "clean_column_names",
+ "plot_value_distributions",
+ "plot_continuous_vs_target",
+ "plot_categorical_vs_target",
+ "encode_categorical_features",
  "split_features_targets",
  "split_continuous_binary",
- "plot_correlation_heatmap",
- "plot_value_distributions",
  "clip_outliers_single",
  "clip_outliers_multi",
+ "drop_outlier_samples",
+ "plot_correlation_heatmap",
  "match_and_filter_columns_by_regex",
- "standardize_percentages"
+ "standardize_percentages",
+ "reconstruct_one_hot",
+ "reconstruct_binary",
+ "finalize_feature_schema"
  ]

@@ -263,7 +271,7 @@ def drop_macro(df: pd.DataFrame,

  # Log initial state
  missing_data = show_null_columns(df=df_clean)
- save_dataframe(df=missing_data.reset_index(drop=False),
+ save_dataframe_filename(df=missing_data.reset_index(drop=False),
  save_dir=log_directory,
  filename="Missing_Data_start")

@@ -292,7 +300,7 @@ def drop_macro(df: pd.DataFrame,

  # log final state
  missing_data = show_null_columns(df=df_clean)
- save_dataframe(df=missing_data.reset_index(drop=False),
+ save_dataframe_filename(df=missing_data.reset_index(drop=False),
  save_dir=log_directory,
  filename="Missing_Data_final")

@@ -300,6 +308,547 @@ def drop_macro(df: pd.DataFrame,
  return df_clean

311
+ def clean_column_names(df: pd.DataFrame, replacement_char: str = '-', replacement_pattern: str = r'[\[\]{}<>,:"]', verbose: bool = True) -> pd.DataFrame:
312
+ """
313
+ Cleans DataFrame column names by replacing special characters.
314
+
315
+ This function is useful for ensuring compatibility with libraries like LightGBM,
316
+ which do not support special JSON characters such as `[]{}<>,:"` in feature names.
317
+
318
+ Args:
319
+ df (pd.DataFrame): The input DataFrame.
320
+ replacement_char (str): The character used to replace each character matched by `replacement_pattern`.
321
+ replacement_pattern (str): Regex pattern to use for the replacement logic.
322
+ verbose (bool): If True, prints the renamed columns.
323
+
324
+ Returns:
325
+ pd.DataFrame: A new DataFrame with cleaned column names.
326
+ """
327
+ new_df = df.copy()
328
+
329
+ original_columns = new_df.columns
330
+ new_columns = original_columns.str.replace(replacement_pattern, replacement_char, regex=True)
331
+
332
+ # Create a map of changes for logging
333
+ rename_map = {old: new for old, new in zip(original_columns, new_columns) if old != new}
334
+
335
+ if verbose:
336
+ if rename_map:
337
+ _LOGGER.info(f"Cleaned {len(rename_map)} column name(s) containing special characters:")
338
+ for old, new in rename_map.items():
339
+ print(f" '{old}' -> '{new}'")
340
+ else:
341
+ _LOGGER.info("No column names required cleaning.")
342
+
343
+ new_df.columns = new_columns
344
+ return new_df
345
+
346
+
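A minimal usage sketch for clean_column_names; the import path ml_tools.data_exploration and the toy DataFrame are assumptions for illustration, not part of this release:

    import pandas as pd
    from ml_tools.data_exploration import clean_column_names

    df = pd.DataFrame({"dose [mg/L]": [1.0, 2.0], "ratio<a:b>": [0.5, 0.7]})

    # Special characters such as []{}<>,:" are replaced, e.g. 'dose [mg/L]' -> 'dose -mg/L-'.
    cleaned = clean_column_names(df, replacement_char="-")
    print(cleaned.columns.tolist())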
347
+ def plot_value_distributions(
348
+ df: pd.DataFrame,
349
+ save_dir: Union[str, Path],
350
+ categorical_columns: Optional[List[str]] = None,
351
+ categorical_cardinality_threshold: int = 10,
352
+ max_categories: int = 50,
353
+ fill_na_with: str = "Missing"
354
+ ):
355
+ """
356
+ Plots and saves the value distributions for all columns in a DataFrame,
357
+ using the best plot type for each column (histogram or count plot).
358
+
359
+ Plots are saved as SVG files under two subdirectories in `save_dir`:
360
+ - "Distribution_Continuous" for continuous numeric features (histograms).
361
+ - "Distribution_Categorical" for categorical features (count plots).
362
+
363
+ Args:
364
+ df (pd.DataFrame): The input DataFrame to analyze.
365
+ save_dir (str | Path): Directory path to save the plots.
366
+ categorical_columns (List[str] | None): If provided, this list
367
+ of column names will be treated as categorical, and all other columns will be treated as continuous. This
368
+ overrides the `categorical_cardinality_threshold` logic.
369
+ categorical_cardinality_threshold (int): A numeric column will be treated
370
+ as 'categorical' if its number of unique values is less than or equal to this threshold. (Ignored if `categorical_columns` is set).
371
+ max_categories (int): The maximum number of unique categories a
372
+ categorical feature can have to be plotted. Features exceeding this limit will be skipped.
373
+ fill_na_with (str): A string to replace NaN values in categorical columns. This allows plotting 'missingness' as its
374
+ own category. Defaults to "Missing".
375
+
376
+ Notes:
377
+ - `seaborn.histplot` with KDE is used for continuous features.
378
+ - `seaborn.countplot` is used for categorical features.
379
+ """
380
+ # 1. Setup save directories
381
+ base_save_path = make_fullpath(save_dir, make=True, enforce="directory")
382
+ numeric_dir = base_save_path / "Distribution_Continuous"
383
+ categorical_dir = base_save_path / "Distribution_Categorical"
384
+ numeric_dir.mkdir(parents=True, exist_ok=True)
385
+ categorical_dir.mkdir(parents=True, exist_ok=True)
386
+
387
+ # 2. Filter columns to plot
388
+ columns_to_plot = df.columns.to_list()
389
+
390
+ # Setup for forced categorical logic
391
+ categorical_set = set(categorical_columns) if categorical_columns is not None else None
392
+
393
+ numeric_plots_saved = 0
394
+ categorical_plots_saved = 0
395
+
396
+ for col_name in columns_to_plot:
397
+ try:
398
+ is_numeric = is_numeric_dtype(df[col_name])
399
+ n_unique = df[col_name].nunique()
400
+
401
+ # --- 3. Determine Plot Type ---
402
+ is_continuous = False
403
+ if categorical_set is not None:
404
+ # Use the explicit list
405
+ if col_name not in categorical_set:
406
+ is_continuous = True
407
+ else:
408
+ # Use auto-detection
409
+ if is_numeric and n_unique > categorical_cardinality_threshold:
410
+ is_continuous = True
411
+
412
+ # --- Case 1: Continuous Numeric (Histogram) ---
413
+ if is_continuous:
414
+ plt.figure(figsize=(10, 6))
415
+ # Drop NaNs for histogram, as they can't be plotted on a numeric axis
416
+ sns.histplot(x=df[col_name].dropna(), kde=True, bins=30)
417
+ plt.title(f"Distribution of '{col_name}' (Continuous)")
418
+ plt.xlabel(col_name)
419
+ plt.ylabel("Count")
420
+
421
+ save_path = numeric_dir / f"{sanitize_filename(col_name)}.svg"
422
+ numeric_plots_saved += 1
423
+
424
+ # --- Case 2: Categorical or Low-Cardinality Numeric (Count Plot) ---
425
+ else:
426
+ # Check max categories
427
+ if n_unique > max_categories:
428
+ _LOGGER.warning(f"Skipping plot for '{col_name}': {n_unique} unique values > {max_categories} max_categories.")
429
+ continue
430
+
431
+ # Adaptive figure size
432
+ fig_width = max(10, n_unique * 0.5)
433
+ plt.figure(figsize=(fig_width, 7))
434
+
435
+ # Make a temporary copy for plotting to handle NaNs
436
+ temp_series = df[col_name].copy()
437
+
438
+ # Handle NaNs by replacing them with the specified string
439
+ if temp_series.isnull().any():
440
+ # Convert to object type first to allow string replacement
441
+ temp_series = temp_series.astype(object).fillna(fill_na_with)
442
+
443
+ # Convert all to string to be safe (handles low-card numeric)
444
+ temp_series = temp_series.astype(str)
445
+
446
+ # Get category order by frequency
447
+ order = temp_series.value_counts().index
448
+ sns.countplot(x=temp_series, order=order, palette="viridis")
449
+
450
+ plt.title(f"Distribution of '{col_name}' (Categorical)")
451
+ plt.xlabel(col_name)
452
+ plt.ylabel("Count")
453
+
454
+ # Smart tick rotation
455
+ max_label_len = 0
456
+ if n_unique > 0:
457
+ max_label_len = max(len(str(s)) for s in order)
458
+
459
+ # Rotate if labels are long OR there are many categories
460
+ if max_label_len > 10 or n_unique > 25:
461
+ plt.xticks(rotation=45, ha='right')
462
+
463
+ save_path = categorical_dir / f"{sanitize_filename(col_name)}.svg"
464
+ categorical_plots_saved += 1
465
+
466
+ # --- 4. Save Plot ---
467
+ plt.grid(True, linestyle='--', alpha=0.6, axis='y')
468
+ plt.tight_layout()
469
+ # Save as .svg
470
+ plt.savefig(save_path, format='svg', bbox_inches="tight")
471
+ plt.close()
472
+
473
+ except Exception as e:
474
+ _LOGGER.error(f"Failed to plot distribution for '{col_name}'. Error: {e}")
475
+ plt.close()
476
+
477
+ _LOGGER.info(f"Saved {numeric_plots_saved} continuous distribution plots to '{numeric_dir.name}'.")
478
+ _LOGGER.info(f"Saved {categorical_plots_saved} categorical distribution plots to '{categorical_dir.name}'.")
479
+
480
+
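A minimal usage sketch for the new plot_value_distributions; the import path and the synthetic data are illustrative assumptions:

    import numpy as np
    import pandas as pd
    from ml_tools.data_exploration import plot_value_distributions

    rng = np.random.default_rng(0)
    df = pd.DataFrame({
        "age": rng.normal(40, 10, size=200),                       # continuous -> histogram + KDE
        "city": rng.choice(["Lima", "Quito", "Cusco"], size=200),  # categorical -> count plot
    })
    df.loc[::10, "city"] = None  # missing values are plotted as their own "Missing" bar

    # Saves one SVG per column under "Distribution_Continuous" and "Distribution_Categorical".
    plot_value_distributions(df, save_dir="eda_plots", categorical_columns=["city"])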
481
+ def plot_continuous_vs_target(
482
+ df: pd.DataFrame,
483
+ targets: List[str],
484
+ save_dir: Union[str, Path],
485
+ features: Optional[List[str]] = None
486
+ ):
487
+ """
488
+ Plots each continuous feature against each target to visualize linear relationships.
489
+
490
+ This function is a common EDA step for regression tasks. It creates a
491
+ scatter plot for each feature-target pair, overlays a simple linear
492
+ regression line, and saves each plot as an individual .svg file.
493
+
494
+ Plots are saved in a structured way, with a subdirectory created for
495
+ each target variable.
496
+
497
+ Args:
498
+ df (pd.DataFrame): The input DataFrame.
499
+ targets (List[str]): A list of target column names to plot (y-axis).
500
+ save_dir (str | Path): The base directory where plots will be saved. A subdirectory will be created here for each target.
501
+ features (List[str] | None): A list of feature column names to plot (x-axis). If None, all non-target columns in the
502
+ DataFrame will be used.
503
+
504
+ Notes:
505
+ - Only numeric features and numeric targets are processed. Non-numeric
506
+ columns in the lists will be skipped with a warning.
507
+ - Rows with NaN in either the feature or the target are dropped
508
+ pairwise for each plot.
509
+ """
510
+ # 1. Validate the base save directory
511
+ base_save_path = make_fullpath(save_dir, make=True, enforce="directory")
512
+
513
+ # 2. Validate helper
514
+ def _validate_numeric_cols(col_list: List[str], col_type: str) -> List[str]:
515
+ valid_cols = []
516
+ for col in col_list:
517
+ if col not in df.columns:
518
+ _LOGGER.warning(f"{col_type} column '{col}' not found. Skipping.")
519
+ elif not is_numeric_dtype(df[col]):
520
+ _LOGGER.warning(f"{col_type} column '{col}' is not numeric. Skipping.")
521
+ else:
522
+ valid_cols.append(col)
523
+ return valid_cols
524
+
525
+ # 3. Validate target columns FIRST
526
+ valid_targets = _validate_numeric_cols(targets, "Target")
527
+ if not valid_targets:
528
+ _LOGGER.error("No valid numeric target columns provided to plot.")
529
+ return
530
+
531
+ # 4. Determine and validate feature columns
532
+ if features is None:
533
+ _LOGGER.info("No 'features' list provided. Using all non-target columns as features.")
534
+ target_set = set(valid_targets)
535
+ # Get all columns that are not in the valid_targets set
536
+ features_to_validate = [col for col in df.columns if col not in target_set]
537
+ else:
538
+ features_to_validate = features
539
+
540
+ valid_features = _validate_numeric_cols(features_to_validate, "Feature")
541
+
542
+ if not valid_features:
543
+ _LOGGER.error("No valid numeric feature columns found to plot.")
544
+ return
545
+
546
+ # 5. Main plotting loop
547
+ total_plots_saved = 0
548
+
549
+ for target_name in valid_targets:
550
+ # Create a sanitized subdirectory for this target
551
+ safe_target_dir_name = sanitize_filename(f"{target_name}_vs_Continuous")
552
+ target_save_dir = base_save_path / safe_target_dir_name
553
+ target_save_dir.mkdir(parents=True, exist_ok=True)
554
+
555
+ _LOGGER.info(f"Generating plots for target: '{target_name}' -> Saving to '{target_save_dir.name}'")
556
+
557
+ for feature_name in valid_features:
558
+
559
+ # Drop NaNs pairwise for this specific plot
560
+ temp_df = df[[feature_name, target_name]].dropna()
561
+
562
+ if temp_df.empty:
563
+ _LOGGER.warning(f"No non-null data for '{feature_name}' vs '{target_name}'. Skipping plot.")
564
+ continue
565
+
566
+ x = temp_df[feature_name]
567
+ y = temp_df[target_name]
568
+
569
+ # 6. Perform linear fit
570
+ try:
571
+ # Use numpy's polyfit to get the slope (pf[0]) and intercept (pf[1])
572
+ pf = np.polyfit(x, y, 1)
573
+ # Create a polynomial function p(x)
574
+ p = np.poly1d(pf)
575
+ plot_regression_line = True
576
+ except (np.linalg.LinAlgError, ValueError):
577
+ _LOGGER.warning(f"Linear regression failed for '{feature_name}' vs '{target_name}'. Plotting scatter only.")
578
+ plot_regression_line = False
579
+
580
+ # 7. Create the plot
581
+ plt.figure(figsize=(10, 6))
582
+ ax = plt.gca()
583
+
584
+ # Plot the raw data points
585
+ ax.plot(x, y, 'o', alpha=0.5, label='Data points', markersize=5)
586
+
587
+ # Plot the regression line
588
+ if plot_regression_line:
589
+ ax.plot(x, p(x), "r--", label='Linear Fit') # type: ignore
590
+
591
+ ax.set_title(f'{feature_name} vs {target_name}')
592
+ ax.set_xlabel(feature_name)
593
+ ax.set_ylabel(target_name)
594
+ ax.legend()
595
+ plt.grid(True, linestyle='--', alpha=0.6)
596
+ plt.tight_layout()
597
+
598
+ # 8. Save the plot
599
+ safe_feature_name = sanitize_filename(feature_name)
600
+ plot_filename = f"{safe_feature_name}_vs_{safe_target_dir_name}.svg"
601
+ plot_path = target_save_dir / plot_filename
602
+
603
+ try:
604
+ plt.savefig(plot_path, bbox_inches="tight", format='svg')
605
+ total_plots_saved += 1
606
+ except Exception as e:
607
+ _LOGGER.error(f"Failed to save plot: {plot_path}. Error: {e}")
608
+
609
+ # Close the figure to free up memory
610
+ plt.close()
611
+
612
+ _LOGGER.info(f"Successfully saved {total_plots_saved} feature-vs-target plots to '{base_save_path}'.")
613
+
614
+
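A minimal usage sketch for plot_continuous_vs_target; the import path and data are illustrative assumptions:

    import numpy as np
    import pandas as pd
    from ml_tools.data_exploration import plot_continuous_vs_target

    rng = np.random.default_rng(0)
    temperature = rng.uniform(0, 10, size=150)
    df = pd.DataFrame({
        "temperature": temperature,
        "pressure": rng.normal(5, 2, size=150),
        "yield": 3.0 * temperature + rng.normal(0, 2, size=150),  # roughly linear in temperature
    })

    # With features=None, every non-target numeric column is plotted against each target;
    # each pair gets a scatter plot with a linear fit, saved under "yield_vs_Continuous/".
    plot_continuous_vs_target(df, targets=["yield"], save_dir="eda_plots")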
615
+ def plot_categorical_vs_target(
616
+ df: pd.DataFrame,
617
+ targets: List[str],
618
+ save_dir: Union[str, Path],
619
+ features: Optional[List[str]] = None,
620
+ plot_type: Literal["box", "violin"] = "box",
621
+ max_categories: int = 20,
622
+ fill_na_with: str = "Missing"
623
+ ):
624
+ """
625
+ Plots each categorical feature against each numeric target using box or violin plots.
626
+
627
+ This function is a core EDA step for regression tasks to understand the
628
+ relationship between a categorical independent variable and a continuous
629
+ dependent variable.
630
+
631
+ Args:
632
+ df (pd.DataFrame): The input DataFrame.
633
+ targets (List[str]): A list of numeric target column names (y-axis).
634
+ save_dir (str | Path): The base directory where plots will be saved. A subdirectory will be created here for each target.
635
+ features (List[str] | None): A list of categorical feature column names (x-axis). If None, all non-numeric (object) columns will be used.
636
+ plot_type (Literal["box", "violin"]): The type of plot to generate.
637
+ max_categories (int): The maximum number of unique categories a feature can have to be plotted. Features exceeding this limit will be skipped.
638
+ fill_na_with (str): A string to replace NaN values in categorical columns. This allows plotting 'missingness' as its own category. Defaults to "Missing".
639
+
640
+ Notes:
641
+ - Only numeric targets are processed.
642
+ - Features are automatically identified as categorical if they are 'object' dtype.
643
+ """
644
+ # 1. Validate the base save directory and inputs
645
+ base_save_path = make_fullpath(save_dir, make=True, enforce="directory")
646
+
647
+ if plot_type not in ["box", "violin"]:
648
+ _LOGGER.error(f"Invalid plot type '{plot_type}'")
649
+ raise ValueError()
650
+
651
+ # 2. Validate target columns (must be numeric)
652
+ valid_targets = []
653
+ for col in targets:
654
+ if col not in df.columns:
655
+ _LOGGER.warning(f"Target column '{col}' not found. Skipping.")
656
+ elif not is_numeric_dtype(df[col]):
657
+ _LOGGER.warning(f"Target column '{col}' is not numeric. Skipping.")
658
+ else:
659
+ valid_targets.append(col)
660
+
661
+ if not valid_targets:
662
+ _LOGGER.error("No valid numeric target columns provided to plot.")
663
+ return
664
+
665
+ # 3. Determine and validate feature columns
666
+ features_to_plot = []
667
+ if features is None:
668
+ _LOGGER.info("No 'features' list provided. Auto-detecting categorical features.")
669
+ for col in df.columns:
670
+ if col in valid_targets:
671
+ continue
672
+
673
+ # Auto-include object dtypes
674
+ if is_object_dtype(df[col]):
675
+ features_to_plot.append(col)
676
+ # Auto-include low-cardinality numeric features - REMOVED
677
+ # elif is_numeric_dtype(df[col]) and df[col].nunique() <= max_categories:
678
+ # _LOGGER.info(f"Treating low-cardinality numeric column '{col}' as categorical.")
679
+ # features_to_plot.append(col)
680
+ else:
681
+ # Validate user-provided list
682
+ for col in features:
683
+ if col not in df.columns:
684
+ _LOGGER.warning(f"Feature column '{col}' not found in DataFrame. Skipping.")
685
+ else:
686
+ features_to_plot.append(col)
687
+
688
+ if not features_to_plot:
689
+ _LOGGER.error("No valid categorical feature columns found to plot.")
690
+ return
691
+
692
+ # 4. Main plotting loop
693
+ total_plots_saved = 0
694
+
695
+ for target_name in valid_targets:
696
+ # Create a sanitized subdirectory for this target
697
+ safe_target_dir_name = sanitize_filename(f"{target_name}_vs_Categorical_{plot_type}")
698
+ target_save_dir = base_save_path / safe_target_dir_name
699
+ target_save_dir.mkdir(parents=True, exist_ok=True)
700
+
701
+ _LOGGER.info(f"Generating '{plot_type}' plots for target: '{target_name}' -> Saving to '{target_save_dir.name}'")
702
+
703
+ for feature_name in features_to_plot:
704
+
705
+ # Make a temporary copy for plotting to handle NaNs and dtypes
706
+ temp_df = df[[feature_name, target_name]].copy()
707
+
708
+ # Check cardinality
709
+ n_unique = temp_df[feature_name].nunique()
710
+ if n_unique > max_categories:
711
+ _LOGGER.warning(f"Skipping '{feature_name}': {n_unique} unique values > {max_categories} max_categories.")
712
+ continue
713
+
714
+ # Handle NaNs by replacing them with the specified string
715
+ if temp_df[feature_name].isnull().any():
716
+ # Convert to object type first to allow string replacement
717
+ temp_df[feature_name] = temp_df[feature_name].astype(object).fillna(fill_na_with)
718
+
719
+ # Convert feature to string to ensure correct plotting order
720
+ temp_df[feature_name] = temp_df[feature_name].astype(str)
721
+
722
+ # 5. Create the plot
723
+ # Increase figure width for categories
724
+ plt.figure(figsize=(max(10, n_unique * 1.2), 7))
725
+
726
+ if plot_type == "box":
727
+ sns.boxplot(x=feature_name, y=target_name, data=temp_df)
728
+ elif plot_type == "violin":
729
+ sns.violinplot(x=feature_name, y=target_name, data=temp_df)
730
+
731
+ plt.title(f'{target_name} vs {feature_name}')
732
+ plt.xlabel(feature_name)
733
+ plt.ylabel(target_name)
734
+ plt.xticks(rotation=45, ha='right')
735
+ plt.grid(True, linestyle='--', alpha=0.6, axis='y')
736
+ plt.tight_layout()
737
+
738
+ # 6. Save the plot
739
+ safe_feature_name = sanitize_filename(feature_name)
740
+ plot_filename = f"{safe_feature_name}_vs_{safe_target_dir_name}.svg"
741
+ plot_path = target_save_dir / plot_filename
742
+
743
+ try:
744
+ plt.savefig(plot_path, bbox_inches="tight", format='svg')
745
+ total_plots_saved += 1
746
+ except Exception as e:
747
+ _LOGGER.error(f"Failed to save plot: {plot_path}. Error: {e}")
748
+
749
+ plt.close()
750
+
751
+ _LOGGER.info(f"Successfully saved {total_plots_saved} categorical-vs-target plots to '{base_save_path}'.")
752
+
753
+
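A minimal usage sketch for plot_categorical_vs_target; the import path and data are illustrative assumptions:

    import numpy as np
    import pandas as pd
    from ml_tools.data_exploration import plot_categorical_vs_target

    rng = np.random.default_rng(0)
    grade = rng.choice(["A", "B", "C"], size=120)
    df = pd.DataFrame({
        "grade": grade,                                                   # object dtype -> auto-detected
        "strength": rng.normal(100, 15, size=120) + (grade == "A") * 20,  # numeric target
    })

    # One box plot of 'strength' per 'grade' category, saved under "strength_vs_Categorical_box/".
    plot_categorical_vs_target(df, targets=["strength"], save_dir="eda_plots", plot_type="box")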
754
+ def encode_categorical_features(
755
+ df: pd.DataFrame,
756
+ columns_to_encode: List[str],
757
+ encode_nulls: bool,
758
+ null_label: str = "Other",
759
+ split_resulting_dataset: bool = True,
760
+ verbose: bool = True
761
+ ) -> Tuple[Dict[str, Dict[str, int]], pd.DataFrame, Optional[pd.DataFrame]]:
762
+ """
763
+ Finds unique values in specified categorical columns, encodes them into integers,
764
+ and returns a dictionary containing the mappings for each column.
765
+
766
+ This function automates the label encoding process and generates a simple,
767
+ human-readable dictionary of the mappings.
768
+
769
+ Args:
770
+ df (pd.DataFrame): The input DataFrame.
771
+ columns_to_encode (List[str]): A list of column names to be encoded.
772
+ encode_nulls (bool):
773
+ - If True, encodes Null values as a distinct category 'null_label' with a value of 0. Other categories start from 1.
774
+ - If False, Nulls are ignored and categories start from 0.
775
+
776
+ null_label (str): Category to encode Nulls to if `encode_nulls` is True. If a name collision with `null_label` occurs, the fallback key will be "__NULL__".
777
+ split_resulting_dataset (bool):
778
+ - If True, returns two separate DataFrames, one with non-categorical columns and one with the encoded columns.
779
+ - If False, returns a single DataFrame with all columns.
780
+ verbose (bool): If True, prints encoding progress.
781
+
782
+ Returns:
783
+ Tuple:
784
+
785
+ - Dict[str, Dict[str, int]]: A dictionary where each key is a column name and the value is its category-to-integer mapping.
786
+
787
+ - pd.DataFrame: The original dataframe with or without encoded columns (see `split_resulting_dataset`).
788
+
789
+ - pd.DataFrame | None: If `split_resulting_dataset` is True, the encoded columns as a new dataframe.
790
+
791
+ ## **Note:**
792
+ Use `encode_nulls=False` when encoding binary columns that contain missing entries; otherwise a malformed encoding will be returned silently.
793
+ """
794
+ df_encoded = df.copy()
795
+
796
+ # Validate columns
797
+ valid_columns = [col for col in columns_to_encode if col in df_encoded.columns]
798
+ missing_columns = set(columns_to_encode) - set(valid_columns)
799
+ if missing_columns:
800
+ _LOGGER.warning(f"Columns not found and will be skipped: {list(missing_columns)}")
801
+
802
+ mappings: Dict[str, Dict[str, int]] = {}
803
+
804
+ _LOGGER.info(f"Encoding {len(valid_columns)} categorical column(s).")
805
+ for col_name in valid_columns:
806
+ has_nulls = df_encoded[col_name].isnull().any()
807
+
808
+ if encode_nulls and has_nulls:
809
+ # Handle nulls: "Other" -> 0, other categories -> 1, 2, 3...
810
+ categories = sorted([str(cat) for cat in df_encoded[col_name].dropna().unique()])
811
+ # Start mapping from 1 for non-null values
812
+ mapping = {category: i + 1 for i, category in enumerate(categories)}
813
+
814
+ # Apply mapping and fill remaining NaNs with 0
815
+ mapped_series = df_encoded[col_name].astype(str).map(mapping)
816
+ df_encoded[col_name] = mapped_series.fillna(0).astype(int)
817
+
818
+ # --- Validate nulls category---
819
+ # Ensure the key for 0 doesn't collide with a real category.
820
+ if null_label in mapping.keys():
821
+ # COLLISION! null_label is a real category
822
+ original_label = null_label
823
+ null_label = "__NULL__" # fallback
824
+ _LOGGER.warning(f"Column '{col_name}': '{original_label}' is a real category. Mapping nulls (0) to '{null_label}' instead.")
825
+
826
+ # Create the complete user-facing map including "Other"
827
+ user_mapping = {**mapping, null_label: 0}
828
+ mappings[col_name] = user_mapping
829
+ else:
830
+ # ignore nulls
831
+ categories = sorted([str(cat) for cat in df_encoded[col_name].dropna().unique()])
832
+
833
+ mapping = {category: i for i, category in enumerate(categories)}
834
+
835
+ df_encoded[col_name] = df_encoded[col_name].astype(str).map(mapping)
836
+
837
+ mappings[col_name] = mapping
838
+
839
+ if verbose:
840
+ cardinality = len(mappings[col_name])
841
+ print(f" - Encoded '{col_name}' with {cardinality} unique values.")
842
+
843
+ # Handle the dataset splitting logic
844
+ if split_resulting_dataset:
845
+ df_categorical = df_encoded[valid_columns].copy() # selecting with a list already returns a DataFrame
846
+ df_non_categorical = df.drop(columns=valid_columns)
847
+ return mappings, df_non_categorical, df_categorical
848
+ else:
849
+ return mappings, df_encoded, None
850
+
851
+
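A minimal usage sketch for encode_categorical_features; the import path and data are illustrative assumptions:

    import pandas as pd
    from ml_tools.data_exploration import encode_categorical_features

    df = pd.DataFrame({
        "color": ["red", "blue", None, "red"],
        "size": ["S", "M", "L", "M"],
        "price": [1.0, 2.5, 3.2, 1.8],
    })

    mappings, df_encoded, _ = encode_categorical_features(
        df,
        columns_to_encode=["color", "size"],
        encode_nulls=True,             # the missing 'color' entry becomes the "Other" -> 0 category
        split_resulting_dataset=False, # keep encoded and non-categorical columns together
    )
    print(mappings["color"])  # {'blue': 1, 'red': 2, 'Other': 0}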
303
852
  def split_features_targets(df: pd.DataFrame, targets: list[str]):
304
853
  """
305
854
  Splits a DataFrame's columns into features and targets.
@@ -369,9 +918,9 @@ def split_continuous_binary(df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFram
369
918
  return df_cont, df_bin # type: ignore
370
919
 
371
920
 
372
- def plot_correlation_heatmap(df: pd.DataFrame,
921
+ def plot_correlation_heatmap(df: pd.DataFrame,
922
+ plot_title: str,
373
923
  save_dir: Union[str, Path, None] = None,
374
- plot_title: str="Correlation Heatmap",
375
924
  method: Literal["pearson", "kendall", "spearman"]="pearson"):
376
925
  """
377
926
  Plots a heatmap of pairwise correlations between numeric features in a DataFrame.
@@ -379,7 +928,7 @@ def plot_correlation_heatmap(df: pd.DataFrame,
379
928
  Args:
380
929
  df (pd.DataFrame): The input dataset.
381
930
  save_dir (str | Path | None): If provided, the heatmap will be saved to this directory as a svg file.
382
- plot_title: To make different plots, or overwrite existing ones.
931
+ plot_title: Base title for the plot; the suffix "<Method> Correlation Heatmap" is appended automatically.
383
932
  method (str): Correlation method to use. Must be one of:
384
933
  - 'pearson' (default): measures linear correlation (assumes normally distributed data),
385
934
  - 'kendall': rank correlation (non-parametric),
@@ -394,6 +943,9 @@ def plot_correlation_heatmap(df: pd.DataFrame,
394
943
  if numeric_df.empty:
395
944
  _LOGGER.warning("No numeric columns found. Heatmap not generated.")
396
945
  return
946
+ if method not in ["pearson", "kendall", "spearman"]:
947
+ _LOGGER.error("'method' must be pearson, kendall, or spearman.")
948
+ raise ValueError()
397
949
 
398
950
  corr = numeric_df.corr(method=method)
399
951
 
@@ -414,7 +966,10 @@ def plot_correlation_heatmap(df: pd.DataFrame,
414
966
  cbar_kws={"shrink": 0.8}
415
967
  )
416
968
 
417
- plt.title(plot_title)
969
+ # add suffix to title
970
+ full_plot_title = f"{plot_title} - {method.title()} Correlation Heatmap"
971
+
972
+ plt.title(full_plot_title)
418
973
  plt.xticks(rotation=45, ha='right')
419
974
  plt.yticks(rotation=0)
420
975
 
@@ -423,119 +978,119 @@ def plot_correlation_heatmap(df: pd.DataFrame,
423
978
  if save_dir:
424
979
  save_path = make_fullpath(save_dir, make=True)
425
980
  # sanitize the plot title to save the file
426
- plot_title = sanitize_filename(plot_title)
427
- plot_title = plot_title + ".svg"
981
+ sanitized_plot_title = sanitize_filename(plot_title)
982
+ plot_filename = sanitized_plot_title + ".svg"
428
983
 
429
- full_path = save_path / plot_title
984
+ full_path = save_path / plot_filename
430
985
 
431
986
  plt.savefig(full_path, bbox_inches="tight", format='svg')
432
- _LOGGER.info(f"Saved correlation heatmap: '{plot_title}'")
987
+ _LOGGER.info(f"Saved correlation heatmap: '{plot_filename}'")
433
988
 
434
989
  plt.show()
435
990
  plt.close()
436
991
 
437
-
438
- def plot_value_distributions(df: pd.DataFrame, save_dir: Union[str, Path], bin_threshold: int=10, skip_cols_with_key: Union[str, None]=None):
439
- """
440
- Plots and saves the value distributions for all (or selected) columns in a DataFrame,
441
- with adaptive binning for numerical columns when appropriate.
442
-
443
- For each column both raw counts and relative frequencies are computed and plotted.
444
-
445
- Plots are saved as PNG files under two subdirectories in `save_dir`:
446
- - "Distribution_Counts" for absolute counts.
447
- - "Distribution_Frequency" for relative frequencies.
448
-
449
- Args:
450
- df (pd.DataFrame): The input DataFrame whose columns are to be analyzed.
451
- save_dir (str | Path): Directory path where the plots will be saved. Will be created if it does not exist.
452
- bin_threshold (int): Minimum number of unique values required to trigger binning
453
- for numerical columns.
454
- skip_cols_with_key (str | None): If provided, any column whose name contains this
455
- substring will be excluded from analysis.
456
-
457
- Notes:
458
- - Binning is adaptive: if quantile binning results in ≤ 2 unique bins, raw values are used instead.
459
- - All non-alphanumeric characters in column names are sanitized for safe file naming.
460
- - Colormap is automatically adapted based on the number of categories or bins.
461
- """
462
- save_path = make_fullpath(save_dir, make=True)
992
+ # OLD IMPLEMENTATION
993
+ # def plot_value_distributions(df: pd.DataFrame, save_dir: Union[str, Path], bin_threshold: int=10, skip_cols_with_key: Union[str, None]=None):
994
+ # """
995
+ # Plots and saves the value distributions for all (or selected) columns in a DataFrame,
996
+ # with adaptive binning for numerical columns when appropriate.
997
+
998
+ # For each column both raw counts and relative frequencies are computed and plotted.
999
+
1000
+ # Plots are saved as PNG files under two subdirectories in `save_dir`:
1001
+ # - "Distribution_Counts" for absolute counts.
1002
+ # - "Distribution_Frequency" for relative frequencies.
1003
+
1004
+ # Args:
1005
+ # df (pd.DataFrame): The input DataFrame whose columns are to be analyzed.
1006
+ # save_dir (str | Path): Directory path where the plots will be saved. Will be created if it does not exist.
1007
+ # bin_threshold (int): Minimum number of unique values required to trigger binning
1008
+ # for numerical columns.
1009
+ # skip_cols_with_key (str | None): If provided, any column whose name contains this
1010
+ # substring will be excluded from analysis.
1011
+
1012
+ # Notes:
1013
+ # - Binning is adaptive: if quantile binning results in ≤ 2 unique bins, raw values are used instead.
1014
+ # - All non-alphanumeric characters in column names are sanitized for safe file naming.
1015
+ # - Colormap is automatically adapted based on the number of categories or bins.
1016
+ # """
1017
+ # save_path = make_fullpath(save_dir, make=True)
463
1018
 
464
- dict_to_plot_std = dict()
465
- dict_to_plot_freq = dict()
1019
+ # dict_to_plot_std = dict()
1020
+ # dict_to_plot_freq = dict()
466
1021
 
467
- # cherry-pick columns
468
- if skip_cols_with_key is not None:
469
- columns = [col for col in df.columns if skip_cols_with_key not in col]
470
- else:
471
- columns = df.columns.to_list()
1022
+ # # cherry-pick columns
1023
+ # if skip_cols_with_key is not None:
1024
+ # columns = [col for col in df.columns if skip_cols_with_key not in col]
1025
+ # else:
1026
+ # columns = df.columns.to_list()
472
1027
 
473
- saved_plots = 0
474
- for col in columns:
475
- if pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() > bin_threshold:
476
- bins_number = 10
477
- binned = pd.qcut(df[col], q=bins_number, duplicates='drop')
478
- while binned.nunique() <= 2:
479
- bins_number -= 1
480
- binned = pd.qcut(df[col], q=bins_number, duplicates='drop')
481
- if bins_number <= 2:
482
- break
483
-
484
- if binned.nunique() <= 2:
485
- view_std = df[col].value_counts(sort=False).sort_index()
486
- else:
487
- view_std = binned.value_counts(sort=False)
1028
+ # saved_plots = 0
1029
+ # for col in columns:
1030
+ # if pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() > bin_threshold:
1031
+ # bins_number = 10
1032
+ # binned = pd.qcut(df[col], q=bins_number, duplicates='drop')
1033
+ # while binned.nunique() <= 2:
1034
+ # bins_number -= 1
1035
+ # binned = pd.qcut(df[col], q=bins_number, duplicates='drop')
1036
+ # if bins_number <= 2:
1037
+ # break
488
1038
 
489
- else:
490
- view_std = df[col].value_counts(sort=False).sort_index()
491
-
492
- # unlikely scenario where the series is empty
493
- if view_std.sum() == 0:
494
- view_freq = view_std
495
- else:
496
- view_freq = 100 * view_std / view_std.sum() # Percentage
497
- # view_freq = df[col].value_counts(normalize=True, bins=10) # relative percentages
1039
+ # if binned.nunique() <= 2:
1040
+ # view_std = df[col].value_counts(sort=False).sort_index()
1041
+ # else:
1042
+ # view_std = binned.value_counts(sort=False)
1043
+
1044
+ # else:
1045
+ # view_std = df[col].value_counts(sort=False).sort_index()
1046
+
1047
+ # # unlikely scenario where the series is empty
1048
+ # if view_std.sum() == 0:
1049
+ # view_freq = view_std
1050
+ # else:
1051
+ # view_freq = 100 * view_std / view_std.sum() # Percentage
1052
+ # # view_freq = df[col].value_counts(normalize=True, bins=10) # relative percentages
498
1053
 
499
- dict_to_plot_std[col] = dict(view_std)
500
- dict_to_plot_freq[col] = dict(view_freq)
501
- saved_plots += 1
1054
+ # dict_to_plot_std[col] = dict(view_std)
1055
+ # dict_to_plot_freq[col] = dict(view_freq)
1056
+ # saved_plots += 1
502
1057
 
503
- # plot helper
504
- def _plot_helper(dict_: dict, target_dir: Path, ylabel: Literal["Frequency", "Counts"], base_fontsize: int=12):
505
- for col, data in dict_.items():
506
- safe_col = sanitize_filename(col)
1058
+ # # plot helper
1059
+ # def _plot_helper(dict_: dict, target_dir: Path, ylabel: Literal["Frequency", "Counts"], base_fontsize: int=12):
1060
+ # for col, data in dict_.items():
1061
+ # safe_col = sanitize_filename(col)
507
1062
 
508
- if isinstance(list(data.keys())[0], pd.Interval):
509
- labels = [str(interval) for interval in data.keys()]
510
- else:
511
- labels = data.keys()
1063
+ # if isinstance(list(data.keys())[0], pd.Interval):
1064
+ # labels = [str(interval) for interval in data.keys()]
1065
+ # else:
1066
+ # labels = data.keys()
512
1067
 
513
- plt.figure(figsize=(10, 6))
514
- colors = plt.cm.tab20.colors if len(data) <= 20 else plt.cm.viridis(np.linspace(0, 1, len(data))) # type: ignore
1068
+ # plt.figure(figsize=(10, 6))
1069
+ # colors = plt.cm.tab20.colors if len(data) <= 20 else plt.cm.viridis(np.linspace(0, 1, len(data))) # type: ignore
515
1070
 
516
- plt.bar(labels, data.values(), color=colors[:len(data)], alpha=0.85)
517
- plt.xlabel("Values", fontsize=base_fontsize)
518
- plt.ylabel(ylabel, fontsize=base_fontsize)
519
- plt.title(f"Value Distribution for '{col}'", fontsize=base_fontsize+2)
520
- plt.xticks(rotation=45, ha='right', fontsize=base_fontsize-2)
521
- plt.yticks(fontsize=base_fontsize-2)
522
- plt.grid(axis='y', linestyle='--', alpha=0.6)
523
- plt.gca().set_facecolor('#f9f9f9')
524
- plt.tight_layout()
1071
+ # plt.bar(labels, data.values(), color=colors[:len(data)], alpha=0.85)
1072
+ # plt.xlabel("Values", fontsize=base_fontsize)
1073
+ # plt.ylabel(ylabel, fontsize=base_fontsize)
1074
+ # plt.title(f"Value Distribution for '{col}'", fontsize=base_fontsize+2)
1075
+ # plt.xticks(rotation=45, ha='right', fontsize=base_fontsize-2)
1076
+ # plt.yticks(fontsize=base_fontsize-2)
1077
+ # plt.grid(axis='y', linestyle='--', alpha=0.6)
1078
+ # plt.gca().set_facecolor('#f9f9f9')
1079
+ # plt.tight_layout()
525
1080
 
526
- plot_path = target_dir / f"{safe_col}.png"
527
- plt.savefig(plot_path, dpi=300, bbox_inches="tight")
528
- plt.close()
1081
+ # plot_path = target_dir / f"{safe_col}.png"
1082
+ # plt.savefig(plot_path, dpi=300, bbox_inches="tight")
1083
+ # plt.close()
529
1084
 
530
- # Save plots
531
- freq_dir = save_path / "Distribution_Frequency"
532
- std_dir = save_path / "Distribution_Counts"
533
- freq_dir.mkdir(parents=True, exist_ok=True)
534
- std_dir.mkdir(parents=True, exist_ok=True)
535
- _plot_helper(dict_=dict_to_plot_std, target_dir=std_dir, ylabel="Counts")
536
- _plot_helper(dict_=dict_to_plot_freq, target_dir=freq_dir, ylabel="Frequency")
1085
+ # # Save plots
1086
+ # freq_dir = save_path / "Distribution_Frequency"
1087
+ # std_dir = save_path / "Distribution_Counts"
1088
+ # freq_dir.mkdir(parents=True, exist_ok=True)
1089
+ # std_dir.mkdir(parents=True, exist_ok=True)
1090
+ # _plot_helper(dict_=dict_to_plot_std, target_dir=std_dir, ylabel="Counts")
1091
+ # _plot_helper(dict_=dict_to_plot_freq, target_dir=freq_dir, ylabel="Frequency")
537
1092
 
538
- _LOGGER.info(f"Saved {saved_plots} value distribution plots.")
1093
+ # _LOGGER.info(f"Saved {saved_plots} value distribution plots.")
539
1094
 
540
1095
 
541
1096
  def clip_outliers_single(
@@ -628,7 +1183,99 @@ def clip_outliers_multi(
628
1183
  if skipped_columns:
629
1184
  _LOGGER.warning("Skipped columns:")
630
1185
  for col, msg in skipped_columns:
631
- print(f" - {col}: {msg}")
1186
+ print(f" - {col}")
1187
+
1188
+ return new_df
1189
+
1190
+
1191
+ def drop_outlier_samples(
1192
+ df: pd.DataFrame,
1193
+ bounds_dict: Dict[str, Tuple[Union[int, float], Union[int, float]]],
1194
+ drop_on_nulls: bool = False,
1195
+ verbose: bool = True
1196
+ ) -> pd.DataFrame:
1197
+ """
1198
+ Drops entire rows where values in specified numeric columns fall outside
1199
+ a given [min, max] range.
1200
+
1201
+ This function processes a copy of the DataFrame, ensuring the original is
1202
+ not modified. It skips columns with invalid specifications.
1203
+
1204
+ Args:
1205
+ df (pd.DataFrame): The input DataFrame.
1206
+ bounds_dict (dict): A dictionary where keys are column names and values
1207
+ are (min_val, max_val) tuples defining the valid range.
1208
+ drop_on_nulls (bool): If True, rows with NaN/None in a checked column
1209
+ will also be dropped. If False, NaN/None are ignored.
1210
+ verbose (bool): If True, prints the number of rows dropped for each column.
1211
+
1212
+ Returns:
1213
+ pd.DataFrame: A new DataFrame with the outlier rows removed.
1214
+
1215
+ Notes:
1216
+ - Invalid specifications (e.g., missing column, non-numeric type,
1217
+ incorrectly formatted bounds) will be reported and skipped.
1218
+ """
1219
+ new_df = df.copy()
1220
+ skipped_columns: List[Tuple[str, str]] = []
1221
+ initial_rows = len(new_df)
1222
+
1223
+ for col, bounds in bounds_dict.items():
1224
+ try:
1225
+ # --- Validation Checks ---
1226
+ if col not in df.columns:
1227
+ _LOGGER.error(f"Column '{col}' not found in DataFrame.")
1228
+ raise ValueError()
1229
+
1230
+ if not pd.api.types.is_numeric_dtype(df[col]):
1231
+ _LOGGER.error(f"Column '{col}' is not of a numeric data type.")
1232
+ raise TypeError()
1233
+
1234
+ if not (isinstance(bounds, tuple) and len(bounds) == 2):
1235
+ _LOGGER.error(f"Bounds for '{col}' must be a tuple of (min, max).")
1236
+ raise ValueError()
1237
+
1238
+ # --- Filtering Logic ---
1239
+ min_val, max_val = bounds
1240
+ rows_before_drop = len(new_df)
1241
+
1242
+ # Create the base mask for values within the specified range
1243
+ # .between() is inclusive and evaluates to False for NaN
1244
+ mask_in_bounds = new_df[col].between(min_val, max_val)
1245
+
1246
+ if drop_on_nulls:
1247
+ # Keep only rows that are within bounds.
1248
+ # Since mask_in_bounds is False for NaN, nulls are dropped.
1249
+ final_mask = mask_in_bounds
1250
+ else:
1251
+ # Keep rows that are within bounds OR are null.
1252
+ mask_is_null = new_df[col].isnull()
1253
+ final_mask = mask_in_bounds | mask_is_null
1254
+
1255
+ # Apply the final mask
1256
+ new_df = new_df[final_mask]
1257
+
1258
+ rows_after_drop = len(new_df)
1259
+
1260
+ if verbose:
1261
+ dropped_count = rows_before_drop - rows_after_drop
1262
+ if dropped_count > 0:
1263
+ print(
1264
+ f" - Column '{col}': Dropped {dropped_count} rows with values outside range [{min_val}, {max_val}]."
1265
+ )
1266
+
1267
+ except (ValueError, TypeError) as e:
1268
+ skipped_columns.append((col, str(e)))
1269
+ continue
1270
+
1271
+ total_dropped = initial_rows - len(new_df)
1272
+ _LOGGER.info(f"Finished processing. Total rows dropped: {total_dropped}.")
1273
+
1274
+ if skipped_columns:
1275
+ _LOGGER.warning("Skipped the following columns due to errors:")
1276
+ for col, msg in skipped_columns:
1277
+ # Only print the column name for cleaner output as the error was already logged
1278
+ print(f" - {col}")
632
1279
 
633
1280
  return new_df
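A minimal usage sketch for drop_outlier_samples; the import path and data are illustrative assumptions:

    import pandas as pd
    from ml_tools.data_exploration import drop_outlier_samples

    df = pd.DataFrame({
        "pH": [6.8, 7.1, 2.0, 7.4, None],
        "temp": [36.5, 37.0, 36.8, 55.0, 36.9],
    })

    # Rows with pH outside [5, 9] or temp outside [30, 45] are dropped; the row with a
    # missing pH is kept because drop_on_nulls defaults to False.
    filtered = drop_outlier_samples(df, bounds_dict={"pH": (5, 9), "temp": (30, 45)})
    print(len(filtered))  # 3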
634
1281
 
@@ -667,7 +1314,8 @@ def standardize_percentages(
667
1314
  df: pd.DataFrame,
668
1315
  columns: list[str],
669
1316
  treat_one_as_proportion: bool = True,
670
- round_digits: int = 2
1317
+ round_digits: int = 2,
1318
+ verbose: bool=True
671
1319
  ) -> pd.DataFrame:
672
1320
  """
673
1321
  Standardizes numeric columns containing mixed-format percentages.
@@ -708,6 +1356,8 @@ def standardize_percentages(
708
1356
 
709
1357
  # Otherwise, the value is assumed to be a correctly formatted percentage
710
1358
  return x
1359
+
1360
+ fixed_columns: list[str] = list()
711
1361
 
712
1362
  for col in columns:
713
1363
  # --- Robustness Checks ---
@@ -725,10 +1375,343 @@ def standardize_percentages(
725
1375
 
726
1376
  # Round the result
727
1377
  df_copy[col] = df_copy[col].round(round_digits)
1378
+
1379
+ fixed_columns.append(col)
1380
+
1381
+ if verbose:
1382
+ _LOGGER.info("Columns standardized:")
1383
+ for fixed_col in fixed_columns:
1384
+ print(f" '{fixed_col}'")
728
1385
 
729
1386
  return df_copy
730
1387
 
731
1388
 
1389
+ def reconstruct_one_hot(
1390
+ df: pd.DataFrame,
1391
+ features_to_reconstruct: List[Union[str, Tuple[str, Optional[str]]]],
1392
+ separator: str = '_',
1393
+ baseline_category_name: Optional[str] = "Other",
1394
+ drop_original: bool = True,
1395
+ verbose: bool = True
1396
+ ) -> pd.DataFrame:
1397
+ """
1398
+ Reconstructs original categorical columns from a one-hot encoded DataFrame.
1399
+
1400
+ This function identifies groups of one-hot encoded columns based on a common
1401
+ prefix (base feature name) and a separator. It then collapses each group
1402
+ into a single column containing the categorical value.
1403
+
1404
+ Args:
1405
+ df (pd.DataFrame):
1406
+ The input DataFrame with one-hot encoded columns.
1407
+ features_to_reconstruct (List[str | Tuple[str, str | None]]):
1408
+ A list defining the features to reconstruct. This list can contain:
1409
+
1410
+ - A string: (e.g., "Color")
1411
+ This reconstructs the feature 'Color' and assumes all-zero rows represent the baseline category ("Other" by default).
1412
+ - A tuple: (e.g., ("Pet", "Dog"))
1413
+ This reconstructs 'Pet' and maps all-zero rows to the baseline category "Dog".
1414
+ - A tuple with None: (e.g., ("Size", None))
1415
+ This reconstructs 'Size' and maps all-zero rows to the NaN value.
1416
+ Example:
1417
+ [
1418
+ "Mood", # All-zeros -> "Other"
1419
+ ("Color", "Red"), # All-zeros -> "Red"
1420
+ ("Size", None) # All-zeros -> NaN
1421
+ ]
1422
+ separator (str):
1423
+ The character separating the base name from the categorical value in
1424
+ the column names (e.g., '_' in 'B_a').
1425
+ baseline_category_name (str | None):
1426
+ The baseline category name to use by default if it is not explicitly provided.
1427
+ drop_original (bool):
1428
+ If True, the original one-hot encoded columns will be dropped from
1429
+ the returned DataFrame.
1430
+
1431
+ Returns:
1432
+ pd.DataFrame:
1433
+ A new DataFrame with the specified one-hot encoded features
1434
+ reconstructed into single categorical columns.
1435
+
1436
+ <br>
1437
+
1438
+ ## Note:
1439
+
1440
+ This function is designed to be robust, but users should be aware of two key edge cases:
1441
+
1442
+ 1. **Ambiguous Base Feature Prefixes**: If `features_to_reconstruct` contains base names where one is a prefix of another (e.g., `['feat', 'feat_ext']`), the order is critical. The function will match columns greedily. To avoid incorrect grouping, always list the **most specific base names first** (e.g., `['feat_ext', 'feat']`).
1443
+
1444
+ 2. **Malformed One-Hot Data**: If a row contains multiple `1`s within the same feature group (e.g., both `B_a` and `B_c` are `1`), the function will not raise an error. It uses `.idxmax()`, which returns the first column that contains the maximum value. This means it will silently select the first category it encounters and ignore the others, potentially masking an upstream data issue.
1445
+ """
1446
+ if not isinstance(df, pd.DataFrame):
1447
+ _LOGGER.error("Input must be a pandas DataFrame.")
1448
+ raise TypeError()
1449
+
1450
+ if not (baseline_category_name is None or isinstance(baseline_category_name, str)):
1451
+ _LOGGER.error("The baseline_category must be None or a string.")
1452
+ raise TypeError()
1453
+
1454
+ new_df = df.copy()
1455
+ all_ohe_cols_to_drop = []
1456
+ reconstructed_count = 0
1457
+
1458
+ # --- 1. Parse and validate the reconstruction config ---
1459
+ # This normalizes the input into a clean {base_name: baseline_val} dict
1460
+ reconstruction_config: Dict[str, Optional[str]] = {}
1461
+ try:
1462
+ for item in features_to_reconstruct:
1463
+ if isinstance(item, str):
1464
+ # Case 1: "Color"
1465
+ base_name = item
1466
+ baseline_val = baseline_category_name
1467
+ elif isinstance(item, tuple) and len(item) == 2:
1468
+ # Case 2: ("Pet", "dog") or ("Size", None)
1469
+ base_name, baseline_val = item
1470
+ if not (isinstance(base_name, str) and (isinstance(baseline_val, str) or baseline_val is None)):
1471
+ _LOGGER.error(f"Invalid tuple format for '{item}'. Must be (str, str|None).")
1472
+ raise ValueError()
1473
+ else:
1474
+ _LOGGER.error(f"Invalid item '{item}'. Must be str or (str, str|None) tuple.")
1475
+ raise ValueError()
1476
+
1477
+ if base_name in reconstruction_config and verbose:
1478
+ _LOGGER.warning(f"Duplicate entry for '{base_name}' found. Using the last provided configuration.")
1479
+
1480
+ reconstruction_config[base_name] = baseline_val
1481
+
1482
+ except Exception as e:
1483
+ _LOGGER.error(f"Failed to parse 'features_to_reconstruct' argument: {e}")
1484
+ raise ValueError("Invalid configuration for 'features_to_reconstruct'.") from e
1485
+
1486
+ _LOGGER.info(f"Attempting to reconstruct {len(reconstruction_config)} one-hot encoded feature(s).")
1487
+
1488
+ # Main logic
1489
+ for base_name, baseline_category in reconstruction_config.items():
1490
+ # Regex to find all columns belonging to this base feature.
1491
+ pattern = f"^{re.escape(base_name)}{re.escape(separator)}"
1492
+
1493
+ # Find matching columns
1494
+ ohe_cols = [col for col in df.columns if re.match(pattern, col)]
1495
+
1496
+ if not ohe_cols:
1497
+ _LOGGER.warning(f"No one-hot encoded columns found for base feature '{base_name}'. Skipping.")
1498
+ continue
1499
+
1500
+ # For each row, find the column name with the maximum value (which is 1)
1501
+ reconstructed_series = new_df[ohe_cols].idxmax(axis=1) # type: ignore
1502
+
1503
+ # Extract the categorical value (the suffix) from the column name
1504
+ # Use n=1 in split to handle cases where the category itself might contain the separator
1505
+ new_column_values = reconstructed_series.str.split(separator, n=1).str[1]
1506
+
1507
+ # Handle rows where all OHE columns were 0 (e.g., original value was NaN or a dropped baseline).
1508
+ all_zero_mask = new_df[ohe_cols].sum(axis=1) == 0 # type: ignore
1509
+
1510
+ if baseline_category is not None:
1511
+ # A baseline category was provided
1512
+ new_column_values.loc[all_zero_mask] = baseline_category
1513
+ else:
1514
+ # No baseline provided: assign NaN
1515
+ new_column_values.loc[all_zero_mask] = np.nan # type: ignore
1516
+
1517
+ if verbose:
1518
+ print(f" - Mapped 'all-zero' rows for '{base_name}' to baseline: '{baseline_category}'.")
1519
+
1520
+ # Assign the new reconstructed column to the DataFrame
1521
+ new_df[base_name] = new_column_values
1522
+
1523
+ all_ohe_cols_to_drop.extend(ohe_cols)
1524
+ reconstructed_count += 1
1525
+ if verbose:
1526
+ print(f" - Reconstructed '{base_name}' from {len(ohe_cols)} columns.")
1527
+
1528
+ # Cleanup
1529
+ if drop_original and all_ohe_cols_to_drop:
1530
+ # Drop the original OHE columns, ensuring no duplicates in the drop list
1531
+ unique_cols_to_drop = list(set(all_ohe_cols_to_drop))
1532
+ new_df.drop(columns=unique_cols_to_drop, inplace=True)
1533
+ _LOGGER.info(f"Dropped {len(unique_cols_to_drop)} original one-hot encoded columns.")
1534
+
1535
+ _LOGGER.info(f"Successfully reconstructed {reconstructed_count} feature(s).")
1536
+
1537
+ return new_df
1538
+
1539
+
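A minimal usage sketch for reconstruct_one_hot; the import path and data are illustrative assumptions:

    import pandas as pd
    from ml_tools.data_exploration import reconstruct_one_hot

    df = pd.DataFrame({
        "Color_red": [1, 0, 0],
        "Color_blue": [0, 1, 0],
        "price": [9.5, 7.0, 8.2],
    })

    # The all-zero row (a dropped baseline) is mapped to "green"; passing ("Color", None)
    # instead would map it to NaN.
    restored = reconstruct_one_hot(df, features_to_reconstruct=[("Color", "green")])
    print(restored["Color"].tolist())  # ['red', 'blue', 'green']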
1540
+ def reconstruct_binary(
1541
+ df: pd.DataFrame,
1542
+ reconstruction_map: Dict[str, Tuple[str, Any, Any]],
1543
+ drop_original: bool = True,
1544
+ verbose: bool = True
1545
+ ) -> pd.DataFrame:
1546
+ """
1547
+ Reconstructs new categorical columns from existing binary (0/1) columns.
1548
+
1549
+ Used to reverse a binary encoding by mapping 0 and 1 back to
1550
+ descriptive categorical labels.
1551
+
1552
+ Args:
1553
+ df (pd.DataFrame):
1554
+ The input DataFrame.
1555
+ reconstruction_map (Dict[str, Tuple[str, Any, Any]]):
1556
+ A dictionary defining the reconstructions.
1557
+ Format:
1558
+ { "new_col_name": ("source_col_name", "label_for_0", "label_for_1") }
1559
+ Example:
1560
+ {
1561
+ "Sex": ("Sex_male", "Female", "Male"),
1562
+ "Smoker": ("Is_Smoker", "No", "Yes")
1563
+ }
1564
+ drop_original (bool):
1565
+ If True, the original binary source columns (e.g., "Sex_male")
1566
+ will be dropped from the returned DataFrame.
1567
+ verbose (bool):
1568
+ If True, prints the details of each reconstruction.
1569
+
1570
+ Returns:
1571
+ pd.DataFrame:
1572
+ A new DataFrame with the reconstructed categorical columns.
1573
+
1574
+ Raises:
1575
+ TypeError: If `df` is not a pandas DataFrame.
1576
+ ValueError: If `reconstruction_map` is not a dictionary or a
1577
+ configuration is invalid (e.g., column name collision).
1578
+
1579
+ Notes:
1580
+ - The function operates on a copy of the DataFrame.
1581
+ - Rows with `NaN` in the source column will have `NaN` in the
1582
+ new column.
1583
+ - Values in the source column other than 0 or 1 (e.g., 2) will
1584
+ result in `NaN` in the new column.
1585
+ """
1586
+ if not isinstance(df, pd.DataFrame):
1587
+ _LOGGER.error("Input must be a pandas DataFrame.")
1588
+ raise TypeError()
1589
+
1590
+ if not isinstance(reconstruction_map, dict):
1591
+ _LOGGER.error("`reconstruction_map` must be a dictionary with the required format.")
1592
+ raise ValueError()
1593
+
1594
+ new_df = df.copy()
1595
+ source_cols_to_drop: List[str] = []
1596
+ reconstructed_count = 0
1597
+
1598
+ _LOGGER.info(f"Attempting to reconstruct {len(reconstruction_map)} binary feature(s).")
1599
+
1600
+ for new_col_name, config in reconstruction_map.items():
1601
+
1602
+ # --- 1. Validation ---
1603
+ if not (isinstance(config, tuple) and len(config) == 3):
1604
+ _LOGGER.error(f"Config for '{new_col_name}' is invalid. Must be a 3-item tuple. Skipping.")
1605
+ raise ValueError()
1606
+
1607
+ source_col, label_for_0, label_for_1 = config
1608
+
1609
+ if source_col not in new_df.columns:
1610
+ _LOGGER.error(f"Source column '{source_col}' for new column '{new_col_name}' not found. Skipping.")
1611
+ raise ValueError()
1612
+
1613
+ if new_col_name in new_df.columns and verbose:
1614
+ _LOGGER.warning(f"New column '{new_col_name}' already exists and will be overwritten.")
1615
+
1616
+ if new_col_name == source_col:
1617
+ _LOGGER.error(f"New column name '{new_col_name}' cannot be the same as source column '{source_col}'.")
1618
+ raise ValueError()
1619
+
1620
+ # --- 2. Reconstruction ---
1621
+ # .map() handles 0, 1, preserves NaNs, and converts any other value to NaN.
1622
+ mapping_dict = {0: label_for_0, 1: label_for_1}
1623
+ new_df[new_col_name] = new_df[source_col].map(mapping_dict)
1624
+
1625
+ # --- 3. Logging/Tracking ---
1626
+ source_cols_to_drop.append(source_col)
1627
+ reconstructed_count += 1
1628
+ if verbose:
1629
+ print(f" - Reconstructed '{new_col_name}' from '{source_col}' (0='{label_for_0}', 1='{label_for_1}').")
1630
+
1631
+ # --- 4. Cleanup ---
1632
+ if drop_original and source_cols_to_drop:
1633
+ # Use set() to avoid duplicates if the same source col was used
1634
+ unique_cols_to_drop = list(set(source_cols_to_drop))
1635
+ new_df.drop(columns=unique_cols_to_drop, inplace=True)
1636
+ _LOGGER.info(f"Dropped {len(unique_cols_to_drop)} original binary source column(s).")
1637
+
1638
+ _LOGGER.info(f"Successfully reconstructed {reconstructed_count} feature(s).")
1639
+
1640
+ return new_df
1641
+
1642
+
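A minimal usage sketch for reconstruct_binary; the import path and data are illustrative assumptions:

    import pandas as pd
    from ml_tools.data_exploration import reconstruct_binary

    df = pd.DataFrame({"Sex_male": [1, 0, None], "age": [34, 29, 51]})

    # 0 -> "Female", 1 -> "Male"; the NaN in the source column stays NaN in the new column.
    restored = reconstruct_binary(df, reconstruction_map={"Sex": ("Sex_male", "Female", "Male")})
    print(restored["Sex"].tolist())  # ['Male', 'Female', nan]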
1643
+ def finalize_feature_schema(
1644
+ df_features: pd.DataFrame,
1645
+ categorical_mappings: Optional[Dict[str, Dict[str, int]]]
1646
+ ) -> FeatureSchema:
1647
+ """
1648
+ Analyzes the final features DataFrame to create a definitive schema.
1649
+
1650
+ This function is the "single source of truth" for column order
1651
+ and type (categorical vs. continuous) for the entire ML pipeline.
1652
+
1653
+ It should be called at the end of the feature engineering process.
1654
+
1655
+ Args:
1656
+ df_features (pd.DataFrame):
1657
+ The final, processed DataFrame containing *only* feature columns
1658
+ in the exact order they will be fed to the model.
1659
+ categorical_mappings (Dict[str, Dict[str, int]] | None):
1660
+ The mappings dictionary generated by
1661
+ `encode_categorical_features`. Can be None if no
1662
+ categorical features exist.
1663
+
1664
+ Returns:
1665
+ FeatureSchema: A NamedTuple containing all necessary metadata for the pipeline.
1666
+ """
1667
+ feature_names: List[str] = df_features.columns.to_list()
1668
+
1669
+ # Intermediate lists for building
1670
+ continuous_feature_names_list: List[str] = []
1671
+ categorical_feature_names_list: List[str] = []
1672
+ categorical_index_map_dict: Dict[int, int] = {}
1673
+
1674
+ # _LOGGER.info("Finalizing feature schema...")
1675
+
1676
+ if categorical_mappings:
1677
+ # --- Categorical features are present ---
1678
+ categorical_names_set = set(categorical_mappings.keys())
1679
+
1680
+ for index, name in enumerate(feature_names):
1681
+ if name in categorical_names_set:
1682
+ # This is a categorical feature
1683
+ cardinality = len(categorical_mappings[name])
1684
+ categorical_index_map_dict[index] = cardinality
1685
+ categorical_feature_names_list.append(name)
1686
+ else:
1687
+ # This is a continuous feature
1688
+ continuous_feature_names_list.append(name)
1689
+
1690
+ # Use the populated dict, or None if it's empty
1691
+ final_index_map = categorical_index_map_dict if categorical_index_map_dict else None
1692
+
1693
+ else:
1694
+ # --- No categorical features ---
1695
+ _LOGGER.info("No categorical mappings provided. Treating all features as continuous.")
1696
+ continuous_feature_names_list = list(feature_names)
1697
+ # categorical_feature_names_list remains empty
1698
+ # categorical_index_map_dict remains empty
1699
+ final_index_map = None # Explicitly set to None to match Optional type
1700
+
1701
+ _LOGGER.info(f"Schema created: {len(continuous_feature_names_list)} continuous, {len(categorical_feature_names_list)} categorical.")
1702
+
1703
+ # Create the final immutable instance
1704
+ schema_instance = FeatureSchema(
1705
+ feature_names=tuple(feature_names),
1706
+ continuous_feature_names=tuple(continuous_feature_names_list),
1707
+ categorical_feature_names=tuple(categorical_feature_names_list),
1708
+ categorical_index_map=final_index_map,
1709
+ categorical_mappings=categorical_mappings
1710
+ )
1711
+
1712
+ return schema_instance
1713
+
1714
+
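A minimal end-of-pipeline sketch combining encode_categorical_features and finalize_feature_schema; the import path and data are illustrative assumptions:

    import pandas as pd
    from ml_tools.data_exploration import encode_categorical_features, finalize_feature_schema

    df_features = pd.DataFrame({
        "temp": [20.5, 22.1, 19.8],       # continuous
        "color": ["red", "blue", "red"],  # categorical
    })

    mappings, df_features, _ = encode_categorical_features(
        df_features, columns_to_encode=["color"], encode_nulls=False, split_resulting_dataset=False
    )
    schema = finalize_feature_schema(df_features, categorical_mappings=mappings)
    print(schema.continuous_feature_names)  # ('temp',)
    print(schema.categorical_index_map)     # {1: 2} -> feature index 1 has cardinality 2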
732
1715
  def _validate_columns(df: pd.DataFrame, columns: list[str]):
733
1716
  valid_columns = [column for column in columns if column in df.columns]
734
1717
  return valid_columns