shaprpy 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- shaprpy/__init__.py +51 -0
- shaprpy/_explain.py +630 -0
- shaprpy/datasets.py +27 -0
- shaprpy/utils.py +67 -0
- shaprpy-0.3.0.dist-info/METADATA +168 -0
- shaprpy-0.3.0.dist-info/RECORD +10 -0
- shaprpy-0.3.0.dist-info/WHEEL +5 -0
- shaprpy-0.3.0.dist-info/licenses/LICENSE +2 -0
- shaprpy-0.3.0.dist-info/licenses/LICENSE.md +21 -0
- shaprpy-0.3.0.dist-info/top_level.txt +1 -0
shaprpy/__init__.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
from importlib.metadata import version, PackageNotFoundError
|
|
2
|
+
from importlib import import_module
|
|
3
|
+
|
|
4
|
+
# Lightweight public re-export (no R dependency)
|
|
5
|
+
from . import datasets # noqa: F401
|
|
6
|
+
|
|
7
|
+
# Names that make up the package's public API.
__all__ = ["explain", "datasets", "ensure_r_ready"]

# Lazy-initialisation state; both values are rebound by ensure_r_ready().
_explain_impl = None
_r_ready = False

# Resolve the installed distribution's version, falling back to a local
# marker when no "shaprpy" distribution metadata can be found.
try:
    __version__ = version("shaprpy")
except PackageNotFoundError:
    __version__ = "0.0.0+local"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def ensure_r_ready() -> bool:
    """Check the R backend and bind the real ``explain`` implementation.

    Verifies that rpy2 can be imported and that the R package 'shapr' can be
    loaded, then imports the private ``_explain`` submodule and stores its
    ``explain`` function in the module-level ``_explain_impl``. Once this has
    succeeded, later calls return immediately (idempotent).

    Returns
    -------
    bool
        Always ``True``; failures are signalled by raising.

    Raises
    ------
    ImportError
        If rpy2 cannot be imported, or if the R package 'shapr' cannot be loaded.
    """
    global _r_ready, _explain_impl

    if not _r_ready:
        # Step 1: rpy2 (and thereby a usable R installation) must be importable.
        try:
            import rpy2.robjects as _ro  # noqa: F401
            from rpy2.robjects.packages import importr
        except Exception as err:
            raise ImportError(
                "shaprpy requires rpy2 and a working R installation.\n"
                "Install R and rpy2, and ensure R is on PATH/R_HOME. See README."
            ) from err

        # Step 2: the R package 'shapr' must be loadable from that R installation.
        try:
            importr("shapr")
        except Exception as err:
            raise ImportError(
                "The R package 'shapr' is not installed or not found.\n"
                "In an R session, run: install.packages('shapr')"
            ) from err

        # Step 3: import the implementation from a private module to avoid a
        # name collision with the public explain() wrapper in this package.
        impl_module = import_module(f"{__name__}._explain")
        _explain_impl = impl_module.explain
        _r_ready = True

    return True
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def explain(*args, **kwargs):
    """Initialize R/shapr on first use, then delegate to the real explain().

    All positional and keyword arguments are forwarded unchanged to the
    implementation bound by ``ensure_r_ready()``, and its return value is
    returned as-is. Any ``ImportError`` raised by ``ensure_r_ready()``
    propagates to the caller.
    """
    ensure_r_ready()
    # Read the global only after initialization, since ensure_r_ready() rebinds it.
    implementation = _explain_impl
    return implementation(*args, **kwargs)
|
shaprpy/_explain.py
ADDED
|
@@ -0,0 +1,630 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
import numpy as np
|
|
3
|
+
import pandas as pd
|
|
4
|
+
from typing import Callable
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
import rpy2.robjects as ro
|
|
7
|
+
from rpy2.robjects.packages import importr
|
|
8
|
+
from rpy2.rinterface import NULL, NA
|
|
9
|
+
from shaprpy.utils import r2py, py2r, recurse_r_tree
|
|
10
|
+
from rpy2.robjects.vectors import StrVector, ListVector
|
|
11
|
+
|
|
12
|
+
data_table = importr('data.table')
|
|
13
|
+
shapr = importr('shapr')
|
|
14
|
+
utils = importr('utils')
|
|
15
|
+
base = importr('base')
|
|
16
|
+
stats = importr('stats')
|
|
17
|
+
|
|
18
|
+
def maybe_null(val):
    """Return `val` unchanged, or rpy2's `NULL` object when `val` is None."""
    if val is None:
        return NULL
    return val
|
|
20
|
+
|
|
21
|
+
def explain(
    model,
    x_explain: pd.DataFrame,
    x_train: pd.DataFrame,
    approach: str | list[str],
    phi0: float,
    iterative: bool | None = None,
    max_n_coalitions: int | None = None,
    group: dict | None = None,
    n_MC_samples: int = 1000,
    seed: int | None = None,
    verbose: str | list[str] | None = "basic",
    predict_model: Callable | None = None,
    get_model_specs: Callable | None = None,
    asymmetric: bool = False,
    causal_ordering: dict | None = None,
    confounding: bool | None = None,
    extra_computation_args: dict | None = None,
    iterative_args: dict | None = None,
    output_args: dict | None = None,
    **kwargs,
  ):
    """
    Explain the output of machine learning models with more accurately estimated Shapley values.

    Computes dependence-aware Shapley values for observations in `x_explain` from the specified
    `model` by using the method specified in `approach` to estimate the conditional expectation.

    Parameters
    ----------
    model: The model whose predictions we want to explain.
      `shaprpy` natively supports `sklearn`, `xgboost` and `keras` models.
      Unsupported models can still be explained by passing `predict_model` and (optionally) `get_model_specs`.
    x_explain: pd.DataFrame
      Contains the features whose predictions ought to be explained.
    x_train: pd.DataFrame
      Contains the data used to estimate the (conditional) distributions for the features
      needed to properly estimate the conditional expectations in the Shapley formula.
    approach: str or list[str]
      The method(s) to estimate the conditional expectation. All elements should
      be one of `"gaussian"`, `"copula"`, `"empirical"`, `"ctree"`, `"categorical"`, `"timeseries"`, `"independence"`,
      `"regression_separate"`, or `"regression_surrogate"`. A single string is wrapped in a list.
    phi0: float
      The prediction value for unseen data, i.e. an estimate of the expected prediction without conditioning on any
      features. Typically we set this value equal to the mean of the response variable in our training data, but other
      choices such as the mean of the predictions in the training data are also reasonable.
    iterative: bool or None, optional
      If `None` (default), the argument is set to `True` if there are more than 5 features/groups, and `False` otherwise.
      If `True`, the Shapley values are estimated iteratively.
    max_n_coalitions: int or None, optional
      The upper limit on the number of unique feature/group coalitions to use in the iterative procedure
      (if `iterative = True`). If `iterative = False` it represents the number of feature/group coalitions to use directly.
      `max_n_coalitions = None` corresponds to `max_n_coalitions=2^n_features`.
    group: dict or None, optional
      If `None` regular feature wise Shapley values are computed.
      If provided, group wise Shapley values are computed. `group` then maps each group name to a list of unique
      feature names with the features included in that group.
    n_MC_samples: int, optional
      Indicating the maximum number of samples to use in the Monte Carlo integration for every conditional expectation.
    seed: int or None, optional
      Specifies the seed before any randomness based code is being run.
      If `None` (default) the seed will be inherited from the calling environment.
    verbose: str or list[str] or None, optional
      Specifies the verbosity (printout detail level) through one or more of the strings `"basic"`, `"progress"`,
      `"convergence"`, `"shapley"` and `"vS_details"`. `None` means no printout.
    predict_model: Callable, optional
      The prediction function used when `model` is not natively supported. The function must have two arguments, `model` and `newdata`
      which specify, respectively, the model and a pandas.DataFrame to compute predictions for. The function must give the prediction as a numpy.Array.
    get_model_specs: Callable, optional
      An optional function for checking model/data consistency when `model` is not natively supported. The function takes `model` as argument
      and provides a `dict` with 3 elements: `labels`, `classes`, and `factor_levels`.
    asymmetric: bool, optional
      If `False` (default), `explain` computes regular symmetric Shapley values. If `True`, then `explain` computes asymmetric Shapley values
      based on the (partial) causal ordering given by `causal_ordering`.
    causal_ordering: dict or None, optional
      A dict whose values are lists of feature names; each value is one component of the partial causal ordering that the
      coalitions must respect. The dict is converted to an R list of character vectors.
    confounding: bool or None, optional
      A vector of logicals specifying whether confounding is assumed or not for each component in the `causal_ordering`.
    extra_computation_args: dict or None, optional
      Specifies extra arguments related to the computation of the Shapley values.
    iterative_args: dict or None, optional
      Specifies the arguments for the iterative procedure.
    output_args: dict or None, optional
      Specifies certain arguments related to the output of the function.
    **kwargs: Further arguments passed to specific approaches. The first underscore in each keyword name is
      replaced by a dot before the arguments are passed on to R.

    Returns
    -------
    dict
      A dictionary containing the following items:
      - "shapley_values_est": pd.DataFrame with the estimated Shapley values.
      - "shapley_values_sd": pd.DataFrame with the standard deviation of the Shapley values.
      - "pred_explain": numpy.Array with the predictions for the explained observations.
      - "MSEv": dict with the values of the MSEv evaluation criterion.
      - "iterative_results": dict with the results of the iterative estimation.
      - "saving_path": str with the path where intermediate results are stored.
      - "internal": dict with the different parameters, data, functions and other output used internally.
      - "timing": dict containing timing information for the different parts of the computation.
    """

    init_time = base.Sys_time() # Timestamp taken with R's Sys.time(); it is handed to shapr.setup below


    # Seed R's RNG before any R code that may use randomness is run
    if seed is not None:
        base.set_seed(seed)

    # Gets and checks feature specs from the model
    rfeature_specs = get_feature_specs(get_model_specs, model)

    # Fixes the conversion from dict to a named list of vectors in R
    r_group = NULL if group is None else ListVector({key: StrVector(value) for key, value in group.items()})

    # Fixes the conversion from dict to a named list of vectors in R
    r_causal_ordering = NULL if causal_ordering is None else ListVector({key: StrVector(value) for key, value in causal_ordering.items()})


    # Fixes method specific argument names by replacing first occurrence of "_" with "."
    if len(kwargs) > 0:
        kwargs = change_first_underscore_to_dot(kwargs)

    # Convert from dict to a named list of vectors in R if `regression.vfold_cv_para` is provided by the user
    if 'regression.vfold_cv_para' in kwargs:
        kwargs['regression.vfold_cv_para'] = ListVector(kwargs['regression.vfold_cv_para'])

    # Convert from None or dict to a named list in R (None becomes an empty list)
    if iterative_args is None:
        iterative_args = ro.ListVector({})
    else:
        iterative_args = ListVector(iterative_args)

    if output_args is None:
        output_args = ro.ListVector({})
    else:
        output_args = ListVector(output_args)

    if extra_computation_args is None:
        extra_computation_args = ro.ListVector({})
    else:
        extra_computation_args = ListVector(extra_computation_args)

    # Fully qualified Python class name of the model, passed to R as a plain string
    model_class = f"{type(model).__module__}.{type(model).__name__}"

    # Sets up and organizes input parameters
    # Checks the input parameters and their compatibility
    # Checks data/model compatibility

    # A single approach name is wrapped in a list so that StrVector(approach) below gets a sequence
    if isinstance(approach, str):
        approach = [approach]

    # Normalize verbose: str -> list -> R character vector; None -> R NULL
    if isinstance(verbose, str):
        verbose = [verbose]
    if isinstance(verbose, list):
        verbose = StrVector(verbose)
    else:
        verbose = maybe_null(verbose)


    rinternal = shapr.setup(
        x_train = py2r(x_train),
        x_explain = py2r(x_explain),
        approach = StrVector(approach),
        phi0 = phi0,
        max_n_coalitions = maybe_null(max_n_coalitions),
        group = r_group,
        n_MC_samples = n_MC_samples,
        seed = maybe_null(seed),
        feature_specs = rfeature_specs,
        verbose = verbose,
        iterative = maybe_null(iterative),
        iterative_args = iterative_args,
        asymmetric = asymmetric,
        causal_ordering = r_causal_ordering,
        confounding = maybe_null(confounding),
        output_args = output_args,
        extra_computation_args = extra_computation_args,
        init_time = init_time,
        is_python = True,
        model_class = model_class,
        **kwargs
    )

    # Gets predict_model (if not passed to explain) and checks that predict_model gives correct format
    # (the check is run on the first two rows of x_train)
    predict_model = get_predict_model(x_test=x_train.head(2), predict_model=predict_model, model=model)

    rinternal.rx2['timing_list'].rx2['test_prediction'] = base.Sys_time()

    # Adds predicted responses to rinternal when a regression-based approach is used
    rinternal = additional_regression_setup(
        rinternal,
        model,
        predict_model,
        x_train,
        x_explain)

    # Not called for approach %in% c("regression_surrogate","vaeac")
    rinternal = shapr.setup_approach(internal = rinternal) # model and predict_model are not supported in Python

    rinternal.rx2['main_timing_list'] = rinternal.rx2['timing_list']

    converged = False
    # 1-based index of the current iteration (length of the R-side iter_list)
    iter = len(rinternal.rx2('iter_list'))

    # Re-seed before the main computation
    if seed is not None:
        base.set_seed(seed)

    shapr.cli_startup(rinternal, verbose)

    rinternal.rx2['iter_timing_list'] = ro.ListVector({})

    while not converged:
        shapr.cli_iter(verbose, rinternal, iter)

        # Start a fresh timing list for this iteration
        rinternal.rx2['timing_list'] = ro.ListVector({'init': base.Sys_time()})

        # Setup the Shapley framework
        rinternal = shapr.shapley_setup(rinternal)

        # Only actually called for approach in ["regression_surrogate", "vaeac"]
        rinternal = shapr.setup_approach(rinternal)

        # Compute the vS
        vS_list = compute_vS(rinternal, model, predict_model)

        # Compute Shapley value estimates and bootstrapped standard deviations
        rinternal = shapr.compute_estimates(rinternal, vS_list)

        # Check convergence based on estimates and standard deviations (and thresholds)
        rinternal = shapr.check_convergence(rinternal)

        # Save intermediate results
        shapr.save_results(rinternal)

        # Preparing parameters for next iteration (does not do anything if already converged)
        rinternal = shapr.prepare_next_iteration(rinternal)

        # Printing iteration information
        shapr.print_iter(rinternal)

        # Setting globals to simplify the loop
        # (iter is 1-based while the Python indexing into iter_list is 0-based)
        converged = rinternal.rx2('iter_list')[iter-1].rx2('converged')[0]

        # Rebuild the R list with one extra named element appended
        rinternal.rx2['timing_list'] = ro.ListVector({**dict(rinternal.rx2['timing_list'].items()), 'postprocess_res': base.Sys_time()})

        # Add the current timing_list to the iter_timing_list
        rinternal.rx2['iter_timing_list'] = ro.ListVector({**dict(rinternal.rx2['iter_timing_list'].items()), f'element_{iter}': rinternal.rx2['timing_list']})

        iter += 1

    rinternal.rx2['main_timing_list'] = ro.ListVector({**dict(rinternal.rx2['main_timing_list'].items()), 'main_computation': base.Sys_time()})

    # Rerun after convergence to get the same output format as for the non-iterative approach
    routput = shapr.finalize_explanation(rinternal)

    rinternal.rx2['main_timing_list'] = ro.ListVector({**dict(rinternal.rx2['main_timing_list'].items()), 'finalize_explanation': base.Sys_time()})


    routput.rx2['timing'] = shapr.compute_time(rinternal)

    # Some cleanup when doing testing
    testing = rinternal.rx2('parameters').rx2('testing')[0]
    if testing:
        routput = shapr.testing_cleanup(routput)

    # Convert R objects to Python objects
    shapley_values_est = recurse_r_tree(routput.rx2('shapley_values_est'))
    shapley_values_sd = recurse_r_tree(routput.rx2('shapley_values_sd'))
    pred_explain = recurse_r_tree(routput.rx2('pred_explain'))
    MSEv = recurse_r_tree(routput.rx2('MSEv'))
    iterative_results = recurse_r_tree(routput.rx2('iterative_results'))
    saving_path = recurse_r_tree(routput.rx2('saving_path'))
    internal = recurse_r_tree(routput.rx2('internal'))
    timing = recurse_r_tree(routput.rx2('timing'))

    return {
        "shapley_values_est": shapley_values_est,
        "shapley_values_sd": shapley_values_sd,
        "pred_explain": pred_explain,
        "MSEv": MSEv,
        "iterative_results": iterative_results,
        "saving_path": saving_path,
        "internal": internal,
        "timing": timing,
    }
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def compute_vS(rinternal, model, predict_model):
    """Compute the v(S) output for every batch of the current iteration.

    Reads `S_batch` from the last entry of `rinternal`'s `iter_list`, calls
    `batch_compute_vS` once per batch, and returns the result of
    `shapr.append_vS_list` applied to the collected R list and `rinternal`.
    """
    n_iterations = len(rinternal.rx2('iter_list'))

    # The last iter_list entry holds the batches for the current iteration.
    batches = rinternal.rx2('iter_list')[n_iterations - 1].rx2('S_batch')

    # verbose
    shapr.cli_compute_vS(rinternal)

    # Perform a single sample to forward the RNG state one step. This is done to
    # ensure consistency with future.apply::future_lapply in R, which does this to
    # guarantee consistency for parallelization.
    # See ?future.apply::future_lapply for details
    stats.rnorm(1)

    batch_results = ro.ListVector({})
    # Positions in the R list are 1-based.
    for position, coalitions in enumerate(S_batch := batches, start=1):
        batch_results.rx2[position] = batch_compute_vS(
            S=coalitions, rinternal=rinternal, model=model, predict_model=predict_model
        )

    #### Adds v_S output above to any vS_list already computed ####
    return shapr.append_vS_list(batch_results, rinternal)
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def batch_compute_vS(S, rinternal, model, predict_model):
    """Compute the contribution function values for one batch of coalitions `S`.

    Dispatches on the `regression` flag stored in `rinternal`'s parameters:
    regression-based approaches are handled by
    `shapr.batch_prepare_vS_regression`, everything else by the Monte Carlo
    path in `batch_prepare_vS_MC`.
    """
    use_regression = rinternal.rx2('parameters').rx2('regression')[0]

    if not use_regression:
        # Monte Carlo integration. The result is either only dt_vS or a list containing
        # dt_vS and dt if internal$parameters$output_args$keep_samp_for_vS = TRUE
        return batch_prepare_vS_MC(S=S, rinternal=rinternal, model=model, predict_model=predict_model)

    return shapr.batch_prepare_vS_regression(S=S, internal=rinternal)
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
def batch_prepare_vS_MC_old(S, rinternal, model, predict_model):
    """Legacy Monte Carlo computation of v(S) for one batch of coalitions `S`.

    Generates the Monte Carlo samples with `shapr.batch_prepare_vS_MC_auxiliary`,
    predicts on them, and integrates with `shapr.compute_MCint`.

    Returns `dt_vS`, or an R list with elements `dt_vS` and `dt_samp_for_vS`
    when the `keep_samp_for_vS` parameter is truthy.
    """
    keep_samp_for_vS = rinternal.rx2('parameters').rx2('keep_samp_for_vS')[0]
    feature_names = list(rinternal.rx2('parameters').rx2('feature_names'))

    dt = shapr.batch_prepare_vS_MC_auxiliary(S=S, internal=rinternal)

    # Use the legacy prediction helper: `compute_preds` requires a `type_` argument
    # that this code path never supplied, so calling it here raised TypeError.
    dt = compute_preds_old(dt=dt, feature_names=feature_names, predict_model=predict_model, model=model)

    dt_vS = shapr.compute_MCint(dt)

    if keep_samp_for_vS:
        return ro.ListVector({'dt_vS': dt_vS, 'dt_samp_for_vS': dt})
    return dt_vS
|
|
355
|
+
|
|
356
|
+
def batch_prepare_vS_MC(S, rinternal, model, predict_model):
    """Compute v(S) for one batch of coalitions `S` by Monte Carlo integration.

    Generates the Monte Carlo samples with `shapr.batch_prepare_vS_MC_auxiliary`,
    adds predictions via `compute_preds`, and integrates with
    `shapr.compute_MCint`. When the `type` parameter equals "forecast", the
    forecast-specific parameters and data stored in `rinternal` are forwarded
    to `compute_preds` as well.

    Returns `dt_vS`, or an R list with elements `dt_vS` and `dt_samp_for_vS`
    when `output_args$keep_samp_for_vS` is truthy.
    """
    parameters = rinternal.rx2('parameters')
    feature_names = list(parameters.rx2('feature_names'))
    keep_samp_for_vS = parameters.rx2('output_args').rx2('keep_samp_for_vS')[0]
    causal_sampling = parameters.rx2('causal_sampling')[0]

    dt = shapr.batch_prepare_vS_MC_auxiliary(S=S, internal=rinternal, causal_sampling=causal_sampling)

    type_ = parameters.rx2('type')[0]

    # Only the forecast type needs the additional arguments; for every other
    # type compute_preds falls back to its None defaults.
    forecast_kwargs = {}
    if type_ == "forecast":
        data = rinternal.rx2('data')
        forecast_kwargs = {
            'horizon': parameters.rx2('horizon')[0],
            'n_endo': data.rx2('n_endo')[0],
            'explain_idx': parameters.rx2('explain_idx')[0],
            'explain_lags': parameters.rx2('explain_lags')[0],
            'y': data.rx2('y'),
            'xreg': data.rx2('xreg'),
        }

    dt = compute_preds(
        dt=dt,
        feature_names=feature_names,
        predict_model=predict_model,
        model=model,
        type_=type_,
        **forecast_kwargs,
    )

    dt_vS = shapr.compute_MCint(dt)

    if keep_samp_for_vS:
        return ro.ListVector({'dt_vS': dt_vS, 'dt_samp_for_vS': dt})
    return dt_vS
|
|
402
|
+
|
|
403
|
+
def compute_preds(
    dt,
    feature_names,
    predict_model,
    model,
    type_,
    horizon=None,
    n_endo=None,
    explain_idx=None,
    explain_lags=None,
    y=None,
    xreg=None
):
    """Predict on the sampled data in `dt` and append the result as column `p_hat`.

    Parameters
    ----------
    dt: R data object holding the samples; converted with `r2py` for prediction.
    feature_names: columns passed to `predict_model` for non-forecast types.
    predict_model: callable invoked as `predict_model(model, data)`, or with the
        additional forecast arguments when `type_ == "forecast"`.
    model: the model handed through to `predict_model`.
    type_: explanation type; `"forecast"` selects the forecast call signature.
    horizon, n_endo, explain_idx, explain_lags, y, xreg: forwarded to
        `predict_model` for the forecast type only.

    Returns
    -------
    The result of R's `cbind(dt, p_hat=...)`.
    """
    # Convert the R object once; the forecast branch previously converted it twice.
    x = r2py(dt)

    if type_ == "forecast":
        # NOTE(review): `.loc` slices by label, not by position. If the columns are
        # string-named and `n_endo` is a column count, `.iloc` may have been intended
        # here; confirm against the forecast use case before changing.
        preds = predict_model(
            model,
            x.loc[:, :n_endo],
            x.loc[:, n_endo:],
            horizon,
            explain_idx,
            explain_lags,
            y,
            xreg
        )
    else:
        preds = predict_model(
            model,
            x.loc[:, feature_names]
        )

    return ro.r.cbind(dt, p_hat=ro.FloatVector(preds.tolist()))
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
def compute_preds_old(dt, feature_names, predict_model, model):
    """Legacy helper: predict on `dt`'s feature columns and append them as `p_hat`.

    Calls `predict_model(model, features)` on the `feature_names` columns of the
    converted `dt` and returns R's `cbind(dt, p_hat=...)`.
    """
    features = r2py(dt).loc[:, feature_names]
    predictions = predict_model(model, features)
    p_hat = ro.FloatVector(predictions.tolist())
    return ro.r.cbind(dt, p_hat=p_hat)
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
def get_feature_specs(get_model_specs, model):
    """Obtain and validate the model's feature specification as an R object.

    Parameters
    ----------
    get_model_specs: callable taking `model` and returning a dict with exactly
        the keys "labels", "classes" and "factor_levels", or None to fall back
        to `prebuilt_get_model_specs(model)`.
    model: the model to inspect.

    Returns
    -------
    An R list with elements `labels`, `classes` and `factor_levels` (entries
    that are None become NA), or R `NULL` when no specs function is available,
    in which case a warning is issued and checks are disabled.

    Raises
    ------
    ValueError
        If `get_model_specs` is neither None nor callable, or if it returns
        something other than a dict with the three expected keys.
    RuntimeError
        If calling `get_model_specs(model)` raises; the original exception is
        attached as the cause.
    """
    model_class0 = type(model)

    if (get_model_specs is not None) and (not callable(get_model_specs)):
        raise ValueError('`get_model_specs` must be None or callable.')

    if get_model_specs is None:
        get_model_specs = prebuilt_get_model_specs(model)
    if get_model_specs is None:
        warnings.warn(f'No pre-built get_model_specs for model of type {type(model)}, disabling checks.')
        return NULL

    try:
        feature_specs = get_model_specs(model)
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(f'The get_model_specs function of class `{model_class0}` is invalid.\nA basic function test threw the following error:\n{e}') from e

    if not isinstance(feature_specs, dict):
        raise ValueError(f'`get_model_specs` returned an object of type `{type(feature_specs)}`, but it should be of type `dict`')
    if set(feature_specs.keys()) != {"labels", "classes", "factor_levels"}:
        raise ValueError(f'`get_model_specs` should return a `dict` with keys ["labels","classes","factor_levels"], but found keys {list(feature_specs.keys())}')

    # From here on feature_specs is a validated dict; each entry may still be None,
    # which is mapped to NA by the converters below.
    def py2r_or_na(v):
        return py2r(v) if v is not None else NA

    def strvec_or_na(v):
        if v is None:
            return NA
        strvec = StrVector(list(v.values()))
        strvec.names = list(v.keys())
        return strvec

    def listvec_or_na(v):
        if v is None:
            return NA
        return ro.ListVector({k: list(val) for k, val in v.items()})

    return ro.ListVector({
        'labels': py2r_or_na(feature_specs['labels']),
        'classes': strvec_or_na(feature_specs['classes']),
        'factor_levels': listvec_or_na(feature_specs['factor_levels']),
    })
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
def get_predict_model(x_test, predict_model, model):
    """Resolve the prediction function for `model` and smoke-test it on `x_test`.

    Parameters
    ----------
    x_test: small data set used for the test prediction (the caller passes the
        first rows of the training data).
    predict_model: callable `predict_model(model, data)`, or None to use
        `prebuilt_predict_model(model)`.
    model: the model to predict with.

    Returns
    -------
    The validated `predict_model` callable.

    Raises
    ------
    ValueError
        If `predict_model` is neither None nor callable, or if no pre-built
        prediction function exists for this model type.
    RuntimeError
        If the test prediction raises, is not numeric, or does not have one
        value per row of `x_test`.
    """
    model_class0 = type(model)

    if (predict_model is not None) and (not callable(predict_model)):
        # The previous message formatted an undefined name `e`, so this branch
        # raised NameError instead of reporting the actual problem.
        raise ValueError('`predict_model` must be None or callable.')

    if predict_model is None:
        predict_model = prebuilt_predict_model(model)
    if predict_model is None:
        raise ValueError(f'No pre-built predict_model for model of type {type(model)}. Please pass a custom predict_model to shaprpy.explain(...).')

    try:
        tmp = py2r(predict_model(model, x_test))
    except Exception as e:
        raise RuntimeError(f'The predict_model function of class `{model_class0}` is invalid.\nA basic function test threw the following error:\n{e}') from e
    if not all(base.is_numeric(tmp)):
        raise RuntimeError('The output of predict_model is expected to be numeric.')
    # Compare against the actual number of test rows rather than a hard-coded 2,
    # so that a training set with fewer than two rows is handled correctly.
    if len(tmp) != len(x_test):
        raise RuntimeError('The output of predict_model does not match the length of the input.')
    return predict_model
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
def prebuilt_get_model_specs(model):
    """Return a built-in `get_model_specs` function for `model`, or None.

    Supports sklearn estimators (`sklearn.base.BaseEstimator`) and
    `xgboost.core.Booster`. The returned function takes the model and returns a
    dict with keys 'labels', 'classes' and 'factor_levels'; only 'labels' is
    populated. Each library is probed on a best-effort basis: if it cannot be
    imported or the check fails, that library is skipped.
    """

    # Look for sklearn
    try:
        from sklearn.base import BaseEstimator
        if isinstance(model, BaseEstimator):
            return lambda m: {
                'labels': m.feature_names_in_,
                'classes': None,  # Not available from model object
                'factor_levels': None,  # Not available from model object
            }
    except Exception:
        # Best-effort probe; `Exception` (not a bare except) so that
        # KeyboardInterrupt and SystemExit are not swallowed.
        pass

    # Look for xgboost.core.Booster
    try:
        import xgboost as xgb
        if isinstance(model, xgb.core.Booster):
            return lambda m: {
                'labels': np.array(m.feature_names),
                'classes': None,  # Not available from model object
                'factor_levels': None,  # Not available from model object
            }
    except Exception:
        pass

    return None
|
|
539
|
+
|
|
540
|
+
|
|
541
|
+
def prebuilt_predict_model(model):
    """Return a built-in prediction function for `model`, or None.

    The returned callable has the signature `(model, data)`:
    - sklearn classifiers: column 1 of `predict_proba(data)`;
    - sklearn regressors: flattened `predict(data)`;
    - `xgboost.core.Booster`: `predict` on an `xgb.DMatrix` built from `data`;
    - `keras.models.Model`: `predict(data)` reshaped to one value per row.

    Each library is probed on a best-effort basis: if it cannot be imported or
    the check fails, that library is skipped.
    """

    # Look for sklearn
    try:
        from sklearn.base import is_classifier, is_regressor
        if is_classifier(model):
            return lambda m, x: m.predict_proba(x)[:, 1]
        if is_regressor(model):
            return lambda m, x: m.predict(x).flatten()
    except Exception:
        # Best-effort probe; `Exception` (not a bare except) so that
        # KeyboardInterrupt and SystemExit are not swallowed.
        pass

    # Look for xgboost.core.Booster
    try:
        import xgboost as xgb
        if isinstance(model, xgb.core.Booster):
            return lambda m, x: m.predict(xgb.DMatrix(x))
    except Exception:
        pass

    # Look for keras
    try:
        from keras.models import Model
        if isinstance(model, Model):
            def predict_fn(m, x):
                pred = m.predict(x)
                return pred.reshape(pred.shape[0],)
            return predict_fn
    except Exception:
        pass

    return None
|
|
571
|
+
|
|
572
|
+
|
|
573
|
+
def compute_time(timing_list):
    """Summarize a mapping of named timestamps into elapsed seconds.

    `timing_list` maps step names to datetime-like values in chronological
    insertion order and must contain the key 'init_time'. Every key except the
    first gets the number of seconds elapsed since the preceding key.

    Returns a dict with 'init_time' (formatted "%Y-%m-%d %H:%M:%S"),
    'total_time_secs' (sum of all step durations) and 'timing_secs'
    (the per-step durations).
    """
    step_names = list(timing_list.keys())

    timing_secs = {}
    # Pair each step with its predecessor to get consecutive differences.
    for earlier, later in zip(step_names[:-1], step_names[1:]):
        elapsed = timing_list[later] - timing_list[earlier]
        timing_secs[f'{later}'] = elapsed.total_seconds()

    return {
        'init_time': timing_list['init_time'].strftime("%Y-%m-%d %H:%M:%S"),
        'total_time_secs': sum(timing_secs.values()),
        'timing_secs': timing_secs,
    }
|
|
586
|
+
|
|
587
|
+
|
|
588
|
+
def additional_regression_setup(rinternal, model, predict_model, x_train, x_explain):
    """Add predicted responses to `rinternal` for regression-based methods.

    When the `regression` flag in `rinternal`'s parameters is truthy, the
    predictions for the training and explain data are added via
    `regression_get_y_hat`; otherwise `rinternal` is returned unchanged.
    """
    uses_regression = rinternal.rx2("parameters").rx2("regression")[0]
    if not uses_regression:
        return rinternal
    return regression_get_y_hat(rinternal, model, predict_model, x_train, x_explain)
|
|
595
|
+
|
|
596
|
+
|
|
597
|
+
def regression_get_y_hat(rinternal, model, predict_model, x_train, x_explain):
    """Store the model's predictions for the training and explain data in `rinternal`.

    The predictions are written to `rinternal`'s `data` list as the R numeric
    vectors `x_train_y_hat` and `x_explain_y_hat`. Returns the updated `rinternal`.
    """
    train_predictions = predict_model(model, x_train)
    explain_predictions = predict_model(model, x_explain)

    # Extract the data list, add the predicted responses, and then write it back
    # to rinternal (direct assignment did not work).
    r_data = rinternal.rx2['data']
    for element_name, predictions in (
        ('x_train_y_hat', train_predictions),
        ('x_explain_y_hat', explain_predictions),
    ):
        r_data.rx2[element_name] = ro.FloatVector(predictions.tolist())
    rinternal.rx2['data'] = r_data

    return rinternal
|
|
608
|
+
|
|
609
|
+
|
|
610
|
+
def regression_remove_objects(routput):
    """Remove the regression-specific entries from `routput`'s `internal` list.

    Sets the regression-related parameters to R `NULL` and, when the first
    approach is "regression_surrogate", also the stored surrogate model.
    Returns the updated `routput`.
    """
    internal = routput.rx2("internal")
    parameters = internal.rx2("parameters")

    regression_keys = StrVector((
        "regression",
        "regression.model",
        "regression.tune_values",
        "regression.vfold_cv_para",
        "regression.recipe_func",
        "regression.tune",
        "regression.surrogate_n_comb",
    ))
    parameters.rx[regression_keys] = NULL
    internal.rx2["parameters"] = parameters

    is_surrogate = parameters.rx2("approach")[0] == "regression_surrogate"
    if is_surrogate:
        r_objects = internal.rx2("objects")
        r_objects.rx["regression.surrogate_model"] = NULL
        internal.rx2["objects"] = r_objects

    routput.rx2["internal"] = internal
    return routput
|
|
623
|
+
|
|
624
|
+
|
|
625
|
+
def change_first_underscore_to_dot(kwargs):
    """Return a copy of `kwargs` with the first "_" of every key replaced by ".".

    Keys without an underscore are kept as they are; values are untouched.
    """
    return {name.replace('_', '.', 1): value for name, value in kwargs.items()}
|
|
630
|
+
|
shaprpy/datasets.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
from sklearn.datasets import fetch_openml, fetch_california_housing, load_iris
|
|
3
|
+
from sklearn.model_selection import train_test_split
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def load_california_housing():
    """Load the California housing data split into a training set and five rows to explain.

    Returns:
        ``(dfx_train, dfx_test, dfy_train, dfy_test)`` as pandas DataFrames. The target
        frames have a single column named ``target``.
    """
    housing = fetch_california_housing()
    features = pd.DataFrame(housing.data, columns=housing.feature_names)
    target = pd.DataFrame({'target': housing.target})

    dfx_train, dfx_test, dfy_train, dfy_test = train_test_split(
        features, target, test_size=0.99, random_state=42
    )

    # Keep only the first five test rows to reduce the computational load.
    return dfx_train, dfx_test[:5], dfy_train, dfy_test[:5]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def load_binary_iris():
    """Load the iris data restricted to its first two classes (a binary problem).

    Returns:
        ``(dfx_train, dfx_test, dfy_train, dfy_test)`` as pandas DataFrames. The target
        frames have a single column named ``target``.
    """
    iris = load_iris()

    # Keep only the rows whose class label is 0 or 1, turning this into a binary
    # classification problem.
    keep = iris.target < 2
    features = pd.DataFrame(iris.data, columns=iris.feature_names).iloc[keep]
    target = pd.DataFrame({'target': iris.target}).iloc[keep]

    dfx_train, dfx_test, dfy_train, dfy_test = train_test_split(
        features, target, test_size=5, random_state=42
    )
    return dfx_train, dfx_test, dfy_train, dfy_test
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def load_adult():
    """Fetch OpenML dataset 1590 and split it into train and test parts.

    Feature columns containing missing values are dropped, and the response is
    integer-encoded with ``factorize`` into a one-column ``target`` DataFrame.

    Returns:
        The result of ``train_test_split`` on the features and target, i.e. the train
        features, test features, train target and test target, in that order.
    """
    features, response = fetch_openml(data_id=1590, return_X_y=True)

    # Drop columns with NAs for simplicity.
    features = features.dropna(axis=1)
    target = pd.DataFrame({'target': response.factorize()[0]})

    return train_test_split(features, target, test_size=5, random_state=42)
|
shaprpy/utils.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
from rpy2.robjects.conversion import localconverter
|
|
4
|
+
from rpy2.robjects import default_converter, Formula
|
|
5
|
+
from rpy2.robjects.functions import SignatureTranslatedFunction
|
|
6
|
+
from rpy2.robjects.numpy2ri import converter as np_converter
|
|
7
|
+
from rpy2.robjects.pandas2ri import converter as pd_converter
|
|
8
|
+
from rpy2.robjects.pandas2ri import _to_pandas_factor
|
|
9
|
+
from rpy2.rinterface import NULL, NA
|
|
10
|
+
from rpy2.robjects.vectors import DataFrame, FloatVector, IntVector, BoolVector, StrVector, ListVector, FactorVector, FloatMatrix, Matrix, POSIXct
|
|
11
|
+
import warnings
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@pd_converter.rpy2py.register(FactorVector)
def rpt2py_factorvector(obj):
    """Convert an R ``FactorVector`` by delegating to rpy2's ``_to_pandas_factor``.

    The decorator registers this function on ``pd_converter`` as the R-to-Python
    conversion used for ``FactorVector`` objects.
    """
    return _to_pandas_factor(obj)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def py2r(obj):
    """Convert a Python object to an R object.

    The conversion runs inside a local converter context that combines rpy2's default,
    numpy and pandas converters.
    """
    conversion_rules = default_converter + np_converter + pd_converter
    with localconverter(conversion_rules) as active_converter:
        return active_converter.py2rpy(obj)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def r2py(robj):
    """Convert an R object to a Python object.

    Uses the combination of rpy2's default, numpy and pandas converters.
    """
    return (default_converter + np_converter + pd_converter).rpy2py(robj)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def recurse_r_tree(data):
    """Recursively convert an rpy2 object into Python, numpy and pandas objects.

    Conversion by (exact) type of ``data``:

    - ``NULL`` -> ``None``
    - ``DataFrame`` -> result of ``r2py``; if that fails, a ``pandas.DataFrame`` built
      column by column
    - ``FloatVector``, ``IntVector``, ``BoolVector``, ``FloatMatrix``, ``Matrix``
      -> ``numpy`` array
    - ``FactorVector`` -> result of ``_to_pandas_factor``
    - ``POSIXct`` -> the first element formatted as ``"%Y-%m-%d %H:%M:%S"``
    - ``SignatureTranslatedFunction``, ``Formula`` -> ``str(data)``
    - ``ListVector`` -> ``dict`` of recursively converted elements; a list without names
      first gets the names ``element_1``, ``element_2``, ... assigned on ``data`` itself
    - ``StrVector`` -> ``list`` of its elements
    - anything else is returned unchanged

    NOTE(review): the branches compare exact types on purpose. Switching to
    ``isinstance`` could change which branch subclasses of these rpy2 classes take.
    """
    if data == NULL:
        return None
    elif type(data) == DataFrame:
        try:
            return r2py(data)
        except Exception:
            # The column "features" in internal$objects$X is known to cause problems,
            # so fall back to converting the data frame one column at a time.
            columns = {}
            for col in data.names:
                try:
                    columns[col] = r2py(data.rx2(col))
                except Exception:
                    # Catch Exception rather than using a bare except, so that
                    # KeyboardInterrupt/SystemExit are not swallowed here.
                    # We manually convert the elements of the column "features" in internal$objects$X
                    columns[col] = [r2py(element) for element in data.rx2(col)]
            return pd.DataFrame(columns, index=data.rownames)
    elif type(data) in [FloatVector, IntVector, BoolVector, FloatMatrix, Matrix]:
        return np.array(data)
    elif type(data) == FactorVector:
        return _to_pandas_factor(data)
    elif type(data) == POSIXct:
        # Warnings raised during the conversion are suppressed.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            tmp = r2py(data).strftime("%Y-%m-%d %H:%M:%S")[0]
        return tmp
    elif type(data) == SignatureTranslatedFunction:
        return str(data)
    elif type(data) == Formula:
        return str(data)
    elif type(data) == ListVector:
        if type(data.names) == type(NULL):
            # Unnamed list: assign generic names so it can be turned into a dict.
            data.names = [f"element_{i+1}" for i in range(len(data))]
        return dict(zip(data.names, [recurse_r_tree(d) for d in data]))
    elif type(data) == StrVector:
        return [recurse_r_tree(d) for d in data]
    else:
        return data  # We reached the end of recursion (if not converted above, return the object as is)
|
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: shaprpy
|
|
3
|
+
Version: 0.3.0
|
|
4
|
+
Summary: Python wrapper for the R package shapr (via rpy2)
|
|
5
|
+
Author: Martin Jullum, Lars Henry Berge Olsen, Didrik Nielsen
|
|
6
|
+
License: YEAR: 2019
|
|
7
|
+
COPYRIGHT HOLDER: Norsk Regnesentral
|
|
8
|
+
|
|
9
|
+
Project-URL: Homepage, https://github.com/NorskRegnesentral/shapr
|
|
10
|
+
Project-URL: Documentation, https://norskregnesentral.github.io/shapr/shaprpy.html
|
|
11
|
+
Project-URL: Issues, https://github.com/NorskRegnesentral/shapr/issues
|
|
12
|
+
Project-URL: Changelog, https://github.com/NorskRegnesentral/shapr/blob/main/python/CHANGELOG.md
|
|
13
|
+
Keywords: explainable-ai,shapley-values,machine-learning,model-interpretability
|
|
14
|
+
Classifier: Programming Language :: Python :: 3
|
|
15
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
20
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
21
|
+
Classifier: Operating System :: POSIX :: Linux
|
|
22
|
+
Classifier: Intended Audience :: Science/Research
|
|
23
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
24
|
+
Requires-Python: >3.10
|
|
25
|
+
Description-Content-Type: text/markdown
|
|
26
|
+
License-File: LICENSE
|
|
27
|
+
License-File: LICENSE.md
|
|
28
|
+
Requires-Dist: rpy2>=3.5.1
|
|
29
|
+
Requires-Dist: numpy>=1.22.3
|
|
30
|
+
Requires-Dist: pandas>=1.4.2
|
|
31
|
+
Requires-Dist: scikit-learn>=1.0.0
|
|
32
|
+
Requires-Dist: tabulate>=0.8.10
|
|
33
|
+
Provides-Extra: test
|
|
34
|
+
Requires-Dist: pytest>=7.0.0; extra == "test"
|
|
35
|
+
Requires-Dist: syrupy>=4.0.0; extra == "test"
|
|
36
|
+
Requires-Dist: xgboost>=1.5.0; extra == "test"
|
|
37
|
+
Dynamic: license-file
|
|
38
|
+
|
|
39
|
+
# shaprpy
|
|
40
|
+
|
|
41
|
+
Python wrapper for the R package [shapr](https://github.com/NorskRegnesentral/shapr).
|
|
42
|
+
|
|
43
|
+
NOTE: This wrapper is not as comprehensively tested as the `R`-package.
|
|
44
|
+
|
|
45
|
+
`shaprpy` relies heavily on the `rpy2` Python library for accessing R from within Python.
|
|
46
|
+
`rpy2` has limited support on Windows. `shaprpy` has only been tested on Linux.
|
|
47
|
+
The instructions below assume a Linux environment.
|
|
48
|
+
|
|
49
|
+
# shaprpy
|
|
50
|
+
|
|
51
|
+
`shaprpy` is a Python wrapper for the R package [shapr](https://github.com/NorskRegnesentral/shapr),
|
|
52
|
+
using the [`rpy2`](https://rpy2.github.io/) Python library to access R from within Python.
|
|
53
|
+
|
|
54
|
+
> **Note:** This wrapper is **not** as comprehensively tested as the R package.
|
|
55
|
+
> `rpy2` has limited support on Windows, and the same therefore applies to `shaprpy`.
|
|
56
|
+
> `shaprpy` has only been tested on Linux (and WSL - Windows Subsystem for Linux), and the below instructions assume a Linux environment.
|
|
57
|
+
>
|
|
58
|
+
> **Requirement:** Python 3.10 or later is required to use `shaprpy`.
|
|
59
|
+
|
|
60
|
+
## Changelog
|
|
61
|
+
|
|
62
|
+
For a list of changes and updates to the `shaprpy` package, see the [shaprpy CHANGELOG](https://norskregnesentral.github.io/shapr/py_changelog.html).
|
|
63
|
+
|
|
64
|
+
---
|
|
65
|
+
|
|
66
|
+
## Installation
|
|
67
|
+
|
|
68
|
+
These instructions assume you already have **pip** and **R** installed and available to the Python environment in which you want to run `shaprpy`.
|
|
69
|
+
|
|
70
|
+
- Official instructions for installing `pip` can be found [here](https://pip.pypa.io/en/stable/installation/).
|
|
71
|
+
- Official instructions for installing R can be found [here](https://cran.r-project.org/).
|
|
72
|
+
|
|
73
|
+
On Debian/Ubuntu-based systems, R can also be installed via:
|
|
74
|
+
```bash
|
|
75
|
+
sudo apt update
|
|
76
|
+
sudo apt install r-base r-base-dev -y
|
|
77
|
+
```
|
|
78
|
+
|
|
79
|
+
### 1. Install the R package `shapr`
|
|
80
|
+
|
|
81
|
+
`shaprpy` requires the R package `shapr` (version 1.0.5 or newer).
|
|
82
|
+
In your R environment, install the latest version from CRAN using:
|
|
83
|
+
|
|
84
|
+
```bash
|
|
85
|
+
Rscript -e 'install.packages("shapr", repos="https://cran.rstudio.com")'
|
|
86
|
+
```
|
|
87
|
+
|
|
88
|
+
### 2. Ensure R is discoverable (R_HOME and PATH)
|
|
89
|
+
|
|
90
|
+
Sometimes `shaprpy` cannot automatically locate your R installation. To ensure proper detection, verify that:
|
|
91
|
+
- R is available in your system `PATH`, **or**
|
|
92
|
+
- The `R_HOME` environment variable is set to your R installation directory.
|
|
93
|
+
|
|
94
|
+
Example:
|
|
95
|
+
```bash
|
|
96
|
+
export R_HOME=$(R RHOME)
|
|
97
|
+
export PATH=$PATH:$(R RHOME)/bin
|
|
98
|
+
```
|
|
99
|
+
|
|
100
|
+
### 3. Install the Python wrapper
|
|
101
|
+
|
|
102
|
+
Install directly from PyPI with:
|
|
103
|
+
|
|
104
|
+
```bash
|
|
105
|
+
pip install shaprpy
|
|
106
|
+
```
|
|
107
|
+
|
|
108
|
+
#### Local development install (for contributors)
|
|
109
|
+
If you have cloned the repository and want to install in development mode for local changes, navigate to the `./python` directory and run:
|
|
110
|
+
```bash
|
|
111
|
+
pip install -e .
|
|
112
|
+
```
|
|
113
|
+
The `-e` flag installs in editable mode, allowing local code changes to be reflected immediately.
|
|
114
|
+
|
|
115
|
+
---
|
|
116
|
+
|
|
117
|
+
## Quick Demo
|
|
118
|
+
|
|
119
|
+
```python
|
|
120
|
+
from sklearn.ensemble import RandomForestRegressor
|
|
121
|
+
from shaprpy import explain
|
|
122
|
+
from shaprpy.datasets import load_california_housing
|
|
123
|
+
|
|
124
|
+
# Load example data
|
|
125
|
+
dfx_train, dfx_test, dfy_train, dfy_test = load_california_housing()
|
|
126
|
+
|
|
127
|
+
# Fit a model
|
|
128
|
+
model = RandomForestRegressor()
|
|
129
|
+
model.fit(dfx_train, dfy_train.values.flatten())
|
|
130
|
+
|
|
131
|
+
# Explain predictions
|
|
132
|
+
explanation = explain(
|
|
133
|
+
model=model,
|
|
134
|
+
x_train=dfx_train,
|
|
135
|
+
x_explain=dfx_test,
|
|
136
|
+
approach="empirical",
|
|
137
|
+
phi0=dfy_train.mean().item(),
|
|
138
|
+
seed=1
|
|
139
|
+
)
|
|
140
|
+
|
|
141
|
+
print(explanation["shapley_values_est"])
|
|
142
|
+
```
|
|
143
|
+
|
|
144
|
+
---
|
|
145
|
+
|
|
146
|
+
## Supported Models
|
|
147
|
+
|
|
148
|
+
`shaprpy` can explain predictions from models built with:
|
|
149
|
+
- [`scikit-learn`](https://scikit-learn.org/)
|
|
150
|
+
- [`keras`](https://keras.io/) (Sequential API)
|
|
151
|
+
- [`xgboost`](https://xgboost.readthedocs.io/)
|
|
152
|
+
|
|
153
|
+
For other model types, you can supply:
|
|
154
|
+
- A custom `predict_model` function
|
|
155
|
+
- (Optionally) a custom `get_model_specs` function
|
|
156
|
+
to `shaprpy.explain`.
|
|
157
|
+
|
|
158
|
+
---
|
|
159
|
+
|
|
160
|
+
## Examples
|
|
161
|
+
|
|
162
|
+
See the `/examples` folder for runnable examples, including:
|
|
163
|
+
- A custom PyTorch model
|
|
164
|
+
- The **regression paradigm** described in [Olsen et al. (2024)](https://link.springer.com/article/10.1007/s10618-024-01016-z),
|
|
165
|
+
which shows:
|
|
166
|
+
- How to specify the regression model
|
|
167
|
+
- How to enable automatic cross-validation of hyperparameters
|
|
168
|
+
- How to apply pre-processing steps before fitting regression models
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
shaprpy/__init__.py,sha256=Z7UA4ws5uuPUny9uRpnn6o3PGv5ONWKjhRzBOv-YJJc,1528
|
|
2
|
+
shaprpy/_explain.py,sha256=k_OH_wqewAB6iRodW2aWI_gprw0hra9vyI3Ir-jQkqE,23831
|
|
3
|
+
shaprpy/datasets.py,sha256=ptX-fcD9evlpEdqXrmCTXwcYTao0vTPv93k5C5PYS2g,1295
|
|
4
|
+
shaprpy/utils.py,sha256=_kD5ZW9bp3kyTXhtS-jz4zKAdql4pOH6KZehFJTxaRk,2461
|
|
5
|
+
shaprpy-0.3.0.dist-info/licenses/LICENSE,sha256=hABnnsrNduCr-dp2Cy4-XOroRGzhaw7LefBhxlY0gjQ,48
|
|
6
|
+
shaprpy-0.3.0.dist-info/licenses/LICENSE.md,sha256=9Gwz5ov3rmu9AK8fTEZUj4pPXL5ROOlzcNqukEKxNOY,1077
|
|
7
|
+
shaprpy-0.3.0.dist-info/METADATA,sha256=82ke02a2Al423a_GoQyg9UQthpGgIOBEAHhjPXPP9fU,5675
|
|
8
|
+
shaprpy-0.3.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
9
|
+
shaprpy-0.3.0.dist-info/top_level.txt,sha256=4LqZ5SreryN3N5zgxtxKjyAeDUo6ggO9w1s0eRTak78,8
|
|
10
|
+
shaprpy-0.3.0.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
# MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2019 Norsk Regnesentral
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
shaprpy
|