edef 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
edef/__init__.py ADDED
@@ -0,0 +1,21 @@
1
# Re-export the explainers, the result container and the closed-form linear
# helpers so that users can import everything from the package root.
from ._linear import LinearExplainer
from ._linear import linear_logistic_components
from ._linear import linear_multiclass_components
from ._linear import linear_regression_components
from ._results import EDEFExplanation
from ._torch import TorchExplainer
from ._tree import TreeExplainer
from ._numerical import NumericalExplainer

# Names exposed by ``from edef import *``.
__all__ = [
    "EDEFExplanation",
    "LinearExplainer",
    "TorchExplainer",
    "TreeExplainer",
    "NumericalExplainer",
    "linear_logistic_components",
    "linear_multiclass_components",
    "linear_regression_components",
]
edef/_linear.py ADDED
@@ -0,0 +1,680 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+
5
+ from ._results import EDEFExplanation
6
+
7
+
8
def linear_regression_components(
    y,
    components,
    *,
    feature_names=None,
    check_additivity: bool = True,
    atol: float = 1e-10,
    rtol: float = 1e-10,
) -> "EDEFExplanation":
    """
    Closed-form EDEF for linear regression with squared-error loss.

    Both ``y`` and every column of ``components`` are mean-centered, so
    intercept effects drop out and the baseline is the intercept-only model.

    Parameters
    ----------
    y : array-like, shape (n_obs,)
        Realized outcomes.

    components : array-like, shape (n_obs, n_features)
        Fitted linear signal components. For a linear model, this is typically

            components[:, j] = X[:, j] * beta[j]

        The fitted prediction, excluding intercept effects, is

            y_hat = components.sum(axis=1)

    feature_names : sequence of str, optional
        Feature names. Defaults to ``x0, x1, ...``.

    check_additivity : bool, default=True
        Whether to check that feature contributions add to total fit improvement.

    atol : float, default=1e-10
        Absolute tolerance for the additivity check.

    rtol : float, default=1e-10
        Relative tolerance for the additivity check. It is multiplied by the
        largest of the baseline loss, the model loss and the summed absolute
        contributions, so that ordinary floating-point rounding on
        large-magnitude targets does not trigger a failure. Pass ``0.0`` for a
        purely absolute check.

    Returns
    -------
    EDEFExplanation
        EDEF result object.

    Raises
    ------
    ValueError
        If the inputs have inconsistent shapes, fewer than two observations,
        non-finite entries, or a wrong number of feature names.
    RuntimeError
        If ``check_additivity`` is True and the contributions do not sum to
        the total fit improvement within tolerance.
    """

    y = np.asarray(y, dtype=float).reshape(-1)
    components = np.asarray(components, dtype=float)

    if components.ndim != 2:
        raise ValueError("components must have shape (n_obs, n_features).")

    n_obs, n_features = components.shape

    if y.shape[0] != n_obs:
        raise ValueError("y and components must have the same number of observations.")

    if n_obs < 2:
        raise ValueError("At least two observations are required.")

    if not np.all(np.isfinite(y)):
        raise ValueError("y must contain only finite values.")

    if not np.all(np.isfinite(components)):
        raise ValueError("components must contain only finite values.")

    if feature_names is None:
        feature_names = [f"x{i}" for i in range(n_features)]
    else:
        feature_names = list(feature_names)
        if len(feature_names) != n_features:
            raise ValueError("feature_names must have length n_features.")

    y_centered = y - y.mean()

    components_centered = components - components.mean(axis=0)
    prediction_centered = components_centered.sum(axis=1)

    baseline_loss = np.mean(y_centered**2)
    model_loss = np.mean((y_centered - prediction_centered) ** 2)
    total = baseline_loss - model_loss

    # baseline - model = mean(p * (2 * y_c - p)) with p = sum_j c_j, so each
    # feature receives mean(c_j * (2 * y_c - p)) and the shares sum to total.
    shared_term = 2.0 * y_centered - prediction_centered

    observation_values = components_centered * shared_term[:, None]
    values = observation_values.mean(axis=0)

    standard_errors = observation_values.std(axis=0, ddof=1) / np.sqrt(n_obs)

    additivity_error = values.sum() - total

    # Rounding error grows with the magnitude of the summed quantities, so
    # the tolerance is scaled accordingly instead of being purely absolute.
    magnitude = max(
        float(baseline_loss),
        float(model_loss),
        float(np.abs(values).sum()),
    )
    tolerance = atol + rtol * magnitude

    if check_additivity and abs(additivity_error) > tolerance:
        raise RuntimeError(
            "EDEF contributions do not add to total fit improvement. "
            f"Additivity error: {additivity_error}"
        )

    return EDEFExplanation(
        values=values,
        observation_values=observation_values,
        standard_errors=standard_errors,
        total=total,
        baseline_loss=baseline_loss,
        model_loss=model_loss,
        loss="squared_error",
        model_type="linear_regression_components",
        feature_names=feature_names,
        n_obs=n_obs,
        additivity_error=additivity_error,
    )
112
+
113
class LinearExplainer:
    """
    SHAP-style EDEF explainer for fitted linear models.

    Supported configurations:

    * ``loss="squared_error"`` with a 1D ``coef_`` (single-output regression);
    * ``loss="log_loss"`` with a 1D ``coef_`` (binary classification) or a 2D
      ``coef_`` of shape ``(n_classes, n_features)`` (multiclass).

    A ``coef_`` of shape ``(1, n_features)`` is flattened to 1D. For log loss
    the explainer requests an extra intercept-shift attribution column from
    the component functions.

    Parameters
    ----------
    model : object
        Fitted model exposing ``coef_``; for log loss it must also expose
        ``decision_function`` or ``intercept_``.

    baseline : object, optional
        Stored on the instance as ``baseline``; not otherwise used by this
        class.

    loss : {"squared_error", "log_loss"}, default="squared_error"
        Loss whose improvement is attributed.

    feature_names : sequence of str, optional
        Default feature names used when none are passed at call time.
    """

    def __init__(
        self,
        model,
        baseline=None,
        *,
        loss: str = "squared_error",
        feature_names=None,
    ):
        if loss not in {"squared_error", "log_loss"}:
            raise ValueError("LinearExplainer supports squared_error and log_loss.")

        self.model = model
        self.baseline = baseline
        self.loss = loss
        self.feature_names = feature_names

        self.coef_ = self._get_coef(model)

    def __call__(
        self,
        X,
        y,
        *,
        feature_names=None,
        check_additivity: bool = True,
        atol: float = 1e-10,
    ):
        """
        Compute the EDEF explanation of the model on ``(X, y)``.

        Parameters
        ----------
        X : array-like, shape (n_obs, n_features)
            Feature matrix. If it exposes a ``columns`` attribute (for
            example a DataFrame), those names are used when no explicit
            feature names are available.

        y : array-like, shape (n_obs,)
            Realized outcomes or class labels.

        feature_names : sequence of str, optional
            Overrides the names given at construction time.

        check_additivity : bool, default=True
            Forwarded to the component function.

        atol : float, default=1e-10
            Forwarded to the component function.

        Returns
        -------
        EDEFExplanation
            Result of the matching ``linear_*_components`` function.

        Raises
        ------
        ValueError
            If ``X`` is not 2D, its width does not match ``coef_``, or
            ``squared_error`` is used with a 2D ``coef_``.
        """
        # Keep the caller's object: np.asarray drops DataFrame column labels,
        # which are needed below as a fallback for feature names.
        X_input = X
        X = np.asarray(X, dtype=float)

        if X.ndim != 2:
            raise ValueError("X must have shape (n_obs, n_features).")

        n_features = X.shape[1]

        if self.coef_.ndim == 1:
            coef_n_features = self.coef_.shape[0]
        else:
            coef_n_features = self.coef_.shape[1]

        if coef_n_features != n_features:
            raise ValueError(
                "Model coefficient dimension does not match X. "
                f"coef has {coef_n_features} features, "
                f"but X has {n_features} columns."
            )

        names = self._resolve_feature_names(X_input, n_features, feature_names)

        if self.loss == "squared_error":
            if self.coef_.ndim != 1:
                raise ValueError(
                    "squared_error requires a single-output linear model "
                    "with a 1D coef_ vector."
                )

            components = X * self.coef_[None, :]

            return linear_regression_components(
                y,
                components,
                feature_names=names,
                check_additivity=check_additivity,
                atol=atol,
            )

        if self.loss == "log_loss":
            eta = self._decision_function(X)

            if self.coef_.ndim == 1:
                components = X * self.coef_[None, :]

                return linear_logistic_components(
                    y,
                    components,
                    eta=eta,
                    include_intercept_component=True,
                    feature_names=names,
                    check_additivity=check_additivity,
                    atol=atol,
                )

            if self.coef_.ndim == 2:
                # components[i, k, j] = X[i, j] * coef_[k, j]
                components = X[:, None, :] * self.coef_[None, :, :]

                return linear_multiclass_components(
                    y,
                    components,
                    eta=eta,
                    include_intercept_component=True,
                    feature_names=names,
                    check_additivity=check_additivity,
                    atol=atol,
                )

        raise RuntimeError(f"Unexpected loss: {self.loss}")

    def _resolve_feature_names(self, X, n_features: int, feature_names):
        """Pick names by priority: call argument, constructor value, ``X.columns``, defaults."""
        if feature_names is not None:
            return feature_names
        if self.feature_names is not None:
            return self.feature_names
        return self._get_feature_names(X, n_features)

    def _decision_function(self, X) -> np.ndarray:
        """
        Return the model scores for ``X`` as a validated float array.

        Uses ``model.decision_function`` when available, otherwise rebuilds
        the scores from ``coef_`` and ``model.intercept_``. The result has
        shape ``(n_obs,)`` for a 1D ``coef_`` and ``(n_obs, n_classes)`` for
        a 2D ``coef_``.

        Raises
        ------
        TypeError
            If the model has neither ``decision_function`` nor ``intercept_``.
        ValueError
            If the intercept or score shapes are inconsistent, or the scores
            contain non-finite values.
        """
        if hasattr(self.model, "decision_function"):
            eta = self.model.decision_function(X)
        elif hasattr(self.model, "intercept_"):
            intercept = np.asarray(self.model.intercept_, dtype=float).reshape(-1)

            if self.coef_.ndim == 1:
                if intercept.size != 1:
                    raise ValueError(
                        "Binary classification requires a scalar intercept."
                    )
                eta = X @ self.coef_ + intercept[0]
            else:
                if intercept.size != self.coef_.shape[0]:
                    raise ValueError(
                        "Multiclass classification requires one intercept per class."
                    )
                eta = X @ self.coef_.T + intercept.reshape(1, -1)
        else:
            raise TypeError(
                "log_loss requires a model with decision_function or intercept_."
            )

        eta = np.asarray(eta, dtype=float)

        if self.coef_.ndim == 1:
            eta = eta.reshape(-1)
            if eta.shape[0] != X.shape[0]:
                raise ValueError("decision_function output must have length n_obs.")
        else:
            expected_shape = (X.shape[0], self.coef_.shape[0])
            if eta.shape != expected_shape:
                raise ValueError(
                    "decision_function output must have shape "
                    "(n_obs, n_classes)."
                )

        if not np.all(np.isfinite(eta)):
            raise ValueError("decision_function output must contain only finite values.")

        return eta

    @staticmethod
    def _get_coef(model) -> np.ndarray:
        """
        Extract ``model.coef_`` as a finite float array of rank 1 or 2.

        A single-row 2D array is flattened to 1D.

        Raises
        ------
        TypeError
            If the model has no ``coef_`` attribute.
        ValueError
            If ``coef_`` has an unsupported rank or non-finite entries.
        """
        if not hasattr(model, "coef_"):
            raise TypeError(
                "LinearExplainer requires a fitted linear model with a coef_ attribute."
            )

        coef = np.asarray(model.coef_, dtype=float)

        if coef.ndim == 2 and coef.shape[0] == 1:
            coef = coef.reshape(-1)

        if coef.ndim not in {1, 2}:
            raise ValueError(
                "LinearExplainer requires coef_ to be 1D for regression/binary "
                "classification or 2D for multiclass classification."
            )

        if not np.all(np.isfinite(coef)):
            raise ValueError("model.coef_ must contain only finite values.")

        return coef

    @staticmethod
    def _get_feature_names(X, n_features: int) -> list[str]:
        """Return ``list(X.columns)`` when ``X`` has columns, else ``x0, x1, ...``."""
        columns = getattr(X, "columns", None)
        if columns is not None:
            return list(columns)
        return [f"x{i}" for i in range(n_features)]
294
+
295
def _sigmoid(z):
    """
    Numerically stable logistic function ``1 / (1 + exp(-z))``.

    Parameters
    ----------
    z : array-like
        Logits; converted to a float array.

    Returns
    -------
    numpy.ndarray or numpy.float64
        Element-wise sigmoid with the shape of ``z``.
    """
    z = np.asarray(z, dtype=float)
    # sigmoid(z) = exp(-log(1 + exp(-z))) = exp(-logaddexp(0, -z)).
    # logaddexp never evaluates exp of a large positive number, so very
    # negative logits no longer overflow on the way to 0.0.
    return np.exp(-np.logaddexp(0.0, -z))
298
+
299
+
300
def _logit(p):
    """
    Log-odds ``log(p / (1 - p))`` of a probability.

    The input is clipped to ``[1e-12, 1 - 1e-12]`` first, so probabilities of
    exactly 0 or 1 map to finite values.
    """
    lower, upper = 1e-12, 1.0 - 1e-12
    prob = np.clip(np.asarray(p, dtype=float), lower, upper)
    odds = prob / (1.0 - prob)
    return np.log(odds)
304
+
305
+
306
def _softplus(z):
    """Return ``log(1 + exp(z))`` element-wise, evaluated as ``logaddexp(0, z)``."""
    zero = 0.0
    return np.logaddexp(zero, z)
308
+
309
+
310
def _binary_log_loss(y, p):
    """
    Per-observation binary cross-entropy ``-[y log p + (1 - y) log(1 - p)]``.

    ``p`` is clipped to ``[1e-12, 1 - 1e-12]`` so the result stays finite.
    """
    prob = np.clip(p, 1e-12, 1.0 - 1e-12)
    positive_term = y * np.log(prob)
    negative_term = (1.0 - y) * np.log1p(-prob)
    return -(positive_term + negative_term)
313
+
314
+
315
def linear_logistic_components(
    y,
    components,
    *,
    eta=None,
    intercept_component=None,
    include_intercept_component: bool = False,
    feature_names=None,
    check_additivity: bool = True,
    atol: float = 1e-10,
) -> "EDEFExplanation":
    """
    Closed-form EDEF for binary linear classification with log loss.

    The baseline is the constant logit ``eta_bar`` of the empirical label
    mean (clipped to ``[1e-12, 1 - 1e-12]``). Each component is weighted by
    the average of ``y - sigmoid(.)`` along the straight path from
    ``eta_bar`` to ``eta``, which has the closed form
    ``y - (softplus(eta) - softplus(eta_bar)) / (eta - eta_bar)``.

    Parameters
    ----------
    y : array-like, shape (n_obs,)
        Binary labels in {0, 1}.

    components : array-like, shape (n_obs, n_features)
        Fitted score/logit components. For logistic regression, this is typically

            components[:, j] = X[:, j] * beta[j]

    eta : array-like, shape (n_obs,), optional
        Full fitted score/logit. If omitted, the score is constructed as

            eta = eta_bar + components.sum(axis=1)

        where eta_bar is the baseline logit.

    intercept_component : array-like, shape (n_obs,), optional
        Additional score component to include, typically the difference between
        the fitted intercept and the baseline logit. Only used when
        ``include_intercept_component`` is True; if omitted then, it is set to
        ``eta - eta_bar - components.sum(axis=1)``.

    include_intercept_component : bool, default=False
        If True, append intercept_component as an additional attribution
        column named ``"__InterceptShift__"``.

    feature_names : sequence of str, optional
        Feature names. Defaults to ``x0, x1, ...``.

    check_additivity : bool, default=True
        Whether to check that feature contributions add to total fit improvement.
        The contributions only add up when the attributed columns sum to
        ``eta - eta_bar``.

    atol : float, default=1e-10
        Absolute tolerance for the additivity check.

    Returns
    -------
    EDEFExplanation
        EDEF result object.

    Raises
    ------
    ValueError
        If the inputs have inconsistent shapes, fewer than two observations,
        non-finite entries, non-binary labels, or a wrong number of feature
        names.
    RuntimeError
        If ``check_additivity`` is True and the contributions do not sum to
        the total fit improvement within ``atol``.
    """

    y = np.asarray(y, dtype=float).reshape(-1)
    components = np.asarray(components, dtype=float)

    if components.ndim != 2:
        raise ValueError("components must have shape (n_obs, n_features).")

    n_obs, n_features = components.shape

    if y.shape[0] != n_obs:
        raise ValueError("y and components must have the same number of observations.")

    if n_obs < 2:
        raise ValueError("At least two observations are required.")

    if not np.all(np.isfinite(y)):
        raise ValueError("y must contain only finite values.")

    if not np.all((y == 0.0) | (y == 1.0)):
        raise ValueError("y must contain only binary labels in {0, 1}.")

    if not np.all(np.isfinite(components)):
        raise ValueError("components must contain only finite values.")

    # Clipping keeps the baseline logit finite when all labels are equal.
    p_bar = float(np.clip(y.mean(), 1e-12, 1.0 - 1e-12))
    eta_bar = float(_logit(p_bar))

    if eta is None:
        eta = eta_bar + components.sum(axis=1)
    else:
        eta = np.asarray(eta, dtype=float).reshape(-1)
        if eta.shape[0] != n_obs:
            raise ValueError("eta must have length n_obs.")
        if not np.all(np.isfinite(eta)):
            raise ValueError("eta must contain only finite values.")

    if feature_names is None:
        feature_names = [f"x{i}" for i in range(n_features)]
    else:
        feature_names = list(feature_names)
        if len(feature_names) != n_features:
            raise ValueError("feature_names must have length n_features.")

    if include_intercept_component:
        if intercept_component is None:
            intercept_component = eta - eta_bar - components.sum(axis=1)
        else:
            intercept_component = np.asarray(intercept_component, dtype=float).reshape(-1)
            if intercept_component.shape[0] != n_obs:
                raise ValueError("intercept_component must have length n_obs.")
            if not np.all(np.isfinite(intercept_component)):
                raise ValueError(
                    "intercept_component must contain only finite values."
                )

        components = np.column_stack([components, intercept_component])
        feature_names = feature_names + ["__InterceptShift__"]

    delta = eta - eta_bar
    sp_eta = _softplus(eta)
    sp_eta_bar = _softplus(eta_bar)

    # Log loss expressed in the logit: -[y log p + (1 - y) log(1 - p)]
    # = softplus(eta) - y * eta. Evaluating it this way, instead of through
    # clipped probabilities, stays exact for large |eta| and is the same
    # quantity that the softplus path weights below integrate to.
    baseline_loss = float(np.mean(sp_eta_bar - y * eta_bar))
    model_loss = float(np.mean(sp_eta - y * eta))
    total = baseline_loss - model_loss

    eps = 1e-12
    path_weight = np.empty(n_obs, dtype=float)

    # Average of (y - sigmoid) along the path; where the path has (almost)
    # zero length the quotient is replaced by its limit y - sigmoid(eta_bar).
    mask = np.abs(delta) > eps
    path_weight[mask] = y[mask] - (sp_eta[mask] - sp_eta_bar) / delta[mask]
    path_weight[~mask] = y[~mask] - _sigmoid(eta_bar)

    observation_values = components * path_weight[:, None]
    values = observation_values.mean(axis=0)

    standard_errors = observation_values.std(axis=0, ddof=1) / np.sqrt(n_obs)

    additivity_error = values.sum() - total

    if check_additivity and abs(additivity_error) > atol:
        raise RuntimeError(
            "EDEF contributions do not add to total fit improvement. "
            f"Additivity error: {additivity_error}"
        )

    return EDEFExplanation(
        values=values,
        observation_values=observation_values,
        standard_errors=standard_errors,
        total=total,
        baseline_loss=baseline_loss,
        model_loss=model_loss,
        loss="log_loss",
        model_type="linear_logistic_components",
        feature_names=feature_names,
        n_obs=n_obs,
        additivity_error=additivity_error,
    )
468
+
469
def _logsumexp(a, axis=None, keepdims=False):
    """
    Stable ``log(sum(exp(a)))`` along ``axis``.

    Parameters
    ----------
    a : array-like
        Input values; converted to a float array.

    axis : int or None, default=None
        Axis to reduce over; ``None`` reduces over all elements.

    keepdims : bool, default=False
        Whether to keep the reduced axis with length one.

    Returns
    -------
    numpy.ndarray
        The reduced array.
    """
    a = np.asarray(a, dtype=float)
    amax = np.max(a, axis=axis, keepdims=True)

    # Shifting by an infinite maximum would produce ``inf - inf = nan``.
    # Using a zero shift there yields -inf when every entry is -inf and
    # +inf when any entry is +inf.
    shift = np.where(np.isfinite(amax), amax, 0.0)

    # log(0) = -inf is the intended result for an all -inf slice.
    with np.errstate(divide="ignore"):
        out = shift + np.log(np.sum(np.exp(a - shift), axis=axis, keepdims=True))

    if not keepdims:
        out = np.squeeze(out, axis=axis)

    return out
478
+
479
+
480
def _softmax(eta):
    """
    Row-wise softmax of a 2D score array.

    Each row is normalized in log space: the log of the denominator is taken
    with ``_logsumexp`` along axis 1 and subtracted before exponentiating.
    """
    log_normalizer = _logsumexp(eta, axis=1, keepdims=True)
    log_probs = eta - log_normalizer
    return np.exp(log_probs)
483
+
484
+
485
def linear_multiclass_components(
    y,
    components,
    *,
    eta=None,
    intercept_component=None,
    include_intercept_component: bool = False,
    feature_names=None,
    check_additivity: bool = True,
    atol: float = 1e-10,
    n_quadrature_nodes: int = 64,
) -> "EDEFExplanation":
    """
    Closed-form EDEF for multiclass linear classification with log loss.

    The baseline is the constant score vector ``eta_bar = log(class_probs)``
    built from the empirical class frequencies (floored at ``1e-12`` and
    renormalized). Each class-score component is weighted by the one-hot
    label minus the softmax probability averaged along the straight path
    from ``eta_bar`` to ``eta``; that average is evaluated numerically with
    Gauss-Legendre quadrature.

    Parameters
    ----------
    y : array-like, shape (n_obs,)
        Integer class labels in {0, ..., n_classes - 1}.

    components : array-like, shape (n_obs, n_classes, n_features)
        Fitted class-score components. For multinomial logistic regression,

            components[i, k, j] = X[i, j] * beta[k, j]

    eta : array-like, shape (n_obs, n_classes), optional
        Full fitted class scores. If omitted, the score is constructed as

            eta = eta_bar + components.sum(axis=2)

        where eta_bar is the baseline class-score vector.

    intercept_component : array-like, shape (n_obs, n_classes), optional
        Additional class-score component, typically the difference between the
        fitted intercept vector and the baseline class-score vector. Only used
        when ``include_intercept_component`` is True; if omitted then, it is
        set to ``eta - eta_bar - components.sum(axis=2)``.

    include_intercept_component : bool, default=False
        If True, append intercept_component as an additional attribution
        column named ``"__InterceptShift__"``.

    feature_names : sequence of str, optional
        Feature names. Defaults to ``x0, x1, ...``.

    check_additivity : bool, default=True
        Whether to check that feature contributions add to total fit improvement.

    atol : float, default=1e-10
        Absolute tolerance for the additivity check.

    n_quadrature_nodes : int, default=64
        Number of Gauss-Legendre nodes used for the path integral. Because
        the integral is approximated, very large score shifts
        ``eta - eta_bar`` may need more nodes (or a looser ``atol``) for the
        additivity check to pass.

    Returns
    -------
    EDEFExplanation
        Scalar log-loss EDEF result. Class dimensions are summed internally, so
        observation_values has shape (n_obs, n_features), plus one column when
        the intercept component is included.

    Raises
    ------
    ValueError
        If the inputs have inconsistent shapes, fewer than two observations,
        non-finite entries, invalid class labels, a wrong number of feature
        names, or ``n_quadrature_nodes`` is smaller than 1.
    RuntimeError
        If ``check_additivity`` is True and the contributions do not sum to
        the total fit improvement within ``atol``.
    """

    y = np.asarray(y).reshape(-1)
    components = np.asarray(components, dtype=float)

    if components.ndim != 3:
        raise ValueError(
            "components must have shape (n_obs, n_classes, n_features)."
        )

    n_obs, n_classes, n_features = components.shape

    if y.shape[0] != n_obs:
        raise ValueError("y and components must have the same number of observations.")

    if n_obs < 2:
        raise ValueError("At least two observations are required.")

    if not np.all(np.isfinite(components)):
        raise ValueError("components must contain only finite values.")

    if n_quadrature_nodes < 1:
        raise ValueError("n_quadrature_nodes must be a positive integer.")

    # Accept float-typed labels as long as every value is integral.
    if not np.issubdtype(y.dtype, np.integer) and not np.all(
        np.equal(y, np.round(y))
    ):
        raise ValueError("y must contain integer class labels.")

    y = y.astype(int)

    if np.any(y < 0) or np.any(y >= n_classes):
        raise ValueError("y must contain class labels in {0, ..., n_classes - 1}.")

    # The floor keeps log(class_probs) finite for classes absent from y.
    class_counts = np.bincount(y, minlength=n_classes).astype(float)
    class_probs = np.clip(class_counts / n_obs, 1e-12, 1.0)
    class_probs = class_probs / class_probs.sum()

    eta_bar = np.log(class_probs)

    if eta is None:
        eta = eta_bar.reshape(1, -1) + components.sum(axis=2)
    else:
        eta = np.asarray(eta, dtype=float)
        if eta.shape != (n_obs, n_classes):
            raise ValueError("eta must have shape (n_obs, n_classes).")
        if not np.all(np.isfinite(eta)):
            raise ValueError("eta must contain only finite values.")

    if feature_names is None:
        feature_names = [f"x{i}" for i in range(n_features)]
    else:
        feature_names = list(feature_names)
        if len(feature_names) != n_features:
            raise ValueError("feature_names must have length n_features.")

    if include_intercept_component:
        if intercept_component is None:
            intercept_component = (
                eta
                - eta_bar.reshape(1, -1)
                - components.sum(axis=2)
            )
        else:
            intercept_component = np.asarray(intercept_component, dtype=float)
            if intercept_component.shape != (n_obs, n_classes):
                raise ValueError(
                    "intercept_component must have shape (n_obs, n_classes)."
                )
            if not np.all(np.isfinite(intercept_component)):
                raise ValueError(
                    "intercept_component must contain only finite values."
                )

        components = np.concatenate(
            [components, intercept_component[:, :, None]],
            axis=2,
        )
        feature_names = feature_names + ["__InterceptShift__"]

    baseline_loss = float(-np.mean(np.log(class_probs[y])))

    log_probs = eta - _logsumexp(eta, axis=1, keepdims=True)
    model_loss = float(-np.mean(log_probs[np.arange(n_obs), y]))

    total = baseline_loss - model_loss

    delta = eta - eta_bar.reshape(1, -1)

    # For each observation and class, compute
    #
    #     integral_0^1 softmax_k(eta_bar + t * delta_i) dt
    #
    # using Gauss-Legendre quadrature. There is no simple binary-style
    # scalar softplus closed form for the multiclass softmax path.
    nodes, weights = np.polynomial.legendre.leggauss(n_quadrature_nodes)
    # Map nodes and weights from the reference interval [-1, 1] to [0, 1].
    nodes = 0.5 * (nodes + 1.0)
    weights = 0.5 * weights

    # Shape (n_nodes, n_obs, n_classes): scores at every point of the path.
    eta_all = (
        eta_bar.reshape(1, 1, -1)
        + nodes.reshape(-1, 1, 1) * delta.reshape(1, n_obs, n_classes)
    )

    # Subtract the per-row maximum before exponentiating to avoid overflow.
    eta_all = eta_all - eta_all.max(axis=2, keepdims=True)

    prob_all = np.exp(eta_all)
    prob_all = prob_all / prob_all.sum(axis=2, keepdims=True)

    avg_prob = np.sum(
        weights.reshape(-1, 1, 1) * prob_all,
        axis=0,
    )

    one_hot = np.zeros((n_obs, n_classes), dtype=float)
    one_hot[np.arange(n_obs), y] = 1.0

    path_weight = one_hot - avg_prob

    # components: (n_obs, n_classes, n_features)
    # path_weight: (n_obs, n_classes)
    # observation_values: sum over classes -> (n_obs, n_features)
    observation_values = np.sum(
        components * path_weight[:, :, None],
        axis=1,
    )

    values = observation_values.mean(axis=0)
    standard_errors = observation_values.std(axis=0, ddof=1) / np.sqrt(n_obs)

    additivity_error = values.sum() - total

    if check_additivity and abs(additivity_error) > atol:
        raise RuntimeError(
            "EDEF contributions do not add to total fit improvement. "
            f"Additivity error: {additivity_error}"
        )

    return EDEFExplanation(
        values=values,
        observation_values=observation_values,
        standard_errors=standard_errors,
        total=total,
        baseline_loss=baseline_loss,
        model_loss=model_loss,
        loss="log_loss",
        model_type="linear_multiclass_components",
        feature_names=feature_names,
        n_obs=n_obs,
        additivity_error=additivity_error,
    )