validmind 2.7.2__py3-none-any.whl → 2.7.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. validmind/__version__.py +1 -1
  2. validmind/ai/test_descriptions.py +20 -4
  3. validmind/ai/test_result_description/user.jinja +5 -0
  4. validmind/datasets/credit_risk/lending_club.py +444 -14
  5. validmind/tests/data_validation/MutualInformation.py +129 -0
  6. validmind/tests/data_validation/ScoreBandDefaultRates.py +139 -0
  7. validmind/tests/data_validation/TooManyZeroValues.py +6 -5
  8. validmind/tests/data_validation/UniqueRows.py +3 -1
  9. validmind/tests/decorator.py +18 -16
  10. validmind/tests/load.py +4 -1
  11. validmind/tests/model_validation/sklearn/CalibrationCurve.py +116 -0
  12. validmind/tests/model_validation/sklearn/ClassifierThresholdOptimization.py +261 -0
  13. validmind/tests/model_validation/sklearn/ConfusionMatrix.py +1 -0
  14. validmind/tests/model_validation/sklearn/HyperParametersTuning.py +144 -56
  15. validmind/tests/model_validation/sklearn/ModelParameters.py +74 -0
  16. validmind/tests/model_validation/sklearn/ScoreProbabilityAlignment.py +130 -0
  17. validmind/tests/model_validation/statsmodels/CumulativePredictionProbabilities.py +5 -6
  18. validmind/tests/model_validation/statsmodels/PredictionProbabilitiesHistogram.py +2 -3
  19. validmind/tests/run.py +43 -72
  20. validmind/utils.py +23 -7
  21. validmind/vm_models/result/result.py +18 -17
  22. {validmind-2.7.2.dist-info → validmind-2.7.5.dist-info}/METADATA +2 -2
  23. {validmind-2.7.2.dist-info → validmind-2.7.5.dist-info}/RECORD +26 -20
  24. {validmind-2.7.2.dist-info → validmind-2.7.5.dist-info}/WHEEL +1 -1
  25. {validmind-2.7.2.dist-info → validmind-2.7.5.dist-info}/LICENSE +0 -0
  26. {validmind-2.7.2.dist-info → validmind-2.7.5.dist-info}/entry_points.txt +0 -0
validmind/__version__.py CHANGED
@@ -1 +1 @@
-__version__ = "2.7.2"
+__version__ = "2.7.5"
validmind/ai/test_descriptions.py CHANGED
@@ -65,6 +65,23 @@ def prompt_to_message(role, prompt):
     return {"role": role, "content": content}
 
 
+def _get_llm_global_context():
+
+    # Get the context from the environment variable
+    context = os.getenv("VALIDMIND_LLM_DESCRIPTIONS_CONTEXT", "")
+
+    # Check if context should be used (similar to descriptions enabled pattern)
+    context_enabled = os.getenv(
+        "VALIDMIND_LLM_DESCRIPTIONS_CONTEXT_ENABLED", "1"
+    ) not in [
+        "0",
+        "false",
+    ]
+
+    # Only use context if it's enabled and not empty
+    return context if context_enabled and context else None
+
+
 def generate_description(
     test_id: str,
     test_description: str,
@@ -79,15 +96,11 @@ def generate_description(
             "No tables, unit metric or figures provided - cannot generate description"
         )
 
-    # # TODO: fix circular import
-    # from validmind.ai.utils import get_client_and_model
-
     client, model = get_client_and_model()
 
     # get last part of test id
    test_name = title or test_id.split(".")[-1]
 
-    # TODO: fully support metrics
     if metric is not None:
         tables = [] if not tables else tables
         tables.append(
@@ -108,12 +121,15 @@
     else:
         summary = None
 
+    context = _get_llm_global_context()
+
     input_data = {
         "test_name": test_name,
        "test_description": test_description,
         "title": title,
         "summary": summary,
         "figures": [figure._get_b64_url() for figure in ([] if tables else figures)],
+        "context": context,
     }
     system, user = _load_prompt()
 
validmind/ai/test_result_description/user.jinja CHANGED
@@ -8,6 +8,11 @@
 
 Generate a description of the following result of the test using the instructions given in your system prompt.
 
+{%- if context %}
+**Context**:
+{{ context }}
+{%- endif %}
+
 {%- if summary %}
 **Test Result Tables** *(Raw Data)*:
 {{ summary }}
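
Together, these two changes add an opt-in mechanism for injecting global context into LLM-generated test descriptions: `_get_llm_global_context` reads it from the environment, `generate_description` passes it into the prompt inputs, and the Jinja template renders it ahead of the result tables. A minimal sketch of how a caller might enable it (the context string itself is illustrative):

```python
import os

# Free-form context that will be prepended to every generated description
os.environ["VALIDMIND_LLM_DESCRIPTIONS_CONTEXT"] = (
    "Retail credit scorecard; write for a model-risk-management audience."
)

# Context is used by default; set to "0" or "false" to turn it off
os.environ["VALIDMIND_LLM_DESCRIPTIONS_CONTEXT_ENABLED"] = "1"
```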
validmind/datasets/credit_risk/lending_club.py CHANGED
@@ -355,33 +355,76 @@ def _woebin(df):
     return bins_df
 
 
-def split(df, add_constant=False):
+def split(df, validation_size=None, test_size=0.2, add_constant=False):
+    """
+    Split dataset into train, validation (optional), and test sets.
+
+    Args:
+        df: Input DataFrame
+        validation_size: If None, returns train/test split. If float, returns train/val/test split
+        test_size: Proportion of data for test set (default: 0.2)
+        add_constant: Whether to add constant column for statsmodels (default: False)
+
+    Returns:
+        If validation_size is None:
+            train_df, test_df
+        If validation_size is float:
+            train_df, validation_df, test_df
+    """
     df = df.copy()
 
-    # Splitting the dataset into training and test sets
-    train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
+    # First split off the test set
+    train_val_df, test_df = train_test_split(df, test_size=test_size, random_state=42)
 
     if add_constant:
-        # Add a constant to the model for both training and testing datasets
-        train_df = sm.add_constant(train_df)
         test_df = sm.add_constant(test_df)
 
-    # Calculate and print details for the training dataset
-    print("After splitting the dataset into training and test sets:")
-    print(
-        f"Training Dataset:\nRows: {train_df.shape[0]}\nColumns: {train_df.shape[1]}\nMissing values: {train_df.isnull().sum().sum()}\n"
+    if validation_size is None:
+        if add_constant:
+            train_val_df = sm.add_constant(train_val_df)
+
+        # Print details for two-way split
+        print("After splitting the dataset into training and test sets:")
+        print(
+            f"Training Dataset:\nRows: {train_val_df.shape[0]}\nColumns: {train_val_df.shape[1]}\n"
+            f"Missing values: {train_val_df.isnull().sum().sum()}\n"
+        )
+        print(
+            f"Test Dataset:\nRows: {test_df.shape[0]}\nColumns: {test_df.shape[1]}\n"
+            f"Missing values: {test_df.isnull().sum().sum()}\n"
+        )
+
+        return train_val_df, test_df
+
+    # Calculate validation size as proportion of remaining data
+    val_size = validation_size / (1 - test_size)
+    train_df, validation_df = train_test_split(
+        train_val_df, test_size=val_size, random_state=42
     )
 
-    # Calculate and print details for the test dataset
+    if add_constant:
+        train_df = sm.add_constant(train_df)
+        validation_df = sm.add_constant(validation_df)
+
+    # Print details for three-way split
+    print("After splitting the dataset into training, validation, and test sets:")
     print(
-        f"Test Dataset:\nRows: {test_df.shape[0]}\nColumns: {test_df.shape[1]}\nMissing values: {test_df.isnull().sum().sum()}\n"
+        f"Training Dataset:\nRows: {train_df.shape[0]}\nColumns: {train_df.shape[1]}\n"
+        f"Missing values: {train_df.isnull().sum().sum()}\n"
+    )
+    print(
+        f"Validation Dataset:\nRows: {validation_df.shape[0]}\nColumns: {validation_df.shape[1]}\n"
+        f"Missing values: {validation_df.isnull().sum().sum()}\n"
+    )
+    print(
+        f"Test Dataset:\nRows: {test_df.shape[0]}\nColumns: {test_df.shape[1]}\n"
+        f"Missing values: {test_df.isnull().sum().sum()}\n"
     )
 
-    return train_df, test_df
+    return train_df, validation_df, test_df
 
 
 def compute_scores(probabilities):
-
     target_score = score_params["target_score"]
     target_odds = score_params["target_odds"]
     pdo = score_params["pdo"]
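
For reference, the reworked `split` keeps the old two-way behavior as the default and adds an opt-in three-way split. A usage sketch, assuming the module's usual demo loaders (`load_data`/`preprocess` are illustrative of the demo flow, not verified signatures):

```python
from validmind.datasets.credit_risk import lending_club

df = lending_club.load_data()      # assumed demo loader from this module
df = lending_club.preprocess(df)   # assumed demo preprocessing step

# Two-way split (default behavior): 80% train / 20% test
train_df, test_df = lending_club.split(df)

# Three-way split: validation_size is a fraction of the *full* dataset,
# so 0.2 here yields 60% train / 20% validation / 20% test
# (internally val_size = 0.2 / (1 - 0.2) = 0.25 of the remaining 80%)
train_df, validation_df, test_df = lending_club.split(
    df, validation_size=0.2, test_size=0.2
)
```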
@@ -389,6 +432,393 @@ def compute_scores(probabilities):
     factor = pdo / np.log(2)
     offset = target_score - (factor * np.log(target_odds))
 
-    scores = offset + factor * np.log(probabilities / (1 - probabilities))
+    # Add negative sign to reverse the relationship
+    scores = offset - factor * np.log(probabilities / (1 - probabilities))
 
     return scores
+
+
+def get_demo_test_config(x_test=None, y_test=None):
+    """Get demo test configuration.
+
+    Args:
+        x_test: Test features DataFrame
+        y_test: Test target Series
+
+    Returns:
+        dict: Test configuration dictionary
+    """
+    default_config = {}
+
+    # RAW DATA TESTS
+    default_config["validmind.data_validation.DatasetDescription:raw_data"] = {
+        "inputs": {
+            "dataset": "raw_dataset",
+        }
+    }
+    default_config["validmind.data_validation.DescriptiveStatistics:raw_data"] = {
+        "inputs": {
+            "dataset": "raw_dataset",
+        }
+    }
+    default_config["validmind.data_validation.MissingValues:raw_data"] = {
+        "inputs": {
+            "dataset": "raw_dataset",
+        },
+        "params": {"min_threshold": 1},
+    }
+    default_config["validmind.data_validation.ClassImbalance:raw_data"] = {
+        "inputs": {
+            "dataset": "raw_dataset",
+        },
+        "params": {"min_percent_threshold": 10},
+    }
+    default_config["validmind.data_validation.Duplicates:raw_data"] = {
+        "inputs": {
+            "dataset": "raw_dataset",
+        },
+        "params": {"min_threshold": 1},
+    }
+    default_config["validmind.data_validation.HighCardinality:raw_data"] = {
+        "inputs": {
+            "dataset": "raw_dataset",
+        },
+        "params": {
+            "num_threshold": 100,
+            "percent_threshold": 0.1,
+            "threshold_type": "percent",
+        },
+    }
+    default_config["validmind.data_validation.Skewness:raw_data"] = {
+        "inputs": {
+            "dataset": "raw_dataset",
+        },
+        "params": {"max_threshold": 1},
+    }
+    default_config["validmind.data_validation.UniqueRows:raw_data"] = {
+        "inputs": {
+            "dataset": "raw_dataset",
+        },
+        "params": {"min_percent_threshold": 1},
+    }
+    default_config["validmind.data_validation.TooManyZeroValues:raw_data"] = {
+        "inputs": {
+            "dataset": "raw_dataset",
+        },
+        "params": {"max_percent_threshold": 0.03},
+    }
+    default_config["validmind.data_validation.IQROutliersTable:raw_data"] = {
+        "inputs": {
+            "dataset": "raw_dataset",
+        },
+        "params": {"threshold": 5},
+    }
+
+    # PREPROCESSED DATA TESTS
+    default_config[
+        "validmind.data_validation.DescriptiveStatistics:preprocessed_data"
+    ] = {
+        "inputs": {
+            "dataset": "preprocess_dataset",
+        }
+    }
+    default_config[
+        "validmind.data_validation.TabularDescriptionTables:preprocessed_data"
+    ] = {
+        "inputs": {
+            "dataset": "preprocess_dataset",
+        }
+    }
+    default_config["validmind.data_validation.MissingValues:preprocessed_data"] = {
+        "inputs": {
+            "dataset": "preprocess_dataset",
+        },
+        "params": {"min_threshold": 1},
+    }
+    default_config[
+        "validmind.data_validation.TabularNumericalHistograms:preprocessed_data"
+    ] = {
+        "inputs": {
+            "dataset": "preprocess_dataset",
+        }
+    }
+    default_config[
+        "validmind.data_validation.TabularCategoricalBarPlots:preprocessed_data"
+    ] = {
+        "inputs": {
+            "dataset": "preprocess_dataset",
+        }
+    }
+    default_config["validmind.data_validation.TargetRateBarPlots:preprocessed_data"] = {
+        "inputs": {
+            "dataset": "preprocess_dataset",
+        },
+        "params": {"default_column": "loan_status"},
+    }
+
+    # DEVELOPMENT DATA TESTS
+    default_config[
+        "validmind.data_validation.DescriptiveStatistics:development_data"
+    ] = {"input_grid": {"dataset": ["train_dataset", "test_dataset"]}}
+
+    default_config[
+        "validmind.data_validation.TabularDescriptionTables:development_data"
+    ] = {"input_grid": {"dataset": ["train_dataset", "test_dataset"]}}
+
+    default_config["validmind.data_validation.ClassImbalance:development_data"] = {
+        "input_grid": {"dataset": ["train_dataset", "test_dataset"]},
+        "params": {"min_percent_threshold": 10},
+    }
+
+    default_config["validmind.data_validation.UniqueRows:development_data"] = {
+        "input_grid": {"dataset": ["train_dataset", "test_dataset"]},
+        "params": {"min_percent_threshold": 1},
+    }
+
+    default_config[
+        "validmind.data_validation.TabularNumericalHistograms:development_data"
+    ] = {"input_grid": {"dataset": ["train_dataset", "test_dataset"]}}
+
+    # FEATURE SELECTION TESTS
+    default_config["validmind.data_validation.MutualInformation:development_data"] = {
+        "input_grid": {"dataset": ["train_dataset", "test_dataset"]},
+        "params": {"min_threshold": 0.01},
+    }
+
+    default_config[
+        "validmind.data_validation.PearsonCorrelationMatrix:development_data"
+    ] = {"input_grid": {"dataset": ["train_dataset", "test_dataset"]}}
+
+    default_config[
+        "validmind.data_validation.HighPearsonCorrelation:development_data"
+    ] = {
+        "input_grid": {"dataset": ["train_dataset", "test_dataset"]},
+        "params": {"max_threshold": 0.3, "top_n_correlations": 10},
+    }
+
+    default_config["validmind.data_validation.WOEBinTable"] = {
+        "input_grid": {"dataset": ["preprocess_dataset"]},
+        "params": {"breaks_adj": breaks_adj},
+    }
+
+    default_config["validmind.data_validation.WOEBinPlots"] = {
+        "input_grid": {"dataset": ["preprocess_dataset"]},
+        "params": {"breaks_adj": breaks_adj},
+    }
+
+    # MODEL TRAINING TESTS
+    default_config["validmind.data_validation.DatasetSplit"] = {
+        "inputs": {"datasets": ["train_dataset", "test_dataset"]}
+    }
+
+    default_config["validmind.model_validation.ModelMetadata"] = {
+        "input_grid": {"model": ["xgb_model", "rf_model"]}
+    }
+
+    default_config["validmind.model_validation.sklearn.ModelParameters"] = {
+        "input_grid": {"model": ["xgb_model", "rf_model"]}
+    }
+
+    # MODEL SELECTION TESTS
+    default_config["validmind.model_validation.statsmodels.GINITable"] = {
+        "input_grid": {
+            "dataset": ["train_dataset", "test_dataset"],
+            "model": ["xgb_model", "rf_model"],
+        }
+    }
+
+    default_config["validmind.model_validation.sklearn.ClassifierPerformance"] = {
+        "input_grid": {
+            "dataset": ["train_dataset", "test_dataset"],
+            "model": ["xgb_model", "rf_model"],
+        }
+    }
+
+    default_config[
+        "validmind.model_validation.sklearn.TrainingTestDegradation:XGBoost"
+    ] = {
+        "inputs": {"datasets": ["train_dataset", "test_dataset"], "model": "xgb_model"},
+        "params": {"max_threshold": 0.1},
+    }
+
+    default_config[
+        "validmind.model_validation.sklearn.TrainingTestDegradation:RandomForest"
+    ] = {
+        "inputs": {"datasets": ["train_dataset", "test_dataset"], "model": "rf_model"},
+        "params": {"max_threshold": 0.1},
+    }
+
+    default_config["validmind.model_validation.sklearn.HyperParametersTuning"] = {
+        "inputs": {"model": "xgb_model", "dataset": "train_dataset"},
+        "params": {
+            "param_grid": {"n_estimators": [50, 100]},
+            "scoring": ["roc_auc", "recall"],
+            "fit_params": {
+                "eval_set": [(x_test, y_test)],
+                "verbose": False,
+            },
+            "thresholds": [0.3, 0.5],
+        },
+    }
+
+    # MODEL PERFORMANCE - DISCRIMINATION TESTS
+    default_config["validmind.model_validation.sklearn.ROCCurve"] = {
+        "input_grid": {
+            "dataset": ["train_dataset", "test_dataset"],
+            "model": ["xgb_model"],
+        }
+    }
+
+    default_config["validmind.model_validation.sklearn.MinimumROCAUCScore"] = {
+        "input_grid": {
+            "dataset": ["train_dataset", "test_dataset"],
+            "model": ["xgb_model"],
+        },
+        "params": {"min_threshold": 0.5},
+    }
+
+    default_config[
+        "validmind.model_validation.statsmodels.PredictionProbabilitiesHistogram"
+    ] = {
+        "input_grid": {
+            "dataset": ["train_dataset", "test_dataset"],
+            "model": ["xgb_model"],
+        }
+    }
+
+    default_config[
+        "validmind.model_validation.statsmodels.CumulativePredictionProbabilities"
+    ] = {
+        "input_grid": {
+            "model": ["xgb_model"],
+            "dataset": ["train_dataset", "test_dataset"],
+        }
+    }
+
+    default_config["validmind.model_validation.sklearn.PopulationStabilityIndex"] = {
+        "inputs": {"datasets": ["train_dataset", "test_dataset"], "model": "xgb_model"},
+        "params": {"num_bins": 10, "mode": "fixed"},
+    }
+
+    # MODEL PERFORMANCE - ACCURACY TESTS
+    default_config["validmind.model_validation.sklearn.ConfusionMatrix"] = {
+        "input_grid": {
+            "dataset": ["train_dataset", "test_dataset"],
+            "model": ["xgb_model"],
+        }
+    }
+
+    default_config["validmind.model_validation.sklearn.MinimumAccuracy"] = {
+        "input_grid": {
+            "dataset": ["train_dataset", "test_dataset"],
+            "model": ["xgb_model"],
+        },
+        "params": {"min_threshold": 0.7},
+    }
+
+    default_config["validmind.model_validation.sklearn.MinimumF1Score"] = {
+        "input_grid": {
+            "dataset": ["train_dataset", "test_dataset"],
+            "model": ["xgb_model"],
+        },
+        "params": {"min_threshold": 0.5},
+    }
+
+    default_config["validmind.model_validation.sklearn.PrecisionRecallCurve"] = {
+        "input_grid": {
+            "dataset": ["train_dataset", "test_dataset"],
+            "model": ["xgb_model"],
+        }
+    }
+
+    default_config["validmind.model_validation.sklearn.CalibrationCurve"] = {
+        "input_grid": {
+            "dataset": ["train_dataset", "test_dataset"],
+            "model": ["xgb_model"],
+        }
+    }
+
+    default_config[
+        "validmind.model_validation.sklearn.ClassifierThresholdOptimization"
+    ] = {
+        "inputs": {"dataset": "train_dataset", "model": "xgb_model"},
+        "params": {
+            "target_recall": 0.8  # Find a threshold that achieves a recall of 80%
+        },
+    }
+
+    # MODEL PERFORMANCE - SCORING TESTS
+    default_config["validmind.model_validation.statsmodels.ScorecardHistogram"] = {
+        "input_grid": {"dataset": ["train_dataset", "test_dataset"]},
+        "params": {"score_column": "xgb_scores"},
+    }
+
+    default_config["validmind.data_validation.ScoreBandDefaultRates"] = {
+        "input_grid": {"dataset": ["train_dataset"], "model": ["xgb_model"]},
+        "params": {
+            "score_column": "xgb_scores",
+            "score_bands": [504, 537, 570],  # Creates four score bands
+        },
+    }
+
+    default_config["validmind.model_validation.sklearn.ScoreProbabilityAlignment"] = {
+        "input_grid": {"dataset": ["train_dataset"], "model": ["xgb_model"]},
+        "params": {"score_column": "xgb_scores"},
+    }
+
+    # MODEL DIAGNOSIS TESTS
+    default_config["validmind.model_validation.sklearn.WeakspotsDiagnosis"] = {
+        "inputs": {
+            "datasets": ["train_dataset", "test_dataset"],
+            "model": "xgb_model",
+        },
+    }
+
+    default_config["validmind.model_validation.sklearn.OverfitDiagnosis"] = {
+        "inputs": {
+            "model": "xgb_model",
+            "datasets": ["train_dataset", "test_dataset"],
+        },
+        "params": {"cut_off_threshold": 0.04},
+    }
+
+    default_config["validmind.model_validation.sklearn.RobustnessDiagnosis"] = {
+        "inputs": {
+            "datasets": ["train_dataset", "test_dataset"],
+            "model": "xgb_model",
+        },
+        "params": {
+            "scaling_factor_std_dev_list": [0.1, 0.2, 0.3, 0.4, 0.5],
+            "performance_decay_threshold": 0.05,
+        },
+    }
+
+    # EXPLAINABILITY TESTS
+    default_config[
+        "validmind.model_validation.sklearn.PermutationFeatureImportance"
+    ] = {
+        "input_grid": {
+            "dataset": ["train_dataset", "test_dataset"],
+            "model": ["xgb_model"],
+        }
+    }
+
+    default_config["validmind.model_validation.FeaturesAUC"] = {
+        "input_grid": {
+            "model": ["xgb_model"],
+            "dataset": ["train_dataset", "test_dataset"],
+        },
+    }
+
+    default_config["validmind.model_validation.sklearn.SHAPGlobalImportance"] = {
+        "input_grid": {
+            "model": ["xgb_model"],
+            "dataset": ["train_dataset", "test_dataset"],
+        },
+        "params": {
+            "kernel_explainer_samples": 10,
+            "tree_or_linear_explainer_samples": 200,
+        },
+    }
+
+    return default_config
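
The sign flip in `compute_scores` restores the usual scorecard convention that higher scores mean lower default probability. A worked sketch of the points-to-double-the-odds arithmetic (the parameter values below are illustrative; the module reads them from its own `score_params`):

```python
import numpy as np

# Illustrative scorecard parameters: 600 points at 50:1 odds, 20 points to double the odds
target_score, target_odds, pdo = 600, 50, 20

factor = pdo / np.log(2)                              # ~28.85
offset = target_score - factor * np.log(target_odds)  # ~487.1

for p in (0.50, 0.20, 0.05):
    score = offset - factor * np.log(p / (1 - p))
    print(f"PD={p:.2f} -> score={score:.0f}")
# PD=0.50 -> 487, PD=0.20 -> 527, PD=0.05 -> 572: lower PD, higher score
```

The appended `get_demo_test_config` simply returns a config dict keyed by test ID; in the demo notebooks a dict like this is passed to ValidMind's documentation-test runner (e.g. `vm.run_documentation_tests(config=...)`).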
validmind/tests/data_validation/MutualInformation.py ADDED
@@ -0,0 +1,129 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+import plotly.graph_objects as go
+from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
+from validmind import tags, tasks
+from validmind.vm_models import VMDataset
+from validmind.vm_models.result import RawData
+
+
+@tags("feature_selection", "data_analysis")
+@tasks("classification", "regression")
+def MutualInformation(
+    dataset: VMDataset, min_threshold: float = 0.01, task: str = "classification"
+):
+    """
+    Calculates mutual information scores between features and target variable to evaluate feature relevance.
+
+    ### Purpose
+
+    The Mutual Information test quantifies the predictive power of each feature by measuring its statistical
+    dependency with the target variable. This helps identify relevant features for model training and
+    detect potential redundant or irrelevant variables, supporting feature selection decisions and model
+    interpretability.
+
+    ### Test Mechanism
+
+    The test employs sklearn's mutual_info_classif/mutual_info_regression functions to compute mutual
+    information between each feature and the target. It produces a normalized score (0 to 1) for each
+    feature, where higher scores indicate stronger relationships. Results are presented in both tabular
+    format and visualized through a bar plot with a configurable threshold line.
+
+    ### Signs of High Risk
+
+    - Many features showing very low mutual information scores
+    - Key business features exhibiting unexpectedly low scores
+    - All features showing similar, low information content
+    - Large discrepancy between business importance and MI scores
+    - Highly skewed distribution of MI scores
+    - Critical features below the minimum threshold
+    - Unexpected zero or near-zero scores for known important features
+    - Inconsistent scores across different data samples
+
+    ### Strengths
+
+    - Captures non-linear relationships between features and target
+    - Scale-invariant measurement of feature relevance
+    - Works for both classification and regression tasks
+    - Provides interpretable scores (0 to 1 scale)
+    - Supports automated feature selection
+    - No assumptions about data distribution
+    - Handles numerical and categorical features
+    - Computationally efficient for most datasets
+
+    ### Limitations
+
+    - Requires sufficient data for reliable estimates
+    - May be computationally intensive for very large datasets
+    - Cannot detect redundant features (pairwise relationships)
+    - Sensitive to feature discretization for continuous variables
+    - Does not account for feature interactions
+    - May underestimate importance of rare but crucial events
+    - Cannot handle missing values directly
+    - May be affected by extreme class imbalance
+    """
+    if task not in ["classification", "regression"]:
+        raise ValueError("task must be either 'classification' or 'regression'")
+
+    X = dataset.x
+    y = dataset.y
+
+    # Select appropriate MI function based on task type
+    if task == "classification":
+        mi_scores = mutual_info_classif(X, y)
+    else:
+        mi_scores = mutual_info_regression(X, y)
+
+    # Create DataFrame for raw data
+    raw_data = RawData(
+        feature=dataset.feature_columns,
+        mutual_information_score=mi_scores.tolist(),
+        pass_fail=["Pass" if score >= min_threshold else "Fail" for score in mi_scores],
+    )
+
+    # Create Plotly figure
+    fig = go.Figure()
+
+    # Sort data for better visualization
+    sorted_indices = sorted(
+        range(len(mi_scores)), key=lambda k: mi_scores[k], reverse=True
+    )
+    sorted_features = [dataset.feature_columns[i] for i in sorted_indices]
+    sorted_scores = [mi_scores[i] for i in sorted_indices]
+
+    # Add bar plot
+    fig.add_trace(
+        go.Bar(
+            x=sorted_features,
+            y=sorted_scores,
+            marker_color=[
+                "blue" if score >= min_threshold else "red" for score in sorted_scores
+            ],
+            name="Mutual Information Score",
+        )
+    )
+
+    # Add threshold line
+    fig.add_hline(
+        y=min_threshold,
+        line_dash="dash",
+        line_color="gray",
+        annotation_text=f"Threshold ({min_threshold})",
+        annotation_position="right",
+    )
+
+    # Update layout
+    fig.update_layout(
+        title="Mutual Information Scores by Feature",
+        xaxis_title="Features",
+        yaxis_title="Mutual Information Score",
+        xaxis_tickangle=-45,
+        showlegend=False,
+        width=1000,
+        height=600,
+        template="plotly_white",
+    )
+
+    return raw_data, fig
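
A minimal sketch of exercising the new test through the standard `run_test` entry point, assuming a ValidMind session has already been initialized; the toy frame, `input_id`, and column names are illustrative:

```python
import pandas as pd
import validmind as vm
from validmind.tests import run_test

# Toy tabular data with a binary target (illustrative only)
df = pd.DataFrame(
    {
        "income": [40, 55, 30, 80, 65, 52],
        "utilization": [0.9, 0.2, 0.7, 0.1, 0.4, 0.3],
        "default": [1, 0, 1, 0, 0, 0],
    }
)

vm_dataset = vm.init_dataset(
    dataset=df, target_column="default", input_id="demo_dataset"
)

result = run_test(
    "validmind.data_validation.MutualInformation",
    inputs={"dataset": vm_dataset},
    params={"min_threshold": 0.01, "task": "classification"},
)
```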