edef 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edef/__init__.py +21 -0
- edef/_linear.py +680 -0
- edef/_numerical.py +292 -0
- edef/_results.py +225 -0
- edef/_torch.py +294 -0
- edef/_tree.py +385 -0
- edef-0.1.0.dist-info/METADATA +601 -0
- edef-0.1.0.dist-info/RECORD +11 -0
- edef-0.1.0.dist-info/WHEEL +5 -0
- edef-0.1.0.dist-info/licenses/LICENSE +11 -0
- edef-0.1.0.dist-info/top_level.txt +1 -0
edef/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
from ._linear import (
|
|
2
|
+
LinearExplainer,
|
|
3
|
+
linear_logistic_components,
|
|
4
|
+
linear_multiclass_components,
|
|
5
|
+
linear_regression_components,
|
|
6
|
+
)
|
|
7
|
+
from ._results import EDEFExplanation
|
|
8
|
+
from ._torch import TorchExplainer
|
|
9
|
+
from ._tree import TreeExplainer
|
|
10
|
+
from ._numerical import NumericalExplainer
|
|
11
|
+
|
|
12
|
+
__all__ = [
|
|
13
|
+
"EDEFExplanation",
|
|
14
|
+
"LinearExplainer",
|
|
15
|
+
"TorchExplainer",
|
|
16
|
+
"TreeExplainer",
|
|
17
|
+
"NumericalExplainer",
|
|
18
|
+
"linear_logistic_components",
|
|
19
|
+
"linear_multiclass_components",
|
|
20
|
+
"linear_regression_components",
|
|
21
|
+
]
|
edef/_linear.py
ADDED
|
@@ -0,0 +1,680 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from ._results import EDEFExplanation
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def linear_regression_components(
    y,
    components,
    *,
    feature_names=None,
    check_additivity: bool = True,
    atol: float = 1e-10,
    rtol: float = 1e-9,
) -> EDEFExplanation:
    """
    Closed-form EDEF for linear regression with squared-error loss.

    ``y`` and every component column are mean-centred before use, so the
    baseline is the intercept-only model (predicting ``mean(y)``) and any
    model intercept is absorbed into that baseline. With centred outcome
    ``yc`` and centred prediction ``pc = cc.sum(axis=1)``, the per-observation
    improvement decomposes exactly across features::

        yc**2 - (yc - pc)**2 = sum_j cc[:, j] * (2 * yc - pc)

    and ``observation_values[:, j]`` holds the j-th term of that sum.

    Parameters
    ----------
    y : array-like, shape (n_obs,)
        Realized outcomes.

    components : array-like, shape (n_obs, n_features)
        Fitted linear signal components. For a linear model, this is typically

            components[:, j] = X[:, j] * beta[j]

        The fitted prediction, excluding intercept effects, is

            y_hat = components.sum(axis=1)

    feature_names : sequence of str, optional
        Feature names. Defaults to ``["x0", "x1", ...]``.

    check_additivity : bool, default=True
        Whether to check that feature contributions add to total fit improvement.

    atol : float, default=1e-10
        Absolute tolerance for the additivity check.

    rtol : float, default=1e-9
        Relative tolerance for the additivity check. The check passes when
        ``|additivity_error| <= atol + rtol * scale``, where ``scale`` is the
        larger of the baseline loss and the mean absolute row sum of
        ``observation_values``. Without this term, outcomes on a large scale
        (e.g. prices ~1e6, losses ~1e12) fail on ordinary rounding error.

    Returns
    -------
    EDEFExplanation
        EDEF result object.

    Raises
    ------
    ValueError
        On malformed shapes, fewer than two observations, non-finite inputs,
        or a ``feature_names`` length mismatch.
    RuntimeError
        If ``check_additivity`` is True and the contributions do not sum to
        the total improvement within tolerance.
    """
    y = np.asarray(y, dtype=float).reshape(-1)
    components = np.asarray(components, dtype=float)

    if components.ndim != 2:
        raise ValueError("components must have shape (n_obs, n_features).")

    n_obs, n_features = components.shape

    if y.shape[0] != n_obs:
        raise ValueError("y and components must have the same number of observations.")

    if n_obs < 2:
        raise ValueError("At least two observations are required.")

    if not np.all(np.isfinite(y)):
        raise ValueError("y must contain only finite values.")

    if not np.all(np.isfinite(components)):
        raise ValueError("components must contain only finite values.")

    if feature_names is None:
        feature_names = [f"x{i}" for i in range(n_features)]
    else:
        feature_names = list(feature_names)
        if len(feature_names) != n_features:
            raise ValueError("feature_names must have length n_features.")

    # Centring is equivalent to refitting the intercept for both the baseline
    # and the model, so the original intercept never enters the attribution.
    y_centered = y - y.mean()
    components_centered = components - components.mean(axis=0)
    prediction_centered = components_centered.sum(axis=1)

    baseline_loss = float(np.mean(y_centered**2))
    model_loss = float(np.mean((y_centered - prediction_centered) ** 2))
    total = baseline_loss - model_loss

    # Per-observation factor shared by all features in the exact decomposition.
    shared_term = 2.0 * y_centered - prediction_centered

    observation_values = components_centered * shared_term[:, None]
    values = observation_values.mean(axis=0)

    standard_errors = observation_values.std(axis=0, ddof=1) / np.sqrt(n_obs)

    additivity_error = values.sum() - total

    scale = max(baseline_loss, float(np.abs(observation_values).sum(axis=1).mean()))
    tolerance = atol + rtol * scale

    if check_additivity and abs(additivity_error) > tolerance:
        raise RuntimeError(
            "EDEF contributions do not add to total fit improvement. "
            f"Additivity error: {additivity_error}"
        )

    return EDEFExplanation(
        values=values,
        observation_values=observation_values,
        standard_errors=standard_errors,
        total=total,
        baseline_loss=baseline_loss,
        model_loss=model_loss,
        loss="squared_error",
        model_type="linear_regression_components",
        feature_names=feature_names,
        n_obs=n_obs,
        additivity_error=additivity_error,
    )
|
|
112
|
+
|
|
113
|
+
class LinearExplainer:
    """
    SHAP-style EDEF explainer for fitted linear models.

    Supported configurations:

    * ``loss="squared_error"`` -- single-output linear regression. ``coef_``
      must be 1D; a 2D ``coef_`` with a single row is flattened. The model
      intercept is absorbed into the intercept-only baseline used by EDEF.
    * ``loss="log_loss"`` -- binary classification (1D ``coef_``) or
      multiclass classification (2D ``coef_`` of shape
      ``(n_classes, n_features)``). Scores come from
      ``model.decision_function`` when available, otherwise from
      ``X @ coef_.T + intercept_``. Whatever part of the score is not
      explained by the feature components is reported in an extra
      ``__InterceptShift__`` attribution column.

    Parameters
    ----------
    model : object
        Fitted linear model exposing ``coef_`` (and, for ``log_loss``, either
        ``decision_function`` or ``intercept_``).
    baseline : optional
        Stored as ``self.baseline``; not used by the current implementation.
    loss : {"squared_error", "log_loss"}, default="squared_error"
        Loss whose improvement over the baseline is attributed.
    feature_names : sequence of str, optional
        Default feature names, used when ``__call__`` is not given any.

    Raises
    ------
    ValueError
        If ``loss`` is unsupported, or ``coef_`` has an invalid shape or
        non-finite entries.
    TypeError
        If ``model`` has no ``coef_`` attribute.
    """

    def __init__(
        self,
        model,
        baseline=None,
        *,
        loss: str = "squared_error",
        feature_names=None,
    ):
        if loss not in {"squared_error", "log_loss"}:
            raise ValueError("LinearExplainer supports squared_error and log_loss.")

        self.model = model
        self.baseline = baseline
        self.loss = loss
        self.feature_names = feature_names

        self.coef_ = self._get_coef(model)

    def __call__(
        self,
        X,
        y,
        *,
        feature_names=None,
        check_additivity: bool = True,
        atol: float = 1e-10,
    ):
        """
        Compute the EDEF explanation of ``self.model`` on ``(X, y)``.

        Parameters
        ----------
        X : array-like, shape (n_obs, n_features)
            Feature matrix. If it has a ``columns`` attribute (e.g. a pandas
            DataFrame), those labels are the fallback feature names.
        y : array-like, shape (n_obs,)
            Outcomes: real values for ``squared_error``, labels in {0, 1} for
            binary ``log_loss``, and integer class indices
            ``0..n_classes - 1`` for multiclass ``log_loss``.
        feature_names : sequence of str, optional
            Overrides both ``self.feature_names`` and ``X``'s column labels.
        check_additivity : bool, default=True
            Forwarded to the component function.
        atol : float, default=1e-10
            Forwarded to the component function.

        Returns
        -------
        EDEFExplanation
        """
        # Keep the caller's object: converting to ndarray drops DataFrame
        # column labels, which feed the default feature names and
        # model.decision_function.
        X_input = X
        X = np.asarray(X, dtype=float)

        if X.ndim != 2:
            raise ValueError("X must have shape (n_obs, n_features).")

        n_features = X.shape[1]

        # The last axis of coef_ is the feature axis for both the 1D
        # (regression/binary) and 2D (multiclass) layouts.
        coef_n_features = self.coef_.shape[-1]

        if coef_n_features != n_features:
            raise ValueError(
                "Model coefficient dimension does not match X. "
                f"coef has {coef_n_features} features, "
                f"but X has {n_features} columns."
            )

        names = feature_names
        if names is None:
            names = self.feature_names
        if names is None:
            names = self._get_feature_names(X_input, n_features)

        if self.loss == "squared_error":
            if self.coef_.ndim != 1:
                raise ValueError(
                    "squared_error requires a single-output linear model "
                    "with a 1D coef_ vector."
                )

            components = X * self.coef_[None, :]

            return linear_regression_components(
                y,
                components,
                feature_names=names,
                check_additivity=check_additivity,
                atol=atol,
            )

        if self.loss == "log_loss":
            eta = self._decision_function(X_input)

            if self.coef_.ndim == 1:
                components = X * self.coef_[None, :]

                return linear_logistic_components(
                    y,
                    components,
                    eta=eta,
                    include_intercept_component=True,
                    feature_names=names,
                    check_additivity=check_additivity,
                    atol=atol,
                )

            if self.coef_.ndim == 2:
                # components[i, k, j] = X[i, j] * coef_[k, j]
                components = X[:, None, :] * self.coef_[None, :, :]

                return linear_multiclass_components(
                    y,
                    components,
                    eta=eta,
                    include_intercept_component=True,
                    feature_names=names,
                    check_additivity=check_additivity,
                    atol=atol,
                )

        raise RuntimeError(f"Unexpected loss: {self.loss}")

    def _decision_function(self, X) -> np.ndarray:
        """
        Return model scores: shape (n_obs,) for a 1D ``coef_``, or
        (n_obs, n_classes) for a 2D ``coef_``.

        ``X`` is passed unchanged to ``model.decision_function`` when the model
        has one; otherwise scores are computed as ``X @ coef_.T + intercept_``.

        Raises
        ------
        TypeError
            If the model has neither ``decision_function`` nor ``intercept_``.
        ValueError
            On an intercept/score shape mismatch or non-finite scores.
        """
        X_arr = np.asarray(X, dtype=float)

        if hasattr(self.model, "decision_function"):
            # Hand over the caller's object so estimators fitted on DataFrames
            # still see their feature names.
            eta = self.model.decision_function(X)
        elif hasattr(self.model, "intercept_"):
            intercept = np.asarray(self.model.intercept_, dtype=float).reshape(-1)

            if self.coef_.ndim == 1:
                if intercept.size != 1:
                    raise ValueError(
                        "Binary classification requires a scalar intercept."
                    )
                eta = X_arr @ self.coef_ + intercept[0]
            else:
                if intercept.size != self.coef_.shape[0]:
                    raise ValueError(
                        "Multiclass classification requires one intercept per class."
                    )
                eta = X_arr @ self.coef_.T + intercept.reshape(1, -1)
        else:
            raise TypeError(
                "log_loss requires a model with decision_function or intercept_."
            )

        eta = np.asarray(eta, dtype=float)

        if self.coef_.ndim == 1:
            eta = eta.reshape(-1)
            if eta.shape[0] != X_arr.shape[0]:
                raise ValueError("decision_function output must have length n_obs.")
        else:
            expected_shape = (X_arr.shape[0], self.coef_.shape[0])
            if eta.shape != expected_shape:
                raise ValueError(
                    "decision_function output must have shape "
                    "(n_obs, n_classes)."
                )

        if not np.all(np.isfinite(eta)):
            raise ValueError("decision_function output must contain only finite values.")

        return eta

    @staticmethod
    def _get_coef(model) -> np.ndarray:
        """
        Return ``model.coef_`` as a finite float array.

        A single-row 2D ``coef_`` (as produced by binary classifiers) is
        flattened to 1D.
        """
        if not hasattr(model, "coef_"):
            raise TypeError(
                "LinearExplainer requires a fitted linear model with a coef_ attribute."
            )

        coef = np.asarray(model.coef_, dtype=float)

        if coef.ndim == 2 and coef.shape[0] == 1:
            coef = coef.reshape(-1)

        if coef.ndim not in {1, 2}:
            raise ValueError(
                "LinearExplainer requires coef_ to be 1D for regression/binary "
                "classification or 2D for multiclass classification."
            )

        if not np.all(np.isfinite(coef)):
            raise ValueError("model.coef_ must contain only finite values.")

        return coef

    @staticmethod
    def _get_feature_names(X, n_features: int) -> list[str]:
        """Return ``list(X.columns)`` when present, else ``["x0", "x1", ...]``."""
        columns = getattr(X, "columns", None)
        if columns is not None:
            return list(columns)
        return [f"x{i}" for i in range(n_features)]
|
|
294
|
+
|
|
295
|
+
def _sigmoid(z):
|
|
296
|
+
z = np.asarray(z, dtype=float)
|
|
297
|
+
return 1.0 / (1.0 + np.exp(-z))
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def _logit(p):
|
|
301
|
+
p = np.asarray(p, dtype=float)
|
|
302
|
+
p = np.clip(p, 1e-12, 1.0 - 1e-12)
|
|
303
|
+
return np.log(p / (1.0 - p))
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def _softplus(z):
|
|
307
|
+
return np.logaddexp(0.0, z)
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def _binary_log_loss(y, p):
|
|
311
|
+
p = np.clip(p, 1e-12, 1.0 - 1e-12)
|
|
312
|
+
return -(y * np.log(p) + (1.0 - y) * np.log1p(-p))
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def linear_logistic_components(
    y,
    components,
    *,
    eta=None,
    intercept_component=None,
    include_intercept_component: bool = False,
    feature_names=None,
    check_additivity: bool = True,
    atol: float = 1e-10,
    rtol: float = 1e-9,
) -> EDEFExplanation:
    """
    Closed-form EDEF for binary linear classification with log loss.

    The baseline is the intercept-only logit ``eta_bar = logit(mean(y))``.
    For each observation the loss improvement along the straight path from
    ``eta_bar`` to ``eta`` is split across components using the path weight

        w = y - (softplus(eta) - softplus(eta_bar)) / (eta - eta_bar),

    i.e. ``y`` minus the average of ``sigmoid`` over the path. Contributions
    add up exactly when the components sum to ``eta - eta_bar``. This holds
    when ``eta`` is omitted or ``include_intercept_component=True``.

    Parameters
    ----------
    y : array-like, shape (n_obs,)
        Binary labels in {0, 1}.

    components : array-like, shape (n_obs, n_features)
        Fitted score/logit components. For logistic regression, this is typically

            components[:, j] = X[:, j] * beta[j]

    eta : array-like, shape (n_obs,), optional
        Full fitted score/logit. If omitted, the score is constructed as

            eta = eta_bar + components.sum(axis=1)

        where eta_bar is the baseline logit.

    intercept_component : array-like, shape (n_obs,), optional
        Additional score component to include, typically the difference between
        the fitted intercept and the baseline logit. Ignored unless
        ``include_intercept_component`` is True. If omitted while that flag is
        True, it is derived as ``eta - eta_bar - components.sum(axis=1)``.

    include_intercept_component : bool, default=False
        If True, append intercept_component as an additional attribution
        column named ``"__InterceptShift__"``.

    feature_names : sequence of str, optional
        Feature names. Defaults to ``["x0", "x1", ...]``.

    check_additivity : bool, default=True
        Whether to check that feature contributions add to total fit improvement.

    atol : float, default=1e-10
        Absolute tolerance for the additivity check.

    rtol : float, default=1e-9
        Relative tolerance for the additivity check. The check passes when
        ``|additivity_error| <= atol + rtol * scale``, where ``scale`` is the
        larger of the baseline loss and the mean absolute row sum of
        ``observation_values``.

    Returns
    -------
    EDEFExplanation
        EDEF result object.

    Raises
    ------
    ValueError
        On malformed shapes, non-binary or non-finite labels, non-finite
        inputs, or a ``feature_names`` length mismatch.
    RuntimeError
        If ``check_additivity`` is True and the contributions do not sum to
        the total improvement within tolerance.
    """
    y = np.asarray(y, dtype=float).reshape(-1)
    components = np.asarray(components, dtype=float)

    if components.ndim != 2:
        raise ValueError("components must have shape (n_obs, n_features).")

    n_obs, n_features = components.shape

    if y.shape[0] != n_obs:
        raise ValueError("y and components must have the same number of observations.")

    if n_obs < 2:
        raise ValueError("At least two observations are required.")

    if not np.all(np.isfinite(y)):
        raise ValueError("y must contain only finite values.")

    if not np.all((y == 0.0) | (y == 1.0)):
        raise ValueError("y must contain only binary labels in {0, 1}.")

    if not np.all(np.isfinite(components)):
        raise ValueError("components must contain only finite values.")

    p_bar = float(np.clip(y.mean(), 1e-12, 1.0 - 1e-12))
    eta_bar = float(_logit(p_bar))

    if eta is None:
        eta = eta_bar + components.sum(axis=1)
    else:
        eta = np.asarray(eta, dtype=float).reshape(-1)
        if eta.shape[0] != n_obs:
            raise ValueError("eta must have length n_obs.")
        if not np.all(np.isfinite(eta)):
            raise ValueError("eta must contain only finite values.")

    if feature_names is None:
        feature_names = [f"x{i}" for i in range(n_features)]
    else:
        feature_names = list(feature_names)
        if len(feature_names) != n_features:
            raise ValueError("feature_names must have length n_features.")

    if include_intercept_component:
        if intercept_component is None:
            intercept_component = eta - eta_bar - components.sum(axis=1)
        else:
            intercept_component = np.asarray(intercept_component, dtype=float).reshape(-1)
            if intercept_component.shape[0] != n_obs:
                raise ValueError("intercept_component must have length n_obs.")
            if not np.all(np.isfinite(intercept_component)):
                raise ValueError(
                    "intercept_component must contain only finite values."
                )

        components = np.column_stack([components, intercept_component])
        feature_names = feature_names + ["__InterceptShift__"]

    sp_eta = _softplus(eta)
    sp_eta_bar = float(_softplus(eta_bar))

    # Per-observation log loss in logit form: softplus(s) - y * s equals
    # -(y log p + (1 - y) log(1 - p)) for p = sigmoid(s). It stays exact for
    # large |s|, where sigmoid(s) rounds to 0/1 and probability clipping would
    # cap the loss. It also matches the softplus terms in the path weight
    # below, so the decomposition stays additive.
    baseline_loss = float(np.mean(sp_eta_bar - y * eta_bar))
    model_loss = float(np.mean(sp_eta - y * eta))
    total = baseline_loss - model_loss

    delta = eta - eta_bar

    # Below this step size the finite difference of softplus loses precision
    # to cancellation. The midpoint value sigmoid(eta_bar + delta / 2)
    # approximates the path average to O(delta**2).
    small = np.abs(delta) <= 1e-6
    large = ~small

    path_weight = np.empty(n_obs, dtype=float)
    path_weight[large] = y[large] - (sp_eta[large] - sp_eta_bar) / delta[large]
    path_weight[small] = y[small] - _sigmoid(eta_bar + 0.5 * delta[small])

    observation_values = components * path_weight[:, None]
    values = observation_values.mean(axis=0)

    standard_errors = observation_values.std(axis=0, ddof=1) / np.sqrt(n_obs)

    additivity_error = values.sum() - total

    scale = max(baseline_loss, float(np.abs(observation_values).sum(axis=1).mean()))
    tolerance = atol + rtol * scale

    if check_additivity and abs(additivity_error) > tolerance:
        raise RuntimeError(
            "EDEF contributions do not add to total fit improvement. "
            f"Additivity error: {additivity_error}"
        )

    return EDEFExplanation(
        values=values,
        observation_values=observation_values,
        standard_errors=standard_errors,
        total=total,
        baseline_loss=baseline_loss,
        model_loss=model_loss,
        loss="log_loss",
        model_type="linear_logistic_components",
        feature_names=feature_names,
        n_obs=n_obs,
        additivity_error=additivity_error,
    )
|
|
468
|
+
|
|
469
|
+
def _logsumexp(a, axis=None, keepdims=False):
|
|
470
|
+
a = np.asarray(a, dtype=float)
|
|
471
|
+
amax = np.max(a, axis=axis, keepdims=True)
|
|
472
|
+
out = amax + np.log(np.sum(np.exp(a - amax), axis=axis, keepdims=True))
|
|
473
|
+
|
|
474
|
+
if not keepdims:
|
|
475
|
+
out = np.squeeze(out, axis=axis)
|
|
476
|
+
|
|
477
|
+
return out
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
def _softmax(eta):
    """
    Row-wise softmax of a 2D score array.

    Subtracting each row's log-partition (computed by ``_logsumexp``)
    before exponentiating keeps the calculation in a safe numeric range.
    """
    return np.exp(eta - _logsumexp(eta, axis=1, keepdims=True))
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
def linear_multiclass_components(
|
|
486
|
+
y,
|
|
487
|
+
components,
|
|
488
|
+
*,
|
|
489
|
+
eta=None,
|
|
490
|
+
intercept_component=None,
|
|
491
|
+
include_intercept_component: bool = False,
|
|
492
|
+
feature_names=None,
|
|
493
|
+
check_additivity: bool = True,
|
|
494
|
+
atol: float = 1e-10,
|
|
495
|
+
) -> EDEFExplanation:
|
|
496
|
+
"""
|
|
497
|
+
Closed-form EDEF for multiclass linear classification with log loss.
|
|
498
|
+
|
|
499
|
+
Parameters
|
|
500
|
+
----------
|
|
501
|
+
y : array-like, shape (n_obs,)
|
|
502
|
+
Integer class labels in {0, ..., n_classes - 1}.
|
|
503
|
+
|
|
504
|
+
components : array-like, shape (n_obs, n_classes, n_features)
|
|
505
|
+
Fitted class-score components. For multinomial logistic regression,
|
|
506
|
+
|
|
507
|
+
components[i, k, j] = X[i, j] * beta[k, j]
|
|
508
|
+
|
|
509
|
+
eta : array-like, shape (n_obs, n_classes), optional
|
|
510
|
+
Full fitted class scores. If omitted, the score is constructed as
|
|
511
|
+
|
|
512
|
+
eta = eta_bar + components.sum(axis=2)
|
|
513
|
+
|
|
514
|
+
where eta_bar is the baseline class-score vector.
|
|
515
|
+
|
|
516
|
+
intercept_component : array-like, shape (n_obs, n_classes), optional
|
|
517
|
+
Additional class-score component, typically the difference between the
|
|
518
|
+
fitted intercept vector and the baseline class-score vector.
|
|
519
|
+
|
|
520
|
+
include_intercept_component : bool, default=False
|
|
521
|
+
If True, append intercept_component as an additional attribution column.
|
|
522
|
+
|
|
523
|
+
feature_names : sequence of str, optional
|
|
524
|
+
Feature names.
|
|
525
|
+
|
|
526
|
+
Returns
|
|
527
|
+
-------
|
|
528
|
+
EDEFExplanation
|
|
529
|
+
Scalar log-loss EDEF result. Class dimensions are summed internally, so
|
|
530
|
+
observation_values has shape (n_obs, n_features).
|
|
531
|
+
"""
|
|
532
|
+
|
|
533
|
+
y = np.asarray(y).reshape(-1)
|
|
534
|
+
components = np.asarray(components, dtype=float)
|
|
535
|
+
|
|
536
|
+
if components.ndim != 3:
|
|
537
|
+
raise ValueError(
|
|
538
|
+
"components must have shape (n_obs, n_classes, n_features)."
|
|
539
|
+
)
|
|
540
|
+
|
|
541
|
+
n_obs, n_classes, n_features = components.shape
|
|
542
|
+
|
|
543
|
+
if y.shape[0] != n_obs:
|
|
544
|
+
raise ValueError("y and components must have the same number of observations.")
|
|
545
|
+
|
|
546
|
+
if n_obs < 2:
|
|
547
|
+
raise ValueError("At least two observations are required.")
|
|
548
|
+
|
|
549
|
+
if not np.all(np.isfinite(components)):
|
|
550
|
+
raise ValueError("components must contain only finite values.")
|
|
551
|
+
|
|
552
|
+
if not np.issubdtype(y.dtype, np.integer):
|
|
553
|
+
if np.all(np.equal(y, np.round(y))):
|
|
554
|
+
y = y.astype(int)
|
|
555
|
+
else:
|
|
556
|
+
raise ValueError("y must contain integer class labels.")
|
|
557
|
+
|
|
558
|
+
y = y.astype(int)
|
|
559
|
+
|
|
560
|
+
if np.any(y < 0) or np.any(y >= n_classes):
|
|
561
|
+
raise ValueError("y must contain class labels in {0, ..., n_classes - 1}.")
|
|
562
|
+
|
|
563
|
+
class_counts = np.bincount(y, minlength=n_classes).astype(float)
|
|
564
|
+
class_probs = np.clip(class_counts / n_obs, 1e-12, 1.0)
|
|
565
|
+
class_probs = class_probs / class_probs.sum()
|
|
566
|
+
|
|
567
|
+
eta_bar = np.log(class_probs)
|
|
568
|
+
|
|
569
|
+
if eta is None:
|
|
570
|
+
eta = eta_bar.reshape(1, -1) + components.sum(axis=2)
|
|
571
|
+
else:
|
|
572
|
+
eta = np.asarray(eta, dtype=float)
|
|
573
|
+
if eta.shape != (n_obs, n_classes):
|
|
574
|
+
raise ValueError("eta must have shape (n_obs, n_classes).")
|
|
575
|
+
if not np.all(np.isfinite(eta)):
|
|
576
|
+
raise ValueError("eta must contain only finite values.")
|
|
577
|
+
|
|
578
|
+
if feature_names is None:
|
|
579
|
+
feature_names = [f"x{i}" for i in range(n_features)]
|
|
580
|
+
else:
|
|
581
|
+
feature_names = list(feature_names)
|
|
582
|
+
if len(feature_names) != n_features:
|
|
583
|
+
raise ValueError("feature_names must have length n_features.")
|
|
584
|
+
|
|
585
|
+
if include_intercept_component:
|
|
586
|
+
if intercept_component is None:
|
|
587
|
+
intercept_component = (
|
|
588
|
+
eta
|
|
589
|
+
- eta_bar.reshape(1, -1)
|
|
590
|
+
- components.sum(axis=2)
|
|
591
|
+
)
|
|
592
|
+
else:
|
|
593
|
+
intercept_component = np.asarray(intercept_component, dtype=float)
|
|
594
|
+
if intercept_component.shape != (n_obs, n_classes):
|
|
595
|
+
raise ValueError(
|
|
596
|
+
"intercept_component must have shape (n_obs, n_classes)."
|
|
597
|
+
)
|
|
598
|
+
if not np.all(np.isfinite(intercept_component)):
|
|
599
|
+
raise ValueError(
|
|
600
|
+
"intercept_component must contain only finite values."
|
|
601
|
+
)
|
|
602
|
+
|
|
603
|
+
components = np.concatenate(
|
|
604
|
+
[components, intercept_component[:, :, None]],
|
|
605
|
+
axis=2,
|
|
606
|
+
)
|
|
607
|
+
feature_names = feature_names + ["__InterceptShift__"]
|
|
608
|
+
n_features = n_features + 1
|
|
609
|
+
|
|
610
|
+
baseline_loss = float(-np.mean(np.log(class_probs[y])))
|
|
611
|
+
|
|
612
|
+
log_probs = eta - _logsumexp(eta, axis=1, keepdims=True)
|
|
613
|
+
model_loss = float(-np.mean(log_probs[np.arange(n_obs), y]))
|
|
614
|
+
|
|
615
|
+
total = baseline_loss - model_loss
|
|
616
|
+
|
|
617
|
+
delta = eta - eta_bar.reshape(1, -1)
|
|
618
|
+
|
|
619
|
+
# For each observation and class, compute
|
|
620
|
+
#
|
|
621
|
+
# integral_0^1 softmax_k(eta_bar + t * delta_i) dt
|
|
622
|
+
#
|
|
623
|
+
# using Gauss-Legendre quadrature. There is no simple binary-style
|
|
624
|
+
# scalar softplus closed form for the multiclass softmax path.
|
|
625
|
+
nodes, weights = np.polynomial.legendre.leggauss(64)
|
|
626
|
+
nodes = 0.5 * (nodes + 1.0)
|
|
627
|
+
weights = 0.5 * weights
|
|
628
|
+
|
|
629
|
+
eta_all = (
|
|
630
|
+
eta_bar.reshape(1, 1, -1)
|
|
631
|
+
+ nodes.reshape(-1, 1, 1) * delta.reshape(1, n_obs, n_classes)
|
|
632
|
+
)
|
|
633
|
+
|
|
634
|
+
eta_all = eta_all - eta_all.max(axis=2, keepdims=True)
|
|
635
|
+
|
|
636
|
+
prob_all = np.exp(eta_all)
|
|
637
|
+
prob_all = prob_all / prob_all.sum(axis=2, keepdims=True)
|
|
638
|
+
|
|
639
|
+
avg_prob = np.sum(
|
|
640
|
+
weights.reshape(-1, 1, 1) * prob_all,
|
|
641
|
+
axis=0,
|
|
642
|
+
)
|
|
643
|
+
|
|
644
|
+
one_hot = np.zeros((n_obs, n_classes), dtype=float)
|
|
645
|
+
one_hot[np.arange(n_obs), y] = 1.0
|
|
646
|
+
|
|
647
|
+
path_weight = one_hot - avg_prob
|
|
648
|
+
|
|
649
|
+
# components: (n_obs, n_classes, n_features)
|
|
650
|
+
# path_weight: (n_obs, n_classes)
|
|
651
|
+
# observation_values: sum over classes -> (n_obs, n_features)
|
|
652
|
+
observation_values = np.sum(
|
|
653
|
+
components * path_weight[:, :, None],
|
|
654
|
+
axis=1,
|
|
655
|
+
)
|
|
656
|
+
|
|
657
|
+
values = observation_values.mean(axis=0)
|
|
658
|
+
standard_errors = observation_values.std(axis=0, ddof=1) / np.sqrt(n_obs)
|
|
659
|
+
|
|
660
|
+
additivity_error = values.sum() - total
|
|
661
|
+
|
|
662
|
+
if check_additivity and abs(additivity_error) > atol:
|
|
663
|
+
raise RuntimeError(
|
|
664
|
+
"EDEF contributions do not add to total fit improvement. "
|
|
665
|
+
f"Additivity error: {additivity_error}"
|
|
666
|
+
)
|
|
667
|
+
|
|
668
|
+
return EDEFExplanation(
|
|
669
|
+
values=values,
|
|
670
|
+
observation_values=observation_values,
|
|
671
|
+
standard_errors=standard_errors,
|
|
672
|
+
total=total,
|
|
673
|
+
baseline_loss=baseline_loss,
|
|
674
|
+
model_loss=model_loss,
|
|
675
|
+
loss="log_loss",
|
|
676
|
+
model_type="linear_multiclass_components",
|
|
677
|
+
feature_names=feature_names,
|
|
678
|
+
n_obs=n_obs,
|
|
679
|
+
additivity_error=additivity_error,
|
|
680
|
+
)
|