validmind 2.1.1__py3-none-any.whl → 2.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113)
  1. validmind/__version__.py +1 -1
  2. validmind/ai.py +72 -49
  3. validmind/api_client.py +42 -16
  4. validmind/client.py +68 -25
  5. validmind/datasets/llm/rag/__init__.py +11 -0
  6. validmind/datasets/llm/rag/datasets/rfp_existing_questions_client_1.csv +30 -0
  7. validmind/datasets/llm/rag/datasets/rfp_existing_questions_client_2.csv +30 -0
  8. validmind/datasets/llm/rag/datasets/rfp_existing_questions_client_3.csv +53 -0
  9. validmind/datasets/llm/rag/datasets/rfp_existing_questions_client_4.csv +53 -0
  10. validmind/datasets/llm/rag/datasets/rfp_existing_questions_client_5.csv +53 -0
  11. validmind/datasets/llm/rag/rfp.py +41 -0
  12. validmind/errors.py +1 -1
  13. validmind/html_templates/__init__.py +0 -0
  14. validmind/html_templates/content_blocks.py +89 -14
  15. validmind/models/__init__.py +7 -4
  16. validmind/models/foundation.py +8 -34
  17. validmind/models/function.py +51 -0
  18. validmind/models/huggingface.py +16 -46
  19. validmind/models/metadata.py +42 -0
  20. validmind/models/pipeline.py +66 -0
  21. validmind/models/pytorch.py +8 -42
  22. validmind/models/r_model.py +33 -82
  23. validmind/models/sklearn.py +39 -38
  24. validmind/template.py +8 -26
  25. validmind/tests/__init__.py +43 -20
  26. validmind/tests/data_validation/ANOVAOneWayTable.py +1 -1
  27. validmind/tests/data_validation/ChiSquaredFeaturesTable.py +1 -1
  28. validmind/tests/data_validation/DescriptiveStatistics.py +2 -4
  29. validmind/tests/data_validation/Duplicates.py +1 -1
  30. validmind/tests/data_validation/IsolationForestOutliers.py +2 -2
  31. validmind/tests/data_validation/LaggedCorrelationHeatmap.py +1 -1
  32. validmind/tests/data_validation/TargetRateBarPlots.py +1 -1
  33. validmind/tests/data_validation/nlp/LanguageDetection.py +59 -0
  34. validmind/tests/data_validation/nlp/PolarityAndSubjectivity.py +48 -0
  35. validmind/tests/data_validation/nlp/Punctuations.py +11 -12
  36. validmind/tests/data_validation/nlp/Sentiment.py +57 -0
  37. validmind/tests/data_validation/nlp/Toxicity.py +45 -0
  38. validmind/tests/decorator.py +12 -7
  39. validmind/tests/model_validation/BertScore.py +100 -98
  40. validmind/tests/model_validation/BleuScore.py +93 -64
  41. validmind/tests/model_validation/ContextualRecall.py +74 -91
  42. validmind/tests/model_validation/MeteorScore.py +86 -74
  43. validmind/tests/model_validation/RegardScore.py +103 -121
  44. validmind/tests/model_validation/RougeScore.py +118 -0
  45. validmind/tests/model_validation/TokenDisparity.py +84 -121
  46. validmind/tests/model_validation/ToxicityScore.py +109 -123
  47. validmind/tests/model_validation/embeddings/CosineSimilarityComparison.py +96 -0
  48. validmind/tests/model_validation/embeddings/CosineSimilarityHeatmap.py +71 -0
  49. validmind/tests/model_validation/embeddings/EuclideanDistanceComparison.py +92 -0
  50. validmind/tests/model_validation/embeddings/EuclideanDistanceHeatmap.py +69 -0
  51. validmind/tests/model_validation/embeddings/PCAComponentsPairwisePlots.py +78 -0
  52. validmind/tests/model_validation/embeddings/StabilityAnalysis.py +35 -23
  53. validmind/tests/model_validation/embeddings/StabilityAnalysisKeyword.py +3 -0
  54. validmind/tests/model_validation/embeddings/StabilityAnalysisRandomNoise.py +7 -1
  55. validmind/tests/model_validation/embeddings/StabilityAnalysisSynonyms.py +3 -0
  56. validmind/tests/model_validation/embeddings/StabilityAnalysisTranslation.py +3 -0
  57. validmind/tests/model_validation/embeddings/TSNEComponentsPairwisePlots.py +99 -0
  58. validmind/tests/model_validation/ragas/AnswerCorrectness.py +131 -0
  59. validmind/tests/model_validation/ragas/AnswerRelevance.py +134 -0
  60. validmind/tests/model_validation/ragas/AnswerSimilarity.py +119 -0
  61. validmind/tests/model_validation/ragas/AspectCritique.py +167 -0
  62. validmind/tests/model_validation/ragas/ContextEntityRecall.py +133 -0
  63. validmind/tests/model_validation/ragas/ContextPrecision.py +123 -0
  64. validmind/tests/model_validation/ragas/ContextRecall.py +123 -0
  65. validmind/tests/model_validation/ragas/ContextRelevancy.py +114 -0
  66. validmind/tests/model_validation/ragas/Faithfulness.py +119 -0
  67. validmind/tests/model_validation/ragas/utils.py +66 -0
  68. validmind/tests/model_validation/sklearn/OverfitDiagnosis.py +3 -7
  69. validmind/tests/model_validation/sklearn/PermutationFeatureImportance.py +8 -9
  70. validmind/tests/model_validation/sklearn/PopulationStabilityIndex.py +5 -10
  71. validmind/tests/model_validation/sklearn/PrecisionRecallCurve.py +3 -2
  72. validmind/tests/model_validation/sklearn/ROCCurve.py +2 -1
  73. validmind/tests/model_validation/sklearn/RegressionR2Square.py +1 -1
  74. validmind/tests/model_validation/sklearn/RobustnessDiagnosis.py +2 -3
  75. validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py +7 -11
  76. validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py +3 -4
  77. validmind/tests/model_validation/statsmodels/RegressionModelForecastPlot.py +1 -1
  78. validmind/tests/model_validation/statsmodels/RegressionModelForecastPlotLevels.py +1 -1
  79. validmind/tests/model_validation/statsmodels/RegressionModelInsampleComparison.py +1 -1
  80. validmind/tests/model_validation/statsmodels/RegressionModelOutsampleComparison.py +1 -1
  81. validmind/tests/model_validation/statsmodels/RegressionModelSummary.py +1 -1
  82. validmind/tests/model_validation/statsmodels/RegressionModelsCoeffs.py +1 -1
  83. validmind/tests/model_validation/statsmodels/RegressionModelsPerformance.py +1 -1
  84. validmind/tests/model_validation/statsmodels/ScorecardHistogram.py +5 -6
  85. validmind/unit_metrics/__init__.py +26 -49
  86. validmind/unit_metrics/composite.py +13 -7
  87. validmind/unit_metrics/regression/sklearn/AdjustedRSquaredScore.py +1 -1
  88. validmind/utils.py +99 -6
  89. validmind/vm_models/__init__.py +1 -1
  90. validmind/vm_models/dataset/__init__.py +7 -0
  91. validmind/vm_models/dataset/dataset.py +560 -0
  92. validmind/vm_models/dataset/utils.py +146 -0
  93. validmind/vm_models/model.py +97 -72
  94. validmind/vm_models/test/metric.py +9 -24
  95. validmind/vm_models/test/result_wrapper.py +124 -28
  96. validmind/vm_models/test/threshold_test.py +10 -28
  97. validmind/vm_models/test_context.py +1 -1
  98. validmind/vm_models/test_suite/summary.py +3 -4
  99. {validmind-2.1.1.dist-info → validmind-2.2.4.dist-info}/METADATA +5 -3
  100. {validmind-2.1.1.dist-info → validmind-2.2.4.dist-info}/RECORD +103 -78
  101. validmind/models/catboost.py +0 -33
  102. validmind/models/statsmodels.py +0 -50
  103. validmind/models/xgboost.py +0 -30
  104. validmind/tests/model_validation/BertScoreAggregate.py +0 -90
  105. validmind/tests/model_validation/RegardHistogram.py +0 -148
  106. validmind/tests/model_validation/RougeMetrics.py +0 -147
  107. validmind/tests/model_validation/RougeMetricsAggregate.py +0 -133
  108. validmind/tests/model_validation/SelfCheckNLIScore.py +0 -112
  109. validmind/tests/model_validation/ToxicityHistogram.py +0 -136
  110. validmind/vm_models/dataset.py +0 -1303
  111. {validmind-2.1.1.dist-info → validmind-2.2.4.dist-info}/LICENSE +0 -0
  112. {validmind-2.1.1.dist-info → validmind-2.2.4.dist-info}/WHEEL +0 -0
  113. {validmind-2.1.1.dist-info → validmind-2.2.4.dist-info}/entry_points.txt +0 -0
validmind/tests/model_validation/ragas/ContextRecall.py
@@ -0,0 +1,123 @@
+ # Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+ # See the LICENSE file in the root of this repository for details.
+ # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+ import warnings
+
+ import plotly.express as px
+ from datasets import Dataset
+ from ragas import evaluate
+ from ragas.metrics import context_recall
+
+ from validmind import tags, tasks
+
+ from .utils import get_renamed_columns
+
+
+ @tags("ragas", "llm", "retrieval_performance")
+ @tasks("text_qa", "text_generation", "text_summarization", "text_classification")
+ def ContextRecall(
+     dataset,
+     question_column: str = "question",
+     contexts_column: str = "contexts",
+     ground_truth_column: str = "ground_truth",
+ ):
+     """
+     Context recall measures the extent to which the retrieved context aligns with the
+     annotated answer, treated as the ground truth. It is computed based on the `ground
+     truth` and the `retrieved context`, and the values range between 0 and 1, with higher
+     values indicating better performance.
+
+     To estimate context recall from the ground truth answer, each sentence in the ground
+     truth answer is analyzed to determine whether it can be attributed to the retrieved
+     context or not. In an ideal scenario, all sentences in the ground truth answer
+     should be attributable to the retrieved context.
+
+
+     The formula for calculating context recall is as follows:
+     $$
+     \\text{context recall} = {|\\text{GT sentences that can be attributed to context}| \\over |\\text{Number of sentences in GT}|}
+     $$
+
+     ### Configuring Columns
+
+     This metric requires the following columns in your dataset:
+     - `question` (str): The text query that was input into the model.
+     - `contexts` (List[str]): A list of text contexts which are retrieved and which
+       will be evaluated to make sure they contain all items in the ground truth.
+     - `ground_truth` (str): The ground truth text to compare with the retrieved contexts.
+
+     If the above data is not in the appropriate column, you can specify different column
+     names for these fields using the parameters `question_column`, `contexts_column`
+     and `ground_truth_column`.
+
+     For example, if your dataset has this data stored in different columns, you can
+     pass the following parameters:
+     ```python
+     {
+         "question_column": "question",
+         "contexts_column": "context_info"
+         "ground_truth_column": "my_ground_truth_col",
+     }
+     ```
+
+     If the data is stored as a dictionary in another column, specify the column and key
+     like this:
+     ```python
+     pred_col = dataset.prediction_column(model)
+     params = {
+         "contexts_column": f"{pred_col}.contexts",
+         "ground_truth_column": "my_ground_truth_col",
+     }
+     ```
+
+     For more complex situations, you can use a function to extract the data:
+     ```python
+     pred_col = dataset.prediction_column(model)
+     params = {
+         "contexts_column": lambda x: [x[pred_col]["context_message"]],
+         "ground_truth_column": "my_ground_truth_col",
+     }
+     ```
+     """
+     warnings.filterwarnings(
+         "ignore",
+         category=FutureWarning,
+         message="promote has been superseded by promote_options='default'.",
+     )
+
+     required_columns = {
+         "question": question_column,
+         "contexts": contexts_column,
+         "ground_truth": ground_truth_column,
+     }
+
+     df = get_renamed_columns(dataset.df, required_columns)
+
+     result_df = evaluate(
+         Dataset.from_pandas(df),
+         metrics=[context_recall],
+     ).to_pandas()
+
+     fig_histogram = px.histogram(x=result_df["context_recall"].to_list(), nbins=10)
+     fig_box = px.box(x=result_df["context_recall"].to_list())
+
+     return (
+         {
+             "Scores": result_df[
+                 ["question", "contexts", "ground_truth", "context_recall"]
+             ],
+             "Aggregate Scores": [
+                 {
+                     "Mean Score": result_df["context_recall"].mean(),
+                     "Median Score": result_df["context_recall"].median(),
+                     "Max Score": result_df["context_recall"].max(),
+                     "Min Score": result_df["context_recall"].min(),
+                     "Standard Deviation": result_df["context_recall"].std(),
+                     "Count": len(result_df),
+                 }
+             ],
+         },
+         fig_histogram,
+         fig_box,
+     )
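For orientation, a minimal usage sketch for the new ragas-backed tests follows. This is illustrative only and not part of the diff: the test ID path, the `init_dataset`/`run_test` call shapes, and the sample data are assumptions, and ragas' `evaluate` additionally needs a judge/embedding LLM configured (typically an OpenAI API key) before it will run.

```python
# Hypothetical usage sketch for the new ContextRecall test -- names and data are illustrative.
import pandas as pd
import validmind as vm
from validmind.tests import run_test

# A toy RAG-evaluation dataset with the column names the test expects by default.
df = pd.DataFrame(
    {
        "question": ["What is the capital of France?"],
        "contexts": [["Paris is the capital and largest city of France."]],
        "ground_truth": ["Paris is the capital of France."],
    }
)

# Wrap the DataFrame as a ValidMind dataset input (assumed init_dataset usage).
vm_ds = vm.init_dataset(dataset=df, input_id="rag_eval_ds")

# Run the test; the defaults match the column names above, so no params are needed.
result = run_test(
    "validmind.model_validation.ragas.ContextRecall",
    inputs={"dataset": vm_ds},
)
```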
validmind/tests/model_validation/ragas/ContextRelevancy.py
@@ -0,0 +1,114 @@
+ # Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+ # See the LICENSE file in the root of this repository for details.
+ # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+ import warnings
+
+ import plotly.express as px
+ from datasets import Dataset
+ from ragas import evaluate
+ from ragas.metrics import context_relevancy
+
+ from validmind import tags, tasks
+
+ from .utils import get_renamed_columns
+
+
+ @tags("ragas", "llm", "retrieval_performance")
+ @tasks("text_qa", "text_generation", "text_summarization", "text_classification")
+ def ContextRelevancy(
+     dataset,
+     question_column: str = "question",
+     contexts_column: str = "contexts",
+ ):
+     """
+     Evaluates the context relevancy metric for entries in a dataset and visualizes the
+     results.
+
+     This metric gauges the relevancy of the retrieved context, calculated based on both
+     the `question` and `contexts`. The values fall within the range of (0, 1), with
+     higher values indicating better relevancy.
+
+     Ideally, the retrieved context should exclusively contain essential information to
+     address the provided query. To compute this, we initially estimate the value of by
+     identifying sentences within the retrieved context that are relevant for answering
+     the given question. The final score is determined by the following formula:
+
+     $$
+     \\text{context relevancy} = {|S| \\over |\\text{Total number of sentences in retrieved context}|}
+     $$
+
+     ### Configuring Columns
+
+     This metric requires the following columns in your dataset:
+     - `question` (str): The text query that was input into the model.
+     - `contexts` (List[str]): A list of text contexts which are retrieved and which
+       will be evaluated to make sure they are relevant to the question.
+
+     If the above data is not in the appropriate column, you can specify different column
+     names for these fields using the parameters `question_column` and `contexts_column`.
+
+     For example, if your dataset has this data stored in different columns, you can
+     pass the following parameters:
+     ```python
+     {
+         "question_column": "question",
+         "contexts_column": "context_info"
+     }
+     ```
+
+     If the data is stored as a dictionary in another column, specify the column and key
+     like this:
+     ```python
+     pred_col = dataset.prediction_column(model)
+     params = {
+         "contexts_column": f"{pred_col}.contexts",
+     }
+     ```
+
+     For more complex situations, you can use a function to extract the data:
+     ```python
+     pred_col = dataset.prediction_column(model)
+     params = {
+         "contexts_column": lambda x: [x[pred_col]["context_message"]],
+     }
+     ```
+     """
+     warnings.filterwarnings(
+         "ignore",
+         category=FutureWarning,
+         message="promote has been superseded by promote_options='default'.",
+     )
+
+     required_columns = {
+         "question": question_column,
+         "contexts": contexts_column,
+     }
+
+     df = get_renamed_columns(dataset.df, required_columns)
+
+     result_df = evaluate(
+         Dataset.from_pandas(df),
+         metrics=[context_relevancy],
+     ).to_pandas()
+
+     fig_histogram = px.histogram(x=result_df["context_relevancy"].to_list(), nbins=10)
+     fig_box = px.box(x=result_df["context_relevancy"].to_list())
+
+     return (
+         {
+             "Scores": result_df[["question", "contexts", "context_relevancy"]],
+             "Aggregate Scores": [
+                 {
+                     "Mean Score": result_df["context_relevancy"].mean(),
+                     "Median Score": result_df["context_relevancy"].median(),
+                     "Max Score": result_df["context_relevancy"].max(),
+                     "Min Score": result_df["context_relevancy"].min(),
+                     "Standard Deviation": result_df["context_relevancy"].std(),
+                     "Count": len(result_df),
+                 }
+             ],
+         },
+         fig_histogram,
+         fig_box,
+     )
validmind/tests/model_validation/ragas/Faithfulness.py
@@ -0,0 +1,119 @@
+ # Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+ # See the LICENSE file in the root of this repository for details.
+ # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+ import warnings
+
+ import plotly.express as px
+ from datasets import Dataset
+ from ragas import evaluate
+ from ragas.metrics import faithfulness
+
+ from validmind import tags, tasks
+
+ from .utils import get_renamed_columns
+
+
+ @tags("ragas", "llm", "rag_performance")
+ @tasks("text_qa", "text_generation", "text_summarization")
+ def Faithfulness(
+     dataset,
+     answer_column="answer",
+     contexts_column="contexts",
+ ):
+     """
+     Evaluates the faithfulness of the generated answers with respect to retrieved contexts.
+
+     This metric uses a judge LLM to measure the factual consistency of the generated answer
+     against the given context(s). It is calculated using the generated text `answer` from
+     the LLM and the retrieved `contexts` which come from some RAG process. The score is
+     a value between 0 and 1, where a higher score indicates that the generated answer is
+     more faithful to the given context(s).
+
+     The generated answer is regarded as faithful if all the claims that are made in the
+     answer can be inferred from the given context. To calculate this a set of claims from
+     the generated answer is first identified. Then each one of these claims are cross checked
+     with given context to determine if it can be inferred from given context or not. The
+     faithfulness score formula is as follows:
+
+     $$
+     \\text{Faithfulness score} = {|\\text{Number of claims in the generated answer that can be inferred from given context}| \\over |\\text{Total number of claims in the generated answer}|}
+     $$
+
+     ### Configuring Columns
+
+     This metric requires the following columns in your dataset:
+     - `contexts` (List[str]): A list of text contexts which are retrieved to generate
+       the answer.
+     - `answer` (str): The response generated by the model which will be evaluated for
+       faithfulness against the given contexts.
+
+     If the above data is not in the appropriate column, you can specify different column
+     names for these fields using the parameters `contexts_column` and `answer_column`.
+
+     For example, if your dataset has this data stored in different columns, you can
+     pass the following parameters:
+     ```python
+     {
+         "contexts_column": "context_info"
+         "answer_column": "my_answer_col",
+     }
+     ```
+
+     If the data is stored as a dictionary in another column, specify the column and key
+     like this:
+     ```python
+     pred_col = dataset.prediction_column(model)
+     params = {
+         "contexts_column": f"{pred_col}.contexts",
+         "answer_column": f"{pred_col}.answer",
+     }
+     ```
+
+     For more complex situations, you can use a function to extract the data:
+     ```python
+     pred_col = dataset.prediction_column(model)
+     params = {
+         "contexts_column": lambda row: [row[pred_col]["context_message"]],
+         "answer_column": lambda row: "\\n\\n".join(row[pred_col]["messages"]),
+     }
+     ```
+     """
+     warnings.filterwarnings(
+         "ignore",
+         category=FutureWarning,
+         message="promote has been superseded by promote_options='default'.",
+     )
+
+     required_columns = {
+         "answer": answer_column,
+         "contexts": contexts_column,
+     }
+
+     df = get_renamed_columns(dataset.df, required_columns)
+
+     result_df = evaluate(
+         Dataset.from_pandas(df),
+         metrics=[faithfulness],
+     ).to_pandas()
+
+     fig_histogram = px.histogram(x=result_df["faithfulness"].to_list(), nbins=10)
+     fig_box = px.box(x=result_df["faithfulness"].to_list())
+
+     return (
+         {
+             "Scores": result_df[["contexts", "answer", "faithfulness"]],
+             "Aggregate Scores": [
+                 {
+                     "Mean Score": result_df["faithfulness"].mean(),
+                     "Median Score": result_df["faithfulness"].median(),
+                     "Max Score": result_df["faithfulness"].max(),
+                     "Min Score": result_df["faithfulness"].min(),
+                     "Standard Deviation": result_df["faithfulness"].std(),
+                     "Count": len(result_df),
+                 }
+             ],
+         },
+         fig_histogram,
+         fig_box,
+     )
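The dictionary-column case described in the docstring above is worth one concrete illustration. The sketch below is hypothetical: it assumes a dataset whose `rag_output` column holds per-row dicts shaped like `{"contexts": [...], "answer": "..."}`, and it reuses the `vm_ds` input from the earlier sketch.

```python
# Hypothetical: mapping Faithfulness inputs out of a dict-valued "rag_output" column.
from validmind.tests import run_test

result = run_test(
    "validmind.model_validation.ragas.Faithfulness",
    inputs={"dataset": vm_ds},  # a vm.init_dataset(...) output, as in the earlier sketch
    params={
        # "root.sub" notation pulls a key out of the dict stored in that column
        "contexts_column": "rag_output.contexts",
        # or pass an arbitrary per-row extraction function
        "answer_column": lambda row: row["rag_output"]["answer"],
    },
)
```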
validmind/tests/model_validation/ragas/utils.py
@@ -0,0 +1,66 @@
+ # Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+ # See the LICENSE file in the root of this repository for details.
+ # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+
+ def _udf_get_sub_col(x, root_col, sub_col):
+     if not isinstance(x, dict):
+         raise TypeError(f"Expected a dictionary in column '{root_col}', got {type(x)}.")
+
+     if sub_col not in x:
+         raise KeyError(
+             f"Sub-column '{sub_col}' not found in dictionary in column '{root_col}'."
+         )
+
+     return x[sub_col]
+
+
+ def get_renamed_columns(df, column_map):
+     """Get a new df with columns renamed according to the column_map
+
+     Supports sub-column notation for getting values out of dictionaries that may be
+     stored in a column. Also supports
+
+     Args:
+         df (pd.DataFrame): The DataFrame to rename columns in.
+         column_map (dict): A dictionary mapping where the keys are the new column names
+             that ragas expects and the values are one of the following:
+             - The column name in the input dataframe
+             - A string in the format "root_col.sub_col" to get a sub-column from a dictionary
+               stored in a column.
+             - A function that takes the value of the column and returns the value to be
+               stored in the new column.
+
+     Returns:
+         pd.DataFrame: The DataFrame with columns renamed.
+     """
+     new_df = df.copy()
+
+     for new_name, source in column_map.items():
+         if callable(source):
+             try:
+                 new_df[new_name] = new_df.apply(source, axis=1)
+             except Exception as e:
+                 raise ValueError(
+                     f"Failed to apply function to DataFrame. Error: {str(e)}"
+                 )
+
+         elif "." in source:
+             root_col, sub_col = source.split(".")
+
+             if root_col in new_df.columns:
+                 new_df[new_name] = new_df[root_col].apply(
+                     lambda x: _udf_get_sub_col(x, root_col, sub_col)
+                 )
+
+             else:
+                 raise KeyError(f"Column '{root_col}' not found in DataFrame.")
+
+         else:
+             if source in new_df.columns:
+                 new_df[new_name] = new_df[source]
+
+             else:
+                 raise KeyError(f"Column '{source}' not found in DataFrame.")
+
+     return new_df
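Since `get_renamed_columns` is the plumbing shared by all of the ragas tests above, a small standalone illustration of the three mapping forms it accepts may help. The import path follows the file list at the top of this diff; the DataFrame contents are made up.

```python
# Illustrative only: exercising the three mapping forms get_renamed_columns accepts.
import pandas as pd

from validmind.tests.model_validation.ragas.utils import get_renamed_columns

df = pd.DataFrame(
    {
        "query": ["What does ValidMind do?"],
        "rag_output": [
            {
                "contexts": ["ValidMind is a model risk management platform."],
                "answer": "Model risk management.",
            }
        ],
    }
)

column_map = {
    "question": "query",  # plain column rename
    "contexts": "rag_output.contexts",  # "root.sub" pulls a key from a dict column
    "answer": lambda row: row["rag_output"]["answer"],  # callable applied row-wise
}

renamed = get_renamed_columns(df, column_map)
print(renamed[["question", "contexts", "answer"]])
```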
validmind/tests/model_validation/sklearn/OverfitDiagnosis.py
@@ -90,7 +90,7 @@ class OverfitDiagnosis(ThresholdTest):
              raise ValueError("features_columns must be provided in params")

          if self.params["features_columns"] is None:
-             features_list = self.inputs.datasets[0].get_features_columns()
+             features_list = self.inputs.datasets[0].feature_columns
          else:
              features_list = self.params["features_columns"]

@@ -101,8 +101,7 @@ class OverfitDiagnosis(ThresholdTest):

          # Check if all elements from features_list are present in the feature columns
          all_present = all(
-             elem in self.inputs.datasets[0].get_features_columns()
-             for elem in features_list
+             elem in self.inputs.datasets[0].feature_columns for elem in features_list
          )
          if not all_present:
              raise ValueError(
@@ -134,10 +133,7 @@ class OverfitDiagnosis(ThresholdTest):

          for feature_column in features_list:
              bins = 10
-             if (
-                 feature_column
-                 in self.inputs.datasets[0].get_categorical_features_columns()
-             ):
+             if feature_column in self.inputs.datasets[0].feature_columns_categorical:
                  bins = len(train_df[feature_column].unique())
              train_df["bin"] = pd.cut(train_df[feature_column], bins=bins)

validmind/tests/model_validation/sklearn/PermutationFeatureImportance.py
@@ -71,15 +71,14 @@ class PermutationFeatureImportance(Metric):
          x = self.inputs.dataset.x_df()
          y = self.inputs.dataset.y_df()

-         model_library = self.inputs.model.model_library()
-         if (
-             model_library == "statsmodels"
-             or model_library == "pytorch"
-             or model_library == "catboost"
-             or model_library == "transformers"
-             or model_library == "R"
-         ):
-             raise SkipTestError(f"Skipping PFI for {model_library} models")
+         if self.inputs.model.library in [
+             "statsmodels",
+             "pytorch",
+             "catboost",
+             "transformers",
+             "R",
+         ]:
+             raise SkipTestError(f"Skipping PFI for {self.inputs.model.library} models")

          pfi_values = permutation_importance(
              self.inputs.model.model,
validmind/tests/model_validation/sklearn/PopulationStabilityIndex.py
@@ -92,9 +92,9 @@ class PopulationStabilityIndex(Metric):
          # The data looks like this: [{"initial": 2652, "percent_initial": 0.5525, "new": 830, "percent_new": 0.5188, "psi": 0.0021},...
          psi_table = [
              {
-                 "Bin": i
-                 if i < (len(metric_value) - 1)
-                 else "Total",  # The last bin is the "Total" bin
+                 "Bin": (
+                     i if i < (len(metric_value) - 1) else "Total"
+                 ),  # The last bin is the "Total" bin
                  "Count Initial": values["initial"],
                  "Percent Initial (%)": values["percent_initial"] * 100,
                  "Count New": values["new"],
@@ -180,13 +180,8 @@ class PopulationStabilityIndex(Metric):
          return psi_df.to_dict(orient="records")

      def run(self):
-         model_library = self.inputs.model.model_library()
-         if (
-             model_library == "statsmodels"
-             or model_library == "pytorch"
-             or model_library == "catboost"
-         ):
-             logger.info(f"Skiping PSI for {model_library} models")
+         if self.inputs.model.library in ["statsmodels", "pytorch", "catboost"]:
+             logger.info(f"Skiping PSI for {self.inputs.model.library} models")
              return

          num_bins = self.params["num_bins"]
validmind/tests/model_validation/sklearn/PrecisionRecallCurve.py
@@ -9,6 +9,7 @@ import plotly.graph_objects as go
  from sklearn.metrics import precision_recall_curve

  from validmind.errors import SkipTestError
+ from validmind.models import FoundationModel
  from validmind.vm_models import Figure, Metric


@@ -42,7 +43,7 @@ class PrecisionRecallCurve(Metric):
      different threshold levels.

      **Limitations**:
-     * This metric is only applicable to binary classification models it raises errors for multiclass classification
+     * This metric is only applicable to binary classification models - it raises errors for multiclass classification
      models or Foundation models.
      * It may not fully represent the overall accuracy of the model if the cost of false positives and false negatives
      are extremely different, or if the dataset is heavily imbalanced.
@@ -62,7 +63,7 @@ class PrecisionRecallCurve(Metric):
      }

      def run(self):
-         if self.inputs.model.model_library() == "FoundationModel":
+         if isinstance(self.inputs.model, FoundationModel):
              raise SkipTestError("Skipping PrecisionRecallCurve for Foundation models")

          y_true = self.inputs.dataset.y
validmind/tests/model_validation/sklearn/ROCCurve.py
@@ -9,6 +9,7 @@ import plotly.graph_objects as go
  from sklearn.metrics import roc_auc_score, roc_curve

  from validmind.errors import SkipTestError
+ from validmind.models import FoundationModel
  from validmind.vm_models import Figure, Metric


@@ -70,7 +71,7 @@ class ROCCurve(Metric):
      }

      def run(self):
-         if self.inputs.model.model_library() == "FoundationModel":
+         if isinstance(self.inputs.model, FoundationModel):
              raise SkipTestError("Skipping ROCCurve for Foundation models")

          y_true = self.inputs.dataset.y
validmind/tests/model_validation/sklearn/RegressionR2Square.py
@@ -90,7 +90,7 @@ class RegressionR2Square(Metric):
              }
          )

-         X_columns = self.inputs.datasets[0].get_features_columns()
+         X_columns = self.inputs.datasets[0].feature_columns
          adj_r2_train = adj_r2_score(
              y_train_true, y_train_pred, len(y_train_true), len(X_columns)
          )
validmind/tests/model_validation/sklearn/RobustnessDiagnosis.py
@@ -109,12 +109,11 @@ class RobustnessDiagnosis(ThresholdTest):

          features_list = self.params["features_columns"]
          if features_list is None:
-             features_list = self.inputs.datasets[0].get_numeric_features_columns()
+             features_list = self.inputs.datasets[0].feature_columns

          # Check if all elements from features_list are present in the numerical feature columns
          all_present = all(
-             elem in self.inputs.datasets[0].get_numeric_features_columns()
-             for elem in features_list
+             elem in self.inputs.datasets[0].feature_columns for elem in features_list
          )
          if not all_present:
              raise ValueError(
validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py
@@ -11,6 +11,7 @@ import shap

  from validmind.errors import UnsupportedModelForSHAPError
  from validmind.logging import get_logger
+ from validmind.models import CatBoostModel, SKlearnModel, StatsModelsModel
  from validmind.vm_models import Figure, Metric

  logger = get_logger(__name__)
@@ -131,20 +132,14 @@ class SHAPGlobalImportance(Metric):
          )

      def run(self):
-         model_library = self.inputs.model.model_library()
-         if model_library in [
-             "statsmodels",
-             "pytorch",
-             "catboost",
-             "transformers",
-             "FoundationModel",
-             "R",
-         ]:
-             logger.info(f"Skiping SHAP for {model_library} models")
+         if not isinstance(self.inputs.model, SKlearnModel) or isinstance(
+             self.inputs.model, (CatBoostModel, StatsModelsModel)
+         ):
+             logger.info(f"Skiping SHAP for {self.inputs.model.library} models")
              return

          trained_model = self.inputs.model.model
-         model_class = self.inputs.model.model_class()
+         model_class = self.inputs.model.class_

          # the shap library generates a bunch of annoying warnings that we don't care about
          warnings.filterwarnings("ignore", category=UserWarning)
@@ -176,6 +171,7 @@ class SHAPGlobalImportance(Metric):
                  ),
              )
          else:
+             model_class = "<ExternalModel>" if model_class is None else model_class
              raise UnsupportedModelForSHAPError(
                  f"Model {model_class} not supported for SHAP importance."
              )
validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py
@@ -113,7 +113,7 @@ class WeakspotsDiagnosis(ThresholdTest):
              raise ValueError(f"Threshold for metric {metric} is missing")

          if self.params["features_columns"] is None:
-             features_list = self.inputs.datasets[0].get_features_columns()
+             features_list = self.inputs.datasets[0].feature_columns
          else:
              features_list = self.params["features_columns"]

@@ -124,8 +124,7 @@ class WeakspotsDiagnosis(ThresholdTest):

          # Check if all elements from features_list are present in the feature columns
          all_present = all(
-             elem in self.inputs.datasets[0].get_features_columns()
-             for elem in features_list
+             elem in self.inputs.datasets[0].feature_columns for elem in features_list
          )
          if not all_present:
              raise ValueError(
@@ -150,7 +149,7 @@ class WeakspotsDiagnosis(ThresholdTest):
          results_headers.extend(self.default_metrics.keys())
          for feature in features_list:
              bins = 10
-             if feature in self.inputs.datasets[0].get_categorical_features_columns():
+             if feature in self.inputs.datasets[0].feature_columns_categorical:
                  bins = len(train_df[feature].unique())
              train_df["bin"] = pd.cut(train_df[feature], bins=bins)

validmind/tests/model_validation/statsmodels/RegressionModelForecastPlot.py
@@ -89,7 +89,7 @@ class RegressionModelForecastPlot(Metric):
          figures = []

          for i, fitted_model in enumerate(model_list):
-             feature_columns = datasets[0].get_features_columns()
+             feature_columns = datasets[0].feature_columns

              train_ds = datasets[0]
              test_ds = datasets[1]
validmind/tests/model_validation/statsmodels/RegressionModelForecastPlotLevels.py
@@ -98,7 +98,7 @@ class RegressionModelForecastPlotLevels(Metric):
          figures = []

          for i, fitted_model in enumerate(model_list):
-             feature_columns = datasets[0].get_features_columns()
+             feature_columns = datasets[0].feature_columns

              train_ds = datasets[0]
              test_ds = datasets[1]
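The remaining hunks all apply one API migration: accessor methods on ValidMind datasets and models become properties (`get_features_columns()` → `feature_columns`, `get_categorical_features_columns()` → `feature_columns_categorical`, `model_library()` → `library`, `model_class()` → `class_`), and string comparisons against `"FoundationModel"` become `isinstance` checks. Below is a minimal sketch of helper code written against the new accessors; `dataset` and `model` are assumed to be standard VMDataset/VMModel inputs, and the helpers themselves are illustrative rather than part of the package.

```python
# Illustrative helpers using the 2.2.4-style accessors seen in the hunks above.
from validmind.errors import SkipTestError
from validmind.models import FoundationModel


def resolve_features(dataset, features_columns=None):
    """Pick feature columns the way the updated tests do (feature_columns is now a property)."""
    features = features_columns or dataset.feature_columns
    missing = [col for col in features if col not in dataset.feature_columns]
    if missing:
        raise ValueError(f"Feature columns not found in dataset: {missing}")
    return features


def skip_if_unsupported(model):
    """Gate a test on model type, mirroring the isinstance/.library checks above."""
    if isinstance(model, FoundationModel):
        raise SkipTestError("Skipping test for Foundation models")
    if model.library in ["statsmodels", "pytorch", "catboost"]:
        raise SkipTestError(f"Skipping test for {model.library} models")
```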