explainiverse 0.6.0__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,120 +1,194 @@
  # src/explainiverse/engine/suite.py
+ """
+ ExplanationSuite - Multi-explainer comparison and evaluation.
+
+ Provides utilities for running multiple explainers on the same instances
+ and comparing their outputs.
+ """
+
+ from typing import Dict, List, Optional, Any, Tuple
+ import numpy as np

- from explainiverse.core.explanation import Explanation
- from explainiverse.explainers.attribution.lime_wrapper import LimeExplainer
- from explainiverse.explainers.attribution.shap_wrapper import ShapExplainer
- from explainiverse.evaluation.metrics import compute_roar
- from sklearn.metrics import accuracy_score
- from sklearn.linear_model import LogisticRegression

  class ExplanationSuite:
      """
-     Runs multiple explainers on a single instance and compares their outputs.
+     Run and compare multiple explainers on the same instances.
+
+     This class provides a unified interface for:
+     - Running multiple explainers on a single instance
+     - Comparing attribution scores side-by-side
+     - Suggesting the best explainer based on model/task characteristics
+     - Evaluating explainers using ROAR (Remove And Retrain)
+
+     Example:
+         >>> from explainiverse import ExplanationSuite, SklearnAdapter
+         >>> suite = ExplanationSuite(
+         ...     model=adapter,
+         ...     explainer_configs=[
+         ...         ("lime", {"training_data": X_train, "feature_names": fnames, "class_names": cnames}),
+         ...         ("shap", {"background_data": X_train[:50], "feature_names": fnames, "class_names": cnames}),
+         ...     ]
+         ... )
+         >>> results = suite.run(X_test[0])
+         >>> suite.compare()
      """

-     def __init__(self, model, explainer_configs, data_meta=None):
+     def __init__(
+         self,
+         model,
+         explainer_configs: List[Tuple[str, Dict[str, Any]]],
+         data_meta: Optional[Dict[str, Any]] = None
+     ):
          """
+         Initialize the ExplanationSuite.
+
          Args:
-             model: a model adapter (e.g., SklearnAdapter)
-             explainer_configs: list of (name, kwargs) tuples for explainers
-             data_meta: optional metadata about the task, scope, or preference
+             model: A model adapter (e.g., SklearnAdapter, PyTorchAdapter)
+             explainer_configs: List of (explainer_name, kwargs) tuples.
+                 The explainer_name should match a registered explainer in
+                 the default_registry (e.g., "lime", "shap", "treeshap").
+             data_meta: Optional metadata about the task, scope, or preference.
+                 Can include "task" ("classification" or "regression").
          """
          self.model = model
          self.configs = explainer_configs
          self.data_meta = data_meta or {}
-         self.explanations = {}
+         self.explanations: Dict[str, Any] = {}
+         self._registry = None
+
+     def _get_registry(self):
+         """Lazy load the registry to avoid circular imports."""
+         if self._registry is None:
+             from explainiverse.core.registry import default_registry
+             self._registry = default_registry
+         return self._registry

-     def run(self, instance):
+     def run(self, instance: np.ndarray) -> Dict[str, Any]:
          """
          Run all configured explainers on a single instance.
+
+         Args:
+             instance: Input instance to explain (1D numpy array)
+
+         Returns:
+             Dictionary mapping explainer names to Explanation objects
          """
+         instance = np.asarray(instance)
+         registry = self._get_registry()
+
          for name, params in self.configs:
-             explainer = self._load_explainer(name, **params)
-             explanation = explainer.explain(instance)
-             self.explanations[name] = explanation
+             try:
+                 explainer = registry.create(name, model=self.model, **params)
+                 explanation = explainer.explain(instance)
+                 self.explanations[name] = explanation
+             except Exception as e:
+                 print(f"[ExplanationSuite] Warning: Failed to run {name}: {e}")
+                 continue
+
          return self.explanations

-     def compare(self):
+     def compare(self) -> None:
          """
-         Print attribution scores side-by-side.
+         Print attribution scores side-by-side for comparison.
          """
-         keys = set()
+         if not self.explanations:
+             print("No explanations to compare. Run suite.run(instance) first.")
+             return
+
+         # Collect all feature names across explanations
+         all_keys = set()
          for explanation in self.explanations.values():
-             keys.update(explanation.explanation_data.get("feature_attributions", {}).keys())
+             attrs = explanation.explanation_data.get("feature_attributions", {})
+             all_keys.update(attrs.keys())

          print("\nSide-by-Side Comparison:")
-         for key in sorted(keys):
-             row = [f"{key}"]
+         print("-" * 60)
+
+         # Header
+         header = ["Feature"] + list(self.explanations.keys())
+         print(" | ".join(f"{h:>15}" for h in header))
+         print("-" * 60)
+
+         # Rows
+         for key in sorted(all_keys):
+             row = [f"{key:>15}"]
              for name in self.explanations:
-                 value = self.explanations[name].explanation_data.get("feature_attributions", {}).get(key, "—")
-                 row.append(f"{name}: {value:.4f}" if isinstance(value, float) else f"{name}: {value}")
+                 value = self.explanations[name].explanation_data.get(
+                     "feature_attributions", {}
+                 ).get(key, None)
+                 if value is not None:
+                     row.append(f"{value:>15.4f}")
+                 else:
+                     row.append(f"{'—':>15}")
              print(" | ".join(row))

-     def suggest_best(self):
+     def suggest_best(self) -> str:
          """
-         Suggest the best explainer based on model type, output structure, and task metadata.
+         Suggest the best explainer based on model type and task characteristics.
+
+         Returns:
+             Name of the suggested explainer
          """
-         if "task" in self.data_meta:
-             task = self.data_meta["task"]
-         else:
-             task = "unknown"
-
-         model = self.model.model
+         task = self.data_meta.get("task", "unknown")
+         model = self.model.model if hasattr(self.model, 'model') else self.model

          # 1. Regression: SHAP preferred due to consistent output
          if task == "regression":
              return "shap"

-         # 2. Model with `predict_proba` → SHAP handles probabilistic outputs well
+         # 2. Model with predict_proba → SHAP handles probabilistic outputs well
          if hasattr(model, "predict_proba"):
              try:
-                 output = self.model.predict([[0] * model.n_features_in_])
-                 if output.shape[1] > 2:
-                     return "shap" # Multi-class, SHAP more stable
-                 else:
-                     return "lime" # Binary, both are okay
+                 # Check output dimensions
+                 if hasattr(model, 'n_features_in_'):
+                     test_input = np.zeros((1, model.n_features_in_))
+                     output = self.model.predict(test_input)
+                     if output.shape[1] > 2:
+                         return "shap" # Multi-class, SHAP more stable
+                     else:
+                         return "lime" # Binary, both are okay
              except Exception:
                  return "shap"

-         # 3. Tree-based models → prefer SHAP (TreeSHAP if available)
-         if "tree" in str(type(model)).lower():
-             return "shap"
+         # 3. Tree-based models → prefer TreeSHAP
+         model_type_str = str(type(model)).lower()
+         if any(tree_type in model_type_str for tree_type in ['tree', 'forest', 'xgb', 'lgbm', 'catboost']):
+             return "treeshap"

-         # 4. Default fallback
-         return "lime"
-
-     def _load_explainer(self, name, **kwargs):
-         if name == "lime":
-             return LimeExplainer(model=self.model, **kwargs)
-         elif name == "shap":
-             return ShapExplainer(model=self.model, **kwargs)
-         else:
-             raise ValueError(f"Unknown explainer: {name}")
-
+         # 4. Neural networks → prefer gradient methods
+         if 'torch' in model_type_str or 'keras' in model_type_str or 'tensorflow' in model_type_str:
+             return "integrated_gradients"

+         # 5. Default fallback
+         return "lime"

      def evaluate_roar(
          self,
-         X_train,
-         y_train,
-         X_test,
-         y_test,
+         X_train: np.ndarray,
+         y_train: np.ndarray,
+         X_test: np.ndarray,
+         y_test: np.ndarray,
          top_k: int = 2,
          model_class=None,
-         model_kwargs: dict = None
-     ):
+         model_kwargs: Optional[Dict] = None
+     ) -> Dict[str, float]:
          """
          Evaluate each explainer using ROAR (Remove And Retrain).

+         ROAR measures explanation quality by retraining the model after
+         removing the top-k important features identified by each explainer.
+         A larger accuracy drop indicates more faithful explanations.
+
          Args:
-             X_train, y_train: training data
-             X_test, y_test: test data
-             top_k: number of features to mask
-             model_class: model constructor with .fit() and .predict() (default: same as current model)
-             model_kwargs: optional keyword args for new model instance
+             X_train, y_train: Training data
+             X_test, y_test: Test data
+             top_k: Number of features to mask
+             model_class: Model constructor with .fit() and .predict()
+                 If None, uses the same type as self.model.model
+             model_kwargs: Optional keyword args for new model instance

          Returns:
-             Dict of {explainer_name: accuracy drop (baseline - retrained)}
+             Dict mapping explainer names to accuracy drops
          """
          from explainiverse.evaluation.metrics import compute_roar

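For orientation, the registry-driven flow added in this hunk (configure, run, compare, suggest_best) can be exercised roughly as below. This is a minimal sketch assembled from the class docstring above; the toy data, the SklearnAdapter constructor call, and the exact LIME/SHAP keyword arguments are assumptions about the surrounding API rather than verified usage.

    # Sketch only: adapter construction and explainer kwargs are assumed,
    # mirroring the docstring example in the diff above.
    from sklearn.datasets import make_classification
    from sklearn.ensemble import RandomForestClassifier
    from explainiverse import ExplanationSuite, SklearnAdapter

    # Toy data so the sketch is self-contained
    X, y = make_classification(n_samples=200, n_features=4, random_state=0)
    X_train, X_test = X[:150], X[150:]
    y_train, y_test = y[:150], y[150:]
    fnames = [f"f{i}" for i in range(4)]
    cnames = ["class_0", "class_1"]

    clf = RandomForestClassifier(random_state=0).fit(X_train, y_train)
    adapter = SklearnAdapter(clf)  # assumed adapter constructor

    suite = ExplanationSuite(
        model=adapter,
        explainer_configs=[
            ("lime", {"training_data": X_train, "feature_names": fnames, "class_names": cnames}),
            ("shap", {"background_data": X_train[:50], "feature_names": fnames, "class_names": cnames}),
        ],
        data_meta={"task": "classification"},
    )

    results = suite.run(X_test[0])   # dict: explainer name -> Explanation
    suite.compare()                  # prints the side-by-side attribution table
    print(suite.suggest_best())      # heuristic pick, e.g. "lime" or "shap"

Note that failed explainers are skipped with a printed warning rather than aborting the whole run, so results may contain fewer entries than explainer_configs.
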
@@ -122,22 +196,57 @@ class ExplanationSuite:

          # Default to type(self.model.model) if not provided
          if model_class is None:
-             model_class = type(self.model.model)
+             raw_model = self.model.model if hasattr(self.model, 'model') else self.model
+             model_class = type(raw_model)

          roar_scores = {}

          for name, explanation in self.explanations.items():
              print(f"[ROAR] Evaluating explainer: {name}")
-             roar = compute_roar(
-                 model_class=model_class,
-                 X_train=X_train,
-                 y_train=y_train,
-                 X_test=X_test,
-                 y_test=y_test,
-                 explanations=[explanation], # single-instance for now
-                 top_k=top_k,
-                 model_kwargs=model_kwargs
-             )
-             roar_scores[name] = roar
-
-         return roar_scores
+             try:
+                 roar = compute_roar(
+                     model_class=model_class,
+                     X_train=X_train,
+                     y_train=y_train,
+                     X_test=X_test,
+                     y_test=y_test,
+                     explanations=[explanation],
+                     top_k=top_k,
+                     model_kwargs=model_kwargs
+                 )
+                 roar_scores[name] = roar
+             except Exception as e:
+                 print(f"[ROAR] Failed for {name}: {e}")
+                 roar_scores[name] = 0.0
+
+         return roar_scores
+
+     def get_explanation(self, name: str):
+         """
+         Get a specific explanation by explainer name.
+
+         Args:
+             name: Name of the explainer
+
+         Returns:
+             Explanation object or None if not found
+         """
+         return self.explanations.get(name)
+
+     def list_explainers(self) -> List[str]:
+         """
+         List all configured explainer names.
+
+         Returns:
+             List of explainer names
+         """
+         return [name for name, _ in self.configs]
+
+     def list_completed(self) -> List[str]:
+         """
+         List explainers that have been run successfully.
+
+         Returns:
+             List of explainer names with results
+         """
+         return list(self.explanations.keys())
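
The ROAR path and the new helper methods close out the diff. Continuing the toy setup from the sketch after the first hunk, and assuming compute_roar returns a single accuracy-drop value per explainer as the docstring describes, the evaluation loop would be driven like this:

    # Continues the earlier sketch. LogisticRegression is an arbitrary
    # retraining surrogate chosen for illustration, not an API requirement.
    from sklearn.linear_model import LogisticRegression

    drops = suite.evaluate_roar(
        X_train, y_train, X_test, y_test,
        top_k=2,
        model_class=LogisticRegression,
        model_kwargs={"max_iter": 1000},
    )

    # Larger accuracy drop => the masked features mattered more, i.e. the
    # explainer's attributions were more faithful.
    for name in suite.list_completed():
        print(f"{name}: accuracy drop = {drops.get(name, 0.0):.3f}")

A failed ROAR run is recorded as 0.0 rather than raised, so a zero score can mean either an unfaithful explainer or an evaluation error; the warning printed by evaluate_roar distinguishes the two cases.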