slatex 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
slatex/__init__.py ADDED
@@ -0,0 +1,18 @@
1
+ """slatex — Sparse Lightweight Additive Threshold Ensemble.
2
+
3
+ A small, fast, interpretable classifier for constrained hardware (edge devices,
4
+ microcontrollers, TinyML). SLATE is an additive model (a GAM) built from simple
5
+ threshold rules of the form "1 if a feature is at or below a learned cut, else 0",
6
+ trained by budgeted, L1-regularized Newton boosting.
7
+
8
+ Basic usage
9
+ -----------
10
+ >>> from slatex import SlateClassifier
11
+ >>> clf = SlateClassifier(budget=32).fit(X_train, y_train)
12
+ >>> clf.predict(X_test)
13
+ >>> clf.predict_proba(X_test)
14
+ """
15
+ from .classifier import NotFittedError, SlateClassifier
16
+
17
+ __all__ = ["SlateClassifier", "NotFittedError", "__version__"]
18
+ __version__ = "0.1.0"
slatex/_core.py ADDED
@@ -0,0 +1,233 @@
1
+ """Binary SLATE engine: budgeted, fully-corrective, L1-regularized Newton
2
+ boosting over axis-aligned threshold indicator atoms ``h_{j,t}(x) = 1[x_j <= t]``.
3
+
4
+ This module is internal. Users should use :class:`slatex.SlateClassifier`, which
5
+ wraps this engine with label encoding, input validation, and multiclass support.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ import numpy as np
10
+
11
+ __all__ = ["_BinarySlate", "_sigmoid"]
12
+
13
+
14
+ def _sigmoid(z):
15
+ """Numerically stable logistic sigmoid."""
16
+ z = np.asarray(z, dtype=np.float64)
17
+ out = np.empty_like(z)
18
+ pos = z >= 0
19
+ out[pos] = 1.0 / (1.0 + np.exp(-z[pos]))
20
+ ez = np.exp(z[~pos])
21
+ out[~pos] = ez / (1.0 + ez)
22
+ return out
23
+
24
+
25
+ class _BinarySlate:
26
+ """Binary additive threshold ensemble trained by budgeted Newton boosting.
27
+
28
+ Operates on a pre-validated, finite, float64 design matrix ``X`` and a label
29
+ vector ``y`` with values in ``{0, 1}``. All user-facing concerns (label
30
+ encoding, validation, multiclass) are handled one layer up.
31
+ """
32
+
33
+ def __init__(self, budget=64, n_bins=32, max_iter=400, learning_rate=0.5,
34
+ l2=2.0, l1=1e-3, corrective_every=5, corrective_passes=2,
35
+ tol=1e-7):
36
+ self.budget = budget
37
+ self.n_bins = n_bins
38
+ self.max_iter = max_iter
39
+ self.learning_rate = learning_rate
40
+ self.l2 = l2
41
+ self.l1 = l1
42
+ self.corrective_every = corrective_every
43
+ self.corrective_passes = corrective_passes
44
+ self.tol = tol
45
+
46
+ # ------------------------------------------------------------------ #
47
+ def _bin_features(self, X):
48
+ """Quantile-bin every feature; store per-feature threshold grids."""
49
+ n, d = X.shape
50
+ self.thresholds_ = []
51
+ codes = np.empty((n, d), dtype=np.int32)
52
+ qs = np.linspace(0, 1, self.n_bins + 1)[1:-1]
53
+ for j in range(d):
54
+ col = X[:, j]
55
+ if qs.size:
56
+ t = np.unique(np.quantile(col, qs))
57
+ else:
58
+ t = np.empty(0, dtype=np.float64)
59
+ # drop degenerate thresholds (>= max never splits anything off)
60
+ if t.size and t[-1] >= col.max():
61
+ t = t[t < col.max()]
62
+ self.thresholds_.append(t.astype(np.float64))
63
+ # searchsorted maps each value to its bin index in [0, len(t)]
64
+ codes[:, j] = np.searchsorted(t, col, side="left")
65
+ return codes
66
+
67
+ # ------------------------------------------------------------------ #
68
+ def fit(self, X, y):
69
+ n, d = X.shape
70
+ codes = self._bin_features(X)
71
+ nthr = np.array([t.size for t in self.thresholds_], dtype=np.int64)
72
+
73
+ # model state -------------------------------------------------- #
74
+ p0 = float(np.clip(y.mean(), 1e-6, 1 - 1e-6))
75
+ self.intercept_ = float(np.log(p0 / (1 - p0)))
76
+ F = np.full(n, self.intercept_)
77
+ atoms: dict[tuple[int, int], float] = {} # (j, bin) -> alpha
78
+ masks: dict[tuple[int, int], np.ndarray] = {} # cached indicator vectors
79
+
80
+ ll_prev = np.inf
81
+ for it in range(self.max_iter):
82
+ p = _sigmoid(F)
83
+ g = p - y # gradient of logloss
84
+ h = np.maximum(p * (1 - p), 1e-12) # Hessian diagonal
85
+
86
+ # --- greedy atom selection via per-feature histograms ----- #
87
+ best_gain, best = -1.0, None
88
+ for j in range(d):
89
+ k = int(nthr[j])
90
+ if k == 0:
91
+ continue
92
+ Gh = np.bincount(codes[:, j], weights=g, minlength=k + 1)
93
+ Hh = np.bincount(codes[:, j], weights=h, minlength=k + 1)
94
+ Gc = np.cumsum(Gh)[:k] # sum of g over {x_j <= t_b}
95
+ Hc = np.cumsum(Hh)[:k]
96
+ gains = Gc * Gc / (Hc + self.l2)
97
+ b = int(np.argmax(gains))
98
+ if gains[b] > best_gain:
99
+ best_gain, best = float(gains[b]), (j, b, Gc[b], Hc[b])
100
+
101
+ if best is None or best_gain < self.tol:
102
+ break
103
+ j, b, G, H = best
104
+ key = (j, b)
105
+ if key not in atoms and len(atoms) >= self.budget:
106
+ # budget full: restrict the greedy step to existing atoms
107
+ key, G, H = self._best_existing(masks, g, h)
108
+ if key is None:
109
+ break
110
+ if key not in masks:
111
+ masks[key] = (codes[:, key[0]] <= key[1])
112
+ atoms.setdefault(key, 0.0)
113
+ G = g[masks[key]].sum()
114
+ H = h[masks[key]].sum()
115
+ step = -self.learning_rate * G / (H + self.l2)
116
+ atoms[key] += step
117
+ F += step * masks[key]
118
+
119
+ # --- intercept Newton update ------------------------------ #
120
+ p = _sigmoid(F)
121
+ db = -(p - y).sum() / (np.maximum(p * (1 - p), 1e-12).sum() + self.l2)
122
+ self.intercept_ += db
123
+ F += db
124
+
125
+ # --- fully-corrective proximal phase ---------------------- #
126
+ if (it + 1) % self.corrective_every == 0:
127
+ F = self._corrective(F, y, atoms, masks)
128
+ ll = self._logloss(F, y)
129
+ if ll_prev - ll < 1e-6 and len(atoms) >= self.budget:
130
+ break
131
+ ll_prev = ll
132
+
133
+ # final corrective polish
134
+ F = self._corrective(F, y, atoms, masks)
135
+ self._pack(atoms)
136
+ return self
137
+
138
+ # ------------------------------------------------------------------ #
139
+ def _best_existing(self, masks, g, h):
140
+ best_gain, best_key, bG, bH = -1.0, None, 0.0, 0.0
141
+ for key, m in masks.items():
142
+ G = g[m].sum()
143
+ H = h[m].sum()
144
+ gain = G * G / (H + self.l2)
145
+ if gain > best_gain:
146
+ best_gain, best_key, bG, bH = gain, key, G, H
147
+ return best_key, bG, bH
148
+
149
+ def _corrective(self, F, y, atoms, masks):
150
+ """Cyclic Newton + soft-threshold (prox of l1) on the active set."""
151
+ for _ in range(self.corrective_passes):
152
+ for key in list(atoms.keys()):
153
+ m = masks[key]
154
+ p = _sigmoid(F)
155
+ g = p - y
156
+ h = np.maximum(p * (1 - p), 1e-12)
157
+ G = g[m].sum()
158
+ H = h[m].sum() + self.l2
159
+ anew = atoms[key] - G / H
160
+ anew = np.sign(anew) * max(abs(anew) - self.l1 / H, 0.0) # prox_{l1/H}
161
+ dF = anew - atoms[key]
162
+ if dF != 0.0:
163
+ F = F + dF * m
164
+ atoms[key] = anew
165
+ if atoms[key] == 0.0:
166
+ del atoms[key]
167
+ del masks[key]
168
+ # intercept
169
+ p = _sigmoid(F)
170
+ db = -(p - y).sum() / (np.maximum(p * (1 - p), 1e-12).sum() + self.l2)
171
+ self.intercept_ += db
172
+ F = F + db
173
+ return F
174
+
175
+ @staticmethod
176
+ def _logloss(F, y):
177
+ if F.size == 0:
178
+ return 0.0
179
+ return float(np.mean(np.log1p(np.exp(-np.where(y == 1, F, -F)))))
180
+
181
+ def _pack(self, atoms):
182
+ """Pack model into flat arrays for O(B) vectorized inference.
183
+
184
+ After packing, the per-feature quantile grids in ``thresholds_`` are no
185
+ longer needed (every inference and interpretability path uses the packed
186
+ ``atom_threshold_`` array instead), so they are dropped to keep the
187
+ serialized model as small as the live inference footprint.
188
+ """
189
+ keys = sorted(atoms.keys())
190
+ self.atom_feature_ = np.array([k[0] for k in keys], dtype=np.int32)
191
+ self.atom_threshold_ = np.array(
192
+ [self.thresholds_[k[0]][k[1]] for k in keys], dtype=np.float64)
193
+ self.atom_coef_ = np.array([atoms[k] for k in keys], dtype=np.float64)
194
+ self.n_atoms_ = len(keys)
195
+ # free training-only scaffolding so a pickled model matches its true
196
+ # deployment footprint (~1 KB instead of tens of KB)
197
+ del self.thresholds_
198
+
199
+ # ------------------------------------------------------------------ #
200
+ def decision_function(self, X):
201
+ """Raw additive margin F(x). Vectorized: O(n * B)."""
202
+ n = X.shape[0]
203
+ if self.n_atoms_ == 0:
204
+ return np.full(n, self.intercept_)
205
+ # (n, B) indicator matrix @ (B,) coefficients
206
+ ind = X[:, self.atom_feature_] <= self.atom_threshold_
207
+ return self.intercept_ + ind @ self.atom_coef_
208
+
209
+ def predict_proba_pos(self, X):
210
+ """P(y = 1 | x) as a 1-D array."""
211
+ return _sigmoid(self.decision_function(X))
212
+
213
+ # ----------------------- interpretability ------------------------- #
214
+ def shape_function(self, j, grid):
215
+ """Exact additive contribution of feature ``j`` over a grid of values."""
216
+ grid = np.asarray(grid, dtype=np.float64)
217
+ out = np.zeros_like(grid)
218
+ sel = self.atom_feature_ == j
219
+ for t, a in zip(self.atom_threshold_[sel], self.atom_coef_[sel]):
220
+ out += a * (grid <= t)
221
+ return out
222
+
223
+ def shapley_values(self, X, X_background):
224
+ """Exact Shapley attributions (closed form for additive models):
225
+ ``phi_j(x) = f_j(x_j) - E[f_j(X_j)]`` against a background set."""
226
+ X = np.asarray(X, dtype=np.float64)
227
+ B = np.asarray(X_background, dtype=np.float64)
228
+ phi = np.zeros((X.shape[0], X.shape[1]))
229
+ for jf, t, a in zip(self.atom_feature_, self.atom_threshold_,
230
+ self.atom_coef_):
231
+ mean_bg = a * (B[:, jf] <= t).mean()
232
+ phi[:, jf] += a * (X[:, jf] <= t) - mean_bg
233
+ return phi
slatex/classifier.py ADDED
@@ -0,0 +1,355 @@
1
+ """Public SLATE estimator: a sparse, lightweight, interpretable classifier."""
2
+ from __future__ import annotations
3
+
4
+ import inspect
5
+
6
+ import numpy as np
7
+
8
+ from ._core import _BinarySlate, _sigmoid
9
+
10
+ __all__ = ["SlateClassifier", "NotFittedError"]
11
+
12
+
13
class NotFittedError(ValueError, AttributeError):
    """Signal that an estimator was used before ``fit`` was called.

    Derives from both ``ValueError`` and ``AttributeError`` so that callers
    catching either of those exception types also catch this one.
    """
15
+
16
+
17
+ def _check_X(X, *, ensure_2d=True):
18
+ """Validate a design matrix: finite float64, 2-D, non-empty."""
19
+ X = np.asarray(X, dtype=np.float64)
20
+ if ensure_2d:
21
+ if X.ndim == 1:
22
+ raise ValueError(
23
+ "Expected a 2-D array, got 1-D. Reshape your data with "
24
+ "X.reshape(-1, 1) for a single feature or X.reshape(1, -1) "
25
+ "for a single sample."
26
+ )
27
+ if X.ndim != 2:
28
+ raise ValueError(f"Expected a 2-D array, got {X.ndim}-D.")
29
+ if X.size == 0:
30
+ raise ValueError("Found an empty array with 0 sample(s) or feature(s).")
31
+ if not np.all(np.isfinite(X)):
32
+ raise ValueError(
33
+ "Input X contains NaN or infinity. SLATE requires finite inputs; "
34
+ "impute or clean your data first (e.g. sklearn SimpleImputer)."
35
+ )
36
+ return X
37
+
38
+
39
class SlateClassifier:
    """Sparse Lightweight Additive Threshold Ensemble.

    A small, fast, interpretable classifier built from simple threshold rules
    of the form ``1 if a feature is at or below a learned cut, else 0``. It is
    a Generalized Additive Model trained by budgeted, L1-regularized Newton
    boosting, designed for constrained hardware (edge devices, microcontrollers,
    TinyML).

    Binary and multiclass targets are both supported; multiclass is handled
    internally via one-vs-rest. Labels may be of any type (ints, strings, etc.)
    and are encoded automatically.

    Parameters
    ----------
    budget : int, default=64
        Hard cap ``B`` on the number of distinct threshold atoms per binary
        model (controls model size and inference cost).
    n_bins : int, default=32
        Maximum number of quantile bins per feature (dictionary granularity).
    max_iter : int, default=400
        Maximum number of boosting iterations per binary model.
    learning_rate : float, default=0.5
        Shrinkage applied to each greedy Newton step.
    l2 : float, default=2.0
        Hessian ridge (Newton damping) per atom.
    l1 : float, default=1e-3
        Soft-threshold level in the fully-corrective proximal passes (drives
        exact sparsity / prunes weak atoms).
    corrective_every : int, default=5
        Run a fully-corrective cyclic Newton + prox pass every ``k`` iterations.
    corrective_passes : int, default=2
        Number of cyclic passes per corrective phase.
    tol : float, default=1e-7
        Stop when the best available Newton gain falls below ``tol``.
    random_state : int or None, default=0
        Accepted for API compatibility. Training is deterministic, so this has
        no effect on results.

    Attributes
    ----------
    classes_ : ndarray
        The class labels seen during :meth:`fit`, in sorted order.
    n_features_in_ : int
        Number of features seen during :meth:`fit`.
    n_atoms_ : int
        Total number of threshold atoms across all internal binary models.
    n_parameters_ : int
        Total parameter count (coefficients + thresholds + intercepts).
    memory_bytes_ : int
        Approximate size of the packed model in bytes.

    Examples
    --------
    >>> from slatex import SlateClassifier
    >>> import numpy as np
    >>> X = np.random.RandomState(0).randn(200, 5)
    >>> y = (X[:, 0] + X[:, 1] > 0).astype(int)
    >>> clf = SlateClassifier(budget=16).fit(X, y)
    >>> clf.predict(X[:3])
    array([...])
    """

    _estimator_type = "classifier"

    def __init__(self, budget=64, n_bins=32, max_iter=400, learning_rate=0.5,
                 l2=2.0, l1=1e-3, corrective_every=5, corrective_passes=2,
                 tol=1e-7, random_state=0):
        self.budget = budget
        self.n_bins = n_bins
        self.max_iter = max_iter
        self.learning_rate = learning_rate
        self.l2 = l2
        self.l1 = l1
        self.corrective_every = corrective_every
        self.corrective_passes = corrective_passes
        self.tol = tol
        self.random_state = random_state

    # ----------------------- sklearn-style params --------------------- #
    @classmethod
    def _param_names(cls):
        """Sorted names of the constructor parameters (excluding ``self``)."""
        sig = inspect.signature(cls.__init__)
        return sorted(p for p in sig.parameters if p != "self")

    def __sklearn_tags__(self):
        """Lazy scikit-learn (>=1.6) tag hook.

        Defined so the estimator integrates with cross-validation, pipelines,
        and grid search on modern scikit-learn, while keeping scikit-learn an
        *optional* dependency (the import only happens if sklearn calls this).
        Construction is delegated to scikit-learn's own ``BaseEstimator`` so it
        stays correct across versions.
        """
        from sklearn.base import BaseEstimator
        from sklearn.utils import ClassifierTags
        tags = BaseEstimator.__sklearn_tags__(self)
        tags.estimator_type = "classifier"
        tags.classifier_tags = ClassifierTags()
        return tags

    def get_params(self, deep=True):
        """Return estimator parameters (scikit-learn compatible).

        ``deep`` is accepted for API compatibility and is not used.
        """
        return {k: getattr(self, k) for k in self._param_names()}

    def set_params(self, **params):
        """Set estimator parameters (scikit-learn compatible).

        Raises ``ValueError`` on the first name that is not a constructor
        parameter. Returns ``self``.
        """
        valid = set(self._param_names())
        for k, v in params.items():
            if k not in valid:
                raise ValueError(
                    f"Invalid parameter {k!r} for SlateClassifier. "
                    f"Valid parameters are: {sorted(valid)}."
                )
            setattr(self, k, v)
        return self

    # --------------------------- validation --------------------------- #
    def _validate_params(self):
        """Raise ``ValueError`` if any hyperparameter is out of range."""
        if not (isinstance(self.budget, (int, np.integer)) and self.budget >= 1):
            raise ValueError(f"budget must be a positive int, got {self.budget!r}.")
        if not (isinstance(self.n_bins, (int, np.integer)) and self.n_bins >= 2):
            raise ValueError(f"n_bins must be an int >= 2, got {self.n_bins!r}.")
        if not (isinstance(self.max_iter, (int, np.integer)) and self.max_iter >= 1):
            raise ValueError(f"max_iter must be a positive int, got {self.max_iter!r}.")
        if not (self.learning_rate > 0):
            raise ValueError(
                f"learning_rate must be > 0, got {self.learning_rate!r}.")
        # written as "not (>= 0)" so that NaN is rejected as well: every
        # ordering comparison with NaN is False, so "< 0" would let it pass
        if not (self.l2 >= 0 and self.l1 >= 0):
            raise ValueError("l1 and l2 must be non-negative.")
        if not (isinstance(self.corrective_every, (int, np.integer))
                and self.corrective_every >= 1):
            raise ValueError("corrective_every must be a positive int.")
        if not (isinstance(self.corrective_passes, (int, np.integer))
                and self.corrective_passes >= 0):
            raise ValueError("corrective_passes must be a non-negative int.")

    def _core_kwargs(self):
        """Constructor arguments for ``_BinarySlate``, cast to plain types."""
        return dict(budget=int(self.budget), n_bins=int(self.n_bins),
                    max_iter=int(self.max_iter),
                    learning_rate=float(self.learning_rate),
                    l2=float(self.l2), l1=float(self.l1),
                    corrective_every=int(self.corrective_every),
                    corrective_passes=int(self.corrective_passes),
                    tol=float(self.tol))

    def _check_is_fitted(self):
        """Raise ``NotFittedError`` unless ``estimators_`` has been set."""
        if not hasattr(self, "estimators_"):
            raise NotFittedError(
                "This SlateClassifier instance is not fitted yet. Call 'fit' "
                "with appropriate arguments before using this estimator."
            )

    # ------------------------------ fit ------------------------------- #
    def fit(self, X, y):
        """Fit the model.

        Fitted attributes are assigned only after training has succeeded, so
        a failed call leaves a previously fitted estimator untouched.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Finite numeric training data.
        y : array-like of shape (n_samples,)
            Target labels (any hashable type; encoded automatically).

        Returns
        -------
        self

        Raises
        ------
        ValueError
            On invalid hyperparameters, invalid ``X``, a length mismatch
            between ``X`` and ``y``, or fewer than two classes in ``y``.
        """
        self._validate_params()
        X = _check_X(X)
        y = np.asarray(y).ravel()
        if y.shape[0] != X.shape[0]:
            raise ValueError(
                f"X and y have inconsistent lengths: {X.shape[0]} vs {y.shape[0]}.")

        classes, y_enc = np.unique(y, return_inverse=True)
        n_classes = classes.shape[0]
        if n_classes < 2:
            raise ValueError(
                "Classifier can't train when only one class is present. "
                f"Found classes: {classes.tolist()}."
            )
        kw = self._core_kwargs()

        # binary: a single model for the second (sorted) class;
        # multiclass: one one-vs-rest model per class
        multiclass = n_classes > 2
        positives = range(n_classes) if multiclass else (1,)
        estimators = [
            _BinarySlate(**kw).fit(X, (y_enc == c).astype(np.float64))
            for c in positives
        ]

        # commit fitted state in one go, after everything that can raise
        self.classes_ = classes
        self.n_features_in_ = X.shape[1]
        self.estimators_ = estimators
        self._multiclass = multiclass
        return self

    # --------------------------- inference ---------------------------- #
    def _check_predict_X(self, X):
        """Validate ``X`` for inference and check its feature count."""
        self._check_is_fitted()
        X = _check_X(X)
        if X.shape[1] != self.n_features_in_:
            raise ValueError(
                f"X has {X.shape[1]} features, but SlateClassifier was fitted "
                f"with {self.n_features_in_} features."
            )
        return X

    def decision_function(self, X):
        """Confidence scores.

        Returns a 1-D margin array for binary problems, or a 2-D array of
        per-class margins of shape ``(n_samples, n_classes)`` for multiclass.
        """
        X = self._check_predict_X(X)
        if not self._multiclass:
            return self.estimators_[0].decision_function(X)
        return np.column_stack([e.decision_function(X) for e in self.estimators_])

    def predict_proba(self, X):
        """Probability estimates of shape ``(n_samples, n_classes)``."""
        X = self._check_predict_X(X)
        if not self._multiclass:
            p = self.estimators_[0].predict_proba_pos(X)
            return np.column_stack([1.0 - p, p])
        P = np.column_stack([e.predict_proba_pos(X) for e in self.estimators_])
        # one-vs-rest normalization into a proper distribution
        denom = np.clip(P.sum(axis=1, keepdims=True), 1e-12, None)
        return P / denom

    def predict_log_proba(self, X):
        """Log of :meth:`predict_proba` (probabilities floored at 1e-12)."""
        return np.log(np.clip(self.predict_proba(X), 1e-12, None))

    def predict(self, X):
        """Predict class labels for ``X``."""
        proba = self.predict_proba(X)
        idx = np.argmax(proba, axis=1)
        return self.classes_[idx]

    def score(self, X, y):
        """Mean accuracy on the given test data and labels.

        Raises ``ValueError`` when ``y`` does not have one label per row of
        ``X``; without this check the comparison would broadcast or collapse
        silently.
        """
        y = np.asarray(y).ravel()
        pred = self.predict(X)
        if y.shape[0] != pred.shape[0]:
            raise ValueError(
                f"X and y have inconsistent lengths: {pred.shape[0]} vs {y.shape[0]}.")
        return float(np.mean(pred == y))

    # ----------------------- interpretability ------------------------- #
    def _resolve_target(self, target):
        """Map a user-supplied class label/index to an internal estimator index.

        Binary models have a single estimator (index 0) describing the
        positive class ``classes_[1]``. ``None``, the positive label, or the
        index ``1`` are accepted. A value equal to the negative label is
        always rejected, even when that label happens to be ``1``.
        """
        if not self._multiclass:
            if target is None:
                return 0
            negative, positive = self.classes_[0], self.classes_[1]
            is_positive = bool(target == positive)
            # label match takes precedence over the index interpretation
            is_negative = (not is_positive) and bool(target == negative)
            if is_negative or not (is_positive or bool(target == 1)):
                raise ValueError(
                    "For binary problems, interpretability is reported for the "
                    "positive class; leave target=None."
                )
            return 0
        if target is None:
            raise ValueError(
                "target is required for multiclass models. Pass one of "
                f"classes_={self.classes_.tolist()}."
            )
        matches = np.where(self.classes_ == target)[0]
        if matches.size == 0:
            raise ValueError(
                f"Unknown target {target!r}. Valid classes: {self.classes_.tolist()}.")
        return int(matches[0])

    def shape_function(self, feature, grid, target=None):
        """Exact additive contribution of one feature over a grid of values.

        Parameters
        ----------
        feature : int
            Feature index.
        grid : array-like
            Values of the feature to evaluate.
        target : label, optional
            Required for multiclass: which class's additive model to inspect.
            For binary models the positive class is always used.
        """
        self._check_is_fitted()
        if not (0 <= feature < self.n_features_in_):
            raise ValueError(
                f"feature must be in [0, {self.n_features_in_}), got {feature}.")
        est = self.estimators_[self._resolve_target(target)]
        return est.shape_function(int(feature), grid)

    def shapley_values(self, X, X_background, target=None):
        """Exact per-feature Shapley attributions against a background set.

        For multiclass, pass ``target`` to select the class model; the returned
        array has shape ``(n_samples, n_features)``.
        """
        X = self._check_predict_X(X)
        X_background = _check_X(X_background)
        if X_background.shape[1] != self.n_features_in_:
            raise ValueError(
                f"X_background has {X_background.shape[1]} features, expected "
                f"{self.n_features_in_}.")
        est = self.estimators_[self._resolve_target(target)]
        return est.shapley_values(X, X_background)

    # ------------------------- model size ----------------------------- #
    @property
    def n_atoms_(self):
        """Total number of atoms across all internal binary models."""
        self._check_is_fitted()
        return int(sum(e.n_atoms_ for e in self.estimators_))

    @property
    def n_parameters_(self):
        """Total parameter count across all internal binary models."""
        self._check_is_fitted()
        # per binary model: 2 * n_atoms (coef + threshold) + 1 intercept
        return int(sum(2 * e.n_atoms_ + 1 for e in self.estimators_))

    @property
    def memory_bytes_(self):
        """Approximate packed model size in bytes."""
        self._check_is_fitted()
        # per atom: 8 (coef) + 8 (threshold) + 4 (feature id); + 8 per intercept
        return int(sum(e.n_atoms_ * (8 + 8 + 4) + 8 for e in self.estimators_))

    def __repr__(self):
        params = ", ".join(f"{k}={getattr(self, k)!r}" for k in self._param_names())
        return f"SlateClassifier({params})"
@@ -0,0 +1,174 @@
1
+ Metadata-Version: 2.4
2
+ Name: slatex
3
+ Version: 0.1.0
4
+ Summary: Sparse Lightweight Additive Threshold Ensemble — a small, fast, interpretable classifier for edge devices and TinyML.
5
+ Project-URL: Homepage, https://github.com/saikirangogineni/slatex
6
+ Project-URL: Repository, https://github.com/saikirangogineni/slatex
7
+ Project-URL: Issues, https://github.com/saikirangogineni/slatex/issues
8
+ Author-email: Saikiran Gogineni <goginenisaikiran31677@gmail.com>
9
+ Maintainer-email: Saikiran Gogineni <goginenisaikiran31677@gmail.com>
10
+ License: MIT License
11
+
12
+ Copyright (c) 2026 Saikiran Gogineni
13
+
14
+ Permission is hereby granted, free of charge, to any person obtaining a copy
15
+ of this software and associated documentation files (the "Software"), to deal
16
+ in the Software without restriction, including without limitation the rights
17
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18
+ copies of the Software, and to permit persons to whom the Software is
19
+ furnished to do so, subject to the following conditions:
20
+
21
+ The above copyright notice and this permission notice shall be included in all
22
+ copies or substantial portions of the Software.
23
+
24
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30
+ SOFTWARE.
31
+ License-File: LICENSE
32
+ Keywords: boosting,classifier,edge-ai,explainable-ai,gam,generalized-additive-model,interpretable-ml,machine-learning,sparse-models,tinyml
33
+ Classifier: Development Status :: 4 - Beta
34
+ Classifier: Intended Audience :: Developers
35
+ Classifier: Intended Audience :: Science/Research
36
+ Classifier: License :: OSI Approved :: MIT License
37
+ Classifier: Operating System :: OS Independent
38
+ Classifier: Programming Language :: Python :: 3
39
+ Classifier: Programming Language :: Python :: 3.9
40
+ Classifier: Programming Language :: Python :: 3.10
41
+ Classifier: Programming Language :: Python :: 3.11
42
+ Classifier: Programming Language :: Python :: 3.12
43
+ Classifier: Programming Language :: Python :: 3.13
44
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
45
+ Requires-Python: >=3.9
46
+ Requires-Dist: numpy>=1.21
47
+ Provides-Extra: dev
48
+ Requires-Dist: build; extra == 'dev'
49
+ Requires-Dist: pytest>=7.0; extra == 'dev'
50
+ Requires-Dist: scikit-learn>=1.0; extra == 'dev'
51
+ Requires-Dist: twine; extra == 'dev'
52
+ Provides-Extra: test
53
+ Requires-Dist: pytest>=7.0; extra == 'test'
54
+ Requires-Dist: scikit-learn>=1.0; extra == 'test'
55
+ Description-Content-Type: text/markdown
56
+
57
+ # slatex
58
+
59
+ **Sparse Lightweight Additive Threshold Ensemble** — a small, fast, interpretable
60
+ classifier for constrained hardware (edge devices, microcontrollers, TinyML).
61
+
62
+ SLATE is an additive model (a Generalized Additive Model) built from simple
63
+ threshold rules of the form *"1 if a feature is at or below a learned cut, else 0"*.
64
+ It is trained by L1-regularized Newton boosting: rules are selected one at a time,
65
+ then all active rules are refit together in a corrective pass, while a hard budget
66
+ keeps the model tiny.
67
+
68
+ - **Tiny.** A hard budget caps the number of rules, so a trained model is on the
69
+ order of a kilobyte and inference is a handful of comparisons and adds. The
70
+ fitted model drops its training-time scaffolding, so a pickled/joblib-saved
71
+ estimator is just ~1-2 KB — the same as its live inference footprint.
72
+ - **Interpretable.** Because the model is additive, you get exact per-feature shape
73
+ functions and exact Shapley attributions in closed form.
74
+ - **Lightweight to install.** Pure NumPy at runtime — no heavy ML stack required.
75
+ - **Familiar API.** `fit` / `predict` / `predict_proba`, drops into scikit-learn
76
+ pipelines and grid search. Binary *and* multiclass are supported.
77
+
78
+ ## Installation
79
+
80
+ ```bash
81
+ pip install slatex
82
+ ```
83
+
84
+ Requires Python 3.9+ and NumPy. To run the test suite you also need
85
+ scikit-learn (`pip install "slatex[test]"`).
86
+
87
+ ## Quickstart
88
+
89
+ ```python
90
+ import numpy as np
91
+ from slatex import SlateClassifier
92
+
93
+ rng = np.random.RandomState(0)
94
+ X = rng.randn(500, 8)
95
+ y = (X[:, 0] + 0.5 * X[:, 3] > 0).astype(int)
96
+
97
+ clf = SlateClassifier(budget=32).fit(X, y)
98
+
99
+ clf.predict(X[:5]) # class labels
100
+ clf.predict_proba(X[:5]) # (n_samples, n_classes) probabilities
101
+ clf.score(X, y) # mean accuracy
102
+
103
+ print(clf.n_atoms_) # number of threshold rules used
104
+ print(clf.memory_bytes_) # approximate packed model size in bytes
105
+ ```
106
+
107
+ Multiclass works the same way (handled internally via one-vs-rest), and labels
108
+ can be any type — integers, strings, etc.:
109
+
110
+ ```python
111
+ y = np.array(["low", "mid", "high"])[rng.randint(0, 3, size=500)]
112
+ clf = SlateClassifier(budget=24).fit(X, y)
113
+ clf.classes_ # array(['high', 'low', 'mid'], dtype='<U4')
114
+ ```
115
+
116
+ ## Interpretability
117
+
118
+ Because SLATE is additive, the contribution of each feature is exact and cheap to
119
+ compute.
120
+
121
+ ```python
122
+ # Shape function: how feature 0 contributes to the score across a range of values
123
+ grid = np.linspace(-3, 3, 50)
124
+ contribution = clf.shape_function(feature=0, grid=grid) # binary
125
+ # contribution = clf.shape_function(0, grid, target="high") # multiclass
126
+
127
+ # Exact Shapley attributions against a background sample
128
+ phi = clf.shapley_values(X[:10], X_background=X) # (10, n_features)
129
+ ```
130
+
131
+ ## Hyperparameters
132
+
133
+ | Parameter | Default | Meaning |
134
+ |---|---|---|
135
+ | `budget` | 64 | Hard cap on the number of threshold rules per binary model |
136
+ | `n_bins` | 32 | Max quantile bins per feature (candidate-cut granularity) |
137
+ | `max_iter` | 400 | Max boosting iterations |
138
+ | `learning_rate` | 0.5 | Shrinkage on each Newton step |
139
+ | `l2` | 2.0 | Newton damping (ridge on the Hessian) |
140
+ | `l1` | 1e-3 | Soft-threshold level for pruning weak rules |
141
+ | `corrective_every` | 5 | Run a fully-corrective refit pass every *k* iterations |
142
+ | `corrective_passes` | 2 | Cyclic passes per corrective phase |
143
+ | `tol` | 1e-7 | Stop when the best Newton gain falls below this |
144
+
145
+ Smaller `budget` → smaller, faster, more interpretable model. Increase `n_bins`
146
+ for finer cuts on continuous features.
147
+
148
+ ## scikit-learn compatibility
149
+
150
+ `SlateClassifier` implements `get_params` / `set_params` and follows the standard
151
+ estimator API, so it composes with scikit-learn tools:
152
+
153
+ ```python
154
+ from sklearn.pipeline import make_pipeline
155
+ from sklearn.preprocessing import StandardScaler
156
+ from sklearn.model_selection import GridSearchCV
157
+
158
+ pipe = make_pipeline(StandardScaler(), SlateClassifier())
159
+ grid = GridSearchCV(pipe, {"slateclassifier__budget": [16, 32, 64]}, cv=3)
160
+ grid.fit(X, y)
161
+ ```
162
+
163
+ (scikit-learn is optional and only needed if you use these helpers.)
164
+
165
+ ## Notes and requirements
166
+
167
+ - Inputs must be **finite numeric arrays** (no NaN/inf). Impute or clean first.
168
+ - Training is **deterministic**; `random_state` is accepted for API compatibility
169
+ but does not change results.
170
+ - This is research-grade software released under the MIT License.
171
+
172
+ ## License
173
+
174
+ MIT © Saikiran Gogineni
@@ -0,0 +1,7 @@
1
+ slatex/__init__.py,sha256=--f_5NDQFC9O2PfAoUwxavN7rPsPYAJIRMAHTbfgnI0,682
2
+ slatex/_core.py,sha256=IzA1Mw49IuTSANeAYARNLTn9jEVwLRcRYCzGQlMiipQ,9583
3
+ slatex/classifier.py,sha256=YFqxgrAmWim7KBJOp_YfSdKiA76-BYNLbrO0cuFFsFM,14438
4
+ slatex-0.1.0.dist-info/METADATA,sha256=0ghTU0HFepQTfataIladqW8-YfYgyu1rN5AaEbegtNY,7311
5
+ slatex-0.1.0.dist-info/WHEEL,sha256=mffPy8wBnZQn2VnJUU5jE99KsxaSfiyMHV9Yt0aLVxs,87
6
+ slatex-0.1.0.dist-info/licenses/LICENSE,sha256=62RL162QqdY58OxqUmFhBwG-drDX48i8ZlrSj3RWnNk,1074
7
+ slatex-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.30.1
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Saikiran Gogineni
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.