fast-causal-shap 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fast_causal_shap/__init__.py +9 -0
- fast_causal_shap/core.py +444 -0
- fast_causal_shap-0.1.3.dist-info/METADATA +104 -0
- fast_causal_shap-0.1.3.dist-info/RECORD +7 -0
- fast_causal_shap-0.1.3.dist-info/WHEEL +5 -0
- fast_causal_shap-0.1.3.dist-info/licenses/LICENSE +21 -0
- fast_causal_shap-0.1.3.dist-info/top_level.txt +1 -0
fast_causal_shap/core.py
ADDED
|
@@ -0,0 +1,444 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
from collections import defaultdict
|
|
4
|
+
from math import factorial
|
|
5
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
6
|
+
|
|
7
|
+
import networkx as nx
|
|
8
|
+
import numpy as np
|
|
9
|
+
import pandas as pd
|
|
10
|
+
from sklearn.linear_model import LinearRegression
|
|
11
|
+
|
|
12
|
+
# Module-level logger; load_causal_strengths uses it to report edges removed to break cycles.
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class FastCausalSHAP:
|
|
16
|
+
def __init__(self, data: pd.DataFrame, model: Any, target_variable: str) -> None:
    """
    Set up a FastCausalSHAP explainer for a fitted model.

    Parameters
    ----------
    data : pd.DataFrame
        Non-empty dataset holding the feature columns and the target column.
    model : Any
        Fitted estimator exposing ``predict`` and ``feature_names_in_``.
        Classifiers must also expose ``predict_proba`` when SHAP values are
        computed with ``is_classifier=True``.
    target_variable : str
        Name of the target column; must be one of ``data.columns``.

    Raises
    ------
    TypeError
        If ``data`` is not a pandas DataFrame.
    ValueError
        If ``data`` is empty or does not contain ``target_variable``.
    AttributeError
        If ``model`` lacks ``predict`` or ``feature_names_in_`` (the latter is
        typically set by scikit-learn only once the model has been fitted).

    Examples
    --------
    >>> from sklearn.ensemble import RandomForestRegressor
    >>> import pandas as pd
    >>> data = pd.DataFrame({'X1': [1, 2, 3], 'X2': [4, 5, 6], 'Y': [7, 8, 9]})
    >>> model = RandomForestRegressor().fit(data[['X1', 'X2']], data['Y'])
    >>> explainer = FastCausalSHAP(data, model, 'Y')
    """
    # Validate in a fixed order: type, emptiness, target column, then model API.
    if not isinstance(data, pd.DataFrame):
        raise TypeError("data must be a pandas DataFrame")
    if data.empty:
        raise ValueError("data must not be empty")
    if target_variable not in data.columns:
        available = list(data.columns)
        raise ValueError(
            f"target_variable '{target_variable}' not found in data columns. "
            f"Available columns: {available}"
        )
    if not hasattr(model, "predict"):
        raise AttributeError("model must have a predict method")
    if not hasattr(model, "feature_names_in_"):
        raise AttributeError(
            "model must have 'feature_names_in_' attribute. "
            "Ensure the model has been fitted before passing it."
        )

    # Caller-supplied inputs.
    self.data: pd.DataFrame = data
    self.model: Any = model
    self.target_variable: str = target_variable

    # Populated by load_causal_strengths().
    self.gamma: Optional[Dict[str, float]] = None
    self.ida_graph: Optional[nx.DiGraph] = None
    self.feature_depths: Dict[str, int] = {}
    self.causal_paths: Dict[str, List[List[str]]] = {}

    # Lazily-filled memoization caches used while computing SHAP values.
    self.regression_models: Dict[Tuple[str, Tuple[str, ...]], Tuple[Any, float]] = {}
    self.path_cache: Dict[Any, float] = {}
|
|
84
|
+
|
|
85
|
+
def remove_cycles(self) -> List[Tuple[str, str, float]]:
    """
    Break every directed cycle in ``self.ida_graph`` by deleting edges.

    Cycles are handled one at a time. For the first cycle reported by
    ``nx.simple_cycles``, the edge with the smallest absolute weight is
    removed, and the search repeats until no cycle remains.
    ``self.ida_graph`` is then replaced by the pruned copy.

    Returns
    -------
    list of (source, target, weight)
        Removed edges in removal order, with their original (signed)
        weights. Empty when no graph is loaded or it is already acyclic.
    """
    if self.ida_graph is None:
        return []

    G = self.ida_graph.copy()
    removed_edges: List[Tuple[str, str, float]] = []

    while True:
        # Only the first cycle is used per iteration. Taking it lazily from the
        # generator yields the same cycle as list(...)[0] without enumerating
        # every simple cycle, which can be exponential in the graph size.
        # (simple_cycles never raises NetworkXNoCycle; it just yields nothing.)
        cycle = next(iter(nx.simple_cycles(G)), None)
        if cycle is None:
            break

        # Locate the weakest edge (smallest |weight|) along the cycle.
        edge_to_remove = None
        min_weight = float("inf")
        for i, source in enumerate(cycle):
            target = cycle[(i + 1) % len(cycle)]
            if G.has_edge(source, target):
                weight = abs(G[source][target]["weight"])
                if weight < min_weight:
                    min_weight = weight
                    edge_to_remove = (source, target)

        if edge_to_remove is None:
            # No weight compared below inf (e.g. all NaN): stop instead of
            # looping forever; the graph may remain cyclic in this case.
            break

        source, target = edge_to_remove
        G.remove_edge(source, target)
        removed_edges.append(
            (source, target, self.ida_graph[source][target]["weight"])
        )

    self.ida_graph = G
    return removed_edges
|
|
142
|
+
|
|
143
|
+
def _compute_causal_paths(self) -> None:
    """
    Fill ``self.causal_paths`` with the simple directed paths to the target.

    For every column of ``self.data`` except the target, all simple paths
    from that column to the target in ``self.ida_graph`` are collected.
    Each path is a list of node names, feature first and target last, and
    is stored under the column name. An empty list means the feature has no
    causal route to the target.
    """
    target = self.target_variable
    candidates = (name for name in self.data.columns if name != target)
    for name in candidates:
        try:
            routes = list(nx.all_simple_paths(self.ida_graph, name, target))
        except nx.NetworkXNoPath:
            routes = []
        self.causal_paths[name] = routes
|
|
155
|
+
|
|
156
|
+
def load_causal_strengths(self, json_file_path: str) -> Dict[str, float]:
    """
    Load causal edge strengths from a JSON file and derive feature weights.

    The file must hold a non-empty list of objects shaped like
    ``{"Pair": "Source->Target", "Mean_Causal_Effect": <number or null>}``.
    Entries whose effect is ``null`` are skipped. The resulting weighted
    digraph is made acyclic with ``remove_cycles``, after which feature
    depths and causal paths are recomputed.

    The total effect (beta) of a feature on the target is the sum, over its
    directed paths to the target in the *pruned, acyclic* graph, of the
    product of edge weights along each path.
    ``gamma[f] = |beta_f| / sum(|beta|)``; all values are 0 when every beta
    is 0.

    Parameters
    ----------
    json_file_path : str
        Path to the JSON file.

    Returns
    -------
    dict
        ``gamma`` keyed by every column of ``self.data`` (the target maps
        to 0). The same dict is stored in ``self.gamma``.

    Raises
    ------
    TypeError
        If ``json_file_path`` is not a string.
    ValueError
        If the path is not a file, the content is not valid JSON, is not a
        non-empty list, or a ``Pair`` is not of the form ``"A->B"``.
    KeyError
        If an entry lacks ``"Pair"`` or ``"Mean_Causal_Effect"``.
    """
    import os

    if not isinstance(json_file_path, str):
        raise TypeError("json_file_path must be a string")
    if not os.path.isfile(json_file_path):
        raise ValueError("json_file_path must be a valid file path")

    try:
        # JSON is UTF-8 by specification; don't depend on the platform default.
        with open(json_file_path, "r", encoding="utf-8") as f:
            causal_effects_list = json.load(f)
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON file: {json_file_path}. Error: {e}") from e

    if not isinstance(causal_effects_list, list):
        raise ValueError(
            f"JSON file must contain a list, got {type(causal_effects_list).__name__}"
        )
    if not causal_effects_list:
        raise ValueError("JSON file contains an empty list")

    nodes = list(self.data.columns)
    G = nx.DiGraph()
    G.add_nodes_from(nodes)
    for item in causal_effects_list:
        pair = item["Pair"]
        mean_causal_effect = item["Mean_Causal_Effect"]
        if mean_causal_effect is None:
            continue
        parts = pair.split("->")
        if len(parts) != 2:
            raise ValueError(
                f"Invalid Pair {pair!r}: expected the form 'Source->Target'"
            )
        source, target = (part.strip() for part in parts)
        G.add_edge(source, target, weight=mean_causal_effect)

    unknown = sorted(set(G.nodes) - set(nodes))
    if unknown:
        # These variables cannot be sampled later (they have no data column).
        logger.warning(
            "Causal graph references variables not present in data: %s", unknown
        )

    self.ida_graph = G
    # Results cached for a previously loaded graph are no longer valid.
    self.path_cache.clear()
    self.regression_models.clear()

    removed_edges = self.remove_cycles()
    if removed_edges:
        logger.info("Removed %d edges to make the graph acyclic:", len(removed_edges))
        for source, target, weight in removed_edges:
            logger.info(" %s -> %s (weight: %s)", source, target, weight)

    self.feature_depths.clear()
    self._compute_feature_depths()
    self._compute_causal_paths()

    # Total effects use the pruned graph's paths, so edges dropped by
    # remove_cycles no longer contribute (the rest of the pipeline uses the
    # same acyclic graph).
    beta_dict: Dict[str, float] = {}
    for feature, paths in self.causal_paths.items():
        total_effect = 0.0
        for path in paths:
            effect = 1.0
            for u, v in zip(path, path[1:]):
                effect *= self.ida_graph[u][v]["weight"]
            total_effect += effect
        if total_effect != 0:
            beta_dict[feature] = total_effect

    features = self.data.columns.tolist()
    total_causal_effect = sum(abs(beta) for beta in beta_dict.values())
    if total_causal_effect == 0:
        self.gamma = {k: 0.0 for k in features}
    else:
        self.gamma = {
            k: abs(beta_dict.get(k, 0.0)) / total_causal_effect for k in features
        }
    return self.gamma
|
|
234
|
+
|
|
235
|
+
def _compute_feature_depths(self) -> None:
    """
    Record, for each non-target feature, the edge count of its shortest
    directed path to the target in ``self.ida_graph``.

    Features with no path to the target are left out of
    ``self.feature_depths``. Entries from any previous call are discarded
    first, so a reloaded graph cannot leave stale depths behind.
    """
    self.feature_depths.clear()
    target = self.target_variable
    for feature in self.data.columns:
        if feature == target:
            continue
        try:
            # A shortest path is always simple, so this equals the minimum of
            # len(path) - 1 over all simple paths. BFS finds it without
            # enumerating every path, which is exponential in dense graphs.
            depth = nx.shortest_path_length(self.ida_graph, feature, target)
        except nx.NetworkXNoPath:
            continue
        self.feature_depths[feature] = depth
|
|
248
|
+
|
|
249
|
+
def get_topological_order(self, S: List[str]) -> List[str]:
    """
    Order the graph's variables topologically after the intervention do(S).

    On a copy of ``self.ida_graph``, the incoming edges of every variable in
    ``S`` are cut. Columns of ``self.data`` absent from the graph are added
    as isolated nodes before sorting.

    Parameters
    ----------
    S : list of str
        Intervened variables.

    Returns
    -------
    list of str
        A topological order of the intervened graph, or ``[]`` when no
        graph has been loaded.

    Raises
    ------
    ValueError
        If the intervened graph still contains a cycle.
    """
    base_graph = self.ida_graph
    if base_graph is None:
        return []

    mutilated = base_graph.copy()
    for node in S:
        # Intervening on ``node`` severs it from its causes.
        mutilated.remove_edges_from(list(mutilated.in_edges(node)))
    mutilated.add_nodes_from(set(self.data.columns) - set(mutilated.nodes))

    try:
        ordering = list(nx.topological_sort(mutilated))
    except nx.NetworkXUnfeasible:
        raise ValueError("The causal graph contains cycles.")
    return ordering
|
|
265
|
+
|
|
266
|
+
def get_parents(self, feature: str) -> List[str]:
    """
    List the direct causes of ``feature``, i.e. its predecessors in
    ``self.ida_graph``. Returns an empty list when no graph is loaded.
    """
    graph = self.ida_graph
    return [] if graph is None else list(graph.predecessors(feature))
|
|
271
|
+
|
|
272
|
+
def sample_marginal(self, feature: str) -> float:
    """
    Draw one observed value of ``feature`` at random from ``self.data``,
    i.e. a sample from the column's empirical marginal distribution.
    """
    column = self.data[feature]
    drawn = column.sample(1)
    return drawn.iloc[0]
|
|
275
|
+
|
|
276
|
+
def sample_conditional(
    self, feature: str, parent_values: Dict[str, float]
) -> float:
    """
    Draw a value of ``feature`` given values for its causal parents.

    The conditional is linear-Gaussian. A ``LinearRegression`` of
    ``feature`` on its non-target parents is fitted on ``self.data`` once
    per (feature, parent set) and cached in ``self.regression_models``.
    A value is then drawn from ``N(prediction, std(residuals))``. With no
    usable parents, the empirical marginal is sampled instead.

    Parameters
    ----------
    feature : str
        Column to sample.
    parent_values : dict
        Value for every non-target parent of ``feature``.

    Returns
    -------
    float
        The sampled value.
    """
    # Sort the parents once and use that order for the cache key, the
    # regression design matrix and the prediction input alike. Previously the
    # key was sorted but the regression used graph order, so a cached model
    # could be queried with its inputs permuted after the graph was rebuilt.
    effective_parents = sorted(
        p for p in self.get_parents(feature) if p != self.target_variable
    )
    if not effective_parents:
        return self.sample_marginal(feature)

    model_key = (feature, tuple(effective_parents))
    if model_key not in self.regression_models:
        X = self.data[effective_parents].values
        y = self.data[feature].values
        reg = LinearRegression()
        reg.fit(X, y)
        # Noise scale: population std (numpy default ddof=0) of the residuals.
        std = (y - reg.predict(X)).std()
        self.regression_models[model_key] = (reg, std)

    reg, std = self.regression_models[model_key]
    parent_values_array = np.array(
        [parent_values[p] for p in effective_parents]
    ).reshape(1, -1)
    mean = reg.predict(parent_values_array)[0]
    return np.random.normal(mean, std)
|
|
301
|
+
|
|
302
|
+
def compute_v_do(
    self,
    S: List[str],
    x_S: Dict[str, float],
    is_classifier: bool = False,
    n_samples: int = 1,
) -> float:
    """
    Estimate E[f(X) | do(X_S = x_S)] by Monte Carlo and memoize the result.

    In each draw, the variables in ``S`` are fixed to ``x_S``. Every other
    non-target variable is then sampled in the topological order of the
    intervened graph: via ``sample_conditional`` when it has non-target
    parents, otherwise via ``sample_marginal``. The model output is
    averaged over the draws.

    Parameters
    ----------
    S : list of str
        Intervened variables.
    x_S : dict
        Intervention values; must contain every variable in ``S``.
    is_classifier : bool, default False
        Average ``predict_proba(...)[:, 1]`` instead of ``predict``.
    n_samples : int, default 1
        Number of Monte Carlo draws. The default reproduces the historical
        single-draw estimate.

    Returns
    -------
    float
        The estimated interventional expectation.

    Raises
    ------
    ValueError
        If ``n_samples`` is less than 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    # The output type (probability vs. prediction) and the sample count both
    # change the result, so both belong in the key. Without is_classifier, a
    # value cached by a regressor-style call was returned for a classifier
    # call with the same intervention.
    cache_key = (frozenset(S), tuple(sorted(x_S.items())), is_classifier, n_samples)
    if cache_key in self.path_cache:
        return self.path_cache[cache_key]

    variables_order = self.get_topological_order(S)
    intervened = set(S)

    rows = []
    for _ in range(n_samples):
        sample = {feature: x_S[feature] for feature in S}
        for feature in variables_order:
            if feature in intervened or feature == self.target_variable:
                continue
            # Intervened parents already hold their x_S value in ``sample``.
            parent_values = {
                p: sample[p]
                for p in self.get_parents(feature)
                if p != self.target_variable
            }
            if parent_values:
                sample[feature] = self.sample_conditional(feature, parent_values)
            else:
                sample[feature] = self.sample_marginal(feature)
        rows.append(sample)

    intervened_data = pd.DataFrame(rows)[self.model.feature_names_in_]
    if is_classifier:
        predictions = self.model.predict_proba(intervened_data)[:, 1]
    else:
        predictions = self.model.predict(intervened_data)

    result = float(np.mean(predictions))
    self.path_cache[cache_key] = result
    return result
|
|
343
|
+
|
|
344
|
+
def is_on_causal_path(self, feature: str, target_feature: str) -> bool:
    """
    Tell whether ``feature`` lies on any stored causal path of
    ``target_feature``.

    ``self.causal_paths[target_feature]`` holds the paths leading from
    ``target_feature`` to the model target. Returns ``False`` when nothing
    is stored for ``target_feature``.
    """
    # Each entry is a *list of paths* (each path a list of node names), so
    # membership must be checked inside every path. Testing
    # ``feature in paths`` compared a string against lists and was always False.
    paths = self.causal_paths.get(target_feature, [])
    return any(feature in path for path in paths)
|
|
350
|
+
|
|
351
|
+
def compute_modified_shap_proba(
    self, x: pd.Series, is_classifier: bool = False
) -> Dict[str, float]:
    """
    Compute causal SHAP values for a single instance.

    Contributions are accumulated for every feature and every stored causal
    path from that feature to the target:

    * For each coalition size ``m = 0 .. d-1``, the marginal interventional
      contribution from ``_compute_path_delta_v`` is weighted by the Shapley
      weight ``m! (d-m-1)! / d!``. Here ``d`` is the number of non-target
      nodes on the path.
    * Each contribution is also scaled by the feature's causal weight
      ``gamma``.

    The results are finally rescaled to sum to ``f(x) - E[f(X)]``, unless
    their raw sum is zero.

    Parameters
    ----------
    x : pd.Series
        The instance; its index must contain every ``model.feature_names_in_``.
    is_classifier : bool, default False
        Explain the positive-class probability (``predict_proba[:, 1]``)
        instead of ``predict``.

    Returns
    -------
    dict
        SHAP value for every non-target column of ``self.data``.

    Raises
    ------
    ValueError
        If ``load_causal_strengths`` has not been called, or ``x`` lacks a
        model feature.
    TypeError
        If ``x`` is not a pandas Series.
    """
    if self.gamma is None:
        raise ValueError(
            "Must call load_causal_strengths before computing SHAP values"
        )
    if not isinstance(x, pd.Series):
        raise TypeError(f"x must be a pandas Series, got {type(x).__name__}")

    # validate x contains required features
    required_features = self.model.feature_names_in_
    missing_features = set(required_features) - set(x.index)
    if missing_features:
        raise ValueError(
            f"x is missing required features: {missing_features}. "
            f"Required features: {list(required_features)}"
        )

    def _predict(frame: pd.DataFrame) -> np.ndarray:
        """Positive-class probability for classifiers, raw prediction otherwise."""
        if is_classifier:
            return self.model.predict_proba(frame)[:, 1]
        return self.model.predict(frame)

    # Give the model exactly its training columns, in training order. Merely
    # dropping the target kept any extra columns and the DataFrame's own
    # column order, which scikit-learn rejects for models fitted on names.
    E_fX = _predict(self.data[required_features]).mean()
    f_x = _predict(x[required_features].to_frame().T)[0]

    target = self.target_variable
    features = [col for col in self.data.columns if col != target]
    phi_causal = {feature: 0.0 for feature in features}

    # Shapley weights for every path length that actually occurs. The former
    # bound (the largest *minimum* depth) left longer paths without weights,
    # so their contributions were silently multiplied by 0.
    longest_path = max(
        (
            sum(1 for node in path if node != target)
            for feature in features
            for path in self.causal_paths.get(feature, [])
        ),
        default=0,
    )
    shapley_weights = {
        (m, d): factorial(m) * factorial(d - m - 1) / factorial(d)
        for d in range(1, longest_path + 1)
        for m in range(d)
    }

    sorted_features = sorted(features, key=lambda f: self.feature_depths.get(f, 0))
    for feature in sorted_features:
        if feature not in self.causal_paths:
            continue
        feature_gamma = self.gamma.get(feature, 0)
        for path in self.causal_paths[feature]:
            path_features = [n for n in path if n != target]
            d = len(path_features)
            n_others = sum(1 for n in path_features if n != feature)
            # Coalition sizes are visited in descending order, as the former
            # subset-counting loop did.
            # NOTE(review): each size m is represented by a single coalition
            # and weighted by m!(d-m-1)!/d! without multiplying by the number
            # of size-m subsets, mirroring the original computation. Confirm
            # this is the intended estimator.
            for m in reversed(range(n_others + 1)):
                weight = shapley_weights.get((m, d), 0) * feature_gamma
                delta_v = self._compute_path_delta_v(
                    feature, path, m, x, is_classifier
                )
                phi_causal[feature] += weight * delta_v

    # Enforce efficiency: the attributions sum to f(x) - E[f(X)].
    sum_phi = sum(phi_causal.values())
    if sum_phi != 0:
        scaling_factor = (f_x - E_fX) / sum_phi
        phi_causal = {k: v * scaling_factor for k, v in phi_causal.items()}

    return phi_causal
|
|
431
|
+
|
|
432
|
+
def _compute_path_delta_v(
    self, feature: str, path: List[str], m: int, x: pd.Series, is_classifier: bool
) -> float:
    """
    Marginal interventional contribution of ``feature`` along one causal path.

    The coalition ``S`` is the first ``m`` nodes of ``path`` other than
    ``feature`` and the model target, restricted to nodes present in ``x``.
    The result is ``v(S + [feature]) - v(S)``, where ``v`` is
    ``compute_v_do`` evaluated at the corresponding values of ``x``.
    """
    # Candidate coalition members, in path order. Nodes missing from ``x``
    # have no value to intervene with, so they are skipped.
    others = [
        n
        for n in path
        if n != feature and n != self.target_variable and n in x.index
    ]
    # ``m`` is the coalition size, so take exactly ``m`` nodes. The previous
    # ``path[:m]`` slice counted ``feature`` itself (the stored paths start at
    # it), so S had only m - 1 nodes and never reached the last intermediate.
    S = others[:m]
    x_S = {n: x[n] for n in S}
    v_S = self.compute_v_do(S, x_S, is_classifier)

    S_with_i = S + [feature]
    x_Si = {**x_S, feature: x[feature]}
    v_Si = self.compute_v_do(S_with_i, x_Si, is_classifier)

    return v_Si - v_S
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: fast-causal-shap
|
|
3
|
+
Version: 0.1.3
|
|
4
|
+
Summary: A Python package for efficient causal SHAP computations
|
|
5
|
+
Author-email: woonyee28 <ngnwy289@gmail.com>
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/woonyee28/CausalSHAP
|
|
8
|
+
Project-URL: Issues, https://github.com/woonyee28/CausalSHAP/issues
|
|
9
|
+
Requires-Python: >=3.9
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
License-File: LICENSE
|
|
12
|
+
Requires-Dist: pandas>=1.0.0
|
|
13
|
+
Requires-Dist: networkx>=2.0
|
|
14
|
+
Requires-Dist: numpy>=1.18.0
|
|
15
|
+
Requires-Dist: scikit-learn>=0.24.0
|
|
16
|
+
Provides-Extra: dev
|
|
17
|
+
Requires-Dist: pytest>=6.0; extra == "dev"
|
|
18
|
+
Requires-Dist: black>=21.0; extra == "dev"
|
|
19
|
+
Requires-Dist: flake8>=3.8; extra == "dev"
|
|
20
|
+
Requires-Dist: mypy>=0.800; extra == "dev"
|
|
21
|
+
Requires-Dist: isort>=5.0; extra == "dev"
|
|
22
|
+
Requires-Dist: pytest-cov>=2.0; extra == "dev"
|
|
23
|
+
Requires-Dist: pre-commit>=2.0; extra == "dev"
|
|
24
|
+
Dynamic: license-file
|
|
25
|
+
|
|
26
|
+
# Fast Causal SHAP
|
|
27
|
+
|
|
28
|
+
Fast Causal SHAP is a Python package designed for efficient and interpretable SHAP value computation in causal inference tasks. It integrates seamlessly with various causal inference frameworks and enables feature attribution with awareness of causal dependencies.
|
|
29
|
+
|
|
30
|
+
## Features
|
|
31
|
+
|
|
32
|
+
- Fast computation of SHAP values for causal models
|
|
33
|
+
- Support for multiple causal inference frameworks
|
|
34
|
+
|
|
35
|
+
## Installation
|
|
36
|
+
|
|
37
|
+
Install the stable version via PyPI:
|
|
38
|
+
|
|
39
|
+
```bash
|
|
40
|
+
pip install fast-causal-shap
|
|
41
|
+
```
|
|
42
|
+
|
|
43
|
+
Or, for the latest development version:
|
|
44
|
+
|
|
45
|
+
```bash
|
|
46
|
+
pip install git+https://github.com/woonyee28/CausalSHAP.git
|
|
47
|
+
```
|
|
48
|
+
|
|
49
|
+
## Usage
|
|
50
|
+
```python
|
|
51
|
+
from fast_causal_shap.core import FastCausalSHAP
|
|
52
|
+
|
|
53
|
+
# Predict probabilities and assign to training data
|
|
54
|
+
predicted_probabilities = model.predict_proba(X_train)[:,1]
|
|
55
|
+
X_train['target'] = predicted_probabilities
|
|
56
|
+
|
|
57
|
+
# Initialize FastCausalSHAP
|
|
58
|
+
ci = FastCausalSHAP(data=X_train, model=model, target_variable='target')
|
|
59
|
+
|
|
60
|
+
# Load causal strengths (precomputed using R packages)
|
|
61
|
+
ci.load_causal_strengths(result_dir + 'Causal_Effect.json')
|
|
62
|
+
|
|
63
|
+
# Compute modified SHAP values for a single instance
|
|
64
|
+
x_instance = X_train.iloc[33]
|
|
65
|
+
|
|
66
|
+
print(ci.compute_modified_shap_proba(x_instance, is_classifier=True))
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
Format of the Causal_Effect.json:
|
|
70
|
+
```
|
|
71
|
+
[
|
|
72
|
+
{
|
|
73
|
+
"Pair": "Bacteroidia->Clostridia",
|
|
74
|
+
"Mean_Causal_Effect": 0.71292
|
|
75
|
+
},
|
|
76
|
+
{
|
|
77
|
+
"Pair": "Clostridia->Alphaproteobacteria",
|
|
78
|
+
"Mean_Causal_Effect": 0.37652
|
|
79
|
+
}, ......
|
|
80
|
+
]
|
|
81
|
+
```
|
|
82
|
+
|
|
83
|
+
Fast Causal SHAP supports integration with causal discovery and causal-effect estimation algorithms such as:
|
|
84
|
+
1. Peter-Clark (PC) Algorithm
|
|
85
|
+
2. IDA Algorithm
|
|
86
|
+
3. Fast Causal Inference (FCI) Algorithm
|
|
87
|
+
You can find example R code for these integrations here: [FastCausalSHAP R code examples](https://github.com/woonyee28/CausalSHAP/tree/main/code/r)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
## Citation
|
|
91
|
+
If you use Fast Causal SHAP in your research, please cite:
|
|
92
|
+
```
|
|
93
|
+
@inproceedings{ng2025causal,
|
|
94
|
+
title={Causal SHAP: Feature Attribution with Dependency Awareness through Causal Discovery},
|
|
95
|
+
author={Ng, Woon Yee and Wang, Li Rong and Liu, Siyuan and Fan, Xiuyi},
|
|
96
|
+
booktitle={Proceedings of the International Joint Conference on Neural Networks (IJCNN)},
|
|
97
|
+
year={2025},
|
|
98
|
+
organization={IEEE}
|
|
99
|
+
}
|
|
100
|
+
```
|
|
101
|
+
|
|
102
|
+
## License
|
|
103
|
+
|
|
104
|
+
This project is licensed under the MIT License.
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
fast_causal_shap/__init__.py,sha256=n62hNTwd-c9-gHOpO5BGHOiq-jKhykeDPFBmydAmJy0,227
|
|
2
|
+
fast_causal_shap/core.py,sha256=XfUhKHd4pmLMe6BF9igLVT0KQDpoIbRVo1VjpnY6csE,17056
|
|
3
|
+
fast_causal_shap-0.1.3.dist-info/licenses/LICENSE,sha256=ACwmltkrXIz5VsEQcrqljq-fat6ZXAMepjXGoe40KtE,1069
|
|
4
|
+
fast_causal_shap-0.1.3.dist-info/METADATA,sha256=IyQJ_QjzMOSd8fp1EM6YSkS5jLlDGbdlqu0_CnoUG-0,3102
|
|
5
|
+
fast_causal_shap-0.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
6
|
+
fast_causal_shap-0.1.3.dist-info/top_level.txt,sha256=nAIqoFfVB4g6cJal-o9z4LmDYIX1lj1x15oJrlsT_4E,17
|
|
7
|
+
fast_causal_shap-0.1.3.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) [year] [fullname]
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
fast_causal_shap
|