explainiverse 0.2.5__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {explainiverse-0.2.5 → explainiverse-0.4.0}/PKG-INFO +2 -1
- {explainiverse-0.2.5 → explainiverse-0.4.0}/pyproject.toml +2 -1
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/__init__.py +1 -1
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/core/registry.py +22 -0
- explainiverse-0.4.0/src/explainiverse/evaluation/__init__.py +60 -0
- explainiverse-0.4.0/src/explainiverse/evaluation/_utils.py +325 -0
- explainiverse-0.4.0/src/explainiverse/evaluation/faithfulness.py +428 -0
- explainiverse-0.4.0/src/explainiverse/evaluation/stability.py +379 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/explainers/__init__.py +8 -0
- explainiverse-0.4.0/src/explainiverse/explainers/example_based/__init__.py +18 -0
- explainiverse-0.4.0/src/explainiverse/explainers/example_based/protodash.py +826 -0
- explainiverse-0.2.5/src/explainiverse/evaluation/__init__.py +0 -8
- {explainiverse-0.2.5 → explainiverse-0.4.0}/LICENSE +0 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/README.md +0 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/adapters/__init__.py +0 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/adapters/base_adapter.py +0 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/adapters/pytorch_adapter.py +0 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/adapters/sklearn_adapter.py +0 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/core/__init__.py +0 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/core/explainer.py +0 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/core/explanation.py +0 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/engine/__init__.py +0 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/engine/suite.py +0 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/evaluation/metrics.py +0 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/explainers/attribution/__init__.py +0 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/explainers/attribution/lime_wrapper.py +0 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/explainers/attribution/shap_wrapper.py +0 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/explainers/attribution/treeshap_wrapper.py +0 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/explainers/counterfactual/__init__.py +0 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/explainers/counterfactual/dice_wrapper.py +0 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/explainers/global_explainers/__init__.py +0 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/explainers/global_explainers/ale.py +0 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/explainers/global_explainers/partial_dependence.py +0 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/explainers/global_explainers/permutation_importance.py +0 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/explainers/global_explainers/sage.py +0 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/explainers/gradient/__init__.py +0 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/explainers/gradient/deeplift.py +0 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/explainers/gradient/gradcam.py +0 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/explainers/gradient/integrated_gradients.py +0 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/explainers/rule_based/__init__.py +0 -0
- {explainiverse-0.2.5 → explainiverse-0.4.0}/src/explainiverse/explainers/rule_based/anchors_wrapper.py +0 -0
--- explainiverse-0.2.5/PKG-INFO
+++ explainiverse-0.4.0/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: explainiverse
-Version: 0.2.5
+Version: 0.4.0
 Summary: Unified, extensible explainability framework supporting LIME, SHAP, Anchors, Counterfactuals, PDP, ALE, SAGE, and more
 Home-page: https://github.com/jemsbhai/explainiverse
 License: MIT
@@ -20,6 +20,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Provides-Extra: torch
 Requires-Dist: lime (>=0.2.0.1,<0.3.0.0)
 Requires-Dist: numpy (>=1.24,<2.0)
+Requires-Dist: pandas (>=1.5,<3.0)
 Requires-Dist: scikit-learn (>=1.1,<1.6)
 Requires-Dist: scipy (>=1.10,<2.0)
 Requires-Dist: shap (>=0.48.0,<0.49.0)
--- explainiverse-0.2.5/pyproject.toml
+++ explainiverse-0.4.0/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "explainiverse"
-version = "0.2.5"
+version = "0.4.0"
 description = "Unified, extensible explainability framework supporting LIME, SHAP, Anchors, Counterfactuals, PDP, ALE, SAGE, and more"
 authors = ["Muntaser Syed <jemsbhai@gmail.com>"]
 license = "MIT"
@@ -27,6 +27,7 @@ numpy = ">=1.24,<2.0"
 lime = "^0.2.0.1"
 scikit-learn = ">=1.1,<1.6"
 shap = "^0.48.0"
+pandas = ">=1.5,<3.0"
 scipy = ">=1.10,<2.0"
 xgboost = ">=1.7,<3.0"
 torch = { version = ">=2.0", optional = true }
--- explainiverse-0.2.5/src/explainiverse/core/registry.py
+++ explainiverse-0.4.0/src/explainiverse/core/registry.py
@@ -372,6 +372,7 @@ def _create_default_registry() -> ExplainerRegistry:
     from explainiverse.explainers.gradient.integrated_gradients import IntegratedGradientsExplainer
     from explainiverse.explainers.gradient.gradcam import GradCAMExplainer
     from explainiverse.explainers.gradient.deeplift import DeepLIFTExplainer, DeepLIFTShapExplainer
+    from explainiverse.explainers.example_based.protodash import ProtoDashExplainer

     registry = ExplainerRegistry()

@@ -604,6 +605,27 @@ def _create_default_registry() -> ExplainerRegistry:
         )
     )

+    # =========================================================================
+    # Example-Based Explainers
+    # =========================================================================
+
+    # Register ProtoDash
+    registry.register(
+        name="protodash",
+        explainer_class=ProtoDashExplainer,
+        meta=ExplainerMeta(
+            scope="local",
+            model_types=["any"],
+            data_types=["tabular"],
+            task_types=["classification", "regression"],
+            description="ProtoDash - prototype selection with importance weights for example-based explanations",
+            paper_reference="Gurumoorthy et al., 2019 - 'Efficient Data Representation by Selecting Prototypes' (ICDM)",
+            complexity="O(n_prototypes * n_samples^2)",
+            requires_training_data=True,
+            supports_batching=True
+        )
+    )
+
     return registry

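The hunk above wires ProtoDash into the default registry through registry.register() with an ExplainerMeta record. For orientation only, here is a hedged sketch of the same call shape applied to a user-defined explainer; MyPrototypeExplainer is a hypothetical placeholder, and the assumption that ExplainerRegistry and ExplainerMeta can be imported from explainiverse.core.registry is inferred from this diff rather than confirmed by it.

    # Hypothetical sketch, not part of the 0.4.0 diff: registering a custom
    # explainer with the same register()/ExplainerMeta call shape used for
    # ProtoDash above. Import locations are assumed from the file paths shown
    # in this diff; MyPrototypeExplainer is a placeholder class and a real one
    # would follow the package's Explainer interface.
    from explainiverse.core.registry import ExplainerRegistry, ExplainerMeta

    class MyPrototypeExplainer:
        """Placeholder example-based explainer for illustration."""
        def __init__(self, model, **kwargs):
            self.model = model

    registry = ExplainerRegistry()
    registry.register(
        name="my_prototypes",
        explainer_class=MyPrototypeExplainer,
        meta=ExplainerMeta(
            scope="local",
            model_types=["any"],
            data_types=["tabular"],
            task_types=["classification"],
            description="Placeholder prototype-selection explainer",
            paper_reference="n/a",
            complexity="O(n_samples^2)",
            requires_training_data=True,
            supports_batching=True,
        ),
    )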
--- /dev/null
+++ explainiverse-0.4.0/src/explainiverse/evaluation/__init__.py
@@ -0,0 +1,60 @@
+# src/explainiverse/evaluation/__init__.py
+"""
+Evaluation metrics for explanation quality.
+
+Includes:
+- Faithfulness metrics (PGI, PGU, Comprehensiveness, Sufficiency)
+- Stability metrics (RIS, ROS, Lipschitz)
+- Perturbation metrics (AOPC, ROAR)
+"""
+
+from explainiverse.evaluation.metrics import (
+    compute_aopc,
+    compute_batch_aopc,
+    compute_roar,
+    compute_roar_curve,
+)
+
+from explainiverse.evaluation.faithfulness import (
+    compute_pgi,
+    compute_pgu,
+    compute_faithfulness_score,
+    compute_comprehensiveness,
+    compute_sufficiency,
+    compute_faithfulness_correlation,
+    compare_explainer_faithfulness,
+    compute_batch_faithfulness,
+)
+
+from explainiverse.evaluation.stability import (
+    compute_ris,
+    compute_ros,
+    compute_lipschitz_estimate,
+    compute_stability_metrics,
+    compute_batch_stability,
+    compare_explainer_stability,
+)
+
+__all__ = [
+    # Perturbation metrics (existing)
+    "compute_aopc",
+    "compute_batch_aopc",
+    "compute_roar",
+    "compute_roar_curve",
+    # Faithfulness metrics (new)
+    "compute_pgi",
+    "compute_pgu",
+    "compute_faithfulness_score",
+    "compute_comprehensiveness",
+    "compute_sufficiency",
+    "compute_faithfulness_correlation",
+    "compare_explainer_faithfulness",
+    "compute_batch_faithfulness",
+    # Stability metrics (new)
+    "compute_ris",
+    "compute_ros",
+    "compute_lipschitz_estimate",
+    "compute_stability_metrics",
+    "compute_batch_stability",
+    "compare_explainer_stability",
+]
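The new explainiverse.evaluation package gathers the perturbation, faithfulness, and stability metrics behind a single import path. A minimal sketch of the surface this exposes to downstream code is shown below; only the imports are shown, since the function signatures live in faithfulness.py and stability.py, which are not reproduced in this excerpt, and the inline glosses follow the groupings in the module docstring above.

    # Import surface added in 0.4.0, as declared by __all__ above.
    from explainiverse.evaluation import (
        compute_pgi,    # faithfulness: prediction gap when important features are perturbed
        compute_pgu,    # faithfulness: prediction gap when unimportant features are perturbed
        compute_ris,    # stability: relative input stability
        compute_ros,    # stability: relative output stability
        compute_aopc,   # perturbation: area over the perturbation curve
    )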
--- /dev/null
+++ explainiverse-0.4.0/src/explainiverse/evaluation/_utils.py
@@ -0,0 +1,325 @@
+# src/explainiverse/evaluation/_utils.py
+"""
+Shared utility functions for evaluation metrics.
+"""
+import numpy as np
+import re
+from typing import Union, Callable, List, Tuple
+from explainiverse.core.explanation import Explanation
+
+
+def _extract_base_feature_name(feature_str: str) -> str:
+    """
+    Extract the base feature name from LIME-style feature strings.
+
+    LIME returns strings like "petal width (cm) <= 0.80" or "feature_2 > 3.5".
+    This extracts just the feature name part.
+
+    Args:
+        feature_str: Feature string possibly with conditions
+
+    Returns:
+        Base feature name
+    """
+    # Remove comparison operators and values
+    # Pattern matches: name <= value, name < value, name >= value, name > value, name = value
+    patterns = [
+        r'^(.+?)\s*<=\s*[\d\.\-]+$',
+        r'^(.+?)\s*>=\s*[\d\.\-]+$',
+        r'^(.+?)\s*<\s*[\d\.\-]+$',
+        r'^(.+?)\s*>\s*[\d\.\-]+$',
+        r'^(.+?)\s*=\s*[\d\.\-]+$',
+    ]
+
+    for pattern in patterns:
+        match = re.match(pattern, feature_str.strip())
+        if match:
+            return match.group(1).strip()
+
+    # No match found, return as-is
+    return feature_str.strip()
+
+
+def _match_feature_to_index(
+    feature_key: str,
+    feature_names: List[str]
+) -> int:
+    """
+    Match a feature key (possibly with LIME conditions) to its index.
+
+    Args:
+        feature_key: Feature name from explanation (may include conditions)
+        feature_names: List of original feature names
+
+    Returns:
+        Index of the matching feature, or -1 if not found
+    """
+    # Try exact match first
+    if feature_key in feature_names:
+        return feature_names.index(feature_key)
+
+    # Try extracting base name
+    base_name = _extract_base_feature_name(feature_key)
+    if base_name in feature_names:
+        return feature_names.index(base_name)
+
+    # Try partial matching (feature name is contained in key)
+    for i, fname in enumerate(feature_names):
+        if fname in feature_key:
+            return i
+
+    # Try index extraction from patterns like "feature_2" or "f2" or "feat_2"
+    patterns = [
+        r'feature[_\s]*(\d+)',
+        r'feat[_\s]*(\d+)',
+        r'^f(\d+)$',
+        r'^x(\d+)$',
+    ]
+    for pattern in patterns:
+        match = re.search(pattern, feature_key, re.IGNORECASE)
+        if match:
+            idx = int(match.group(1))
+            if 0 <= idx < len(feature_names):
+                return idx
+
+    return -1
+
+
+def get_sorted_feature_indices(
+    explanation: Explanation,
+    descending: bool = True
+) -> List[int]:
+    """
+    Extract feature indices sorted by absolute attribution value.
+
+    Handles various feature naming conventions:
+    - Clean names: "sepal length", "feature_0"
+    - LIME-style: "sepal length <= 5.0", "feature_0 > 2.3"
+    - Indexed: "f0", "x1", "feat_2"
+
+    Args:
+        explanation: Explanation object with feature_attributions
+        descending: If True, sort from most to least important
+
+    Returns:
+        List of feature indices sorted by importance
+    """
+    attributions = explanation.explanation_data.get("feature_attributions", {})
+
+    if not attributions:
+        raise ValueError("No feature attributions found in explanation.")
+
+    # Sort features by absolute importance
+    sorted_features = sorted(
+        attributions.items(),
+        key=lambda x: abs(x[1]),
+        reverse=descending
+    )
+
+    # Map feature names to indices
+    feature_indices = []
+    feature_names = getattr(explanation, 'feature_names', None)
+
+    for i, (fname, _) in enumerate(sorted_features):
+        if feature_names is not None:
+            idx = _match_feature_to_index(fname, feature_names)
+            if idx >= 0:
+                feature_indices.append(idx)
+            else:
+                # Fallback: use position in sorted list
+                feature_indices.append(i % len(feature_names))
+        else:
+            # No feature_names available - try to extract index from name
+            patterns = [
+                r'feature[_\s]*(\d+)',
+                r'feat[_\s]*(\d+)',
+                r'^f(\d+)',
+                r'^x(\d+)',
+            ]
+            found = False
+            for pattern in patterns:
+                match = re.search(pattern, fname, re.IGNORECASE)
+                if match:
+                    feature_indices.append(int(match.group(1)))
+                    found = True
+                    break
+            if not found:
+                feature_indices.append(i)
+
+    return feature_indices
+
+
+def compute_baseline_values(
+    baseline: Union[str, float, np.ndarray, Callable],
+    background_data: np.ndarray = None,
+    n_features: int = None
+) -> np.ndarray:
+    """
+    Compute per-feature baseline values for perturbation.
+
+    Args:
+        baseline: Baseline specification - one of:
+            - "mean": Use mean of background_data
+            - "median": Use median of background_data
+            - float/int: Use this value for all features
+            - np.ndarray: Use these values directly (must match n_features)
+            - Callable: Function that takes background_data and returns baseline array
+        background_data: Reference data for computing statistics (required for "mean"/"median")
+        n_features: Number of features (required if baseline is scalar)
+
+    Returns:
+        1D numpy array of baseline values, one per feature
+    """
+    if isinstance(baseline, str):
+        if background_data is None:
+            raise ValueError(f"background_data required for baseline='{baseline}'")
+        if baseline == "mean":
+            return np.mean(background_data, axis=0)
+        elif baseline == "median":
+            return np.median(background_data, axis=0)
+        else:
+            raise ValueError(f"Unsupported string baseline: {baseline}")
+
+    elif callable(baseline):
+        if background_data is None:
+            raise ValueError("background_data required for callable baseline")
+        result = baseline(background_data)
+        return np.asarray(result)
+
+    elif isinstance(baseline, np.ndarray):
+        return baseline
+
+    elif isinstance(baseline, (float, int, np.number)):
+        if n_features is None:
+            raise ValueError("n_features required for scalar baseline")
+        return np.full(n_features, baseline)
+
+    else:
+        raise ValueError(f"Invalid baseline type: {type(baseline)}")
+
+
+def apply_feature_mask(
+    instance: np.ndarray,
+    feature_indices: List[int],
+    baseline_values: np.ndarray
+) -> np.ndarray:
+    """
+    Replace specified features with baseline values.
+
+    Args:
+        instance: Original instance (1D array)
+        feature_indices: Indices of features to replace
+        baseline_values: Per-feature baseline values
+
+    Returns:
+        Modified instance with specified features replaced
+    """
+    modified = instance.copy()
+    for idx in feature_indices:
+        if idx < len(modified) and idx < len(baseline_values):
+            modified[idx] = baseline_values[idx]
+    return modified
+
+
+def resolve_k(k: Union[int, float], n_features: int) -> int:
+    """
+    Resolve k to an integer number of features.
+
+    Args:
+        k: Either an integer count or a float fraction (0-1)
+        n_features: Total number of features
+
+    Returns:
+        Integer number of features
+    """
+    if isinstance(k, float) and 0 < k <= 1:
+        return max(1, int(k * n_features))
+    elif isinstance(k, int) and k > 0:
+        return min(k, n_features)
+    else:
+        raise ValueError(f"k must be positive int or float in (0, 1], got {k}")
+
+
+def get_prediction_value(
+    model,
+    instance: np.ndarray,
+    output_type: str = "probability"
+) -> float:
+    """
+    Get a scalar prediction value from a model.
+
+    Works with both raw sklearn models and explainiverse adapters.
+    For adapters, .predict() typically returns probabilities.
+
+    Args:
+        model: Model adapter with predict/predict_proba methods
+        instance: Single instance (1D array)
+        output_type: "probability" (max prob) or "class" (predicted class)
+
+    Returns:
+        Scalar prediction value
+    """
+    instance_2d = instance.reshape(1, -1)
+
+    if output_type == "probability":
+        # Try predict_proba first (raw sklearn model)
+        if hasattr(model, 'predict_proba'):
+            proba = model.predict_proba(instance_2d)
+            if isinstance(proba, np.ndarray):
+                if proba.ndim == 2:
+                    return float(np.max(proba[0]))
+                return float(np.max(proba))
+            return float(np.max(proba[0]))
+
+        # Fall back to predict (adapter returns probs from predict)
+        pred = model.predict(instance_2d)
+        if isinstance(pred, np.ndarray):
+            if pred.ndim == 2:
+                return float(np.max(pred[0]))
+            elif pred.ndim == 1:
+                return float(np.max(pred))
+        return float(pred[0]) if hasattr(pred, '__getitem__') else float(pred)
+
+    elif output_type == "class":
+        # For class prediction, use argmax of probabilities
+        if hasattr(model, 'predict_proba'):
+            proba = model.predict_proba(instance_2d)
+            return float(np.argmax(proba[0]))
+        pred = model.predict(instance_2d)
+        if isinstance(pred, np.ndarray) and pred.ndim == 2:
+            return float(np.argmax(pred[0]))
+        return float(pred[0]) if hasattr(pred, '__getitem__') else float(pred)
+
+    else:
+        raise ValueError(f"Unknown output_type: {output_type}")
+
+
+def compute_prediction_change(
+    model,
+    original: np.ndarray,
+    perturbed: np.ndarray,
+    metric: str = "absolute"
+) -> float:
+    """
+    Compute the change in prediction between original and perturbed instances.
+
+    Args:
+        model: Model adapter
+        original: Original instance
+        perturbed: Perturbed instance
+        metric: "absolute" for |p1 - p2|, "relative" for |p1 - p2| / p1
+
+    Returns:
+        Prediction change value
+    """
+    orig_pred = get_prediction_value(model, original)
+    pert_pred = get_prediction_value(model, perturbed)
+
+    if metric == "absolute":
+        return abs(orig_pred - pert_pred)
+    elif metric == "relative":
+        if abs(orig_pred) < 1e-10:
+            return abs(pert_pred)
+        return abs(orig_pred - pert_pred) / abs(orig_pred)
+    else:
+        raise ValueError(f"Unknown metric: {metric}")