lens-eval 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lens_eval/__init__.py ADDED
@@ -0,0 +1,55 @@
1
+ """lens-eval: interpretable multi-dimension text quality scoring.
2
+
3
+ Public entrypoints:
4
+
5
+ from lens_eval import LENS
6
+ from lens_eval import semantic_score, nli_score, naturalness_score, emotion_score
7
+ from lens_eval import configure # device, paths, naturalness mode/centroid
8
+
9
+ See README.md for usage.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
# Package version string; matches the released wheel (lens-eval 0.1.0).
__version__ = "0.1.0"
15
+
16
+ from .encoders import (
17
+ DIMENSIONS,
18
+ configure,
19
+ emotion_score,
20
+ featurize,
21
+ free,
22
+ naturalness_score,
23
+ nli_score,
24
+ semantic_score,
25
+ )
26
+ from .errors import (
27
+ AmbiguousTaskError,
28
+ CombinerBackendMissing,
29
+ DegenerateTargetError,
30
+ EncoderVersionMismatchError,
31
+ InsufficientDataError,
32
+ LensEvalError,
33
+ ReferenceModeError,
34
+ )
35
+ from .lens import LENS
36
+
37
# Public API: the LENS estimator, the per-dimension encoder entrypoints and
# configuration helpers re-exported from ``.encoders``, and the typed
# exception hierarchy from ``.errors``. This list is what
# ``from lens_eval import *`` exposes.
__all__ = [
    "LENS",
    "DIMENSIONS",
    "configure",
    "featurize",
    "free",
    "semantic_score",
    "nli_score",
    "naturalness_score",
    "emotion_score",
    "LensEvalError",
    "InsufficientDataError",
    "DegenerateTargetError",
    "AmbiguousTaskError",
    "ReferenceModeError",
    "EncoderVersionMismatchError",
    "CombinerBackendMissing",
    "__version__",
]
lens_eval/_validate.py ADDED
@@ -0,0 +1,85 @@
1
+ """Centralized structural validation for fit-time inputs.
2
+
3
+ Both ``LENS.fit`` and the CLI route their inputs through this module so a
4
+ missing ``groups`` array, an ambiguous task channel, or a malformed pairs
5
+ matrix surfaces the *same* typed exception with the *same* error message
6
+ regardless of where the call originated. The gateway is intentionally
7
+ narrow — it does shape and channel checks only, no semantic validation.
8
+ That belongs in the combiner layer.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from typing import Any, Optional, Sequence, Tuple
14
+
15
+ import numpy as np
16
+
17
+ from .errors import AmbiguousTaskError
18
+
19
+
20
# Accepted target channels, in the order they are reported in errors.
TASK_CHANNELS = ("scores", "pairs", "ranks")
# Channel name -> the LENS task it implies.
TASK_TO_NAME = {"scores": "regression", "pairs": "pairwise", "ranks": "ranking"}


def validate_task_channels(
    *,
    scores: Optional[Any] = None,
    pairs: Optional[Any] = None,
    ranks: Optional[Any] = None,
    groups: Optional[Any] = None,
) -> Tuple[str, str]:
    """Resolve the task channel and validate channel-specific structure.

    Returns ``(channel_name, task)`` where:
      * ``channel_name`` is one of ``"scores" | "pairs" | "ranks"``.
      * ``task`` is the inferred ``LENS`` task —
        ``"regression" | "pairwise" | "ranking"``.

    Raises ``AmbiguousTaskError`` if zero or multiple channels are supplied.
    Raises ``ValueError`` for shape / dtype / coverage issues with the chosen
    channel:

      * ``pairs`` must be an ``(M, 2)`` array of finite, integer-valued
        indices (integer dtype, or float/bool entries that are whole numbers).
        Strings, objects, complex values, NaN/inf and fractional values are
        rejected.
      * ``ranks`` requires ``groups``, and both must be per-row arrays of the
        same length.

    Used by both ``LENS.fit`` and the CLI so the user-visible failure modes
    are identical across entry points.
    """
    supplied = {"scores": scores, "pairs": pairs, "ranks": ranks}
    provided = [name for name in TASK_CHANNELS if supplied[name] is not None]
    if len(provided) != 1:
        # Zero ⇒ nothing to fit on. Two+ ⇒ which to optimise? Both are bad.
        raise AmbiguousTaskError(
            f"fit() / CLI needs exactly one of (scores, pairs, ranks); got {provided!r}"
        )
    channel = provided[0]
    task = TASK_TO_NAME[channel]

    if channel == "pairs":
        arr = np.asarray(pairs)
        if arr.ndim != 2 or arr.shape[1] != 2:
            raise ValueError(
                f"`pairs` must have shape (M, 2) of (winner_idx, loser_idx); got {arr.shape}"
            )
        if arr.size and not np.issubdtype(arr.dtype, np.integer):
            # Float (or bool) indices are tolerated when every entry is a
            # finite whole number. The dtype kind is checked first, and the
            # `or` short-circuits, because np.round()/np.isfinite() raise
            # TypeError on string/object arrays. That TypeError would escape
            # the documented ValueError contract. inf passes allclose(), so
            # finiteness is checked explicitly.
            if (
                arr.dtype.kind not in "bf"
                or not np.all(np.isfinite(arr))
                or not np.allclose(arr, np.round(arr))
            ):
                raise ValueError("`pairs` indices must be integer-valued.")
    elif channel == "ranks":
        # Ranking is the only channel that needs an auxiliary grouping array —
        # lambdarank needs to know which rows belong to the same query. Without
        # groups every row is treated as its own group, which silently
        # collapses to "fit a regressor on rank labels" rather than learning
        # within-query orderings. Reject upfront.
        if groups is None:
            raise ValueError(
                "task='ranking' requires `groups` so lambdarank can isolate "
                "within-query orderings — pass a per-row group-id array."
            )
        ranks_arr = np.asarray(ranks)
        groups_arr = np.asarray(groups)
        # A scalar has no rows; without this guard `.shape[0]` below would
        # raise IndexError instead of a ValueError.
        if ranks_arr.ndim == 0 or groups_arr.ndim == 0:
            raise ValueError("`ranks` and `groups` must be per-row arrays, not scalars.")
        if ranks_arr.shape[0] != groups_arr.shape[0]:
            raise ValueError(
                f"`ranks` ({ranks_arr.shape[0]}) and `groups` ({groups_arr.shape[0]}) "
                f"must align row-wise."
            )

    return channel, task
lens_eval/cli.py ADDED
@@ -0,0 +1,215 @@
1
+ """`lens-eval` CLI — thin wrapper around the Python API.
2
+
3
+ lens-eval fit \
4
+ --texts hyp.txt --refs ref.txt --scores scores.csv \
5
+ --output ./my-lens --task auto --verbose
6
+
7
+ lens-eval score \
8
+ --model ./my-lens \
9
+ --texts new_hyp.txt --refs new_ref.txt \
10
+ --output predictions.csv
11
+
12
+ lens-eval report ./my-lens --html report.html
13
+
14
Both ``fit`` and ``score`` accept ``--features path.csv`` to skip encoding
entirely (``report`` only reads a saved model). The features CSV must have
one column per dimension. For ``score``, the column order must follow the
fitted LENS's ``dimensions_used_``. ``fit`` has no ``--dimensions`` flag, so
its columns must already be in the order ``LENS.fit(features=...)`` expects.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import argparse
23
+ import csv
24
+ import sys
25
+ from pathlib import Path
26
+ from typing import List, Optional
27
+
28
+ import numpy as np
29
+
30
+
31
def _read_lines(path: str | Path) -> List[str]:
    """Read one text per record from *path*.

    ``.csv`` files (by suffix, case-insensitive) yield the first column of
    every non-empty row. Anything else is read as plain text, one entry per
    line, with empty lines dropped. Both are decoded as UTF-8; a leading BOM
    is stripped.
    """
    p = Path(path)
    if p.suffix.lower() == ".csv":
        # Give the csv module the raw stream (newline="") so a quoted field
        # with embedded newlines stays one record with its newlines intact.
        # Pre-splitting with splitlines() made the reader join the pieces
        # with the line breaks silently removed.
        with p.open(newline="", encoding="utf-8-sig") as f:
            return [row[0] for row in csv.reader(f) if row]
    # utf-8-sig: a spreadsheet/Notepad BOM would otherwise be glued onto the
    # first text.
    text = p.read_text(encoding="utf-8-sig")
    return [ln for ln in text.splitlines() if ln]
37
+
38
+
39
def _read_scalar_csv(path: str | Path) -> np.ndarray:
    """Read the first column of a CSV as a float vector (one target per row).

    Blank rows are skipped. Non-numeric rows are treated as a header and
    skipped only until the first number is read. After that, an unparseable
    row raises instead of being dropped: dropping it would silently shift
    every later target against its text.

    Raises:
        ValueError: a non-numeric first cell appears after numeric data.
    """
    values: List[float] = []
    # utf-8-sig strips a leading BOM, which would otherwise make the first
    # cell unparseable and silently lose the first value.
    with open(path, newline="", encoding="utf-8-sig") as f:
        reader = csv.reader(f)
        for row in reader:
            if not row:
                continue
            try:
                values.append(float(row[0]))
            except ValueError:
                if values:
                    raise ValueError(
                        f"{path}: line {reader.line_num}: expected a number in the "
                        f"first column, got {row[0]!r}"
                    ) from None
                # Leading header row(s): skip.
    return np.asarray(values, dtype=float)
51
+
52
+
53
def _read_features_csv(path: str | Path) -> np.ndarray:
    """Read a precomputed feature matrix (one row per sample, one column per dimension).

    Blank rows are skipped. Non-numeric rows are treated as a header only
    until the first numeric row is read.

    Raises:
        ValueError: a non-numeric row follows numeric data, or a row's column
            count differs from the first data row's. Silently dropping or
            mis-shaping rows would misalign features with texts and targets.
    """
    matrix: List[List[float]] = []
    # utf-8-sig strips a leading BOM that would otherwise break the first cell.
    with open(path, newline="", encoding="utf-8-sig") as f:
        reader = csv.reader(f)
        for row in reader:
            if not row:
                continue
            try:
                values = [float(x) for x in row]
            except ValueError:
                if matrix:
                    raise ValueError(
                        f"{path}: line {reader.line_num}: non-numeric feature row {row!r}"
                    ) from None
                continue  # leading header row(s)
            if matrix and len(values) != len(matrix[0]):
                raise ValueError(
                    f"{path}: line {reader.line_num}: expected {len(matrix[0])} "
                    f"columns, got {len(values)}"
                )
            matrix.append(values)
    return np.asarray(matrix, dtype=float)
64
+
65
+
66
def _read_pairs_csv(path: str | Path) -> np.ndarray:
    """Read ``(winner_idx, loser_idx)`` integer pairs from the first two CSV columns.

    Blank rows are skipped. Unparseable rows are treated as a header only
    until the first pair is read.

    Returns:
        An ``(M, 2)`` int array.

    Raises:
        ValueError: no pair could be parsed, or an unparseable row follows
            parsed pairs (dropping it silently would lose a comparison).
    """
    pairs = []
    with open(path, newline="", encoding="utf-8-sig") as f:
        reader = csv.reader(f)
        for row in reader:
            if not row:
                continue
            try:
                a, b = int(row[0]), int(row[1])
            except (ValueError, IndexError):
                if pairs:
                    raise ValueError(
                        f"{path}: line {reader.line_num}: expected two integer "
                        f"columns (winner_idx, loser_idx), got {row!r}"
                    ) from None
                continue  # leading header row(s)
            pairs.append((a, b))
    if not pairs:
        raise ValueError(f"no pair rows parsed from {path}")
    return np.asarray(pairs, dtype=int)
80
+
81
+
82
def _read_int_column_csv(path: str | Path) -> np.ndarray:
    """Read the first CSV column as integers (used for ``--ranks`` and ``--groups``).

    Blank rows are skipped. Non-integer rows are treated as a header only
    until the first integer is read.

    Raises:
        ValueError: no integer could be parsed, or a non-integer row follows
            parsed data (dropping it would misalign ranks/groups with texts).
    """
    vals: List[int] = []
    with open(path, newline="", encoding="utf-8-sig") as f:
        reader = csv.reader(f)
        for row in reader:
            if not row:
                continue
            try:
                vals.append(int(row[0]))
            except ValueError:
                if vals:
                    raise ValueError(
                        f"{path}: line {reader.line_num}: expected an integer in the "
                        f"first column, got {row[0]!r}"
                    ) from None
                # Leading header row(s): skip.
    if not vals:
        raise ValueError(f"no integer rows parsed from {path}")
    return np.asarray(vals, dtype=int)
95
+
96
+
97
def _cmd_fit(args) -> int:
    """Handle ``lens-eval fit``: read inputs, validate the target channel, fit and save.

    Returns ``0`` on success. Returns ``2`` (after printing ``error: ...`` to
    stderr) when an input file cannot be read or parsed, the target channels
    are invalid, or neither ``--texts`` nor ``--features`` was given.
    Exceptions raised by ``LENS.fit`` / ``LENS.save`` themselves propagate.
    """
    from . import encoders as enc
    from ._validate import validate_task_channels
    from .lens import LENS

    if args.naturalness_mode:
        enc.configure(naturalness_mode=args.naturalness_mode)

    # Parsing runs inside the same guard as channel validation. A missing path
    # or malformed CSV then ends in a one-line "error: ..." with exit code 2,
    # like any other input problem, instead of a raw traceback.
    try:
        texts = _read_lines(args.texts) if args.texts else None
        refs = _read_lines(args.refs) if args.refs else None
        feats = _read_features_csv(args.features) if args.features else None
        y = _read_scalar_csv(args.scores) if args.scores else None
        pairs = _read_pairs_csv(args.pairs) if args.pairs else None
        ranks = _read_int_column_csv(args.ranks) if args.ranks else None
        groups = _read_int_column_csv(args.groups) if args.groups else None

        channel, _ = validate_task_channels(
            scores=y, pairs=pairs, ranks=ranks, groups=groups,
        )
        # Same precondition the `score` command enforces: without texts or
        # precomputed features there is nothing to build a feature matrix from.
        if texts is None and feats is None:
            raise ValueError("provide --texts or --features")
    except Exception as exc:  # CLI boundary: report and exit, don't traceback
        print(f"error: {exc}", file=sys.stderr)
        return 2

    lens = LENS()
    lens.fit(
        texts=texts,
        references=refs,
        features=feats,
        scores=y, pairs=pairs, ranks=ranks, groups=groups,
        task=args.task,
        target_type=args.target_type,
        selection=args.selection,
        verbose=args.verbose,
    )
    lens.save(args.output)
    if args.html:
        lens.report_html(args.html)
    print(f"saved fitted LENS to {args.output} ({channel} channel)")
    return 0
137
+
138
+
139
def _cmd_score(args) -> int:
    """Handle ``lens-eval score``: apply a saved LENS to new inputs.

    Loads the model, reads the optional texts / references / feature files,
    and writes one ``score`` column to ``args.output``. Returns 2 when
    neither texts nor features were supplied, otherwise 0.
    """
    from .lens import LENS

    model = LENS.load(args.model)

    def maybe_read(reader, source):
        # Optional CLI inputs: only touch the file when a path was given.
        return reader(source) if source else None

    texts = maybe_read(_read_lines, args.texts)
    refs = maybe_read(_read_lines, args.refs)
    feats = maybe_read(_read_features_csv, args.features)

    if feats is None and texts is None:
        print("error: provide --texts or --features", file=sys.stderr)
        return 2

    predictions = model.score(texts, references=refs, features=feats)

    destination = Path(args.output)
    with destination.open("w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(["score"])
        writer.writerows([float(value)] for value in predictions)
    print(f"wrote {len(predictions)} scores to {destination}")
    return 0
161
+
162
+
163
def _cmd_report(args) -> int:
    """Handle ``lens-eval report``: print the saved model's report, optionally as HTML too."""
    from .lens import LENS

    fitted = LENS.load(args.model)
    fitted.report()

    html_path = args.html
    if not html_path:
        return 0
    fitted.report_html(html_path)
    print(f"wrote HTML report to {html_path}")
    return 0
172
+
173
+
174
def main(argv: Optional[List[str]] = None) -> int:
    """Entry point of the ``lens-eval`` CLI.

    Builds the ``fit`` / ``score`` / ``report`` sub-command parser, parses
    *argv* (``sys.argv[1:]`` when ``None``) and dispatches to the handler
    bound through ``set_defaults(func=...)``. A handler result of ``None``
    is treated as exit code 0.
    """

    def add_input_sources(cmd: argparse.ArgumentParser) -> None:
        # `fit` and `score` read candidates / references / features identically.
        cmd.add_argument("--texts", help="path to candidate texts (txt or csv)")
        cmd.add_argument("--refs", help="path to references (txt or csv)")
        cmd.add_argument("--features", help="path to precomputed feature matrix CSV (D columns)")

    parser = argparse.ArgumentParser(prog="lens-eval", description="lens-eval CLI")
    commands = parser.add_subparsers(dest="command", required=True)

    # --- fit -------------------------------------------------------------
    fit_cmd = commands.add_parser("fit", help="fit a LENS combiner")
    add_input_sources(fit_cmd)
    fit_cmd.add_argument("--scores", help="scalar targets CSV (1 column)")
    fit_cmd.add_argument("--pairs", help="pairwise CSV: 2 columns (winner_idx, loser_idx)")
    fit_cmd.add_argument("--ranks", help="ranking CSV: 1 column of integer ranks")
    fit_cmd.add_argument(
        "--groups",
        help="ranking grouping CSV: 1 column of group ids — REQUIRED with --ranks",
    )
    fit_cmd.add_argument("--output", required=True, help="where to save the fitted LENS directory")
    task_choices = ("auto", "regression", "pairwise", "ranking")
    target_choices = ("auto", "bounded", "ordinal", "binary", "continuous")
    selection_choices = ("auto", "fast", "exhaustive", "glm", "glm_interactions", "ebm", "gbm")
    fit_cmd.add_argument("--task", default="auto", choices=task_choices)
    fit_cmd.add_argument("--target-type", default="auto", choices=target_choices)
    fit_cmd.add_argument("--selection", default="auto", choices=selection_choices)
    fit_cmd.add_argument("--naturalness-mode", default="centroid", choices=("centroid", "reference"))
    fit_cmd.add_argument("--html", help="optional: also write an HTML report to this path")
    fit_cmd.add_argument("--verbose", action="store_true")
    fit_cmd.set_defaults(func=_cmd_fit)

    # --- score -----------------------------------------------------------
    score_cmd = commands.add_parser("score", help="score new texts with a saved LENS")
    score_cmd.add_argument("--model", required=True, help="path to a saved LENS directory")
    add_input_sources(score_cmd)
    score_cmd.add_argument("--output", required=True, help="output CSV path for predictions")
    score_cmd.set_defaults(func=_cmd_score)

    # --- report ----------------------------------------------------------
    report_cmd = commands.add_parser(
        "report", help="print/render the selection report from a saved LENS"
    )
    report_cmd.add_argument("model", help="path to a saved LENS directory")
    report_cmd.add_argument("--html", help="optional: write an HTML report to this path too")
    report_cmd.set_defaults(func=_cmd_report)

    parsed = parser.parse_args(argv)
    exit_code = parsed.func(parsed)
    return int(exit_code or 0)
212
+
213
+
214
if __name__ == "__main__":
    # Executing this module directly (e.g. ``python -m lens_eval.cli``)
    # exits the process with main()'s return code.
    sys.exit(main())
@@ -0,0 +1,93 @@
1
+ """Combiner registry.
2
+
3
+ Each combiner implements the BaseCombiner protocol. The `AVAILABLE` dict maps
4
+ combiner type names → (factory, is_backend_available) so the selection layer can
5
+ filter to what's actually installable without having to import each module
6
+ eagerly.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from typing import Callable, Dict, Optional, Tuple
12
+
13
+ from .base import BaseCombiner, expand_pairs
14
+ from .glm import GLMCombiner
15
+ from .glm_interactions import GLMInteractionsCombiner
16
+
17
+
18
def _ebm_available() -> bool:
    """Report whether interpret-ml's ``interpret.glassbox`` can be imported.

    This performs the real import. Any exception means "unusable": an
    ImportError when the package is missing, or an OSError when a build
    lacks its native dependencies.
    """
    try:
        import interpret.glassbox  # noqa: F401
    except Exception:
        usable = False
    else:
        usable = True
    return usable
26
+
27
+
28
def _gbm_available() -> bool:
    """Report whether LightGBM can be imported.

    The ctypes load can raise OSError (for example on macOS without libomp).
    Every failure is reported as "not installed" so the candidate filter
    keeps running instead of crashing.
    """
    try:
        import lightgbm  # noqa: F401
    except Exception:
        usable = False
    else:
        usable = True
    return usable
36
+
37
+
38
def _ebm_factory(**kwargs) -> BaseCombiner:
    """Construct an ``EBMCombiner`` from *kwargs*.

    The submodule import happens here, not at module import time. That lets
    ``AVAILABLE`` be built without touching interpret-ml; the module is only
    loaded when this factory is actually called.
    """
    from .ebm import EBMCombiner as _Combiner
    return _Combiner(**kwargs)
44
+
45
+
46
def _gbm_factory(**kwargs) -> BaseCombiner:
    """Construct a ``GBMCombiner`` from *kwargs*, importing its module on demand."""
    from .gbm import GBMCombiner as _Combiner
    return _Combiner(**kwargs)
50
+
51
+
52
# Registry: name → (factory, availability_check). make() calls the check
# first and only calls the factory if it passes. The GLM tiers have no
# optional dependency, so their check is always True. The EBM/GBM checks DO
# import their backend (interpret-ml / lightgbm) to probe it. They are not
# import-free, but the combiner modules themselves stay unimported until
# the factory runs.
AVAILABLE: Dict[str, Tuple[Callable[..., BaseCombiner], Callable[[], bool]]] = {
    "glm": (lambda **kw: GLMCombiner(**kw), lambda: True),
    "glm_interactions": (lambda **kw: GLMInteractionsCombiner(**kw), lambda: True),
    "ebm": (_ebm_factory, _ebm_available),
    "gbm": (_gbm_factory, _gbm_available),
}

# Capacity ordering: simpler first. Intended for the 1-SE selection rule,
# which breaks ties toward lower capacity (more interpretable, less prone to
# overfit). NOTE(review): the selection layer that consumes this is not
# shown here.
CAPACITY_ORDER: Tuple[str, ...] = ("glm", "glm_interactions", "ebm", "gbm")
65
+
66
+
67
def make(combiner_type: str, **kwargs) -> BaseCombiner:
    """Instantiate the combiner registered under *combiner_type*.

    Keyword arguments are forwarded to that combiner's factory.

    Raises:
        ValueError: *combiner_type* is not a key of ``AVAILABLE``.
        CombinerBackendMissing: the combiner's optional backend cannot be
            imported. This is re-checked here even though candidate
            filtering should already have excluded it, so direct ``make()``
            calls fail cleanly too.
    """
    entry = AVAILABLE.get(combiner_type)
    if entry is None:
        raise ValueError(
            f"unknown combiner {combiner_type!r}; choose from {list(AVAILABLE)}"
        )
    build, backend_ok = entry
    if backend_ok():
        return build(**kwargs)

    from ..errors import CombinerBackendMissing
    raise CombinerBackendMissing(
        f"combiner {combiner_type!r} requires an optional dependency. "
        f"Install with: pip install 'lens-eval[{combiner_type}]'"
    )
83
+
84
+
85
# Public names of the combiners package. EBMCombiner / GBMCombiner are
# deliberately absent: they are only reachable through make(), whose
# factories import them lazily so the optional backends stay optional.
__all__ = [
    "BaseCombiner",
    "GLMCombiner",
    "GLMInteractionsCombiner",
    "AVAILABLE",
    "CAPACITY_ORDER",
    "expand_pairs",
    "make",
]
@@ -0,0 +1,93 @@
1
+ """BaseCombiner: the protocol every combiner tier implements.
2
+
3
+ Invariants:
4
+
5
+ 1. ``predict(X)`` returns scores in target space (logit/latent/raw — the
6
+ combiner decides). LENS applies any inverse link before returning to the
7
+ user; ranking only needs ordering, so we don't enforce calibration here.
8
+ 2. ``contributions(X)`` returns an (N, D) attribution matrix. Linear combiners
9
+ return ``X * coef``; EBM/GBM return per-feature SHAP-style contributions.
10
+ 3. ``coefficients()`` returns a dict of interpretable parameters used by the
11
+ report layer.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from abc import ABC, abstractmethod
17
+ from dataclasses import dataclass, field
18
+ from typing import Any, Dict, Optional
19
+
20
+ import numpy as np
21
+
22
+
23
@dataclass
class BaseCombiner(ABC):
    """Configuration and fitted-state flag shared by every combiner tier.

    Concrete subclasses must implement ``fit``, ``predict`` and
    ``contributions``. ``coefficients`` is optional and defaults to an
    empty dict.
    """

    task: str = "regression"  # one of: 'regression' | 'pairwise' | 'ranking'
    target_type: str = "continuous"  # one of: 'bounded' | 'ordinal' | 'binary' | 'continuous'
    link: str = "identity"  # one of: 'identity' | 'logit' | 'cumulative_logit'
    random_state: int = 42
    hyperparameters: Dict[str, Any] = field(default_factory=dict)
    feature_names: Optional[list] = None

    # Not an __init__ argument. Every concrete fit() must flip this to True
    # once training succeeds; _check_fitted() relies on it.
    is_fitted_: bool = field(default=False, init=False)

    @abstractmethod
    def fit(
        self,
        X: np.ndarray,
        y: Optional[np.ndarray] = None,
        *,
        pairs: Optional[np.ndarray] = None,
        ranks: Optional[np.ndarray] = None,
        groups: Optional[np.ndarray] = None,
    ) -> "BaseCombiner":
        """Train on X against the supplied target channel; return ``self``."""

    @abstractmethod
    def predict(self, X: np.ndarray) -> np.ndarray:
        """Score in target space; higher = better quality."""

    @abstractmethod
    def contributions(self, X: np.ndarray) -> np.ndarray:
        """Per-sample per-feature contribution matrix, shape (N, D)."""

    def coefficients(self) -> Dict[str, Any]:
        """Interpretable parameters for the report layer (empty unless overridden)."""
        return {}

    def _check_fitted(self) -> None:
        """Raise RuntimeError unless ``fit`` has completed on this instance."""
        if self.is_fitted_:
            return
        cls_name = type(self).__name__
        raise RuntimeError(
            f"{cls_name} is not fitted; call .fit() before prediction."
        )

    @property
    def type_name(self) -> str:
        """Registry key for this combiner (lower-cased class name if unregistered)."""
        cls_name = type(self).__name__
        return _TYPE_NAME.get(cls_name, cls_name.lower())
70
+
71
+
72
# Class name -> registry key used by BaseCombiner.type_name and by make().
_TYPE_NAME = dict(
    GLMCombiner="glm",
    GLMInteractionsCombiner="glm_interactions",
    EBMCombiner="ebm",
    GBMCombiner="gbm",
)
78
+
79
+
80
def expand_pairs(X: np.ndarray, pairs: np.ndarray):
    """Bradley-Terry antisymmetric expansion.

    Each pair ``(a, b)`` (a beat b) becomes two training rows: ``(x_a - x_b, 1)``
    and ``(x_b - x_a, 0)``. All "won" rows come first, then their mirrored
    "lost" rows.

    Args:
        X: ``(N, D)`` feature matrix (anything ``np.asarray`` accepts).
        pairs: ``(M, 2)`` integer indices into the rows of ``X``. An empty
            sequence is allowed and yields empty outputs.

    Returns:
        ``(X_diff, y_bin)`` of length ``2 * len(pairs)``.

    Raises:
        ValueError: ``pairs`` is non-empty and not shaped ``(M, 2)``.
    """
    pairs = np.asarray(pairs, dtype=int)
    if pairs.size == 0:
        # np.asarray([]) is 1-D, and `pairs[:, 0]` would raise IndexError.
        # Give it the (0, 2) shape so no comparisons means no rows.
        pairs = pairs.reshape(0, 2)
    if pairs.ndim != 2 or pairs.shape[1] != 2:
        raise ValueError(
            f"`pairs` must have shape (M, 2) of (winner_idx, loser_idx); got {pairs.shape}"
        )
    X = np.asarray(X)
    diff = X[pairs[:, 0]] - X[pairs[:, 1]]
    n_pairs = len(pairs)
    X_use = np.vstack([diff, -diff])
    y_use = np.concatenate([
        np.ones(n_pairs, dtype=int),
        np.zeros(n_pairs, dtype=int),
    ])
    return X_use, y_use
@@ -0,0 +1,141 @@
1
+ """EBM combiner — wraps interpret-ml's Explainable Boosting Machine."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Any, Dict, Optional
7
+
8
+ import numpy as np
9
+
10
+ from .base import BaseCombiner, expand_pairs
11
+
12
+
13
@dataclass
class EBMCombiner(BaseCombiner):
    """Combiner backed by interpret-ml's Explainable Boosting Machine.

    ``fit`` trains an ``ExplainableBoostingClassifier`` in two cases:
    ``task == "pairwise"`` (on the Bradley-Terry expansion from
    ``expand_pairs`` when ``pairs`` is given, otherwise on ``y``) and
    ``target_type == "binary"``. Every other configuration trains an
    ``ExplainableBoostingRegressor`` on ``y``. interpret-ml is imported inside
    ``fit`` only, so this module imports without it.
    """

    # Forwarded to the interpret-ml estimator constructor.
    max_bins: int = 256
    outer_bags: int = 8
    interactions: Any = "auto"  # 0, int, or "auto"

    # The fitted interpret-ml estimator; None until fit() runs.
    model_: Optional[Any] = field(default=None, init=False, repr=False)

    def fit(
        self,
        X: np.ndarray,
        y: Optional[np.ndarray] = None,
        *,
        pairs: Optional[np.ndarray] = None,
        ranks: Optional[np.ndarray] = None,
        groups: Optional[np.ndarray] = None,
    ) -> "EBMCombiner":
        """Fit the underlying EBM and return ``self``.

        NOTE(review): ``ranks`` and ``groups`` are accepted for interface
        compatibility but never read. With ``task="ranking"`` the regressor
        branch is taken and trains on ``y``, so if ``y`` is None it receives
        ``np.asarray(None)``. Confirm callers never route ranking-only data here.
        """
        # Deferred import so the module is importable without interpret-ml.
        from interpret.glassbox import (
            ExplainableBoostingClassifier,
            ExplainableBoostingRegressor,
        )

        X = np.asarray(X, dtype=float)

        # `interactions`: number of pairwise interactions EBM auto-discovers
        # via FAST. n_features - 1 is a sensible default at D=4 (3 pairs).
        inter = max(0, X.shape[1] - 1) if self.interactions == "auto" else self.interactions

        common = dict(
            max_bins=self.max_bins,
            outer_bags=self.outer_bags,
            interactions=inter,
            random_state=self.random_state,
        )
        # Push real dim names through to interpret-ml so explain_global /
        # explain_local don't return "feature_0000" — that breaks both
        # LENS.feature_importance() (name-based lookup) and the text report
        # (interactions render as "feature_0000 & feature_0001" otherwise).
        if self.feature_names:
            common["feature_names"] = list(self.feature_names)

        if self.task == "pairwise":
            if pairs is not None:
                # Train on (x_a - x_b) differences labelled 1/0.
                X_use, y_use = expand_pairs(X, pairs)
            else:
                X_use = X
                y_use = np.asarray(y, dtype=int).ravel()
            self.model_ = ExplainableBoostingClassifier(**common)
            self.model_.fit(X_use, y_use)
        elif self.target_type == "binary":
            self.model_ = ExplainableBoostingClassifier(**common)
            self.model_.fit(X, np.asarray(y, dtype=int).ravel())
        else:
            self.model_ = ExplainableBoostingRegressor(**common)
            self.model_.fit(X, np.asarray(y, dtype=float).ravel())

        self.is_fitted_ = True
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Return positive-class probabilities for classifier fits, raw predictions otherwise."""
        self._check_fitted()
        X = np.asarray(X, dtype=float)
        # Classifier → positive-class probability, matching GLM.predict shape.
        if self.task == "pairwise" or self.target_type == "binary":
            return self.model_.predict_proba(X)[:, 1]
        return self.model_.predict(X)

    def contributions(self, X: np.ndarray) -> np.ndarray:
        """Per-feature local contributions via interpret-ml's ``explain_local``.

        Interaction-term contributions are split 50/50 between their parents so
        the result stays (N, D_base). Falls back to finite differences when
        ``explain_local`` raises (the API shifts across interpret-ml versions).

        NOTE(review): the try block wraps the whole extraction loop, so an
        exception from any step (not only ``explain_local``) discards partial
        results and triggers the fallback. The fallback differentiates
        ``predict()``, which for classifiers returns probabilities. The
        explain_local scores are presumably on interpret-ml's link (log-odds)
        scale, so the two paths may not be on the same scale. Confirm.
        """
        self._check_fitted()
        X = np.asarray(X, dtype=float)
        D = X.shape[1]
        try:
            explanation = self.model_.explain_local(X)
            contribs = np.zeros((X.shape[0], D), dtype=float)
            # Map term names back to column indices; interpret-ml's default
            # "feature_XXXX" naming is assumed when feature_names_in_ is absent.
            base_names = (list(self.model_.feature_names_in_)
                          if hasattr(self.model_, "feature_names_in_")
                          else [f"feature_{i:04d}" for i in range(D)])
            name_to_idx = {n: i for i, n in enumerate(base_names)}
            for row_idx in range(X.shape[0]):
                row = explanation.data(row_idx)
                if row is None:
                    continue
                for n, s in zip(row["names"], row["scores"]):
                    n = str(n)
                    # Interaction terms are named "a & b"; share them equally.
                    if " & " in n:
                        a, b = n.split(" & ", 1)
                        if a in name_to_idx and b in name_to_idx:
                            contribs[row_idx, name_to_idx[a]] += 0.5 * float(s)
                            contribs[row_idx, name_to_idx[b]] += 0.5 * float(s)
                    elif n in name_to_idx:
                        contribs[row_idx, name_to_idx[n]] += float(s)
            return contribs
        except Exception:
            return _finite_diff_contribs(self.predict, X)

    def coefficients(self) -> Dict[str, Any]:
        """Global term names/importances (best-effort) plus the configured hyperparameters.

        Returns an empty dict before ``fit``. If ``explain_global`` fails,
        only ``hyperparameters`` is returned. ``interactions`` is reported as
        configured (possibly ``"auto"``), not the resolved count.
        """
        out: Dict[str, Any] = {}
        if self.model_ is None:
            return out
        try:
            g = self.model_.explain_global().data()
            out["feature_names"] = list(g.get("names", []))
            out["feature_importances"] = list(g.get("scores", []))
        except Exception:
            pass
        out["hyperparameters"] = {
            "max_bins": self.max_bins,
            "outer_bags": self.outer_bags,
            "interactions": self.interactions,
        }
        return out
131
+
132
+
133
def _finite_diff_contribs(predict_fn, X: np.ndarray, eps: float = 1e-3) -> np.ndarray:
    """Gradient-times-input attribution: ``∂f/∂x_j · x_j`` per sample and feature.

    Each partial derivative uses a central difference
    ``(f(x + eps·e_j) - f(x - eps·e_j)) / (2·eps)``, which costs ``2·D``
    calls to ``predict_fn``. This is the first-order Taylor term of ``f``
    around 0.

    Args:
        predict_fn: callable mapping an ``(N, D)`` float array to ``N`` scores.
        X: ``(N, D)`` inputs. Anything ``np.asarray`` accepts, including
            integer arrays or nested lists; it is evaluated in float64.
        eps: finite-difference step.

    Returns:
        ``(N, D)`` float array of contributions.
    """
    # Work in float64. With an integer X the in-place `+= eps` below raised
    # a casting error, and a list X had no `.shape`/`.copy()`.
    X = np.asarray(X, dtype=float)
    out = np.zeros_like(X)
    for j in range(X.shape[1]):
        X_plus = X.copy()
        X_minus = X.copy()
        X_plus[:, j] += eps
        X_minus[:, j] -= eps
        slope = (
            np.asarray(predict_fn(X_plus), dtype=float)
            - np.asarray(predict_fn(X_minus), dtype=float)
        ) / (2 * eps)
        out[:, j] = slope * X[:, j]
    return out