validmind 2.6.10__py3-none-any.whl → 2.7.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. validmind/__init__.py +2 -0
  2. validmind/__version__.py +1 -1
  3. validmind/ai/test_descriptions.py +20 -4
  4. validmind/ai/test_result_description/user.jinja +5 -0
  5. validmind/datasets/credit_risk/lending_club.py +444 -14
  6. validmind/tests/data_validation/MutualInformation.py +129 -0
  7. validmind/tests/data_validation/ScoreBandDefaultRates.py +139 -0
  8. validmind/tests/data_validation/TooManyZeroValues.py +6 -5
  9. validmind/tests/data_validation/UniqueRows.py +3 -1
  10. validmind/tests/decorator.py +18 -16
  11. validmind/tests/model_validation/sklearn/CalibrationCurve.py +116 -0
  12. validmind/tests/model_validation/sklearn/ClassifierThresholdOptimization.py +261 -0
  13. validmind/tests/model_validation/sklearn/ConfusionMatrix.py +1 -0
  14. validmind/tests/model_validation/sklearn/HyperParametersTuning.py +144 -56
  15. validmind/tests/model_validation/sklearn/ModelParameters.py +74 -0
  16. validmind/tests/model_validation/sklearn/ROCCurve.py +26 -23
  17. validmind/tests/model_validation/sklearn/ScoreProbabilityAlignment.py +130 -0
  18. validmind/tests/model_validation/statsmodels/CumulativePredictionProbabilities.py +5 -6
  19. validmind/tests/model_validation/statsmodels/PredictionProbabilitiesHistogram.py +2 -3
  20. validmind/tests/output.py +10 -1
  21. validmind/tests/run.py +52 -54
  22. validmind/utils.py +34 -7
  23. validmind/vm_models/figure.py +15 -0
  24. validmind/vm_models/result/__init__.py +2 -2
  25. validmind/vm_models/result/result.py +136 -23
  26. {validmind-2.6.10.dist-info → validmind-2.7.4.dist-info}/METADATA +1 -1
  27. {validmind-2.6.10.dist-info → validmind-2.7.4.dist-info}/RECORD +30 -24
  28. {validmind-2.6.10.dist-info → validmind-2.7.4.dist-info}/LICENSE +0 -0
  29. {validmind-2.6.10.dist-info → validmind-2.7.4.dist-info}/WHEEL +0 -0
  30. {validmind-2.6.10.dist-info → validmind-2.7.4.dist-info}/entry_points.txt +0 -0
validmind/tests/data_validation/../model_validation/sklearn/ClassifierThresholdOptimization.py
@@ -0,0 +1,261 @@
+ # Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+ # See the LICENSE file in the root of this repository for details.
+ # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+ import numpy as np
+ import pandas as pd
+ import plotly.graph_objects as go
+ from plotly.subplots import make_subplots
+ from sklearn.metrics import (
+     roc_curve,
+     precision_recall_curve,
+     confusion_matrix,
+ )
+ from validmind import tags, tasks
+ from validmind.vm_models import VMDataset, VMModel
+
+
+ def find_optimal_threshold(y_true, y_prob, method="youden", target_recall=None):
+     """
+     Find the optimal classification threshold using various methods.
+
+     Args:
+         y_true: True binary labels
+         y_prob: Predicted probabilities
+         method: Method to use for finding optimal threshold
+         target_recall: Required if method='target_recall'
+
+     Returns:
+         dict: Dictionary containing threshold and metrics
+     """
+     # Get ROC and PR curve points
+     fpr, tpr, thresholds_roc = roc_curve(y_true, y_prob)
+     precision, recall, thresholds_pr = precision_recall_curve(y_true, y_prob)
+
+     # Find optimal threshold based on method
+     if method == "naive":
+         optimal_threshold = 0.5
+     elif method == "youden":
+         j_scores = tpr - fpr
+         best_idx = np.argmax(j_scores)
+         optimal_threshold = thresholds_roc[best_idx]
+     elif method == "f1":
+         f1_scores = 2 * (precision * recall) / (precision + recall)
+         best_idx = np.argmax(f1_scores)
+         optimal_threshold = (
+             thresholds_pr[best_idx] if best_idx < len(thresholds_pr) else 1.0
+         )
+     elif method == "precision_recall":
+         diff = abs(precision - recall)
+         best_idx = np.argmin(diff)
+         optimal_threshold = (
+             thresholds_pr[best_idx] if best_idx < len(thresholds_pr) else 1.0
+         )
+     elif method == "target_recall":
+         if target_recall is None:
+             raise ValueError(
+                 "target_recall must be specified when method='target_recall'"
+             )
+         idx = np.argmin(abs(recall - target_recall))
+         optimal_threshold = thresholds_pr[idx] if idx < len(thresholds_pr) else 1.0
+     else:
+         raise ValueError(f"Unknown method: {method}")
+
+     # Calculate predictions with optimal threshold
+     y_pred = (y_prob >= optimal_threshold).astype(int)
+
+     # Calculate confusion matrix
+     tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
+
+     # Calculate metrics directly
+     metrics = {
+         "method": method,
+         "threshold": optimal_threshold,
+         "precision": tp / (tp + fp) if (tp + fp) > 0 else 0,
+         "recall": tp / (tp + fn) if (tp + fn) > 0 else 0,
+         "f1_score": 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else 0,
+         "accuracy": (tp + tn) / (tp + tn + fp + fn),
+     }
+
+     return metrics
+
+
+ @tags("model_validation", "threshold_optimization", "classification_metrics")
+ @tasks("classification")
+ def ClassifierThresholdOptimization(
+     dataset: VMDataset, model: VMModel, methods=None, target_recall=None
+ ):
+     """
+     Analyzes and visualizes different threshold optimization methods for binary classification models.
+
+     ### Purpose
+
+     The Classifier Threshold Optimization test identifies optimal decision thresholds using various
+     methods to balance different performance metrics. This helps adapt the model's decision boundary
+     to specific business requirements, such as minimizing false positives in fraud detection or
+     achieving target recall in medical diagnosis.
+
+     ### Test Mechanism
+
+     The test implements multiple threshold optimization methods:
+     1. Youden's J statistic (maximizing sensitivity + specificity - 1)
+     2. F1-score optimization (balancing precision and recall)
+     3. Precision-Recall equality point
+     4. Target recall achievement
+     5. Naive (0.5) threshold
+     For each method, it computes ROC and PR curves, identifies optimal points, and provides
+     comprehensive performance metrics at each threshold.
+
+     ### Signs of High Risk
+
+     - Large discrepancies between different optimization methods
+     - Optimal thresholds far from the default 0.5
+     - Poor performance metrics across all thresholds
+     - Significant gap between achieved and target recall
+     - Unstable thresholds across different methods
+     - Extreme trade-offs between precision and recall
+     - Threshold optimization showing minimal impact
+     - Business metrics not improving with optimization
+
+     ### Strengths
+
+     - Multiple optimization strategies for different needs
+     - Visual and numerical results for comparison
+     - Support for business-driven optimization (target recall)
+     - Comprehensive performance metrics at each threshold
+     - Integration with ROC and PR curves
+     - Handles class imbalance through various metrics
+     - Enables informed threshold selection
+     - Supports cost-sensitive decision making
+
+     ### Limitations
+
+     - Assumes cost of false positives/negatives are known
+     - May need adjustment for highly imbalanced datasets
+     - Threshold might not be stable across different samples
+     - Cannot handle multi-class problems directly
+     - Optimization methods may conflict with business needs
+     - Requires sufficient validation data
+     - May not capture temporal changes in optimal threshold
+     - Single threshold may not be optimal for all subgroups
+
+     Args:
+         dataset: VMDataset containing features and target
+         model: VMModel containing predictions
+         methods: List of methods to compare (default: ['youden', 'f1', 'precision_recall'])
+         target_recall: Target recall value if using 'target_recall' method
+
+     Returns:
+         Dictionary containing:
+             - table: DataFrame comparing different threshold optimization methods
+               (using weighted averages for precision, recall, and f1)
+             - figure: Plotly figure showing ROC and PR curves with optimal thresholds
+     """
+     # Verify binary classification
+     unique_values = np.unique(dataset.y)
+     if len(unique_values) != 2:
+         raise ValueError("Target variable must be binary")
+
+     if methods is None:
+         methods = ["naive", "youden", "f1", "precision_recall"]
+         if target_recall is not None:
+             methods.append("target_recall")
+
+     y_true = dataset.y
+     y_prob = dataset.y_prob(model)
+
+     # Get curve points for plotting
+     fpr, tpr, thresholds_roc = roc_curve(y_true, y_prob)
+     precision, recall, thresholds_pr = precision_recall_curve(y_true, y_prob)
+
+     # Calculate optimal thresholds and metrics
+     results = []
+     optimal_points = {}
+
+     for method in methods:
+         metrics = find_optimal_threshold(y_true, y_prob, method, target_recall)
+         results.append(metrics)
+
+         # Store optimal points for plotting
+         if method == "youden":
+             idx = np.argmax(tpr - fpr)
+             optimal_points[method] = {
+                 "x": fpr[idx],
+                 "y": tpr[idx],
+                 "threshold": thresholds_roc[idx],
+             }
+         elif method in ["f1", "precision_recall", "target_recall"]:
+             idx = np.argmin(abs(thresholds_pr - metrics["threshold"]))
+             optimal_points[method] = {
+                 "x": recall[idx],
+                 "y": precision[idx],
+                 "threshold": metrics["threshold"],
+             }
+
+     # Create visualization
+     fig = make_subplots(
+         rows=1, cols=2, subplot_titles=("ROC Curve", "Precision-Recall Curve")
+     )
+
+     # Plot ROC curve
+     fig.add_trace(
+         go.Scatter(x=fpr, y=tpr, name="ROC Curve", line=dict(color="blue")),
+         row=1,
+         col=1,
+     )
+
+     # Plot PR curve
+     fig.add_trace(
+         go.Scatter(x=recall, y=precision, name="PR Curve", line=dict(color="green")),
+         row=1,
+         col=2,
+     )
+
+     # Add optimal points
+     colors = {
+         "youden": "red",
+         "f1": "orange",
+         "precision_recall": "purple",
+         "target_recall": "brown",
+     }
+
+     for method, points in optimal_points.items():
+         if method == "youden":
+             fig.add_trace(
+                 go.Scatter(
+                     x=[points["x"]],
+                     y=[points["y"]],
+                     name=f'{method} (t={points["threshold"]:.2f})',
+                     mode="markers",
+                     marker=dict(size=10, color=colors[method]),
+                 ),
+                 row=1,
+                 col=1,
+             )
+         else:
+             fig.add_trace(
+                 go.Scatter(
+                     x=[points["x"]],
+                     y=[points["y"]],
+                     name=f'{method} (t={points["threshold"]:.2f})',
+                     mode="markers",
+                     marker=dict(size=10, color=colors[method]),
+                 ),
+                 row=1,
+                 col=2,
+             )
+
+     # Update layout
+     fig.update_layout(
+         height=500, title_text="Threshold Optimization Analysis", showlegend=True
+     )
+
+     fig.update_xaxes(title_text="False Positive Rate", row=1, col=1)
+     fig.update_xaxes(title_text="Recall", row=1, col=2)
+     fig.update_yaxes(title_text="True Positive Rate", row=1, col=1)
+     fig.update_yaxes(title_text="Precision", row=1, col=2)
+
+     # Create results table and sort by threshold descending
+     table = pd.DataFrame(results).sort_values("threshold", ascending=False)
+
+     return fig, table
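
The new test above is a plain decorated function, so it can be exercised through the standard test harness once a VMDataset/VMModel pair has predictions attached. The sketch below is not part of the diff: the `init_model` / `init_dataset` / `assign_predictions` / `run_test` wiring follows the usual ValidMind workflow, and the names (`clf`, `df`) and parameter values are illustrative assumptions.

import validmind as vm

# Assumed setup: a fitted binary classifier `clf` and a pandas DataFrame `df` with a "target" column.
vm_model = vm.init_model(clf, input_id="demo_model")
vm_ds = vm.init_dataset(dataset=df, input_id="validation_ds", target_column="target")
vm_ds.assign_predictions(model=vm_model)  # populates the probabilities read by dataset.y_prob(model)

result = vm.tests.run_test(
    "validmind.model_validation.sklearn.ClassifierThresholdOptimization",
    inputs={"dataset": vm_ds, "model": vm_model},
    params={"methods": ["naive", "youden", "f1", "target_recall"], "target_recall": 0.90},
)

Passing `methods` explicitly, as here, bypasses the default list; `target_recall` is only required when the "target_recall" method is included.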
validmind/tests/model_validation/sklearn/ConfusionMatrix.py
@@ -106,6 +106,7 @@ def ConfusionMatrix(dataset: VMDataset, model: VMModel):
          autosize=False,
          width=600,
          height=600,
+         title_text="Confusion Matrix",
      )

      fig.add_annotation(
validmind/tests/model_validation/sklearn/HyperParametersTuning.py
@@ -2,73 +2,161 @@
  # See the LICENSE file in the root of this repository for details.
  # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial

- from typing import Union
-
+ from typing import Union, Dict, List
  from sklearn.model_selection import GridSearchCV
+ from sklearn.metrics import make_scorer, recall_score

  from validmind import tags, tasks
- from validmind.errors import SkipTestError
  from validmind.vm_models import VMDataset, VMModel


  @tags("sklearn", "model_performance")
  @tasks("classification", "clustering")
+ def custom_recall(y_true, y_pred_proba, threshold=0.5):
+     y_pred = (y_pred_proba >= threshold).astype(int)
+     return recall_score(y_true, y_pred)
+
+
+ def _get_metrics(scoring):
+     """Convert scoring parameter to list of metrics."""
+     if scoring is None:
+         return ["accuracy"]
+     return (
+         scoring
+         if isinstance(scoring, list)
+         else list(scoring.keys()) if isinstance(scoring, dict) else [scoring]
+     )
+
+
+ def _get_thresholds(thresholds):
+     """Convert thresholds parameter to list."""
+     if thresholds is None:
+         return [0.5]
+     return [thresholds] if isinstance(thresholds, (int, float)) else thresholds
+
+
+ def _create_scoring_dict(scoring, metrics, threshold):
+     """Create scoring dictionary for GridSearchCV."""
+     if scoring is None:
+         return None
+
+     scoring_dict = {}
+     for metric in metrics:
+         if metric == "recall":
+             scoring_dict[metric] = make_scorer(
+                 custom_recall, needs_proba=True, threshold=threshold
+             )
+         elif metric == "roc_auc":
+             scoring_dict[metric] = "roc_auc"
+         else:
+             scoring_dict[metric] = metric
+     return scoring_dict
+
+
+ @tags("sklearn", "model_performance")
+ @tasks("clustering", "classification")
  def HyperParametersTuning(
      model: VMModel,
      dataset: VMDataset,
-     param_grid: Union[dict, None] = None,
-     scoring: Union[str, None] = None,
+     param_grid: dict,
+     scoring: Union[str, List, Dict] = None,
+     thresholds: Union[float, List[float]] = None,
+     fit_params: dict = None,
  ):
      """
-     Exerts exhaustive grid search to identify optimal hyperparameters for the model, improving performance.
-
-     ### Purpose:
-
-     The "HyperParametersTuning" metric aims to find the optimal set of hyperparameters for a given model. The test is
-     designed to enhance the performance of the model by determining the best configuration of hyperparameters. The
-     parameters that are being optimized are defined by the parameter grid provided to the metric.
-
-     ### Test Mechanism:
-
-     The HyperParametersTuning test employs a grid search mechanism using the GridSearchCV function from the
-     scikit-learn library. The grid search algorithm systematically works through multiple combinations of parameter
-     values, cross-validating to determine which combination gives the best model performance. The chosen model and the
-     parameter grid passed for tuning are necessary inputs. Once the grid search is complete, the test caches and
-     returns details of the best model and its associated parameters.
-
-     ### Signs of High Risk:
-
-     - The test raises a SkipTestError if the param_grid is not supplied, indicating a lack of specific parameters to
-     optimize, which can be risky for certain model types reliant on parameter tuning.
-     - Poorly chosen scoring metrics that do not align well with the specific model or problem at hand could reflect
-     potential risks or failures in achieving optimal performance.
-
-     ### Strengths:
-
-     - Provides a comprehensive exploration mechanism to identify the best set of hyperparameters for the supplied
-     model, thereby enhancing its performance.
-     - Implements GridSearchCV, simplifying and automating the time-consuming task of hyperparameter tuning.
-
-     ### Limitations:
-
-     - The grid search algorithm can be computationally expensive, especially with large datasets or complex models, and
-     can be time-consuming as it tests all possible combinations within the specified parameter grid.
-     - The effectiveness of the tuning is heavily dependent on the quality of data and only accepts datasets with
-     numerical or ordered categories.
-     - Assumes that the same set of hyperparameters is optimal for all problem sets, which may not be true in every
-     scenario.
-     - There's a potential risk of overfitting the model if the training set is not representative of the data that the
-     model will be applied to.
+     Performs exhaustive grid search over specified parameter ranges to find optimal model configurations
+     across different metrics and decision thresholds.
+
+     ### Purpose
+
+     The Hyperparameter Tuning test systematically explores the model's parameter space to identify optimal
+     configurations. It supports multiple optimization metrics and decision thresholds, providing a comprehensive
+     view of how different parameter combinations affect various aspects of model performance.
+
+     ### Test Mechanism
+
+     The test uses scikit-learn's GridSearchCV to perform cross-validation for each parameter combination.
+     For each specified threshold and optimization metric, it creates a scoring dictionary with
+     threshold-adjusted metrics, performs grid search with cross-validation, records best parameters and
+     corresponding scores, and combines results into a comparative table. This process is repeated for each
+     optimization metric to provide a comprehensive view of model performance under different configurations.
+
+     ### Signs of High Risk
+
+     - Large performance variations across different parameter combinations
+     - Significant discrepancies between different optimization metrics
+     - Best parameters at the edges of the parameter grid
+     - Unstable performance across different thresholds
+     - Overly complex model configurations (risk of overfitting)
+     - Very different optimal parameters for different metrics
+     - Cross-validation scores showing high variance
+     - Extreme parameter values in best configurations
+
+     ### Strengths
+
+     - Comprehensive exploration of parameter space
+     - Supports multiple optimization metrics
+     - Allows threshold optimization
+     - Provides comparative view across different configurations
+     - Uses cross-validation for robust evaluation
+     - Helps understand trade-offs between different metrics
+     - Enables systematic parameter selection
+     - Supports both classification and clustering tasks
+
+     ### Limitations
+
+     - Computationally expensive for large parameter grids
+     - May not find global optimum (limited to grid points)
+     - Cannot handle dependencies between parameters
+     - Memory intensive for large datasets
+     - Limited to scikit-learn compatible models
+     - Cross-validation splits may not preserve time series structure
+     - Grid search may miss optimal values between grid points
+     - Resource intensive for high-dimensional parameter spaces
      """
-     if not param_grid:
-         raise SkipTestError("'param_grid' dictionary must be provided to run this test")
-
-     estimators = GridSearchCV(model.model, param_grid=param_grid, scoring=scoring)
-     estimators.fit(dataset.x, dataset.y)
-
-     return [
-         {
-             "Best Model": estimators.best_estimator_,
-             "Best Parameters": estimators.best_params_,
-         }
-     ]
+     fit_params = fit_params or {}
+
+     # Simple case: no scoring and no thresholds
+     if scoring is None and thresholds is None:
+         estimators = GridSearchCV(model.model, param_grid=param_grid, scoring=None)
+         estimators.fit(dataset.x_df(), dataset.y, **fit_params)
+         return [
+             {
+                 "Best Model": estimators.best_estimator_,
+                 "Best Parameters": estimators.best_params_,
+             }
+         ]
+
+     # Complex case: with scoring or thresholds
+     results = []
+     metrics = _get_metrics(scoring)
+     thresholds = _get_thresholds(thresholds)
+
+     for threshold in thresholds:
+         scoring_dict = _create_scoring_dict(scoring, metrics, threshold)
+
+         for optimize_for in metrics:
+             estimators = GridSearchCV(
+                 model.model,
+                 param_grid=param_grid,
+                 scoring=scoring_dict,
+                 refit=optimize_for if scoring is not None else True,
+             )
+
+             estimators.fit(dataset.x_df(), dataset.y, **fit_params)
+
+             best_index = estimators.best_index_
+             row_result = {
+                 "Optimized for": optimize_for,
+                 "Threshold": threshold,
+                 "Best Parameters": estimators.best_params_,
+             }
+
+             score_key = (
+                 "mean_test_score" if scoring is None else f"mean_test_{optimize_for}"
+             )
+             row_result[optimize_for] = estimators.cv_results_[score_key][best_index]
+
+             results.append(row_result)
+
+     return results
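
Note the interface change above: `param_grid` is now required (the SkipTestError guard is gone), `scoring` may be a string, list, or dict, and the new `thresholds` and `fit_params` arguments enable threshold-aware recall scoring via `custom_recall`. A usage sketch with illustrative values; the grid, metrics, and thresholds below are assumptions rather than defaults shipped in 2.7.4, and `vm_model` / `vm_train_ds` are assumed to be initialized as in the earlier sketch.

results = vm.tests.run_test(
    "validmind.model_validation.sklearn.HyperParametersTuning",
    inputs={"model": vm_model, "dataset": vm_train_ds},
    params={
        "param_grid": {"n_estimators": [100, 300], "max_depth": [3, 5]},  # illustrative grid
        "scoring": ["roc_auc", "recall"],  # one grid search per metric, refit on that metric
        "thresholds": [0.3, 0.5],          # recall is re-scored at each cut-off
    },
)

Each (threshold, metric) pair yields one row containing the best parameters and the best mean cross-validated score for that metric.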
validmind/tests/model_validation/sklearn/ModelParameters.py
@@ -0,0 +1,74 @@
+ # Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+ # See the LICENSE file in the root of this repository for details.
+ # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+ import pandas as pd
+ from validmind import tags, tasks
+
+
+ @tags("model_training", "metadata")
+ @tasks("classification", "regression")
+ def ModelParameters(model, model_params=None):
+     """
+     Extracts and displays model parameters in a structured format for transparency and reproducibility.
+
+     ### Purpose
+
+     The Model Parameters test is designed to provide transparency into model configuration and ensure
+     reproducibility of machine learning models. It accomplishes this by extracting and presenting all
+     relevant parameters that define the model's behavior, making it easier to audit, validate, and
+     reproduce model training.
+
+     ### Test Mechanism
+
+     The test leverages scikit-learn's API convention of get_params() to extract model parameters. It
+     produces a structured DataFrame containing parameter names and their corresponding values. For models
+     that follow scikit-learn's API (including XGBoost, RandomForest, and other estimators), all
+     parameters are automatically extracted and displayed.
+
+     ### Signs of High Risk
+
+     - Missing crucial parameters that should be explicitly set
+     - Extreme parameter values that could indicate overfitting (e.g., unlimited tree depth)
+     - Inconsistent parameters across different versions of the same model type
+     - Parameter combinations known to cause instability or poor performance
+     - Default values used for critical parameters that should be tuned
+
+     ### Strengths
+
+     - Universal compatibility with scikit-learn API-compliant models
+     - Ensures transparency in model configuration
+     - Facilitates model reproducibility and version control
+     - Enables systematic parameter auditing
+     - Supports both classification and regression models
+     - Helps identify potential configuration issues
+
+     ### Limitations
+
+     - Only works with models implementing scikit-learn's get_params() method
+     - Cannot capture dynamic parameters set during model training
+     - Does not validate parameter values for model-specific appropriateness
+     - Parameter meanings and impacts may vary across different model types
+     - Cannot detect indirect parameter interactions or their effects on model performance
+     """
+     # Check if model implements get_params()
+     if not hasattr(model.model, "get_params"):
+         return pd.DataFrame()
+
+     # Get all model parameters
+     params = model.model.get_params()
+
+     # If model_params is None, use all parameters from get_params()
+     if model_params is None:
+         model_params = sorted(params.keys())  # Sort for consistent ordering
+
+     # Create DataFrame with parameters and their values
+     param_df = pd.DataFrame(
+         [
+             {"Parameter": param, "Value": str(params.get(param, "Not specified"))}
+             for param in model_params
+             if params.get(param) is not None
+         ]
+     )
+
+     return param_df
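
Since ModelParameters only relies on the get_params() convention, it can also be invoked directly on an initialized VMModel, outside the test harness. A small sketch: the direct import path mirrors the new file's location, and `vm_model` plus the parameter subset are assumptions.

from validmind.tests.model_validation.sklearn.ModelParameters import ModelParameters

# Full parameter table for any estimator exposing get_params() (sklearn, XGBoost, etc.).
param_table = ModelParameters(vm_model)

# Restrict the output to a hypothetical subset of parameters of interest;
# parameters whose value is None are skipped by the comprehension above.
key_params = ModelParameters(vm_model, model_params=["max_depth", "n_estimators", "learning_rate"])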
validmind/tests/model_validation/sklearn/ROCCurve.py
@@ -6,7 +6,7 @@ import numpy as np
  import plotly.graph_objects as go
  from sklearn.metrics import roc_auc_score, roc_curve

- from validmind import tags, tasks
+ from validmind import RawData, tags, tasks
  from validmind.errors import SkipTestError
  from validmind.vm_models import VMDataset, VMModel

@@ -77,28 +77,31 @@ def ROCCurve(model: VMModel, dataset: VMDataset):
      fpr, tpr, _ = roc_curve(y_true, y_prob, drop_intermediate=False)
      auc = roc_auc_score(y_true, y_prob)

-     return go.Figure(
-         data=[
-             go.Scatter(
-                 x=fpr,
-                 y=tpr,
-                 mode="lines",
-                 name=f"ROC curve (AUC = {auc:.2f})",
-                 line=dict(color="#DE257E"),
+     return (
+         RawData(fpr=fpr, tpr=tpr, auc=auc),
+         go.Figure(
+             data=[
+                 go.Scatter(
+                     x=fpr,
+                     y=tpr,
+                     mode="lines",
+                     name=f"ROC curve (AUC = {auc:.2f})",
+                     line=dict(color="#DE257E"),
+                 ),
+                 go.Scatter(
+                     x=[0, 1],
+                     y=[0, 1],
+                     mode="lines",
+                     name="Random (AUC = 0.5)",
+                     line=dict(color="grey", dash="dash"),
+                 ),
+             ],
+             layout=go.Layout(
+                 title=f"ROC Curve for {model.input_id} on {dataset.input_id}",
+                 xaxis=dict(title="False Positive Rate"),
+                 yaxis=dict(title="True Positive Rate"),
+                 width=700,
+                 height=500,
              ),
-             go.Scatter(
-                 x=[0, 1],
-                 y=[0, 1],
-                 mode="lines",
-                 name="Random (AUC = 0.5)",
-                 line=dict(color="grey", dash="dash"),
-             ),
-         ],
-         layout=go.Layout(
-             title=f"ROC Curve for {model.input_id} on {dataset.input_id}",
-             xaxis=dict(title="False Positive Rate"),
-             yaxis=dict(title="True Positive Rate"),
-             width=700,
-             height=500,
          ),
      )
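
The ROCCurve change is behavioral as well as cosmetic: the test now returns a (RawData, Figure) tuple instead of a bare figure, matching the new `from validmind import RawData` import at the top of the hunk. Direct callers of the function need to unpack the tuple; a brief sketch, reusing the assumed `vm_model` / `vm_ds` objects from the earlier sketches.

from validmind.tests.model_validation.sklearn.ROCCurve import ROCCurve

# The test now returns raw curve data alongside the plot.
raw_data, fig = ROCCurve(model=vm_model, dataset=vm_ds)
fig.show()  # raw_data holds the fpr, tpr, and auc values passed to RawData(...)

Runs through the test harness shouldn't require changes; the result-handling updates elsewhere in this release (`validmind/tests/output.py`, `validmind/vm_models/result/result.py`) appear to cover the new return shape.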