dragon-ml-toolbox 13.1.0__py3-none-any.whl → 14.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dragon-ml-toolbox might be problematic. Click here for more details.
- {dragon_ml_toolbox-13.1.0.dist-info → dragon_ml_toolbox-14.3.1.dist-info}/METADATA +11 -2
- dragon_ml_toolbox-14.3.1.dist-info/RECORD +48 -0
- {dragon_ml_toolbox-13.1.0.dist-info → dragon_ml_toolbox-14.3.1.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
- ml_tools/MICE_imputation.py +207 -5
- ml_tools/ML_datasetmaster.py +63 -205
- ml_tools/ML_evaluation.py +23 -15
- ml_tools/ML_evaluation_multi.py +5 -6
- ml_tools/ML_inference.py +0 -1
- ml_tools/ML_models.py +22 -6
- ml_tools/ML_models_advanced.py +323 -0
- ml_tools/ML_trainer.py +463 -20
- ml_tools/ML_utilities.py +302 -4
- ml_tools/ML_vision_datasetmaster.py +1395 -0
- ml_tools/ML_vision_evaluation.py +260 -0
- ml_tools/ML_vision_inference.py +428 -0
- ml_tools/ML_vision_models.py +627 -0
- ml_tools/ML_vision_transformers.py +58 -0
- ml_tools/_ML_vision_recipe.py +88 -0
- ml_tools/__init__.py +1 -0
- ml_tools/_schema.py +79 -2
- ml_tools/custom_logger.py +37 -14
- ml_tools/data_exploration.py +502 -93
- ml_tools/keys.py +42 -1
- ml_tools/math_utilities.py +1 -1
- ml_tools/serde.py +77 -15
- ml_tools/utilities.py +192 -3
- dragon_ml_toolbox-13.1.0.dist-info/RECORD +0 -41
- {dragon_ml_toolbox-13.1.0.dist-info → dragon_ml_toolbox-14.3.1.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-13.1.0.dist-info → dragon_ml_toolbox-14.3.1.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-13.1.0.dist-info → dragon_ml_toolbox-14.3.1.dist-info}/top_level.txt +0 -0
ml_tools/data_exploration.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import pandas as pd
|
|
2
|
-
from pandas.api.types import is_numeric_dtype
|
|
2
|
+
from pandas.api.types import is_numeric_dtype, is_object_dtype
|
|
3
3
|
import numpy as np
|
|
4
4
|
import matplotlib.pyplot as plt
|
|
5
5
|
import seaborn as sns
|
|
@@ -22,14 +22,16 @@ __all__ = [
|
|
|
22
22
|
"drop_columns_with_missing_data",
|
|
23
23
|
"drop_macro",
|
|
24
24
|
"clean_column_names",
|
|
25
|
+
"plot_value_distributions",
|
|
26
|
+
"plot_continuous_vs_target",
|
|
27
|
+
"plot_categorical_vs_target",
|
|
25
28
|
"encode_categorical_features",
|
|
26
29
|
"split_features_targets",
|
|
27
30
|
"split_continuous_binary",
|
|
28
|
-
"plot_correlation_heatmap",
|
|
29
|
-
"plot_value_distributions",
|
|
30
31
|
"clip_outliers_single",
|
|
31
32
|
"clip_outliers_multi",
|
|
32
33
|
"drop_outlier_samples",
|
|
34
|
+
"plot_correlation_heatmap",
|
|
33
35
|
"match_and_filter_columns_by_regex",
|
|
34
36
|
"standardize_percentages",
|
|
35
37
|
"reconstruct_one_hot",
|
|
@@ -342,6 +344,413 @@ def clean_column_names(df: pd.DataFrame, replacement_char: str = '-', replacemen
|
|
|
342
344
|
return new_df
|
|
343
345
|
|
|
344
346
|
|
|
347
|
+
def plot_value_distributions(
    df: pd.DataFrame,
    save_dir: Union[str, Path],
    categorical_columns: Optional[List[str]] = None,
    categorical_cardinality_threshold: int = 10,
    max_categories: int = 50,
    fill_na_with: str = "Missing"
) -> None:
    """
    Plots and saves the value distribution of every column in a DataFrame,
    choosing a histogram or a count plot per column.

    Plots are saved as SVG files under two subdirectories in `save_dir`:
    - "Distribution_Continuous" for continuous features (histograms).
    - "Distribution_Categorical" for categorical features (count plots).

    Args:
        df (pd.DataFrame): The input DataFrame to analyze.
        save_dir (str | Path): Directory path to save the plots.
        categorical_columns (List[str] | None): If provided, exactly these
            columns are treated as categorical and every other column is
            treated as continuous, regardless of dtype. This overrides the
            `categorical_cardinality_threshold` logic.
        categorical_cardinality_threshold (int): In auto-detection mode, a
            numeric column is treated as continuous only if its number of
            unique non-null values is strictly greater than this threshold;
            otherwise it is treated as categorical. Non-numeric columns are
            always categorical in this mode. (Ignored if
            `categorical_columns` is set).
        max_categories (int): The maximum number of unique non-null values a
            categorical feature can have to be plotted. Features exceeding
            this limit are skipped with a warning.
        fill_na_with (str): A string to replace NaN values in categorical
            columns, so that 'missingness' is plotted as its own category.
            Defaults to "Missing".

    Notes:
        - `seaborn.histplot` with KDE is used for continuous features.
        - `seaborn.countplot` is used for categorical features.
        - A failure on one column is logged and does not stop the remaining
          columns from being processed.
        - The reported totals only count plots that were actually written.
    """
    # 1. Setup save directories
    base_save_path = make_fullpath(save_dir, make=True, enforce="directory")
    numeric_dir = base_save_path / "Distribution_Continuous"
    categorical_dir = base_save_path / "Distribution_Categorical"
    numeric_dir.mkdir(parents=True, exist_ok=True)
    categorical_dir.mkdir(parents=True, exist_ok=True)

    # 2. Setup for forced categorical logic (None means auto-detection)
    categorical_set = set(categorical_columns) if categorical_columns is not None else None

    numeric_plots_saved = 0
    categorical_plots_saved = 0

    for col_name in df.columns.to_list():
        # Track the figure created for this column so that cleanup never
        # touches a figure owned by the caller.
        fig = None
        try:
            n_unique = df[col_name].nunique()

            # --- 3. Determine Plot Type ---
            if categorical_set is not None:
                # Use the explicit list
                is_continuous = col_name not in categorical_set
            else:
                # Use auto-detection
                is_continuous = bool(
                    is_numeric_dtype(df[col_name])
                    and n_unique > categorical_cardinality_threshold
                )

            # --- Case 1: Continuous Numeric (Histogram) ---
            if is_continuous:
                fig = plt.figure(figsize=(10, 6))
                # Drop NaNs for histogram, as they can't be plotted on a numeric axis
                sns.histplot(x=df[col_name].dropna(), kde=True, bins=30)
                plt.title(f"Distribution of '{col_name}' (Continuous)")
                plt.xlabel(col_name)
                plt.ylabel("Count")

                save_path = numeric_dir / f"{sanitize_filename(col_name)}.svg"

            # --- Case 2: Categorical or Low-Cardinality Numeric (Count Plot) ---
            else:
                # Check max categories before any figure is created
                if n_unique > max_categories:
                    _LOGGER.warning(f"Skipping plot for '{col_name}': {n_unique} unique values > {max_categories} max_categories.")
                    continue

                # Adaptive figure size
                fig_width = max(10, n_unique * 0.5)
                fig = plt.figure(figsize=(fig_width, 7))

                # Make a temporary copy for plotting to handle NaNs
                temp_series = df[col_name].copy()

                # Handle NaNs by replacing them with the specified string
                if temp_series.isnull().any():
                    # Convert to object type first to allow string replacement
                    temp_series = temp_series.astype(object).fillna(fill_na_with)

                # Convert all to string to be safe (handles low-card numeric)
                temp_series = temp_series.astype(str)

                # Get category order by frequency
                order = temp_series.value_counts().index
                sns.countplot(x=temp_series, order=order, palette="viridis")

                plt.title(f"Distribution of '{col_name}' (Categorical)")
                plt.xlabel(col_name)
                plt.ylabel("Count")

                # Smart tick rotation
                max_label_len = 0
                if n_unique > 0:
                    max_label_len = max(len(str(s)) for s in order)

                # Rotate if labels are long OR there are many categories
                if max_label_len > 10 or n_unique > 25:
                    plt.xticks(rotation=45, ha='right')

                save_path = categorical_dir / f"{sanitize_filename(col_name)}.svg"

            # --- 4. Save Plot ---
            plt.grid(True, linestyle='--', alpha=0.6, axis='y')
            plt.tight_layout()
            # Save as .svg
            plt.savefig(save_path, format='svg', bbox_inches="tight")

            # Count the plot only once it has really been written to disk
            if is_continuous:
                numeric_plots_saved += 1
            else:
                categorical_plots_saved += 1

        except Exception as e:
            _LOGGER.error(f"Failed to plot distribution for '{col_name}'. Error: {e}")
        finally:
            # Close only the figure this iteration created (if any)
            if fig is not None:
                plt.close(fig)

    _LOGGER.info(f"Saved {numeric_plots_saved} continuous distribution plots to '{numeric_dir.name}'.")
    _LOGGER.info(f"Saved {categorical_plots_saved} categorical distribution plots to '{categorical_dir.name}'.")
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
def plot_continuous_vs_target(
    df: pd.DataFrame,
    targets: List[str],
    save_dir: Union[str, Path],
    features: Optional[List[str]] = None
) -> None:
    """
    Plots each continuous feature against each target to visualize linear relationships.

    This function is a common EDA step for regression tasks. It creates a
    scatter plot for each feature-target pair, overlays a simple linear
    regression line, and saves each plot as an individual .svg file.

    Plots are saved in a structured way, with a subdirectory created for
    each target variable.

    Args:
        df (pd.DataFrame): The input DataFrame.
        targets (List[str]): A list of target column names to plot (y-axis).
        save_dir (str | Path): The base directory where plots will be saved.
            A subdirectory will be created here for each target.
        features (List[str] | None): A list of feature column names to plot
            (x-axis). If None, all non-target columns in the DataFrame will
            be used.

    Notes:
        - Only numeric features and numeric targets are processed. Missing or
          non-numeric columns in the lists are skipped with a warning.
        - A column listed both as feature and as target is never plotted
          against itself.
        - Rows with NaN in either the feature or the target are dropped
          pairwise for each plot.
        - The regression line is omitted (scatter only) when the fit fails or
          when the feature has fewer than two distinct values.
    """
    # 1. Validate the base save directory
    base_save_path = make_fullpath(save_dir, make=True, enforce="directory")

    # 2. Validate helper: keep only columns that exist and are numeric
    def _validate_numeric_cols(col_list: List[str], col_type: str) -> List[str]:
        valid_cols = []
        for col in col_list:
            if col not in df.columns:
                _LOGGER.warning(f"{col_type} column '{col}' not found. Skipping.")
            elif not is_numeric_dtype(df[col]):
                _LOGGER.warning(f"{col_type} column '{col}' is not numeric. Skipping.")
            else:
                valid_cols.append(col)
        return valid_cols

    # 3. Validate target columns FIRST
    valid_targets = _validate_numeric_cols(targets, "Target")
    if not valid_targets:
        _LOGGER.error("No valid numeric target columns provided to plot.")
        return

    # 4. Determine and validate feature columns
    if features is None:
        _LOGGER.info("No 'features' list provided. Using all non-target columns as features.")
        target_set = set(valid_targets)
        # Get all columns that are not in the valid_targets set
        features_to_validate = [col for col in df.columns if col not in target_set]
    else:
        features_to_validate = features

    valid_features = _validate_numeric_cols(features_to_validate, "Feature")

    if not valid_features:
        _LOGGER.error("No valid numeric feature columns found to plot.")
        return

    # 5. Main plotting loop
    total_plots_saved = 0

    for target_name in valid_targets:
        # Create a sanitized subdirectory for this target
        safe_target_dir_name = sanitize_filename(f"{target_name}_vs_Continuous")
        target_save_dir = base_save_path / safe_target_dir_name
        target_save_dir.mkdir(parents=True, exist_ok=True)

        _LOGGER.info(f"Generating plots for target: '{target_name}' -> Saving to '{target_save_dir.name}'")

        for feature_name in valid_features:

            # An explicit `features` list may overlap `targets`. Selecting the
            # same column twice yields a duplicate-column frame that breaks
            # the fit, so such pairs are skipped.
            if feature_name == target_name:
                _LOGGER.warning(f"Feature '{feature_name}' is also the target. Skipping plot.")
                continue

            # Drop NaNs pairwise for this specific plot
            temp_df = df[[feature_name, target_name]].dropna()

            if temp_df.empty:
                _LOGGER.warning(f"No non-null data for '{feature_name}' vs '{target_name}'. Skipping plot.")
                continue

            x = temp_df[feature_name]
            y = temp_df[target_name]

            # 6. Perform linear fit (None means "scatter only")
            fitted_line = None
            if x.nunique() < 2:
                # The slope is undefined with fewer than two distinct x values
                _LOGGER.warning(f"Linear regression failed for '{feature_name}' vs '{target_name}'. Plotting scatter only.")
            else:
                try:
                    # np.polyfit returns [slope, intercept] for degree 1;
                    # poly1d turns it into a callable p(x)
                    fitted_line = np.poly1d(np.polyfit(x, y, 1))
                except (np.linalg.LinAlgError, ValueError):
                    _LOGGER.warning(f"Linear regression failed for '{feature_name}' vs '{target_name}'. Plotting scatter only.")

            # 7. Create the plot
            fig, ax = plt.subplots(figsize=(10, 6))
            try:
                # Plot the raw data points
                ax.plot(x, y, 'o', alpha=0.5, label='Data points', markersize=5)

                # Plot the regression line between the two extreme x values
                # only: tracing it through every (unsorted) sample makes the
                # path retrace itself, which smears the dash pattern.
                if fitted_line is not None:
                    x_line = np.array([x.min(), x.max()], dtype=float)
                    ax.plot(x_line, fitted_line(x_line), "r--", label='Linear Fit')

                ax.set_title(f'{feature_name} vs {target_name}')
                ax.set_xlabel(feature_name)
                ax.set_ylabel(target_name)
                ax.legend()
                ax.grid(True, linestyle='--', alpha=0.6)
                fig.tight_layout()

                # 8. Save the plot
                safe_feature_name = sanitize_filename(feature_name)
                plot_filename = f"{safe_feature_name}_vs_{safe_target_dir_name}.svg"
                plot_path = target_save_dir / plot_filename

                try:
                    fig.savefig(plot_path, bbox_inches="tight", format='svg')
                    total_plots_saved += 1
                except Exception as e:
                    _LOGGER.error(f"Failed to save plot: {plot_path}. Error: {e}")
            finally:
                # Close the figure to free up memory, even if plotting raised
                plt.close(fig)

    _LOGGER.info(f"Successfully saved {total_plots_saved} feature-vs-target plots to '{base_save_path}'.")
|
|
613
|
+
|
|
614
|
+
|
|
615
|
+
def plot_categorical_vs_target(
    df: pd.DataFrame,
    targets: List[str],
    save_dir: Union[str, Path],
    features: Optional[List[str]] = None,
    plot_type: Literal["box", "violin"] = "box",
    max_categories: int = 20,
    fill_na_with: str = "Missing"
) -> None:
    """
    Plots each categorical feature against each numeric target using box or violin plots.

    This function is a core EDA step for regression tasks to understand the
    relationship between a categorical independent variable and a continuous
    dependent variable.

    Args:
        df (pd.DataFrame): The input DataFrame.
        targets (List[str]): A list of numeric target column names (y-axis).
        save_dir (str | Path): The base directory where plots will be saved.
            A subdirectory will be created here for each target.
        features (List[str] | None): A list of categorical feature column
            names (x-axis). If None, all non-target columns with 'object'
            dtype will be used.
        plot_type (Literal["box", "violin"]): The type of plot to generate.
        max_categories (int): The maximum number of unique non-null values a
            feature can have to be plotted. Features exceeding this limit
            will be skipped.
        fill_na_with (str): A string to replace NaN values in categorical
            columns. This allows plotting 'missingness' as its own category.
            Defaults to "Missing".

    Raises:
        ValueError: If `plot_type` is neither "box" nor "violin". Raised
            before any directory is created.

    Notes:
        - Only numeric targets are processed; others are skipped with a warning.
        - Features are automatically identified as categorical if they are 'object' dtype.
        - A user-provided feature is only checked for existence, not dtype.
        - A column listed both as feature and as target is never plotted
          against itself.
    """
    # 1. Validate inputs before touching the filesystem
    if plot_type not in ["box", "violin"]:
        _LOGGER.error(f"Invalid plot type '{plot_type}'")
        raise ValueError(f"Invalid plot type '{plot_type}'. Expected 'box' or 'violin'.")

    base_save_path = make_fullpath(save_dir, make=True, enforce="directory")

    # 2. Validate target columns (must be numeric)
    valid_targets = []
    for col in targets:
        if col not in df.columns:
            _LOGGER.warning(f"Target column '{col}' not found. Skipping.")
        elif not is_numeric_dtype(df[col]):
            _LOGGER.warning(f"Target column '{col}' is not numeric. Skipping.")
        else:
            valid_targets.append(col)

    if not valid_targets:
        _LOGGER.error("No valid numeric target columns provided to plot.")
        return

    # 3. Determine and validate feature columns
    features_to_plot = []
    if features is None:
        _LOGGER.info("No 'features' list provided. Auto-detecting categorical features.")
        target_set = set(valid_targets)
        # Auto-include object dtypes only; low-cardinality numeric columns
        # are deliberately not treated as categorical.
        features_to_plot = [
            col for col in df.columns
            if col not in target_set and is_object_dtype(df[col])
        ]
    else:
        # Validate user-provided list
        for col in features:
            if col not in df.columns:
                _LOGGER.warning(f"Feature column '{col}' not found in DataFrame. Skipping.")
            else:
                features_to_plot.append(col)

    if not features_to_plot:
        _LOGGER.error("No valid categorical feature columns found to plot.")
        return

    # Resolve the seaborn plotting function once; plot_type was validated above
    plot_function = sns.boxplot if plot_type == "box" else sns.violinplot

    # 4. Main plotting loop
    total_plots_saved = 0

    for target_name in valid_targets:
        # Create a sanitized subdirectory for this target
        safe_target_dir_name = sanitize_filename(f"{target_name}_vs_Categorical_{plot_type}")
        target_save_dir = base_save_path / safe_target_dir_name
        target_save_dir.mkdir(parents=True, exist_ok=True)

        _LOGGER.info(f"Generating '{plot_type}' plots for target: '{target_name}' -> Saving to '{target_save_dir.name}'")

        for feature_name in features_to_plot:

            # An explicit `features` list may overlap `targets`. Selecting the
            # same column twice yields a duplicate-column frame that breaks
            # the cardinality check, so such pairs are skipped.
            if feature_name == target_name:
                _LOGGER.warning(f"Feature '{feature_name}' is also the target. Skipping plot.")
                continue

            # Make a temporary copy for plotting to handle NaNs and dtypes
            temp_df = df[[feature_name, target_name]].copy()

            # Check cardinality (NaN is not counted by nunique)
            n_unique = temp_df[feature_name].nunique()
            if n_unique > max_categories:
                _LOGGER.warning(f"Skipping '{feature_name}': {n_unique} unique values > {max_categories} max_categories.")
                continue

            # Handle NaNs by replacing them with the specified string
            if temp_df[feature_name].isnull().any():
                # Convert to object type first to allow string replacement
                temp_df[feature_name] = temp_df[feature_name].astype(object).fillna(fill_na_with)

            # Convert feature to string so every value is drawn as a discrete category
            temp_df[feature_name] = temp_df[feature_name].astype(str)

            # 5. Create the plot
            # Increase figure width for categories
            fig = plt.figure(figsize=(max(10, n_unique * 1.2), 7))
            try:
                plot_function(x=feature_name, y=target_name, data=temp_df)

                plt.title(f'{target_name} vs {feature_name}')
                plt.xlabel(feature_name)
                plt.ylabel(target_name)
                plt.xticks(rotation=45, ha='right')
                plt.grid(True, linestyle='--', alpha=0.6, axis='y')
                plt.tight_layout()

                # 6. Save the plot
                safe_feature_name = sanitize_filename(feature_name)
                plot_filename = f"{safe_feature_name}_vs_{safe_target_dir_name}.svg"
                plot_path = target_save_dir / plot_filename

                try:
                    plt.savefig(plot_path, bbox_inches="tight", format='svg')
                    total_plots_saved += 1
                except Exception as e:
                    _LOGGER.error(f"Failed to save plot: {plot_path}. Error: {e}")
            finally:
                # Always release the figure, even if plotting raised
                plt.close(fig)

    _LOGGER.info(f"Successfully saved {total_plots_saved} categorical-vs-target plots to '{base_save_path}'.")
|
|
752
|
+
|
|
753
|
+
|
|
345
754
|
def encode_categorical_features(
|
|
346
755
|
df: pd.DataFrame,
|
|
347
756
|
columns_to_encode: List[str],
|
|
@@ -580,108 +989,108 @@ def plot_correlation_heatmap(df: pd.DataFrame,
|
|
|
580
989
|
plt.show()
|
|
581
990
|
plt.close()
|
|
582
991
|
|
|
583
|
-
|
|
584
|
-
def plot_value_distributions(df: pd.DataFrame, save_dir: Union[str, Path], bin_threshold: int=10, skip_cols_with_key: Union[str, None]=None):
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
992
|
+
# OLD IMPLEMENTATION
|
|
993
|
+
# def plot_value_distributions(df: pd.DataFrame, save_dir: Union[str, Path], bin_threshold: int=10, skip_cols_with_key: Union[str, None]=None):
|
|
994
|
+
# """
|
|
995
|
+
# Plots and saves the value distributions for all (or selected) columns in a DataFrame,
|
|
996
|
+
# with adaptive binning for numerical columns when appropriate.
|
|
997
|
+
|
|
998
|
+
# For each column both raw counts and relative frequencies are computed and plotted.
|
|
999
|
+
|
|
1000
|
+
# Plots are saved as PNG files under two subdirectories in `save_dir`:
|
|
1001
|
+
# - "Distribution_Counts" for absolute counts.
|
|
1002
|
+
# - "Distribution_Frequency" for relative frequencies.
|
|
1003
|
+
|
|
1004
|
+
# Args:
|
|
1005
|
+
# df (pd.DataFrame): The input DataFrame whose columns are to be analyzed.
|
|
1006
|
+
# save_dir (str | Path): Directory path where the plots will be saved. Will be created if it does not exist.
|
|
1007
|
+
# bin_threshold (int): Minimum number of unique values required to trigger binning
|
|
1008
|
+
# for numerical columns.
|
|
1009
|
+
# skip_cols_with_key (str | None): If provided, any column whose name contains this
|
|
1010
|
+
# substring will be excluded from analysis.
|
|
1011
|
+
|
|
1012
|
+
# Notes:
|
|
1013
|
+
# - Binning is adaptive: if quantile binning results in ≤ 2 unique bins, raw values are used instead.
|
|
1014
|
+
# - All non-alphanumeric characters in column names are sanitized for safe file naming.
|
|
1015
|
+
# - Colormap is automatically adapted based on the number of categories or bins.
|
|
1016
|
+
# """
|
|
1017
|
+
# save_path = make_fullpath(save_dir, make=True)
|
|
609
1018
|
|
|
610
|
-
|
|
611
|
-
|
|
1019
|
+
# dict_to_plot_std = dict()
|
|
1020
|
+
# dict_to_plot_freq = dict()
|
|
612
1021
|
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
1022
|
+
# # cherry-pick columns
|
|
1023
|
+
# if skip_cols_with_key is not None:
|
|
1024
|
+
# columns = [col for col in df.columns if skip_cols_with_key not in col]
|
|
1025
|
+
# else:
|
|
1026
|
+
# columns = df.columns.to_list()
|
|
618
1027
|
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
1028
|
+
# saved_plots = 0
|
|
1029
|
+
# for col in columns:
|
|
1030
|
+
# if pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() > bin_threshold:
|
|
1031
|
+
# bins_number = 10
|
|
1032
|
+
# binned = pd.qcut(df[col], q=bins_number, duplicates='drop')
|
|
1033
|
+
# while binned.nunique() <= 2:
|
|
1034
|
+
# bins_number -= 1
|
|
1035
|
+
# binned = pd.qcut(df[col], q=bins_number, duplicates='drop')
|
|
1036
|
+
# if bins_number <= 2:
|
|
1037
|
+
# break
|
|
629
1038
|
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
1039
|
+
# if binned.nunique() <= 2:
|
|
1040
|
+
# view_std = df[col].value_counts(sort=False).sort_index()
|
|
1041
|
+
# else:
|
|
1042
|
+
# view_std = binned.value_counts(sort=False)
|
|
634
1043
|
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
1044
|
+
# else:
|
|
1045
|
+
# view_std = df[col].value_counts(sort=False).sort_index()
|
|
1046
|
+
|
|
1047
|
+
# # unlikely scenario where the series is empty
|
|
1048
|
+
# if view_std.sum() == 0:
|
|
1049
|
+
# view_freq = view_std
|
|
1050
|
+
# else:
|
|
1051
|
+
# view_freq = 100 * view_std / view_std.sum() # Percentage
|
|
1052
|
+
# # view_freq = df[col].value_counts(normalize=True, bins=10) # relative percentages
|
|
644
1053
|
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
1054
|
+
# dict_to_plot_std[col] = dict(view_std)
|
|
1055
|
+
# dict_to_plot_freq[col] = dict(view_freq)
|
|
1056
|
+
# saved_plots += 1
|
|
648
1057
|
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
1058
|
+
# # plot helper
|
|
1059
|
+
# def _plot_helper(dict_: dict, target_dir: Path, ylabel: Literal["Frequency", "Counts"], base_fontsize: int=12):
|
|
1060
|
+
# for col, data in dict_.items():
|
|
1061
|
+
# safe_col = sanitize_filename(col)
|
|
653
1062
|
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
1063
|
+
# if isinstance(list(data.keys())[0], pd.Interval):
|
|
1064
|
+
# labels = [str(interval) for interval in data.keys()]
|
|
1065
|
+
# else:
|
|
1066
|
+
# labels = data.keys()
|
|
658
1067
|
|
|
659
|
-
|
|
660
|
-
|
|
1068
|
+
# plt.figure(figsize=(10, 6))
|
|
1069
|
+
# colors = plt.cm.tab20.colors if len(data) <= 20 else plt.cm.viridis(np.linspace(0, 1, len(data))) # type: ignore
|
|
661
1070
|
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
1071
|
+
# plt.bar(labels, data.values(), color=colors[:len(data)], alpha=0.85)
|
|
1072
|
+
# plt.xlabel("Values", fontsize=base_fontsize)
|
|
1073
|
+
# plt.ylabel(ylabel, fontsize=base_fontsize)
|
|
1074
|
+
# plt.title(f"Value Distribution for '{col}'", fontsize=base_fontsize+2)
|
|
1075
|
+
# plt.xticks(rotation=45, ha='right', fontsize=base_fontsize-2)
|
|
1076
|
+
# plt.yticks(fontsize=base_fontsize-2)
|
|
1077
|
+
# plt.grid(axis='y', linestyle='--', alpha=0.6)
|
|
1078
|
+
# plt.gca().set_facecolor('#f9f9f9')
|
|
1079
|
+
# plt.tight_layout()
|
|
671
1080
|
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
1081
|
+
# plot_path = target_dir / f"{safe_col}.png"
|
|
1082
|
+
# plt.savefig(plot_path, dpi=300, bbox_inches="tight")
|
|
1083
|
+
# plt.close()
|
|
675
1084
|
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
1085
|
+
# # Save plots
|
|
1086
|
+
# freq_dir = save_path / "Distribution_Frequency"
|
|
1087
|
+
# std_dir = save_path / "Distribution_Counts"
|
|
1088
|
+
# freq_dir.mkdir(parents=True, exist_ok=True)
|
|
1089
|
+
# std_dir.mkdir(parents=True, exist_ok=True)
|
|
1090
|
+
# _plot_helper(dict_=dict_to_plot_std, target_dir=std_dir, ylabel="Counts")
|
|
1091
|
+
# _plot_helper(dict_=dict_to_plot_freq, target_dir=freq_dir, ylabel="Frequency")
|
|
683
1092
|
|
|
684
|
-
|
|
1093
|
+
# _LOGGER.info(f"Saved {saved_plots} value distribution plots.")
|
|
685
1094
|
|
|
686
1095
|
|
|
687
1096
|
def clip_outliers_single(
|
|
@@ -1262,7 +1671,7 @@ def finalize_feature_schema(
|
|
|
1262
1671
|
categorical_feature_names_list: List[str] = []
|
|
1263
1672
|
categorical_index_map_dict: Dict[int, int] = {}
|
|
1264
1673
|
|
|
1265
|
-
_LOGGER.info("Finalizing feature schema...")
|
|
1674
|
+
# _LOGGER.info("Finalizing feature schema...")
|
|
1266
1675
|
|
|
1267
1676
|
if categorical_mappings:
|
|
1268
1677
|
# --- Categorical features are present ---
|