validmind 2.7.2__py3-none-any.whl → 2.7.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. validmind/__version__.py +1 -1
  2. validmind/ai/test_descriptions.py +20 -4
  3. validmind/ai/test_result_description/user.jinja +5 -0
  4. validmind/datasets/credit_risk/lending_club.py +444 -14
  5. validmind/tests/data_validation/MutualInformation.py +129 -0
  6. validmind/tests/data_validation/ScoreBandDefaultRates.py +139 -0
  7. validmind/tests/data_validation/TooManyZeroValues.py +6 -5
  8. validmind/tests/data_validation/UniqueRows.py +3 -1
  9. validmind/tests/decorator.py +18 -16
  10. validmind/tests/load.py +4 -1
  11. validmind/tests/model_validation/sklearn/CalibrationCurve.py +116 -0
  12. validmind/tests/model_validation/sklearn/ClassifierThresholdOptimization.py +261 -0
  13. validmind/tests/model_validation/sklearn/ConfusionMatrix.py +1 -0
  14. validmind/tests/model_validation/sklearn/HyperParametersTuning.py +144 -56
  15. validmind/tests/model_validation/sklearn/ModelParameters.py +74 -0
  16. validmind/tests/model_validation/sklearn/ScoreProbabilityAlignment.py +130 -0
  17. validmind/tests/model_validation/statsmodels/CumulativePredictionProbabilities.py +5 -6
  18. validmind/tests/model_validation/statsmodels/PredictionProbabilitiesHistogram.py +2 -3
  19. validmind/tests/run.py +43 -72
  20. validmind/utils.py +23 -7
  21. validmind/vm_models/result/result.py +18 -17
  22. {validmind-2.7.2.dist-info → validmind-2.7.5.dist-info}/METADATA +2 -2
  23. {validmind-2.7.2.dist-info → validmind-2.7.5.dist-info}/RECORD +26 -20
  24. {validmind-2.7.2.dist-info → validmind-2.7.5.dist-info}/WHEEL +1 -1
  25. {validmind-2.7.2.dist-info → validmind-2.7.5.dist-info}/LICENSE +0 -0
  26. {validmind-2.7.2.dist-info → validmind-2.7.5.dist-info}/entry_points.txt +0 -0
validmind/tests/data_validation/ScoreBandDefaultRates.py ADDED
@@ -0,0 +1,139 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+import pandas as pd
+import numpy as np
+from validmind import tags, tasks
+from validmind.vm_models import VMDataset, VMModel
+
+
+@tags("visualization", "credit_risk", "scorecard")
+@tasks("classification")
+def ScoreBandDefaultRates(
+    dataset: VMDataset,
+    model: VMModel,
+    score_column: str = "score",
+    score_bands: list = None,
+):
+    """
+    Analyzes default rates and population distribution across credit score bands.
+
+    ### Purpose
+
+    The Score Band Default Rates test evaluates the discriminatory power of credit scores by analyzing
+    default rates across different score bands. This helps validate score effectiveness, supports
+    policy decisions, and provides insights into portfolio risk distribution.
+
+    ### Test Mechanism
+
+    The test segments the score distribution into bands and calculates key metrics for each band:
+    1. Population count and percentage in each band
+    2. Default rate within each band
+    3. Cumulative statistics across bands
+    The results show how well the scores separate good and bad accounts.
+
+    ### Signs of High Risk
+
+    - Non-monotonic default rates across score bands
+    - Insufficient population in critical score bands
+    - Unexpected default rates for score ranges
+    - High concentration in specific score bands
+    - Similar default rates across adjacent bands
+    - Unstable default rates in key decision bands
+    - Extreme population skewness
+    - Poor risk separation between bands
+
+    ### Strengths
+
+    - Clear view of score effectiveness
+    - Supports policy threshold decisions
+    - Easy to interpret and communicate
+    - Directly links to business decisions
+    - Shows risk segmentation power
+    - Identifies potential score issues
+    - Helps validate scoring model
+    - Supports portfolio monitoring
+
+    ### Limitations
+
+    - Sensitive to band definition choices
+    - May mask within-band variations
+    - Requires sufficient data in each band
+    - Cannot capture non-linear patterns
+    - Point-in-time analysis only
+    - No temporal trend information
+    - Assumes band boundaries are appropriate
+    - May oversimplify risk patterns
+    """
+
+    if score_column not in dataset.df.columns:
+        raise ValueError(
+            f"The required column '{score_column}' is not present in the dataset with input_id {dataset.input_id}"
+        )
+
+    df = dataset._df.copy()
+
+    # Default score bands if none provided
+    if score_bands is None:
+        score_bands = [410, 440, 470]
+
+    # Create band labels
+    band_labels = [
+        f"{score_bands[i]}-{score_bands[i+1]}" for i in range(len(score_bands) - 1)
+    ]
+    band_labels.insert(0, f"<{score_bands[0]}")
+    band_labels.append(f">{score_bands[-1]}")
+
+    # Bin the scores with infinite upper bound
+    df["score_band"] = pd.cut(
+        df[score_column], bins=[-np.inf] + score_bands + [np.inf], labels=band_labels
+    )
+
+    # Calculate min and max scores for the total row
+    min_score = df[score_column].min()
+    max_score = df[score_column].max()
+
+    # Get predicted classes (0/1)
+    y_pred = dataset.y_pred(model)
+
+    # Calculate metrics by band using target_column name
+    results = []
+    for band in band_labels:
+        band_mask = df["score_band"] == band
+        population = band_mask.sum()
+        observed_defaults = df[band_mask][dataset.target_column].sum()
+        predicted_defaults = y_pred[
+            band_mask
+        ].sum()  # Sum of 1s gives number of predicted defaults
+
+        results.append(
+            {
+                "Score Band": band,
+                "Population Count": population,
+                "Population (%)": population / len(df) * 100,
+                "Predicted Default Rate (%)": (
+                    predicted_defaults / population * 100 if population > 0 else 0
+                ),
+                "Observed Default Rate (%)": (
+                    observed_defaults / population * 100 if population > 0 else 0
+                ),
+            }
+        )
+
+    # Add total row
+    total_population = len(df)
+    total_observed = df[dataset.target_column].sum()
+    total_predicted = y_pred.sum()  # Total number of predicted defaults
+
+    results.append(
+        {
+            "Score Band": f"Total ({min_score:.0f}-{max_score:.0f})",
+            "Population Count": total_population,
+            "Population (%)": sum(r["Population (%)"] for r in results),
+            "Predicted Default Rate (%)": total_predicted / total_population * 100,
+            "Observed Default Rate (%)": total_observed / total_population * 100,
+        }
+    )
+
+    return pd.DataFrame(results)
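
For context on the new ScoreBandDefaultRates test above, here is a minimal standalone sketch (not part of the diff) of its core banding step: pd.cut with open-ended -inf/+inf bounds and the default bands [410, 440, 470]. The sample scores and default flags are made up for illustration.

    import numpy as np
    import pandas as pd

    # Hypothetical scores and observed default flags, for illustration only
    df = pd.DataFrame(
        {"score": [395, 425, 455, 480, 468, 500], "default": [1, 1, 0, 0, 0, 0]}
    )

    score_bands = [410, 440, 470]  # same defaults as the new test
    band_labels = (
        [f"<{score_bands[0]}"]
        + [f"{score_bands[i]}-{score_bands[i + 1]}" for i in range(len(score_bands) - 1)]
        + [f">{score_bands[-1]}"]
    )

    # Open-ended bins on both sides, mirroring the test's pd.cut call
    df["score_band"] = pd.cut(
        df["score"], bins=[-np.inf] + score_bands + [np.inf], labels=band_labels
    )

    summary = df.groupby("score_band", observed=False).agg(
        population=("score", "size"), observed_default_rate=("default", "mean")
    )
    print(summary)
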
validmind/tests/data_validation/TooManyZeroValues.py CHANGED
@@ -61,24 +61,25 @@ def TooManyZeroValues(dataset: VMDataset, max_percent_threshold: float = 0.03):
     issues.
     """
     df = dataset.df
-
     table = []
 
     for col in dataset.feature_columns_numeric:
         value_counts = df[col].value_counts()
+        row_count = df.shape[0]
 
         if 0 not in value_counts.index:
            continue
 
        n_zeros = value_counts[0]
-        p_zeros = n_zeros / df.shape[0]
+        p_zeros = (n_zeros / row_count) * 100
 
        table.append(
            {
-                "Column": col,
+                "Variable": col,
+                "Row Count": row_count,
                "Number of Zero Values": n_zeros,
-                "Percentage of Zero Values (%)": p_zeros * 100,
-                "Pass/Fail": "Pass" if p_zeros < max_percent_threshold else "Fail",
+                "Percentage of Zero Values (%)": p_zeros,
+                "Pass/Fail": ("Pass" if p_zeros < (max_percent_threshold) else "Fail"),
            }
        )
 
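
To make the changed comparison above concrete, here is a tiny illustrative sketch (not part of the diff): p_zeros is now expressed as a percentage and compared directly against max_percent_threshold, whose default of 0.03 is unchanged. The toy column is invented for illustration.

    import pandas as pd

    max_percent_threshold = 0.03  # unchanged default from the test signature
    col = pd.Series([0, 0, 1, 2, 3, 4, 5, 6, 7, 8])  # toy column with 20% zeros

    row_count = col.shape[0]
    n_zeros = int((col == 0).sum())
    p_zeros = (n_zeros / row_count) * 100  # now a percentage, as in the new code

    # The percentage is compared directly against max_percent_threshold
    print(p_zeros, "Pass" if p_zeros < max_percent_threshold else "Fail")  # 20.0 Fail
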
validmind/tests/data_validation/UniqueRows.py CHANGED
@@ -61,7 +61,9 @@ def UniqueRows(dataset: VMDataset, min_percent_threshold: float = 1):
             "Number of Unique Values": unique_rows[col],
             "Percentage of Unique Values (%)": unique_rows[col] / rows * 100,
             "Pass/Fail": (
-                "Pass" if unique_rows[col] / rows >= min_percent_threshold else "Fail"
+                "Pass"
+                if (unique_rows[col] / rows * 100) >= min_percent_threshold
+                else "Fail"
             ),
         }
         for col in unique_rows.index
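
A minimal standalone sketch (not part of the diff) of the updated pass/fail rule above: the unique-value ratio is now scaled to a percentage before being compared with min_percent_threshold (default 1). The toy DataFrame, the "Column" key, and the use of DataFrame.nunique() are illustrative assumptions, not taken from the library.

    import pandas as pd

    min_percent_threshold = 1  # default from the test signature
    df = pd.DataFrame({"a": [1, 1, 2, 2, 3], "b": [1, 2, 3, 4, 5]})

    rows = len(df)
    unique_rows = df.nunique()  # unique values per column (illustrative stand-in)

    table = [
        {
            "Column": col,
            "Number of Unique Values": unique_rows[col],
            "Percentage of Unique Values (%)": unique_rows[col] / rows * 100,
            "Pass/Fail": (
                "Pass"
                if (unique_rows[col] / rows * 100) >= min_percent_threshold
                else "Fail"
            ),
        }
        for col in unique_rows.index
    ]
    print(pd.DataFrame(table))
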
validmind/tests/decorator.py CHANGED
@@ -24,6 +24,11 @@ def _get_save_func(func, test_id):
     test library.
     """
 
+    # get og source before its wrapped by the test decorator
+    source = inspect.getsource(func)
+    # remove decorator line
+    source = source.split("\n", 1)[1]
+
     def save(root_folder=".", imports=None):
         parts = test_id.split(".")
 
@@ -41,35 +46,32 @@ def _get_save_func(func, test_id):
 
         full_path = os.path.join(path, f"{test_name}.py")
 
-        source = inspect.getsource(func)
-        # remove decorator line
-        source = source.split("\n", 1)[1]
+        _source = source.replace(f"def {func.__name__}", f"def {test_name}")
+
         if imports:
             imports = "\n".join(imports)
-            source = f"{imports}\n\n\n{source}"
+            _source = f"{imports}\n\n\n{_source}"
+
         # add comment to the top of the file
-        source = f"""
+        _source = f"""
 # Saved from {func.__module__}.{func.__name__}
 # Original Test ID: {test_id}
 # New Test ID: {new_test_id}
 
-{source}
+{_source}
 """
 
-        # ensure that the function name matches the test name
-        source = source.replace(f"def {func.__name__}", f"def {test_name}")
-
         # use black to format the code
         try:
             import black
 
-            source = black.format_str(source, mode=black.FileMode())
+            _source = black.format_str(_source, mode=black.FileMode())
         except ImportError:
             # ignore if not available
             pass
 
         with open(full_path, "w") as file:
-            file.writelines(source)
+            file.writelines(_source)
 
         logger.info(
             f"Saved to {os.path.abspath(full_path)}!"
@@ -119,12 +121,12 @@ def test(func_or_id):
         test_func = load_test(test_id, func, reload=True)
         test_store.register_test(test_id, test_func)
 
-        @wraps(test_func)
-        def wrapper(*args, **kwargs):
-            return test_func(*args, **kwargs)
-
         # special function to allow the function to be saved to a file
-        wrapper.save = _get_save_func(test_func, test_id)
+        save_func = _get_save_func(func, test_id)
+
+        wrapper = wraps(func)(test_func)
+        wrapper.test_id = test_id
+        wrapper.save = save_func
 
         return wrapper
 
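
The refactor above drops the pass-through wrapper function in favour of wraps(func)(test_func) plus attached attributes. Here is a small standalone sketch (not part of the diff) of that functools pattern, with made-up function names standing in for the user function and the loaded test function.

    from functools import wraps

    def original(x):
        """Docstring of the user-defined test function."""
        return x

    def loaded(x):
        # Stand-in for the function returned by load_test()
        return x

    # wraps(original)(loaded) copies __name__, __doc__, etc. from `original`
    # onto `loaded` and returns `loaded` itself, with no extra call layer added.
    wrapper = wraps(original)(loaded)
    wrapper.test_id = "validmind.custom_tests.original"  # extra attributes can be attached

    print(wrapper.__name__, wrapper.__doc__)
    print(wrapper is loaded)  # True
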
validmind/tests/load.py CHANGED
@@ -191,7 +191,7 @@ def list_tags():
     return list(unique_tags)
 
 
-def list_tasks_and_tags():
+def list_tasks_and_tags(as_json=False):
     """
     List all task types and their associated tags, with one row per task type and
     all tags for a task type in one row.
@@ -205,6 +205,9 @@ def list_tasks_and_tags():
         for task in test.__tasks__:
             task_tags_dict.setdefault(task, set()).update(test.__tags__)
 
+    if as_json:
+        return task_tags_dict
+
     return format_dataframe(
         pd.DataFrame(
             [
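
To illustrate the new as_json branch above, here is a minimal sketch (not part of the diff) of the mapping it returns: task type to the set of tags collected via setdefault/update. The stand-in test classes are invented for illustration.

    # Stand-ins for registered test classes; in the library these come from the test store
    class TestA:
        __tasks__ = ["classification"]
        __tags__ = ["sklearn", "model_performance"]

    class TestB:
        __tasks__ = ["classification", "regression"]
        __tags__ = ["visualization"]

    task_tags_dict = {}
    for test in (TestA, TestB):
        for task in test.__tasks__:
            task_tags_dict.setdefault(task, set()).update(test.__tags__)

    # list_tasks_and_tags(as_json=True) returns this dict instead of a formatted DataFrame
    print(task_tags_dict)
    # e.g. {'classification': {'sklearn', 'model_performance', 'visualization'},
    #       'regression': {'visualization'}}
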
validmind/tests/model_validation/sklearn/CalibrationCurve.py ADDED
@@ -0,0 +1,116 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+from sklearn.calibration import calibration_curve
+import plotly.graph_objects as go
+from validmind import tags, tasks
+from validmind.vm_models import VMModel, VMDataset
+from validmind.vm_models.result import RawData
+
+
+@tags("sklearn", "model_performance", "classification")
+@tasks("classification")
+def CalibrationCurve(model: VMModel, dataset: VMDataset, n_bins: int = 10):
+    """
+    Evaluates the calibration of probability estimates by comparing predicted probabilities against observed
+    frequencies.
+
+    ### Purpose
+
+    The Calibration Curve test assesses how well a model's predicted probabilities align with actual
+    observed frequencies. This is crucial for applications requiring accurate probability estimates,
+    such as risk assessment, decision-making systems, and cost-sensitive applications where probability
+    calibration directly impacts business decisions.
+
+    ### Test Mechanism
+
+    The test uses sklearn's calibration_curve function to:
+    1. Sort predictions into bins based on predicted probabilities
+    2. Calculate the mean predicted probability in each bin
+    3. Compare against the observed frequency of positive cases
+    4. Plot the results against the perfect calibration line (y=x)
+    The resulting curve shows how well the predicted probabilities match empirical probabilities.
+
+    ### Signs of High Risk
+
+    - Significant deviation from the perfect calibration line
+    - Systematic overconfidence (predictions too close to 0 or 1)
+    - Systematic underconfidence (predictions clustered around 0.5)
+    - Empty or sparse bins indicating poor probability coverage
+    - Sharp discontinuities in the calibration curve
+    - Different calibration patterns across different probability ranges
+    - Consistent over/under estimation in critical probability regions
+    - Large confidence intervals in certain probability ranges
+
+    ### Strengths
+
+    - Visual and intuitive interpretation of probability quality
+    - Identifies systematic biases in probability estimates
+    - Supports probability threshold selection
+    - Helps understand model confidence patterns
+    - Applicable across different classification models
+    - Enables comparison between different models
+    - Guides potential need for recalibration
+    - Critical for risk-sensitive applications
+
+    ### Limitations
+
+    - Sensitive to the number of bins chosen
+    - Requires sufficient samples in each bin for reliable estimates
+    - May mask local calibration issues within bins
+    - Does not account for feature-dependent calibration issues
+    - Limited to binary classification problems
+    - Cannot detect all forms of miscalibration
+    - Assumes bin boundaries are appropriate for the problem
+    - May be affected by class imbalance
+    """
+    prob_true, prob_pred = calibration_curve(
+        dataset.y, dataset.y_prob(model), n_bins=n_bins
+    )
+
+    # Create DataFrame for raw data
+    raw_data = RawData(
+        mean_predicted_probability=prob_pred, observed_frequency=prob_true
+    )
+
+    # Create Plotly figure
+    fig = go.Figure()
+
+    # Add perfect calibration line
+    fig.add_trace(
+        go.Scatter(
+            x=[0, 1],
+            y=[0, 1],
+            mode="lines",
+            name="Perfect Calibration",
+            line=dict(dash="dash", color="gray"),
+        )
+    )
+
+    # Add calibration curve
+    fig.add_trace(
+        go.Scatter(
+            x=prob_pred,
+            y=prob_true,
+            mode="lines+markers",
+            name="Model Calibration",
+            line=dict(color="blue"),
+            marker=dict(size=8),
+        )
+    )
+
+    # Update layout
+    fig.update_layout(
+        title="Calibration Curve",
+        xaxis_title="Mean Predicted Probability",
+        yaxis_title="Observed Frequency",
+        xaxis=dict(range=[0, 1]),
+        yaxis=dict(range=[0, 1]),
+        width=800,
+        height=600,
+        showlegend=True,
+        template="plotly_white",
+    )
+
+    return raw_data, fig
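
As a quick illustration of the sklearn call at the heart of the new CalibrationCurve test, here is a standalone sketch (not part of the diff) on synthetic data that is well calibrated by construction; the data generation is invented for illustration.

    import numpy as np
    from sklearn.calibration import calibration_curve

    rng = np.random.default_rng(0)

    # Synthetic "predicted probabilities" and labels drawn from them,
    # so the model is well calibrated by construction
    y_prob = rng.uniform(0, 1, size=5000)
    y_true = rng.binomial(1, y_prob)

    prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=10)

    # For a calibrated model the two arrays should track each other closely
    for mean_pred, obs_freq in zip(prob_pred, prob_true):
        print(f"mean predicted={mean_pred:.2f}  observed frequency={obs_freq:.2f}")
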
validmind/tests/model_validation/sklearn/ClassifierThresholdOptimization.py ADDED
@@ -0,0 +1,261 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+import numpy as np
+import pandas as pd
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
+from sklearn.metrics import (
+    roc_curve,
+    precision_recall_curve,
+    confusion_matrix,
+)
+from validmind import tags, tasks
+from validmind.vm_models import VMDataset, VMModel
+
+
+def find_optimal_threshold(y_true, y_prob, method="youden", target_recall=None):
+    """
+    Find the optimal classification threshold using various methods.
+
+    Args:
+        y_true: True binary labels
+        y_prob: Predicted probabilities
+        method: Method to use for finding optimal threshold
+        target_recall: Required if method='target_recall'
+
+    Returns:
+        dict: Dictionary containing threshold and metrics
+    """
+    # Get ROC and PR curve points
+    fpr, tpr, thresholds_roc = roc_curve(y_true, y_prob)
+    precision, recall, thresholds_pr = precision_recall_curve(y_true, y_prob)
+
+    # Find optimal threshold based on method
+    if method == "naive":
+        optimal_threshold = 0.5
+    elif method == "youden":
+        j_scores = tpr - fpr
+        best_idx = np.argmax(j_scores)
+        optimal_threshold = thresholds_roc[best_idx]
+    elif method == "f1":
+        f1_scores = 2 * (precision * recall) / (precision + recall)
+        best_idx = np.argmax(f1_scores)
+        optimal_threshold = (
+            thresholds_pr[best_idx] if best_idx < len(thresholds_pr) else 1.0
+        )
+    elif method == "precision_recall":
+        diff = abs(precision - recall)
+        best_idx = np.argmin(diff)
+        optimal_threshold = (
+            thresholds_pr[best_idx] if best_idx < len(thresholds_pr) else 1.0
+        )
+    elif method == "target_recall":
+        if target_recall is None:
+            raise ValueError(
+                "target_recall must be specified when method='target_recall'"
+            )
+        idx = np.argmin(abs(recall - target_recall))
+        optimal_threshold = thresholds_pr[idx] if idx < len(thresholds_pr) else 1.0
+    else:
+        raise ValueError(f"Unknown method: {method}")
+
+    # Calculate predictions with optimal threshold
+    y_pred = (y_prob >= optimal_threshold).astype(int)
+
+    # Calculate confusion matrix
+    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
+
+    # Calculate metrics directly
+    metrics = {
+        "method": method,
+        "threshold": optimal_threshold,
+        "precision": tp / (tp + fp) if (tp + fp) > 0 else 0,
+        "recall": tp / (tp + fn) if (tp + fn) > 0 else 0,
+        "f1_score": 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else 0,
+        "accuracy": (tp + tn) / (tp + tn + fp + fn),
+    }
+
+    return metrics
+
+
+@tags("model_validation", "threshold_optimization", "classification_metrics")
+@tasks("classification")
+def ClassifierThresholdOptimization(
+    dataset: VMDataset, model: VMModel, methods=None, target_recall=None
+):
+    """
+    Analyzes and visualizes different threshold optimization methods for binary classification models.
+
+    ### Purpose
+
+    The Classifier Threshold Optimization test identifies optimal decision thresholds using various
+    methods to balance different performance metrics. This helps adapt the model's decision boundary
+    to specific business requirements, such as minimizing false positives in fraud detection or
+    achieving target recall in medical diagnosis.
+
+    ### Test Mechanism
+
+    The test implements multiple threshold optimization methods:
+    1. Youden's J statistic (maximizing sensitivity + specificity - 1)
+    2. F1-score optimization (balancing precision and recall)
+    3. Precision-Recall equality point
+    4. Target recall achievement
+    5. Naive (0.5) threshold
+    For each method, it computes ROC and PR curves, identifies optimal points, and provides
+    comprehensive performance metrics at each threshold.
+
+    ### Signs of High Risk
+
+    - Large discrepancies between different optimization methods
+    - Optimal thresholds far from the default 0.5
+    - Poor performance metrics across all thresholds
+    - Significant gap between achieved and target recall
+    - Unstable thresholds across different methods
+    - Extreme trade-offs between precision and recall
+    - Threshold optimization showing minimal impact
+    - Business metrics not improving with optimization
+
+    ### Strengths
+
+    - Multiple optimization strategies for different needs
+    - Visual and numerical results for comparison
+    - Support for business-driven optimization (target recall)
+    - Comprehensive performance metrics at each threshold
+    - Integration with ROC and PR curves
+    - Handles class imbalance through various metrics
+    - Enables informed threshold selection
+    - Supports cost-sensitive decision making
+
+    ### Limitations
+
+    - Assumes cost of false positives/negatives are known
+    - May need adjustment for highly imbalanced datasets
+    - Threshold might not be stable across different samples
+    - Cannot handle multi-class problems directly
+    - Optimization methods may conflict with business needs
+    - Requires sufficient validation data
+    - May not capture temporal changes in optimal threshold
+    - Single threshold may not be optimal for all subgroups
+
+    Args:
+        dataset: VMDataset containing features and target
+        model: VMModel containing predictions
+        methods: List of methods to compare (default: ['youden', 'f1', 'precision_recall'])
+        target_recall: Target recall value if using 'target_recall' method
+
+    Returns:
+        Dictionary containing:
+        - table: DataFrame comparing different threshold optimization methods
+          (using weighted averages for precision, recall, and f1)
+        - figure: Plotly figure showing ROC and PR curves with optimal thresholds
+    """
+    # Verify binary classification
+    unique_values = np.unique(dataset.y)
+    if len(unique_values) != 2:
+        raise ValueError("Target variable must be binary")
+
+    if methods is None:
+        methods = ["naive", "youden", "f1", "precision_recall"]
+        if target_recall is not None:
+            methods.append("target_recall")
+
+    y_true = dataset.y
+    y_prob = dataset.y_prob(model)
+
+    # Get curve points for plotting
+    fpr, tpr, thresholds_roc = roc_curve(y_true, y_prob)
+    precision, recall, thresholds_pr = precision_recall_curve(y_true, y_prob)
+
+    # Calculate optimal thresholds and metrics
+    results = []
+    optimal_points = {}
+
+    for method in methods:
+        metrics = find_optimal_threshold(y_true, y_prob, method, target_recall)
+        results.append(metrics)
+
+        # Store optimal points for plotting
+        if method == "youden":
+            idx = np.argmax(tpr - fpr)
+            optimal_points[method] = {
+                "x": fpr[idx],
+                "y": tpr[idx],
+                "threshold": thresholds_roc[idx],
+            }
+        elif method in ["f1", "precision_recall", "target_recall"]:
+            idx = np.argmin(abs(thresholds_pr - metrics["threshold"]))
+            optimal_points[method] = {
+                "x": recall[idx],
+                "y": precision[idx],
+                "threshold": metrics["threshold"],
+            }
+
+    # Create visualization
+    fig = make_subplots(
+        rows=1, cols=2, subplot_titles=("ROC Curve", "Precision-Recall Curve")
+    )
+
+    # Plot ROC curve
+    fig.add_trace(
+        go.Scatter(x=fpr, y=tpr, name="ROC Curve", line=dict(color="blue")),
+        row=1,
+        col=1,
+    )
+
+    # Plot PR curve
+    fig.add_trace(
+        go.Scatter(x=recall, y=precision, name="PR Curve", line=dict(color="green")),
+        row=1,
+        col=2,
+    )
+
+    # Add optimal points
+    colors = {
+        "youden": "red",
+        "f1": "orange",
+        "precision_recall": "purple",
+        "target_recall": "brown",
+    }
+
+    for method, points in optimal_points.items():
+        if method == "youden":
+            fig.add_trace(
+                go.Scatter(
+                    x=[points["x"]],
+                    y=[points["y"]],
+                    name=f'{method} (t={points["threshold"]:.2f})',
+                    mode="markers",
+                    marker=dict(size=10, color=colors[method]),
+                ),
+                row=1,
+                col=1,
+            )
+        else:
+            fig.add_trace(
+                go.Scatter(
+                    x=[points["x"]],
+                    y=[points["y"]],
+                    name=f'{method} (t={points["threshold"]:.2f})',
+                    mode="markers",
+                    marker=dict(size=10, color=colors[method]),
+                ),
+                row=1,
+                col=2,
+            )
+
+    # Update layout
+    fig.update_layout(
+        height=500, title_text="Threshold Optimization Analysis", showlegend=True
+    )
+
+    fig.update_xaxes(title_text="False Positive Rate", row=1, col=1)
+    fig.update_xaxes(title_text="Recall", row=1, col=2)
+    fig.update_yaxes(title_text="True Positive Rate", row=1, col=1)
+    fig.update_yaxes(title_text="Precision", row=1, col=2)
+
+    # Create results table and sort by threshold descending
+    table = pd.DataFrame(results).sort_values("threshold", ascending=False)
+
+    return fig, table
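
To show what the "youden" branch of find_optimal_threshold computes, here is a standalone sketch (not part of the diff) on synthetic data; the label and probability generation is invented for illustration.

    import numpy as np
    from sklearn.metrics import roc_curve

    rng = np.random.default_rng(42)

    # Synthetic binary labels and probabilities with some signal
    y_true = rng.integers(0, 2, size=2000)
    y_prob = np.clip(0.3 * y_true + rng.normal(0.35, 0.2, size=2000), 0, 1)

    # Youden's J statistic: maximize TPR - FPR over the ROC thresholds,
    # mirroring the "youden" branch of find_optimal_threshold above
    fpr, tpr, thresholds = roc_curve(y_true, y_prob)
    j_scores = tpr - fpr
    best_idx = np.argmax(j_scores)
    optimal_threshold = thresholds[best_idx]

    y_pred = (y_prob >= optimal_threshold).astype(int)
    print(f"optimal threshold (Youden): {optimal_threshold:.3f}")
    print(f"accuracy at that threshold: {(y_pred == y_true).mean():.3f}")
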
validmind/tests/model_validation/sklearn/ConfusionMatrix.py CHANGED
@@ -106,6 +106,7 @@ def ConfusionMatrix(dataset: VMDataset, model: VMModel):
         autosize=False,
         width=600,
         height=600,
+        title_text="Confusion Matrix",
     )
 
     fig.add_annotation(