shaprpy 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
shaprpy/__init__.py ADDED
@@ -0,0 +1,51 @@
1
+ from importlib.metadata import version, PackageNotFoundError
2
+ from importlib import import_module
3
+
4
+ # Lightweight public re-export (no R dependency)
5
+ from . import datasets # noqa: F401
6
+
7
# Names that make up the public API of the package.
__all__ = ["explain", "datasets", "ensure_r_ready"]

# Look up the installed distribution's version; use a placeholder when no
# package metadata can be found.
try:
    __version__ = version("shaprpy")
except PackageNotFoundError:
    __version__ = "0.0.0+local"

# Lazy-initialization state, filled in by ensure_r_ready().
_r_ready = False  # becomes True once rpy2 and the R package 'shapr' were loaded
_explain_impl = None  # becomes the `explain` function of the private `_explain` module
16
+
17
+
18
def ensure_r_ready() -> bool:
    """Make sure rpy2 and the R package 'shapr' can be loaded, then bind the real explain().

    The work is done at most once; once it has succeeded, later calls return immediately.

    Returns
    -------
    bool
        Always `True`. Failures are reported by raising instead.

    Raises
    ------
    ImportError
        If rpy2 cannot be imported, or if the R package 'shapr' cannot be loaded.
    """
    global _r_ready, _explain_impl

    if not _r_ready:
        # Step 1: rpy2 itself (which in turn needs a working R installation)
        try:
            import rpy2.robjects as _ro  # noqa: F401
            from rpy2.robjects.packages import importr
        except Exception as e:
            raise ImportError(
                "shaprpy requires rpy2 and a working R installation.\n"
                "Install R and rpy2, and ensure R is on PATH/R_HOME. See README."
            ) from e

        # Step 2: the R package that is being wrapped
        try:
            importr("shapr")
        except Exception as e:
            raise ImportError(
                "The R package 'shapr' is not installed or not found.\n"
                "In an R session, run: install.packages('shapr')"
            ) from e

        # Step 3: the implementation lives in a private module to avoid a name collision
        _explain_impl = import_module(f"{__name__}._explain").explain
        _r_ready = True

    return True
46
+
47
+
48
def explain(*args, **kwargs):
    """Forward to the real explain(), initializing R/shapr on first use.

    All positional and keyword arguments are passed through unchanged, and the
    result of the real implementation is returned as is.
    """
    ensure_r_ready()  # binds _explain_impl, or raises ImportError
    real_explain = _explain_impl
    return real_explain(*args, **kwargs)
shaprpy/_explain.py ADDED
@@ -0,0 +1,630 @@
1
+ import warnings
2
+ import numpy as np
3
+ import pandas as pd
4
+ from typing import Callable
5
+ from datetime import datetime
6
+ import rpy2.robjects as ro
7
+ from rpy2.robjects.packages import importr
8
+ from rpy2.rinterface import NULL, NA
9
+ from shaprpy.utils import r2py, py2r, recurse_r_tree
10
+ from rpy2.robjects.vectors import StrVector, ListVector
11
+
12
# Handles to R packages, loaded once through rpy2 when this module is imported.
data_table = importr('data.table')
shapr = importr('shapr')  # the R package wrapped by this module
utils = importr('utils')
base = importr('base')  # provides Sys_time(), set_seed() and is_numeric() used below
stats = importr('stats')  # provides rnorm() used in compute_vS
17
+
18
def maybe_null(val):
    """Return `val` unchanged, or the R `NULL` object when `val` is `None`."""
    if val is None:
        return NULL
    return val
20
+
21
def explain(
    model,
    x_explain: pd.DataFrame,
    x_train: pd.DataFrame,
    approach: str | list[str],
    phi0: float,
    iterative: bool | None = None,
    max_n_coalitions: int | None = None,
    group: dict | None = None,
    n_MC_samples: int = 1000,
    seed: int | None = None,
    verbose: str | list[str] | None = "basic",
    predict_model: Callable | None = None,
    get_model_specs: Callable | None = None,
    asymmetric: bool = False,
    causal_ordering: dict | None = None,
    confounding: bool | None = None,
    extra_computation_args: dict | None = None,
    iterative_args: dict | None = None,
    output_args: dict | None = None,
    **kwargs,
  ):
    """
    Explain the output of machine learning models with more accurately estimated Shapley values.

    Computes dependence-aware Shapley values for observations in `x_explain` from the specified
    `model` by using the method specified in `approach` to estimate the conditional expectation.

    The Python side converts the inputs to R objects, lets the R package `shapr` drive the
    (possibly iterative) estimation, evaluates the Python model whenever predictions are needed,
    and finally converts the R result back to Python objects.

    Parameters
    ----------
    model: The model whose predictions we want to explain.
      `shaprpy` natively supports `sklearn`, `xgboost` and `keras` models.
      Unsupported models can still be explained by passing `predict_model` and (optionally) `get_model_specs`.
    x_explain: pd.DataFrame
      Contains the features whose predictions ought to be explained.
    x_train: pd.DataFrame
      Contains the data used to estimate the (conditional) distributions for the features
      needed to properly estimate the conditional expectations in the Shapley formula.
    approach: str or list[str]
      The method(s) to estimate the conditional expectation. All elements should
      be one of `"gaussian"`, `"copula"`, `"empirical"`, `"ctree"`, `"categorical"`, `"timeseries"`, `"independence"`,
      `"regression_separate"`, or `"regression_surrogate"`.
    phi0: float
      The prediction value for unseen data, i.e. an estimate of the expected prediction without conditioning on any
      features. Typically we set this value equal to the mean of the response variable in our training data, but other
      choices such as the mean of the predictions in the training data are also reasonable.
    iterative: bool or None, optional
      If `None` (default), the argument is set to `True` if there are more than 5 features/groups, and `False` otherwise.
      If `True`, the Shapley values are estimated iteratively.
    max_n_coalitions: int or None, optional
      The upper limit on the number of unique feature/group coalitions to use in the iterative procedure
      (if `iterative = True`). If `iterative = False` it represents the number of feature/group coalitions to use directly.
      `max_n_coalitions = None` corresponds to `max_n_coalitions=2^n_features`.
    group: dict or None, optional
      If `None` regular feature wise Shapley values are computed.
      If provided, group wise Shapley values are computed. `group` then maps each group name to a list of unique
      feature names with the features included in that group.
    n_MC_samples: int, optional
      Indicating the maximum number of samples to use in the Monte Carlo integration for every conditional expectation.
    seed: int or None, optional
      Specifies the seed before any randomness based code is being run.
      If `None` (default) the seed will be inherited from the calling environment.
    verbose: str or list[str] or None, optional
      Specifies the verbosity (printout detail level) through one or more of the strings `"basic"`, `"progress"`,
      `"convergence"`, `"shapley"` and `"vS_details"`. `None` means no printout.
    predict_model: Callable, optional
      The prediction function used when `model` is not natively supported. The function must have two arguments, `model` and `newdata`
      which specify, respectively, the model and a pandas.DataFrame to compute predictions for. The function must give the prediction as a numpy.Array.
    get_model_specs: Callable, optional
      An optional function for checking model/data consistency when `model` is not natively supported. The function takes `model` as argument
      and provides a `dict` with 3 elements: `labels`, `classes`, and `factor_levels`.
    asymmetric: bool, optional
      If `False` (default), `explain` computes regular symmetric Shapley values. If `True`, then `explain` computes asymmetric Shapley values
      based on the (partial) causal ordering given by `causal_ordering`.
    causal_ordering: dict or None, optional
      The components of the partial causal ordering that the coalitions must respect. Each value of the dict is
      converted to an R character vector, so the values should be lists of strings.
    confounding: bool or None, optional
      A vector of logicals specifying whether confounding is assumed or not for each component in the `causal_ordering`.
    extra_computation_args: dict or None, optional
      Specifies extra arguments related to the computation of the Shapley values.
    iterative_args: dict or None, optional
      Specifies the arguments for the iterative procedure.
    output_args: dict or None, optional
      Specifies certain arguments related to the output of the function.
    **kwargs: Further arguments passed to specific approaches. The first underscore in each name is replaced
      by a dot before the argument is handed to R (e.g. `a_b_c` becomes `a.b_c`).

    Returns
    -------
    dict
      A dictionary containing the following items:
      - "shapley_values_est": pd.DataFrame with the estimated Shapley values.
      - "shapley_values_sd": pd.DataFrame with the standard deviation of the Shapley values.
      - "pred_explain": numpy.Array with the predictions for the explained observations.
      - "MSEv": dict with the values of the MSEv evaluation criterion.
      - "iterative_results": dict with the results of the iterative estimation.
      - "saving_path": str with the path where intermediate results are stored.
      - "internal": dict with the different parameters, data, functions and other output used internally.
      - "timing": dict containing timing information for the different parts of the computation.
    """

    # Start time is taken with R's clock, since the timing is computed on the R side
    init_time = base.Sys_time() # datetime.now()


    if seed is not None:
        base.set_seed(seed)

    # Gets and checks feature specs from the model
    rfeature_specs = get_feature_specs(get_model_specs, model)

    # Fixes the conversion from dict to a named list of vectors in R
    r_group = NULL if group is None else ListVector({key: StrVector(value) for key, value in group.items()})

    # Fixes the conversion from dict to a named list of vectors in R
    r_causal_ordering = NULL if causal_ordering is None else ListVector({key: StrVector(value) for key, value in causal_ordering.items()})


    # Fixes method specific argument names by replacing first occurrence of "_" with "."
    if len(kwargs) > 0:
        kwargs = change_first_underscore_to_dot(kwargs)

        # Convert from dict to a named list of vectors in R if `regression.vfold_cv_para` is provided by the user
        if 'regression.vfold_cv_para' in kwargs:
            kwargs['regression.vfold_cv_para'] = ListVector(kwargs['regression.vfold_cv_para'])

    # Convert from None or dict to a named list in R
    if iterative_args is None:
        iterative_args = ro.ListVector({})
    else:
        iterative_args = ListVector(iterative_args)

    if output_args is None:
        output_args = ro.ListVector({})
    else:
        output_args = ListVector(output_args)

    if extra_computation_args is None:
        extra_computation_args = ro.ListVector({})
    else:
        extra_computation_args = ListVector(extra_computation_args)

    # Fully qualified Python class name of the model, passed on to R as `model_class`
    model_class = f"{type(model).__module__}.{type(model).__name__}"

    # Sets up and organizes input parameters
    # Checks the input parameters and their compatibility
    # Checks data/model compatibility

    if isinstance(approach, str):
        approach = [approach]

    # `verbose` ends up as an R character vector, or as R NULL when it is None
    if isinstance(verbose, str):
        verbose = [verbose]
    if isinstance(verbose, list):
        verbose = StrVector(verbose)
    else:
        verbose = maybe_null(verbose)


    rinternal = shapr.setup(
        x_train = py2r(x_train),
        x_explain = py2r(x_explain),
        approach = StrVector(approach),
        phi0 = phi0,
        max_n_coalitions = maybe_null(max_n_coalitions),
        group = r_group,
        n_MC_samples = n_MC_samples,
        seed = maybe_null(seed),
        feature_specs = rfeature_specs,
        verbose = verbose,
        iterative = maybe_null(iterative),
        iterative_args = iterative_args,
        asymmetric = asymmetric,
        causal_ordering = r_causal_ordering,
        confounding = maybe_null(confounding),
        output_args = output_args,
        extra_computation_args = extra_computation_args,
        init_time = init_time,
        is_python = True,
        model_class = model_class,
        **kwargs
    )

    # Gets predict_model (if not passed to explain) and checks that predict_model gives correct format
    # (the check is run on the first two rows of the training data)
    predict_model = get_predict_model(x_test=x_train.head(2), predict_model=predict_model, model=model)

    rinternal.rx2['timing_list'].rx2['test_prediction'] = base.Sys_time()

    # Adds model predictions for x_train/x_explain to rinternal when a regression approach is used
    rinternal = additional_regression_setup(
        rinternal,
        model,
        predict_model,
        x_train,
        x_explain)

    # Not called for approach %in% c("regression_surrogate","vaeac")
    rinternal = shapr.setup_approach(internal = rinternal) # model and predict_model are not supported in Python

    rinternal.rx2['main_timing_list'] = rinternal.rx2['timing_list']

    converged = False
    # 1-based number of the current iteration (R lists are 1-indexed, hence `iter-1` when indexing from Python)
    iter = len(rinternal.rx2('iter_list'))

    # Re-seed so that the estimation below starts from the same RNG state regardless of the setup above
    if seed is not None:
        base.set_seed(seed)

    shapr.cli_startup(rinternal, verbose)

    rinternal.rx2['iter_timing_list'] = ro.ListVector({})

    while not converged:
        shapr.cli_iter(verbose, rinternal, iter)

        # Fresh timing list for this iteration
        rinternal.rx2['timing_list'] = ro.ListVector({'init': base.Sys_time()})

        # Setup the Shapley framework
        rinternal = shapr.shapley_setup(rinternal)

        # Only actually called for approach in ["regression_surrogate", "vaeac"]
        rinternal = shapr.setup_approach(rinternal)

        # Compute the vS
        vS_list = compute_vS(rinternal, model, predict_model)

        # Compute Shapley value estimates and bootstrapped standard deviations
        rinternal = shapr.compute_estimates(rinternal, vS_list)

        # Check convergence based on estimates and standard deviations (and thresholds)
        rinternal = shapr.check_convergence(rinternal)

        # Save intermediate results
        shapr.save_results(rinternal)

        # Preparing parameters for next iteration (does not do anything if already converged)
        rinternal = shapr.prepare_next_iteration(rinternal)

        # Printing iteration information
        shapr.print_iter(rinternal)

        # Setting globals to simplify the loop
        converged = rinternal.rx2('iter_list')[iter-1].rx2('converged')[0]

        # The list is rebuilt (old entries plus the new one) and assigned back as a whole
        rinternal.rx2['timing_list'] = ro.ListVector({**dict(rinternal.rx2['timing_list'].items()), 'postprocess_res': base.Sys_time()})

        # Add the current timing_list to the iter_timing_list
        rinternal.rx2['iter_timing_list'] = ro.ListVector({**dict(rinternal.rx2['iter_timing_list'].items()), f'element_{iter}': rinternal.rx2['timing_list']})

        iter += 1

    rinternal.rx2['main_timing_list'] = ro.ListVector({**dict(rinternal.rx2['main_timing_list'].items()), 'main_computation': base.Sys_time()})

    # Rerun after convergence to get the same output format as for the non-iterative approach
    routput = shapr.finalize_explanation(rinternal)

    rinternal.rx2['main_timing_list'] = ro.ListVector({**dict(rinternal.rx2['main_timing_list'].items()), 'finalize_explanation': base.Sys_time()})


    routput.rx2['timing'] = shapr.compute_time(rinternal)

    # Some cleanup when doing testing
    testing = rinternal.rx2('parameters').rx2('testing')[0]
    if testing:
        routput = shapr.testing_cleanup(routput)

    # Convert R objects to Python objects
    shapley_values_est = recurse_r_tree(routput.rx2('shapley_values_est'))
    shapley_values_sd = recurse_r_tree(routput.rx2('shapley_values_sd'))
    pred_explain = recurse_r_tree(routput.rx2('pred_explain'))
    MSEv = recurse_r_tree(routput.rx2('MSEv'))
    iterative_results = recurse_r_tree(routput.rx2('iterative_results'))
    saving_path = recurse_r_tree(routput.rx2('saving_path'))
    internal = recurse_r_tree(routput.rx2('internal'))
    timing = recurse_r_tree(routput.rx2('timing'))

    return {
        "shapley_values_est": shapley_values_est,
        "shapley_values_sd": shapley_values_sd,
        "pred_explain": pred_explain,
        "MSEv": MSEv,
        "iterative_results": iterative_results,
        "saving_path": saving_path,
        "internal": internal,
        "timing": timing,
    }
303
+
304
+
305
def compute_vS(rinternal, model, predict_model):
    """Compute the contribution function values v(S) for the batches of the latest iteration.

    Parameters
    ----------
    rinternal: The R `internal` list of the ongoing explanation.
    model: The Python model being explained.
    predict_model: Callable
      Prediction function taking `model` and a pandas.DataFrame.

    Returns
    -------
    The R list produced by `shapr.append_vS_list`, i.e. the newly computed values added to
    any vS_list already computed.
    """
    iteration_list = rinternal.rx2('iter_list')
    latest = len(iteration_list) - 1  # 0-based position of the latest iteration
    S_batch = iteration_list[latest].rx2('S_batch')

    # verbose
    shapr.cli_compute_vS(rinternal)

    # Perform a single sample to forward the RNG state one step. This is done to ensure consistency with
    # future.apply::future_lapply in R, which does this to guarantee consistency for parallelization.
    # See ?future.apply::future_lapply for details
    stats.rnorm(1)

    vS_list = ro.ListVector({})
    for position, S in enumerate(S_batch, start=1):  # R lists are 1-indexed
        vS_list.rx2[position] = batch_compute_vS(S=S, rinternal=rinternal, model=model, predict_model=predict_model)

    #### Adds v_S output above to any vS_list already computed ####
    return shapr.append_vS_list(vS_list, rinternal)
326
+
327
+
328
def batch_compute_vS(S, rinternal, model, predict_model):
    """Compute v(S) for one batch of coalitions `S`.

    Regression-based approaches are handled entirely by `shapr.batch_prepare_vS_regression`;
    all other approaches use Monte Carlo integration with predictions from the Python model.

    Returns
    -------
    Either only dt_vS, or a list containing dt_vS and dt if
    internal$parameters$output_args$keep_samp_for_vS = TRUE (Monte Carlo case).
    """
    use_regression = rinternal.rx2('parameters').rx2('regression')[0]

    if not use_regression:
        return batch_prepare_vS_MC(S=S, rinternal=rinternal, model=model, predict_model=predict_model)

    return shapr.batch_prepare_vS_regression(S=S, internal=rinternal)
339
+
340
+
341
def batch_prepare_vS_MC_old(S, rinternal, model, predict_model):
    """Legacy Monte Carlo computation of v(S) for one batch of coalitions `S`.

    Superseded by `batch_prepare_vS_MC`; kept for reference.

    Parameters
    ----------
    S: The batch of coalitions to process.
    rinternal: The R `internal` list of the ongoing explanation.
    model: The Python model being explained.
    predict_model: Callable
      Prediction function taking `model` and a pandas.DataFrame.

    Returns
    -------
    `dt_vS`, or an R list with `dt_vS` and `dt_samp_for_vS` when `keep_samp_for_vS` is set.
    """
    # NOTE(review): batch_prepare_vS_MC reads this flag from parameters$output_args instead;
    # confirm that parameters$keep_samp_for_vS still exists before reviving this function.
    keep_samp_for_vS = rinternal.rx2('parameters').rx2('keep_samp_for_vS')[0]
    feature_names = list(rinternal.rx2('parameters').rx2('feature_names'))

    dt = shapr.batch_prepare_vS_MC_auxiliary(S=S, internal=rinternal)

    # compute_preds() requires a `type_` argument that this legacy code path never had,
    # so the matching legacy helper (same four arguments) must be used here.
    dt = compute_preds_old(dt=dt, feature_names=feature_names, predict_model=predict_model, model=model)

    dt_vS = shapr.compute_MCint(dt)

    if keep_samp_for_vS:
        return ro.ListVector({'dt_vS':dt_vS, 'dt_samp_for_vS':dt})
    else:
        return dt_vS
355
+
356
def batch_prepare_vS_MC(S, rinternal, model, predict_model):
    """Compute v(S) for one batch of coalitions `S` by Monte Carlo integration.

    R generates the Monte Carlo samples, the Python model predicts on them, and R then
    integrates the predictions.

    Returns
    -------
    `dt_vS`, or an R list with `dt_vS` and `dt_samp_for_vS` when
    parameters$output_args$keep_samp_for_vS is set.
    """
    params = rinternal.rx2('parameters')

    feature_names = list(params.rx2('feature_names'))
    keep_samp_for_vS = params.rx2('output_args').rx2('keep_samp_for_vS')[0]
    causal_sampling = params.rx2('causal_sampling')[0]
    output_size = int(params.rx2('output_size')[0])

    dt = shapr.batch_prepare_vS_MC_auxiliary(S=S, internal=rinternal, causal_sampling=causal_sampling)

    pred_cols = [f"p_hat{i+1}" for i in range(output_size)]
    type_ = params.rx2('parameters' if False else 'type')[0]

    # The forecast type needs a number of additional inputs for the prediction function
    forecast_inputs = {}
    if type_ == "forecast":
        data = rinternal.rx2('data')
        forecast_inputs = {
            'horizon': params.rx2('horizon')[0],
            'n_endo': data.rx2('n_endo')[0],
            'explain_idx': params.rx2('explain_idx')[0],
            'explain_lags': params.rx2('explain_lags')[0],
            'y': data.rx2('y'),
            'xreg': data.rx2('xreg'),
        }

    dt = compute_preds(
        dt=dt,
        feature_names=feature_names,
        predict_model=predict_model,
        model=model,
        type_=type_,
        **forecast_inputs,
    )

    dt_vS = shapr.compute_MCint(dt)

    if not keep_samp_for_vS:
        return dt_vS
    return ro.ListVector({'dt_vS': dt_vS, 'dt_samp_for_vS': dt})
402
+
403
def compute_preds(
    dt,
    feature_names,
    predict_model,
    model,
    type_,
    horizon=None,
    n_endo=None,
    explain_idx=None,
    explain_lags=None,
    y=None,
    xreg=None
  ):
    """Predict on the Monte Carlo samples in `dt` and append the result as column `p_hat`.

    Parameters
    ----------
    dt: R table with the samples to predict on.
    feature_names: list[str]
      Columns handed to `predict_model` (not used when `type_ == "forecast"`).
    predict_model: Callable
      Prediction function; must return an object with a `tolist()` method (e.g. a numpy array).
    model: The Python model being explained.
    type_: str
      `"forecast"` selects the forecast call signature; any other value the regular one.
    horizon, n_endo, explain_idx, explain_lags, y, xreg:
      Only used when `type_ == "forecast"`; passed on to `predict_model`.

    Returns
    -------
    The result of R's `cbind(dt, p_hat=...)`.
    """
    # Convert the R table to pandas once; the forecast branch needs two slices of it
    df = r2py(dt)

    # Predictions
    if type_ == "forecast":
        # NOTE(review): `.loc` slices by column label, while `n_endo` looks like a count - confirm
        preds = predict_model(
            model,
            df.loc[:,:n_endo],
            df.loc[:,n_endo:],
            horizon,
            explain_idx,
            explain_lags,
            y,
            xreg
        )

    else:
        preds = predict_model(
            model,
            df.loc[:,feature_names]
        )

    return ro.r.cbind(dt, p_hat=ro.FloatVector(preds.tolist()))
436
+
437
+
438
+
439
def compute_preds_old(dt, feature_names, predict_model, model):
    """Legacy variant of `compute_preds`: predict on `dt` and append the result as column `p_hat`."""
    features = r2py(dt).loc[:,feature_names]
    preds = predict_model(model, features)
    p_hat = ro.FloatVector(preds.tolist())
    return ro.r.cbind(dt, p_hat=p_hat)
442
+
443
+
444
+
445
def get_feature_specs(get_model_specs, model):
    """Get the feature specification of `model` and convert it to an R list.

    Parameters
    ----------
    get_model_specs: Callable or None
      Function taking `model` and returning a dict with exactly the keys `labels`, `classes`
      and `factor_levels`. If `None`, a pre-built function for the model type is looked up.
    model: The Python model being explained.

    Returns
    -------
    An R list with the elements `labels`, `classes` and `factor_levels` (R `NA` for elements
    that are `None`), or R `NULL` when no pre-built function exists for the model type.

    Raises
    ------
    ValueError
      If `get_model_specs` is neither `None` nor callable, or if it returns something other
      than a dict with the three expected keys.
    RuntimeError
      If calling `get_model_specs(model)` fails; the original error is attached as the cause.
    """
    model_class0 = type(model)

    if (get_model_specs is not None) and (not callable(get_model_specs)):
        raise ValueError('`get_model_specs` must be None or callable.')

    if get_model_specs is None:
        get_model_specs = prebuilt_get_model_specs(model)
        if get_model_specs is None:
            warnings.warn(f'No pre-built get_model_specs for model of type {type(model)}, disabling checks.')
            return NULL

    # From here on get_model_specs is known to be callable
    try:
        feature_specs = get_model_specs(model)
    except Exception as e:
        raise RuntimeError(f'The get_model_specs function of class `{model_class0}` is invalid.\nA basic function test threw the following error:\n{e}') from e

    if not isinstance(feature_specs, dict):
        raise ValueError(f'`get_model_specs` returned an object of type `{type(feature_specs)}`, but it should be of type `dict`')
    if set(feature_specs.keys()) != {"labels", "classes", "factor_levels"}:
        raise ValueError(f'`get_model_specs` should return a `dict` with keys ["labels","classes","factor_levels"], but found keys {list(feature_specs.keys())}')

    # Each element is converted on its own; None becomes R's NA
    def py2r_or_na(v):
        return py2r(v) if v is not None else NA

    def strvec_or_na(v):
        if v is None:
            return NA
        strvec = StrVector(list(v.values()))
        strvec.names = list(v.keys())
        return strvec

    def listvec_or_na(v):
        if v is None:
            return NA
        return ro.ListVector({k: list(val) for k, val in v.items()})

    return ro.ListVector({
        'labels': py2r_or_na(feature_specs['labels']),
        'classes': strvec_or_na(feature_specs['classes']),
        'factor_levels': listvec_or_na(feature_specs['factor_levels']),
    })
487
+
488
+
489
def get_predict_model(x_test, predict_model, model):
    """Return a validated prediction function for `model`.

    Parameters
    ----------
    x_test: pd.DataFrame
      A few rows used to test the prediction function.
    predict_model: Callable or None
      Prediction function taking `model` and a pandas.DataFrame. If `None`, a pre-built
      function for the model type is looked up.
    model: The Python model being explained.

    Returns
    -------
    Callable
      The prediction function that passed the test.

    Raises
    ------
    ValueError
      If `predict_model` is neither `None` nor callable, or if no pre-built prediction
      function exists for the model type.
    RuntimeError
      If the test prediction fails (the original error is attached as the cause), is not
      numeric, or does not have one value per row of `x_test`.
    """
    model_class0 = type(model)

    # No exception is in flight here, so the message must not try to format one
    if (predict_model is not None) and (not callable(predict_model)):
        raise ValueError('`predict_model` must be None or callable.')

    if predict_model is None:
        predict_model = prebuilt_predict_model(model)
        if predict_model is None:
            raise ValueError(f'No pre-built predict_model for model of type {type(model)}. Please pass a custom predict_model to shaprpy.explain(...).')

    try:
        tmp = py2r(predict_model(model, x_test))
    except Exception as e:
        raise RuntimeError(f'The predict_model function of class `{model_class0}` is invalid.\nA basic function test threw the following error:\n{e}') from e
    if not all(base.is_numeric(tmp)):
        raise RuntimeError('The output of predict_model is expected to be numeric.')
    # Compare against the actual number of test rows rather than a hard-coded 2
    if len(tmp) != len(x_test):
        raise RuntimeError('The output of predict_model does not match the length of the input.')
    return predict_model
510
+
511
+
512
def prebuilt_get_model_specs(model):
    """Return a pre-built `get_model_specs` function for natively supported models.

    Supported are sklearn estimators and `xgboost.core.Booster` objects. The optional
    libraries are probed on a best-effort basis: a library that is missing or fails to
    import is simply skipped.

    Returns
    -------
    Callable or None
      A function taking the model and returning a dict with the keys `labels`, `classes`
      and `factor_levels`, or `None` if the model type is not supported.
    """
    # Look for sklearn
    try:
        from sklearn.base import BaseEstimator
        if isinstance(model, BaseEstimator):
            return lambda m: {
                'labels': m.feature_names_in_,
                'classes': None, # Not available from model object
                'factor_levels': None, # Not available from model object
            }
    except Exception:
        # Best-effort probe; unlike a bare `except`, KeyboardInterrupt/SystemExit still propagate
        pass

    # Look for xgboost.core.Booster
    try:
        import xgboost as xgb
        if isinstance(model, xgb.core.Booster):
            return lambda m: {
                'labels': np.array(m.feature_names),
                'classes': None, # Not available from model object
                'factor_levels': None, # Not available from model object
            }
    except Exception:
        # Best-effort probe, see above
        pass

    return None
539
+
540
+
541
def prebuilt_predict_model(model):
    """Return a pre-built prediction function for natively supported models.

    Supported are sklearn classifiers and regressors, `xgboost.core.Booster` objects and
    keras models. The optional libraries are probed on a best-effort basis: a library that
    is missing or fails to import is simply skipped.

    Returns
    -------
    Callable or None
      A function taking the model and the data to predict on, or `None` if the model type
      is not supported.
    """
    # Look for sklearn
    try:
        from sklearn.base import is_classifier, is_regressor
        # For classifiers the second column of predict_proba is used
        if is_classifier(model):
            return lambda m, x: m.predict_proba(x)[:,1]
        if is_regressor(model):
            return lambda m, x: m.predict(x).flatten()
    except Exception:
        # Best-effort probe; unlike a bare `except`, KeyboardInterrupt/SystemExit still propagate
        pass

    # Look for xgboost.core.Booster
    try:
        import xgboost as xgb
        if isinstance(model, xgb.core.Booster):
            return lambda m, x: m.predict(xgb.DMatrix(x))
    except Exception:
        # Best-effort probe, see above
        pass

    # Look for keras
    try:
        from keras.models import Model
        if isinstance(model, Model):
            def predict_fn(m,x):
                pred = m.predict(x)
                # Flatten to one value per row
                return pred.reshape(pred.shape[0],)
            return predict_fn
    except Exception:
        # Best-effort probe, see above
        pass

    return None
571
+
572
+
573
def compute_time(timing_list):
    """Summarize a mapping of time stamps into elapsed seconds per step.

    Parameters
    ----------
    timing_list: dict
      Maps step names to datetime objects, in chronological order. Must contain the
      key `init_time`.

    Returns
    -------
    dict
      - "init_time": the `init_time` stamp formatted as "%Y-%m-%d %H:%M:%S".
      - "total_time_secs": the sum of all step durations.
      - "timing_secs": dict with, for every key but the first, the seconds elapsed since
        the preceding key.
    """
    keys = list(timing_list.keys())

    timing_secs = {}
    for prev_key, key in zip(keys, keys[1:]):
        elapsed = timing_list[key] - timing_list[prev_key]
        timing_secs[f'{key}'] = elapsed.total_seconds()

    return {
        'init_time': timing_list['init_time'].strftime("%Y-%m-%d %H:%M:%S"),
        'total_time_secs': sum(timing_secs.values()),
        'timing_secs': timing_secs
    }
586
+
587
+
588
def additional_regression_setup(rinternal, model, predict_model, x_train, x_explain):
    """Add the predicted responses to `rinternal` when a regression-based method is used.

    For all other methods `rinternal` is returned untouched.
    """
    # Add the predicted response of the training and explain data to the internal list for regression-based methods
    is_regression = rinternal.rx2("parameters").rx2("regression")[0]
    if not is_regression:
        return rinternal

    return regression_get_y_hat(rinternal, model, predict_model, x_train, x_explain)
595
+
596
+
597
def regression_get_y_hat(rinternal, model, predict_model, x_train, x_explain):
    """Store the model's predictions for `x_train` and `x_explain` in `rinternal$data`.

    The predictions are stored as R numeric vectors under the names `x_train_y_hat` and
    `x_explain_y_hat`. Returns the updated `rinternal`.
    """
    y_hat_train = predict_model(model, x_train)
    y_hat_explain = predict_model(model, x_explain)

    # Extract data list, add the predicted responses, and then update rinternal (direct assignment did not work)
    r_data = rinternal.rx2['data']
    r_data.rx2['x_train_y_hat'] = ro.FloatVector(y_hat_train.tolist())
    r_data.rx2['x_explain_y_hat'] = ro.FloatVector(y_hat_explain.tolist())
    rinternal.rx2['data'] = r_data

    return rinternal
608
+
609
+
610
def regression_remove_objects(routput):
    """Remove the regression-specific entries from `routput$internal`.

    The listed regression entries are dropped from `internal$parameters`. When the approach
    is "regression_surrogate", the surrogate model is dropped from `internal$objects` too.
    Returns the updated `routput`.
    """
    internal = routput.rx2("internal")
    parameters = internal.rx2("parameters")

    # Assigning NULL to the selected elements removes them from the R list
    regression_entries = StrVector(("regression", "regression.model", "regression.tune_values", "regression.vfold_cv_para",
                                    "regression.recipe_func", "regression.tune", "regression.surrogate_n_comb"))
    parameters.rx[regression_entries] = NULL
    internal.rx2["parameters"] = parameters

    if parameters.rx2("approach")[0] == "regression_surrogate":
        r_objects = internal.rx2("objects")
        r_objects.rx["regression.surrogate_model"] = NULL
        internal.rx2["objects"] = r_objects

    routput.rx2["internal"] = internal
    return routput
623
+
624
+
625
def change_first_underscore_to_dot(kwargs):
    """Return a copy of `kwargs` where the first "_" of every key is replaced by ".".

    Keys without an underscore are kept as they are; the values are not touched.
    """
    return {name.replace('_', '.', 1): value for name, value in kwargs.items()}
630
+
shaprpy/datasets.py ADDED
@@ -0,0 +1,27 @@
1
+ import pandas as pd
2
+ from sklearn.datasets import fetch_openml, fetch_california_housing, load_iris
3
+ from sklearn.model_selection import train_test_split
4
+
5
+
6
def load_california_housing():
    """Load the California housing data as a train/test split.

    Returns
    -------
    tuple
      `(dfx_train, dfx_test, dfy_train, dfy_test)` as pandas DataFrames. The response is
      stored in a single column named 'target'. Only the first 5 test rows are kept.
    """
    raw = fetch_california_housing()
    features = pd.DataFrame(raw.data, columns=raw.feature_names)
    response = pd.DataFrame({'target': raw.target})

    x_train, x_test, y_train, y_test = train_test_split(features, response, test_size=0.99, random_state=42)

    # To reduce computational load
    return x_train, x_test[:5], y_train, y_test[:5]
13
+
14
+
15
def load_binary_iris():
    """Load the iris data restricted to the classes 0 and 1, as a train/test split.

    Returns
    -------
    tuple
      `(dfx_train, dfx_test, dfy_train, dfy_test)` as pandas DataFrames. The response is
      stored in a single column named 'target'. The test part has 5 rows.
    """
    iris = load_iris()

    # Turning it into a binary classification problem
    features = pd.DataFrame(iris.data, columns=iris.feature_names).iloc[iris.target<2]
    response = pd.DataFrame({'target': iris.target}).iloc[iris.target<2]

    x_train, x_test, y_train, y_test = train_test_split(features, response, test_size=5, random_state=42)
    return x_train, x_test, y_train, y_test
21
+
22
+
23
def load_adult():
    """Fetch OpenML data set 1590 and return it as a train/test split.

    Feature columns containing missing values are dropped, and the response is integer
    encoded into a single column named 'target'. The test part has 5 rows.

    Returns
    -------
    The output of `train_test_split`: train features, test features, train response,
    test response.
    """
    features, labels = fetch_openml(data_id=1590, return_X_y=True)

    # Drop columns with NAs for simplicity
    features = features.dropna(axis=1)

    codes = labels.factorize()[0]
    response = pd.DataFrame({'target': codes})

    return train_test_split(features, response, test_size=5, random_state=42)
shaprpy/utils.py ADDED
@@ -0,0 +1,67 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ from rpy2.robjects.conversion import localconverter
4
+ from rpy2.robjects import default_converter, Formula
5
+ from rpy2.robjects.functions import SignatureTranslatedFunction
6
+ from rpy2.robjects.numpy2ri import converter as np_converter
7
+ from rpy2.robjects.pandas2ri import converter as pd_converter
8
+ from rpy2.robjects.pandas2ri import _to_pandas_factor
9
+ from rpy2.rinterface import NULL, NA
10
+ from rpy2.robjects.vectors import DataFrame, FloatVector, IntVector, BoolVector, StrVector, ListVector, FactorVector, FloatMatrix, Matrix, POSIXct
11
+ import warnings
12
+
13
+
14
@pd_converter.rpy2py.register(FactorVector)
def rpt2py_factorvector(obj):
    """Convert an R factor vector with rpy2's `_to_pandas_factor` (registered on the pandas converter)."""
    pandas_factor = _to_pandas_factor(obj)
    return pandas_factor
17
+
18
+
19
def py2r(obj):
    """Convert a Python object to its R counterpart using the default, numpy and pandas rules."""
    rules = default_converter + np_converter + pd_converter
    with localconverter(rules) as active:
        return active.py2rpy(obj)
23
+
24
+
25
def r2py(robj):
    """Convert an R object to its Python counterpart using the default, numpy and pandas rules."""
    rules = default_converter + np_converter + pd_converter
    return rules.rpy2py(robj)
29
+
30
+
31
def recurse_r_tree(data):
    """Recursively convert an R object tree to plain Python objects.

    Conversion by type of `data`:
    - R `NULL` -> `None`
    - `DataFrame` -> pandas.DataFrame (with a column-by-column fallback)
    - numeric/integer/logical vectors and matrices -> numpy array
    - `FactorVector` -> the result of rpy2's `_to_pandas_factor`
    - `POSIXct` -> the first time stamp formatted as "%Y-%m-%d %H:%M:%S"
    - R functions and formulas -> their string representation
    - `ListVector` -> dict with recursively converted values; a list without names gets
      the keys `element_1`, `element_2`, ...
    - `StrVector` -> list with recursively converted elements
    - anything else is returned as is.

    The input object is not modified.
    """
    # NOTE: the checks compare exact types on purpose. `isinstance` would also match
    # subclasses and could thereby change which branch a given R object ends up in.
    if data == NULL:
        return None
    elif type(data) == DataFrame:
        try:
            return r2py(data)
        except Exception:
            # The column "features" in internal$objects$X is known to cause problems
            columns = {}
            for col in data.names:
                try:
                    columns[col] = r2py(data.rx2(col))
                except Exception:
                    # We manually convert the elements of the column "features" in internal$objects$X
                    columns[col] = [r2py(item) for item in data.rx2(col)]
            return pd.DataFrame(columns, index=data.rownames)
    elif type(data) in [FloatVector, IntVector, BoolVector, FloatMatrix, Matrix]:
        return np.array(data)
    elif type(data) == FactorVector:
        return _to_pandas_factor(data)
    elif type(data) == POSIXct:
        # Warnings raised during the conversion are silenced
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            tmp = r2py(data).strftime("%Y-%m-%d %H:%M:%S")[0]
        return tmp
    elif type(data) == SignatureTranslatedFunction:
        return str(data)
    elif type(data) == Formula:
        return str(data)
    elif type(data) == ListVector:
        # Generate keys for unnamed lists locally instead of renaming the caller's R object
        names = data.names
        if type(names) == type(NULL):
            names = [f"element_{i+1}" for i in range(len(data))]
        return dict(zip(names, [recurse_r_tree(d) for d in data]))
    elif type(data) == StrVector:
        return [recurse_r_tree(d) for d in data]
    else:
        return data  # We reached the end of recursion (if not converted above, return the object as is)
@@ -0,0 +1,168 @@
1
+ Metadata-Version: 2.4
2
+ Name: shaprpy
3
+ Version: 0.3.0
4
+ Summary: Python wrapper for the R package shapr (via rpy2)
5
+ Author: Martin Jullum, Lars Henry Berge Olsen, Didrik Nielsen
6
+ License: YEAR: 2019
7
+ COPYRIGHT HOLDER: Norsk Regnesentral
8
+
9
+ Project-URL: Homepage, https://github.com/NorskRegnesentral/shapr
10
+ Project-URL: Documentation, https://norskregnesentral.github.io/shapr/shaprpy.html
11
+ Project-URL: Issues, https://github.com/NorskRegnesentral/shapr/issues
12
+ Project-URL: Changelog, https://github.com/NorskRegnesentral/shapr/blob/main/python/CHANGELOG.md
13
+ Keywords: explainable-ai,shapley-values,machine-learning,model-interpretability
14
+ Classifier: Programming Language :: Python :: 3
15
+ Classifier: Programming Language :: Python :: 3 :: Only
16
+ Classifier: Programming Language :: Python :: 3.10
17
+ Classifier: Programming Language :: Python :: 3.11
18
+ Classifier: Programming Language :: Python :: 3.12
19
+ Classifier: Programming Language :: Python :: 3.13
20
+ Classifier: License :: OSI Approved :: MIT License
21
+ Classifier: Operating System :: POSIX :: Linux
22
+ Classifier: Intended Audience :: Science/Research
23
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
24
+ Requires-Python: >3.10
25
+ Description-Content-Type: text/markdown
26
+ License-File: LICENSE
27
+ License-File: LICENSE.md
28
+ Requires-Dist: rpy2>=3.5.1
29
+ Requires-Dist: numpy>=1.22.3
30
+ Requires-Dist: pandas>=1.4.2
31
+ Requires-Dist: scikit-learn>=1.0.0
32
+ Requires-Dist: tabulate>=0.8.10
33
+ Provides-Extra: test
34
+ Requires-Dist: pytest>=7.0.0; extra == "test"
35
+ Requires-Dist: syrupy>=4.0.0; extra == "test"
36
+ Requires-Dist: xgboost>=1.5.0; extra == "test"
37
+ Dynamic: license-file
38
+
39
+ # shaprpy
40
+
41
+ Python wrapper for the R package [shapr](https://github.com/NorskRegnesentral/shapr).
42
+
43
+ NOTE: This wrapper is not as comprehensively tested as the `R`-package.
44
+
45
+ `shaprpy` relies heavily on the `rpy2` Python library for accessing R from within Python.
46
+ `rpy2` has limited support on Windows. `shaprpy` has only been tested on Linux.
47
+ The instructions below assume a Linux environment.
48
+
49
+ # shaprpy
50
+
51
+ `shaprpy` is a Python wrapper for the R package [shapr](https://github.com/NorskRegnesentral/shapr),
52
+ using the [`rpy2`](https://rpy2.github.io/) Python library to access R from within Python.
53
+
54
+ > **Note:** This wrapper is **not** as comprehensively tested as the R package.
55
+ > `rpy2` has limited support on Windows, and the same therefore applies to `shaprpy`.
56
+ > `shaprpy` has only been tested on Linux (and WSL - Windows Subsystem for Linux), and the below instructions assume a Linux environment.
57
+ >
58
+ > **Requirement:** Python 3.10 or later is required to use `shaprpy`.
59
+
60
+ ## Changelog
61
+
62
+ For a list of changes and updates to the `shaprpy` package, see the [shaprpy CHANGELOG](https://norskregnesentral.github.io/shapr/py_changelog.html).
63
+
64
+ ---
65
+
66
+ ## Installation
67
+
68
+ These instructions assume you already have **pip** and **R** installed and available to the Python environment in which you want to run `shaprpy`.
69
+
70
+ - Official instructions for installing `pip` can be found [here](https://pip.pypa.io/en/stable/installation/).
71
+ - Official instructions for installing R can be found [here](https://cran.r-project.org/).
72
+
73
+ On Debian/Ubuntu-based systems, R can also be installed via:
74
+ ```bash
75
+ sudo apt update
76
+ sudo apt install r-base r-base-dev -y
77
+ ```
78
+
79
+ ### 1. Install the R package `shapr`
80
+
81
+ `shaprpy` requires the R package `shapr` (version 1.0.5 or newer).
82
+ Install the latest version from CRAN by running the following in a terminal:
83
+
84
+ ```bash
85
+ Rscript -e 'install.packages("shapr", repos="https://cran.rstudio.com")'
86
+ ```
87
+
88
+ ### 2. Ensure R is discoverable (R_HOME and PATH)
89
+
90
+ Sometimes `shaprpy` cannot automatically locate your R installation. To ensure proper detection, verify that:
91
+ - R is available in your system `PATH`, **or**
92
+ - The `R_HOME` environment variable is set to your R installation directory.
93
+
94
+ Example:
95
+ ```bash
96
+ export R_HOME=$(R RHOME)
97
+ export PATH=$PATH:$(R RHOME)/bin
98
+ ```
99
+
100
+ ### 3. Install the Python wrapper
101
+
102
+ Install directly from PyPI with:
103
+
104
+ ```bash
105
+ pip install shaprpy
106
+ ```
107
+
108
+ #### Local development install (for contributors)
109
+ If you have cloned the repository and want to install in development mode for local changes, navigate to the `./python` directory and run:
110
+ ```bash
111
+ pip install -e .
112
+ ```
113
+ The `-e` flag installs in editable mode, allowing local code changes to be reflected immediately.
114
+
115
+ ---
116
+
117
+ ## Quick Demo
118
+
119
+ ```python
120
+ from sklearn.ensemble import RandomForestRegressor
121
+ from shaprpy import explain
122
+ from shaprpy.datasets import load_california_housing
123
+
124
+ # Load example data
125
+ dfx_train, dfx_test, dfy_train, dfy_test = load_california_housing()
126
+
127
+ # Fit a model
128
+ model = RandomForestRegressor()
129
+ model.fit(dfx_train, dfy_train.values.flatten())
130
+
131
+ # Explain predictions
132
+ explanation = explain(
133
+ model=model,
134
+ x_train=dfx_train,
135
+ x_explain=dfx_test,
136
+ approach="empirical",
137
+ phi0=dfy_train.mean().item(),
138
+ seed=1
139
+ )
140
+
141
+ print(explanation["shapley_values_est"])
142
+ ```
143
+
144
+ ---
145
+
146
+ ## Supported Models
147
+
148
+ `shaprpy` can explain predictions from models built with:
149
+ - [`scikit-learn`](https://scikit-learn.org/)
150
+ - [`keras`](https://keras.io/) (Sequential API)
151
+ - [`xgboost`](https://xgboost.readthedocs.io/)
152
+
153
+ For other model types, you can supply:
154
+ - A custom `predict_model` function
155
+ - (Optionally) a custom `get_model_specs` function
156
+ to `shaprpy.explain`.
157
+
158
+ ---
159
+
160
+ ## Examples
161
+
162
+ See the `/examples` folder for runnable examples, including:
163
+ - A custom PyTorch model
164
+ - The **regression paradigm** described in [Olsen et al. (2024)](https://link.springer.com/article/10.1007/s10618-024-01016-z),
165
+ which shows:
166
+ - How to specify the regression model
167
+ - How to enable automatic cross-validation of hyperparameters
168
+ - How to apply pre-processing steps before fitting regression models
@@ -0,0 +1,10 @@
1
+ shaprpy/__init__.py,sha256=Z7UA4ws5uuPUny9uRpnn6o3PGv5ONWKjhRzBOv-YJJc,1528
2
+ shaprpy/_explain.py,sha256=k_OH_wqewAB6iRodW2aWI_gprw0hra9vyI3Ir-jQkqE,23831
3
+ shaprpy/datasets.py,sha256=ptX-fcD9evlpEdqXrmCTXwcYTao0vTPv93k5C5PYS2g,1295
4
+ shaprpy/utils.py,sha256=_kD5ZW9bp3kyTXhtS-jz4zKAdql4pOH6KZehFJTxaRk,2461
5
+ shaprpy-0.3.0.dist-info/licenses/LICENSE,sha256=hABnnsrNduCr-dp2Cy4-XOroRGzhaw7LefBhxlY0gjQ,48
6
+ shaprpy-0.3.0.dist-info/licenses/LICENSE.md,sha256=9Gwz5ov3rmu9AK8fTEZUj4pPXL5ROOlzcNqukEKxNOY,1077
7
+ shaprpy-0.3.0.dist-info/METADATA,sha256=82ke02a2Al423a_GoQyg9UQthpGgIOBEAHhjPXPP9fU,5675
8
+ shaprpy-0.3.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
9
+ shaprpy-0.3.0.dist-info/top_level.txt,sha256=4LqZ5SreryN3N5zgxtxKjyAeDUo6ggO9w1s0eRTak78,8
10
+ shaprpy-0.3.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (80.9.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,2 @@
1
+ YEAR: 2019
2
+ COPYRIGHT HOLDER: Norsk Regnesentral
@@ -0,0 +1,21 @@
1
+ # MIT License
2
+
3
+ Copyright (c) 2019 Norsk Regnesentral
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1 @@
1
+ shaprpy