scratchkit 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. mlscratch/__init__.py +56 -0
  2. mlscratch/__main__.py +118 -0
  3. mlscratch/bayesian/__init__.py +53 -0
  4. mlscratch/bayesian/bayesian_linear_regression.py +171 -0
  5. mlscratch/bayesian/bayesian_network.py +248 -0
  6. mlscratch/bayesian/bayesian_nn.py +315 -0
  7. mlscratch/bayesian/gaussian_process.py +207 -0
  8. mlscratch/bayesian/hmm.py +277 -0
  9. mlscratch/bayesian/init.py +52 -0
  10. mlscratch/bayesian/kalman_filter.py +182 -0
  11. mlscratch/bayesian/naive_bayes.py +209 -0
  12. mlscratch/metrics/__init__.py +59 -0
  13. mlscratch/metrics/classification.py +365 -0
  14. mlscratch/metrics/regression.py +79 -0
  15. mlscratch/neural/__init__.py +121 -0
  16. mlscratch/neural/attention.py +420 -0
  17. mlscratch/neural/autoencoder.py +543 -0
  18. mlscratch/neural/boltzmann.py +231 -0
  19. mlscratch/neural/cnn.py +593 -0
  20. mlscratch/neural/cvnn.py +322 -0
  21. mlscratch/neural/gan.py +364 -0
  22. mlscratch/neural/hopfield.py +193 -0
  23. mlscratch/neural/perceptron.py +398 -0
  24. mlscratch/neural/rbf_network.py +230 -0
  25. mlscratch/neural/recurrent.py +569 -0
  26. mlscratch/preprocessing/__init__.py +38 -0
  27. mlscratch/preprocessing/encoders.py +140 -0
  28. mlscratch/preprocessing/model_selection.py +119 -0
  29. mlscratch/preprocessing/polynomial.py +105 -0
  30. mlscratch/preprocessing/scalers.py +220 -0
  31. mlscratch/py.typed +0 -0
  32. mlscratch/reinforcement/__init__.py +59 -0
  33. mlscratch/reinforcement/ddpg.py +363 -0
  34. mlscratch/reinforcement/dqn.py +319 -0
  35. mlscratch/reinforcement/ppo.py +452 -0
  36. mlscratch/reinforcement/q_learning.py +352 -0
  37. mlscratch/reinforcement/sac.py +382 -0
  38. mlscratch/reinforcement/utils.py +594 -0
  39. mlscratch/supervised/__init__.py +76 -0
  40. mlscratch/supervised/_validation.py +50 -0
  41. mlscratch/supervised/adaboost.py +255 -0
  42. mlscratch/supervised/decision_tree.py +495 -0
  43. mlscratch/supervised/gradient_boosting.py +354 -0
  44. mlscratch/supervised/knn.py +234 -0
  45. mlscratch/supervised/lasso_regression.py +125 -0
  46. mlscratch/supervised/linear_models.py +459 -0
  47. mlscratch/supervised/linear_regression.py +197 -0
  48. mlscratch/supervised/logistic_regression.py +119 -0
  49. mlscratch/supervised/naive_bayes.py +113 -0
  50. mlscratch/supervised/random_forest.py +321 -0
  51. mlscratch/supervised/ridge_regression.py +93 -0
  52. mlscratch/supervised/svm.py +356 -0
  53. mlscratch/unsupervised/__init__.py +39 -0
  54. mlscratch/unsupervised/apriori.py +178 -0
  55. mlscratch/unsupervised/dbscan.py +141 -0
  56. mlscratch/unsupervised/gmm.py +204 -0
  57. mlscratch/unsupervised/hierarchical_clustering.py +137 -0
  58. mlscratch/unsupervised/ica.py +167 -0
  59. mlscratch/unsupervised/kmeans.py +135 -0
  60. mlscratch/unsupervised/kmedoids.py +133 -0
  61. mlscratch/unsupervised/pca.py +103 -0
  62. mlscratch/unsupervised/tsne.py +200 -0
  63. scratchkit-0.2.0.dist-info/METADATA +241 -0
  64. scratchkit-0.2.0.dist-info/RECORD +68 -0
  65. scratchkit-0.2.0.dist-info/WHEEL +5 -0
  66. scratchkit-0.2.0.dist-info/entry_points.txt +2 -0
  67. scratchkit-0.2.0.dist-info/licenses/LICENSE +201 -0
  68. scratchkit-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,495 @@
1
+ r"""
2
+ Decision Trees
3
+ ==============
4
+ From-scratch CART (Classification And Regression Trees), pure numpy.
5
+
6
+ DecisionTreeClassifier
7
+ -----------------------
8
+ Binary or multiclass classification. At each node the split that
9
+ minimises the weighted child impurity is chosen, where impurity is
10
+ either Gini:
11
+
12
+ .. math::
13
+ G = 1 - \sum_{k=1}^K p_k^2
14
+
15
+ or Shannon entropy:
16
+
17
+ .. math::
18
+ H = -\sum_{k=1}^K p_k \log_2 p_k
19
+
20
+ DecisionTreeRegressor
21
+ -----------------------
22
+ Minimises weighted variance (mean squared error) of the target within
23
+ each child:
24
+
25
+ .. math::
26
+ \mathrm{MSE} = \frac{1}{W}\sum_i w_i (y_i - \bar y)^2
27
+
28
+ Both trees support per-sample weights (``sample_weight``), which is
29
+ what lets :mod:`mlscratch.supervised.adaboost` and
30
+ :mod:`mlscratch.supervised.random_forest` reuse the exact same split
31
+ logic instead of re-deriving it.
32
+
33
+ Algorithm
34
+ ---------
35
+ For each candidate feature the rows are sorted once (`O(n log n)`),
36
+ then a single left-to-right vectorised sweep evaluates every possible
37
+ split point in `O(n)` using running (weighted) cumulative sums — no
38
+ per-split re-scan of the data. Overall a node with `n` samples and
39
+ `d` features costs `O(n d log n)`.
40
+
41
+ Complexity
42
+ ----------
43
+ - Training : O(n d log n) per node, O(depth) nodes on the root path
44
+ - Inference: O(depth) per sample
45
+ - Space : O(n_nodes)
46
+ """
47
+
48
+ from __future__ import annotations
49
+
50
+ from dataclasses import dataclass
51
+
52
+ import numpy as np
53
+ from numpy.typing import ArrayLike, NDArray
54
+
55
+ from ._validation import validate_sample_weight, validate_x, validate_xy
56
+
57
# Array aliases used in the annotations throughout this module.
FloatArray = NDArray[np.float64]
IntArray = NDArray[np.int64]

# Tolerance below which weights, impurities and impurity gains are treated
# as zero.
_EPS = 1e-12
61
+
62
+
63
+ # ──────────────────────────────────────────────────────────────────────────
64
+ # Shared node / tree-walking helpers
65
+ # ──────────────────────────────────────────────────────────────────────────
66
+
67
+
68
@dataclass
class _Node:
    """One vertex of a binary decision tree.

    ``value`` is the prediction used when the node is a leaf: a
    class-probability vector for :class:`DecisionTreeClassifier`, or a
    scalar mean for :class:`DecisionTreeRegressor`.  A node that has been
    split additionally records the split itself (``feature_index`` and
    ``threshold``) together with references to its two children.
    """

    # Number of training rows that reached this node.
    n_samples: int
    # Sum of the sample weights of those rows.
    weighted_n_samples: float
    # Node impurity (Gini / entropy for the classifier, variance for the regressor).
    impurity: float
    # Leaf prediction: probability vector or scalar mean.
    value: object
    # Split description; these keep their ``None`` defaults on a leaf.
    feature_index: int | None = None
    threshold: float | None = None
    left: _Node | None = None
    right: _Node | None = None

    @property
    def is_leaf(self) -> bool:
        """Whether this node is terminal, i.e. no split feature is recorded."""
        has_split = self.feature_index is not None
        return not has_split
90
+
91
+
92
def _apply(root: _Node, X: FloatArray) -> list[_Node]:
    """Send each row of ``X`` down the tree and collect the leaf it ends in.

    At an internal node a row goes to the left child when its value in the
    node's split feature is ``<=`` the node's threshold, otherwise to the
    right child.  The returned list has one entry per row, in row order.
    """

    def _descend(sample) -> _Node:
        # Walk from the root until a terminal node is reached.
        current = root
        while True:
            if current.is_leaf:
                return current
            goes_left = sample[current.feature_index] <= current.threshold
            current = current.left if goes_left else current.right

    return [_descend(sample) for sample in X]
101
+
102
+
103
def group_by_leaf(leaves: list[_Node]) -> dict[int, tuple[_Node, list[int]]]:
    """Bucket sample positions by the leaf object each one landed in.

    The result maps ``id(leaf)`` to a ``(leaf, positions)`` pair, where
    ``positions`` lists, in increasing order, the indices into ``leaves``
    at which that leaf object occurs.  Leaves are distinguished by object
    identity, not by equality.

    Used internally by gradient boosting to re-fit leaf values after the
    tree structure has been chosen by the (cheaper) variance criterion.
    """
    buckets: dict[int, tuple[_Node, list[int]]] = {}
    for position, leaf in enumerate(leaves):
        # setdefault creates the (leaf, []) entry on first sight of a leaf.
        _, members = buckets.setdefault(id(leaf), (leaf, []))
        members.append(position)
    return buckets
116
+
117
+
118
+ # ──────────────────────────────────────────────────────────────────────────
119
+ # DecisionTreeClassifier
120
+ # ──────────────────────────────────────────────────────────────────────────
121
+
122
+
123
class DecisionTreeClassifier:
    """A binary or multiclass CART decision tree classifier.

    Parameters
    ----------
    max_depth : int | None, default=None
        Maximum tree depth. ``None`` grows nodes until they are pure or
        too small to split.
    min_samples_split : int, default=2
        Minimum number of samples a node must have to be eligible for
        splitting.
    min_samples_leaf : int, default=1
        Minimum number of samples required in each child of a split.
    criterion : str, default='gini'
        Split quality measure: ``'gini'`` or ``'entropy'``.
    random_state : int | None, default=None
        Unused by the splitting rule itself (which is deterministic);
        accepted for API symmetry with ensembles that seed their trees.

    Attributes
    ----------
    tree_ : the fitted root node
    classes_ : sorted unique labels seen during fit
    n_classes_ : number of classes
    n_features_in_ : number of features seen during fit
    feature_importances_ : impurity-decrease-based importances, sums to 1
        (left as all zeros when the tree has no split)

    Raises
    ------
    ValueError
        If ``criterion`` is not recognised, ``min_samples_split < 2`` or
        ``min_samples_leaf < 1``.
    """

    def __init__(
        self,
        max_depth: int | None = None,
        min_samples_split: int = 2,
        min_samples_leaf: int = 1,
        criterion: str = "gini",
        random_state: int | None = None,
    ) -> None:
        if criterion not in ("gini", "entropy"):
            raise ValueError("criterion must be 'gini' or 'entropy'.")
        if int(min_samples_split) < 2:
            raise ValueError("min_samples_split must be >= 2.")
        if int(min_samples_leaf) < 1:
            raise ValueError("min_samples_leaf must be >= 1.")
        self.max_depth = max_depth
        self.min_samples_split = int(min_samples_split)
        self.min_samples_leaf = int(min_samples_leaf)
        self.criterion = criterion
        self.random_state = random_state

        # Fitted state; populated by fit().
        self.tree_: _Node | None = None
        self.classes_: IntArray | None = None
        self.n_classes_: int | None = None
        self.n_features_in_: int | None = None
        self.feature_importances_: FloatArray | None = None

    # -- public API ---------------------------------------------------------

    def fit(
        self, X: ArrayLike, y: ArrayLike, sample_weight: ArrayLike | None = None
    ) -> DecisionTreeClassifier:
        """Grow the decision tree from training data.

        Labels are re-encoded as indices into ``classes_`` before the tree
        is grown.  Returns ``self``.
        """
        X_arr, y_raw = validate_xy(X, y)
        # return_inverse gives each label's position in the sorted unique labels.
        self.classes_, y_idx = np.unique(y_raw, return_inverse=True)
        y_idx = y_idx.astype(np.int64)
        self.n_classes_ = int(self.classes_.size)
        self.n_features_in_ = X_arr.shape[1]
        w = validate_sample_weight(sample_weight, X_arr.shape[0])

        # _grow accumulates weighted impurity decreases per feature in place.
        importances = np.zeros(self.n_features_in_, dtype=np.float64)
        self.tree_ = self._grow(X_arr, y_idx, w, depth=0, importances=importances)
        total = importances.sum()
        self.feature_importances_ = importances / total if total > _EPS else importances
        return self

    def predict_proba(self, X: ArrayLike) -> FloatArray:
        """Return class-probability estimates, columns ordered as ``classes_``.

        Raises
        ------
        RuntimeError
            If the tree has not been fitted.
        """
        if self.tree_ is None:
            raise RuntimeError("Call fit() before predict_proba().")
        X_arr = validate_x(X)
        leaves = _apply(self.tree_, X_arr)
        if not leaves:
            # np.vstack rejects an empty sequence; zero rows in -> zero rows out.
            return np.empty((0, self.n_classes_), dtype=np.float64)
        return np.vstack([leaf.value for leaf in leaves])

    def predict(self, X: ArrayLike) -> NDArray:
        """Predict the most likely class label for each row of X."""
        proba = self.predict_proba(X)
        return self.classes_[np.argmax(proba, axis=1)]

    def score(self, X: ArrayLike, y: ArrayLike, sample_weight: ArrayLike | None = None) -> float:
        """Return (optionally weighted) classification accuracy."""
        X_arr, y_arr = validate_xy(X, y)
        w = validate_sample_weight(sample_weight, X_arr.shape[0])
        preds = self.predict(X_arr)
        return float(np.average(preds == y_arr, weights=w))

    def apply(self, X: ArrayLike) -> list[_Node]:
        """Return the terminal leaf node each row of X is routed to.

        Raises
        ------
        RuntimeError
            If the tree has not been fitted.
        """
        if self.tree_ is None:
            raise RuntimeError("Call fit() before apply().")
        return _apply(self.tree_, validate_x(X))

    # -- tree construction ----------------------------------------------------

    @staticmethod
    def _split_threshold(lo: float, hi: float) -> float:
        """Return a threshold ``t`` with ``lo <= t < hi`` for two sorted values.

        The plain midpoint is used whenever it is representable strictly
        below ``hi``.  For adjacent floats the midpoint can round up to
        ``hi``, and for very large magnitudes ``lo + hi`` overflows; in
        either case ``X <= t`` would no longer reproduce the partition
        that was scored, so fall back to an overflow-safe midpoint and
        finally to ``lo`` itself.
        """
        lo, hi = float(lo), float(hi)
        mid = (lo + hi) / 2.0
        if lo <= mid < hi:
            return mid
        mid = lo / 2.0 + hi / 2.0  # cannot overflow
        if lo <= mid < hi:
            return mid
        return lo

    def _impurity(self, weighted_counts: FloatArray, total_w: float) -> float:
        """Gini or entropy impurity of a node from its weighted class counts."""
        if total_w <= _EPS:
            return 0.0
        p = weighted_counts / total_w
        if self.criterion == "gini":
            return float(1.0 - np.sum(p**2))
        with np.errstate(divide="ignore", invalid="ignore"):
            # Substitute 1.0 for empty classes so log2 is defined; they contribute 0.
            log_p = np.log2(np.where(p > 0, p, 1.0))
        return float(-np.sum(np.where(p > 0, p * log_p, 0.0)))

    def _grow(
        self, X: FloatArray, y: IntArray, w: FloatArray, depth: int, importances: FloatArray
    ) -> _Node:
        """Recursively build the subtree for the rows in ``X`` / ``y`` / ``w``.

        ``importances`` is updated in place with ``gain * total_w`` for the
        feature chosen at every split.
        """
        weighted_counts = np.bincount(y, weights=w, minlength=self.n_classes_)
        total_w = float(weighted_counts.sum())
        if total_w > _EPS:
            proba = weighted_counts / total_w
        else:
            # No usable weight at this node: fall back to a uniform distribution.
            proba = np.full(self.n_classes_, 1.0 / self.n_classes_)
        impurity = self._impurity(weighted_counts, total_w)

        node = _Node(
            n_samples=X.shape[0], weighted_n_samples=total_w, impurity=impurity, value=proba
        )

        can_split = (
            X.shape[0] >= self.min_samples_split
            and impurity > _EPS
            and self.n_classes_ > 1
            and (self.max_depth is None or depth < self.max_depth)
        )
        if can_split:
            feat_idx, threshold, gain = self._best_split(X, y, w, total_w, impurity)
            if feat_idx is not None:
                mask = X[:, feat_idx] <= threshold
                importances[feat_idx] += gain * total_w
                node.feature_index = feat_idx
                node.threshold = threshold
                node.left = self._grow(X[mask], y[mask], w[mask], depth + 1, importances)
                node.right = self._grow(X[~mask], y[~mask], w[~mask], depth + 1, importances)
        return node

    def _best_split(
        self, X: FloatArray, y: IntArray, w: FloatArray, total_w: float, parent_impurity: float
    ) -> tuple[int | None, float | None, float]:
        """Pick the feature/threshold with the lowest weighted child impurity.

        Returns ``(feature, threshold, gain)``; ``(None, None, 0.0)`` when no
        feature improves on ``parent_impurity`` by more than ``_EPS``.
        """
        best_feat, best_thr, best_impurity = None, None, parent_impurity
        for feat in range(self.n_features_in_):
            thr, impurity = self._best_split_feature(X[:, feat], y, w)
            if thr is not None and impurity < best_impurity - _EPS:
                best_impurity, best_feat, best_thr = impurity, feat, thr
        if best_feat is None:
            return None, None, 0.0
        return best_feat, best_thr, parent_impurity - best_impurity

    def _best_split_feature(
        self, col: FloatArray, y: IntArray, w: FloatArray
    ) -> tuple[float | None, float]:
        """Evaluate every split point of one feature in a single sorted sweep.

        Returns ``(threshold, weighted_child_impurity)`` for the best valid
        split, or ``(None, inf)`` when the feature offers none.
        """
        n = col.shape[0]
        if n < 2:
            return None, np.inf

        order = np.argsort(col, kind="mergesort")
        xs, ys, ws = col[order], y[order], w[order]

        # Row i of left_cum holds the weighted class counts of sorted rows 0..i.
        one_hot = np.zeros((n, self.n_classes_), dtype=np.float64)
        one_hot[np.arange(n), ys] = ws
        left_cum = np.cumsum(one_hot, axis=0)
        W_left = np.cumsum(ws)
        total_counts, W_total = left_cum[-1], W_left[-1]
        right_cum = total_counts - left_cum
        W_right = W_total - W_left

        # Candidate i splits between sorted rows i and i+1; it is only valid
        # where the feature value changes and both children are large enough.
        left_sizes = np.arange(1, n)
        right_sizes = n - left_sizes
        valid = (
            (xs[1:] != xs[:-1])
            & (left_sizes >= self.min_samples_leaf)
            & (right_sizes >= self.min_samples_leaf)
        )
        if not np.any(valid):
            return None, np.inf

        Wl, Wr = W_left[:-1], W_right[:-1]
        # Avoid dividing by (near-)zero child weight; such children get weight ~0 below.
        safe_Wl = np.where(Wl > _EPS, Wl, 1.0)
        safe_Wr = np.where(Wr > _EPS, Wr, 1.0)
        pl = left_cum[:-1] / safe_Wl[:, None]
        pr = right_cum[:-1] / safe_Wr[:, None]

        if self.criterion == "gini":
            imp_l = 1.0 - np.sum(pl**2, axis=1)
            imp_r = 1.0 - np.sum(pr**2, axis=1)
        else:
            with np.errstate(divide="ignore", invalid="ignore"):
                log_pl = np.log2(np.where(pl > 0, pl, 1.0))
                log_pr = np.log2(np.where(pr > 0, pr, 1.0))
            imp_l = -np.sum(np.where(pl > 0, pl * log_pl, 0.0), axis=1)
            imp_r = -np.sum(np.where(pr > 0, pr * log_pr, 0.0), axis=1)

        weighted_impurity = np.where(valid, (Wl * imp_l + Wr * imp_r) / W_total, np.inf)
        best_i = int(np.argmin(weighted_impurity))
        if not np.isfinite(weighted_impurity[best_i]):
            return None, np.inf
        # Guarantee xs[best_i] <= threshold < xs[best_i + 1] so that the mask
        # used when growing reproduces exactly the partition scored here.
        threshold = self._split_threshold(xs[best_i], xs[best_i + 1])
        return threshold, float(weighted_impurity[best_i])
328
+
329
+
330
+ # ──────────────────────────────────────────────────────────────────────────
331
+ # DecisionTreeRegressor
332
+ # ──────────────────────────────────────────────────────────────────────────
333
+
334
+
335
class DecisionTreeRegressor:
    """A CART decision tree regressor minimising weighted MSE.

    Parameters
    ----------
    max_depth : int | None, default=None
        Maximum tree depth. ``None`` places no depth limit on growth.
    min_samples_split : int, default=2
        Minimum number of samples a node must have to be eligible for
        splitting.
    min_samples_leaf : int, default=1
        Minimum number of samples required in each child of a split.
    random_state : int | None, default=None
        Unused by the (deterministic) splitting rule; kept for API
        symmetry with ensembles.

    Attributes
    ----------
    tree_ : the fitted root node
    n_features_in_ : number of features seen during fit
    feature_importances_ : impurity-decrease-based importances, sums to 1
        (left as all zeros when the tree has no split)

    Raises
    ------
    ValueError
        If ``min_samples_split < 2`` or ``min_samples_leaf < 1``.
    """

    def __init__(
        self,
        max_depth: int | None = None,
        min_samples_split: int = 2,
        min_samples_leaf: int = 1,
        random_state: int | None = None,
    ) -> None:
        if int(min_samples_split) < 2:
            raise ValueError("min_samples_split must be >= 2.")
        if int(min_samples_leaf) < 1:
            raise ValueError("min_samples_leaf must be >= 1.")
        self.max_depth = max_depth
        self.min_samples_split = int(min_samples_split)
        self.min_samples_leaf = int(min_samples_leaf)
        self.random_state = random_state

        # Fitted state; populated by fit().
        self.tree_: _Node | None = None
        self.n_features_in_: int | None = None
        self.feature_importances_: FloatArray | None = None

    # -- public API ---------------------------------------------------------

    def fit(
        self, X: ArrayLike, y: ArrayLike, sample_weight: ArrayLike | None = None
    ) -> DecisionTreeRegressor:
        """Grow the regression tree from training data. Returns ``self``."""
        X_arr, y_arr = validate_xy(X, y)
        y_arr = y_arr.astype(np.float64)
        self.n_features_in_ = X_arr.shape[1]
        w = validate_sample_weight(sample_weight, X_arr.shape[0])

        # _grow accumulates weighted variance reductions per feature in place.
        importances = np.zeros(self.n_features_in_, dtype=np.float64)
        self.tree_ = self._grow(X_arr, y_arr, w, depth=0, importances=importances)
        total = importances.sum()
        self.feature_importances_ = importances / total if total > _EPS else importances
        return self

    def predict(self, X: ArrayLike) -> FloatArray:
        """Predict the target value for each row of X.

        Raises
        ------
        RuntimeError
            If the tree has not been fitted.
        """
        if self.tree_ is None:
            raise RuntimeError("Call fit() before predict().")
        leaves = self.apply(X)
        return np.array([leaf.value for leaf in leaves], dtype=np.float64)

    def score(self, X: ArrayLike, y: ArrayLike, sample_weight: ArrayLike | None = None) -> float:
        """Return the coefficient of determination R^2 of the prediction.

        Returns ``0.0`` when the weighted total sum of squares is
        (numerically) zero, i.e. the targets are constant.
        """
        X_arr, y_arr = validate_xy(X, y)
        w = validate_sample_weight(sample_weight, X_arr.shape[0])
        preds = self.predict(X_arr)
        y_mean = float(np.average(y_arr, weights=w))
        ss_res = float(np.sum(w * (y_arr - preds) ** 2))
        ss_tot = float(np.sum(w * (y_arr - y_mean) ** 2))
        return 1.0 - ss_res / ss_tot if ss_tot > _EPS else 0.0

    def apply(self, X: ArrayLike) -> list[_Node]:
        """Return the terminal leaf node each row of X is routed to.

        Raises
        ------
        RuntimeError
            If the tree has not been fitted.
        """
        if self.tree_ is None:
            raise RuntimeError("Call fit() before apply().")
        return _apply(self.tree_, validate_x(X))

    # -- tree construction ----------------------------------------------------

    @staticmethod
    def _split_threshold(lo: float, hi: float) -> float:
        """Return a threshold ``t`` with ``lo <= t < hi`` for two sorted values.

        The plain midpoint is used whenever it is representable strictly
        below ``hi``.  For adjacent floats the midpoint can round up to
        ``hi``, and for very large magnitudes ``lo + hi`` overflows; in
        either case ``X <= t`` would no longer reproduce the partition
        that was scored, so fall back to an overflow-safe midpoint and
        finally to ``lo`` itself.
        """
        lo, hi = float(lo), float(hi)
        mid = (lo + hi) / 2.0
        if lo <= mid < hi:
            return mid
        mid = lo / 2.0 + hi / 2.0  # cannot overflow
        if lo <= mid < hi:
            return mid
        return lo

    def _grow(
        self, X: FloatArray, y: FloatArray, w: FloatArray, depth: int, importances: FloatArray
    ) -> _Node:
        """Recursively build the subtree for the rows in ``X`` / ``y`` / ``w``.

        ``importances`` is updated in place with ``gain * total_w`` for the
        feature chosen at every split.
        """
        total_w = float(w.sum())
        # With no usable weight, fall back to the unweighted mean and zero variance.
        mean = float(np.average(y, weights=w)) if total_w > _EPS else float(np.mean(y))
        variance = float(np.average((y - mean) ** 2, weights=w)) if total_w > _EPS else 0.0

        node = _Node(
            n_samples=X.shape[0], weighted_n_samples=total_w, impurity=variance, value=mean
        )

        can_split = (
            X.shape[0] >= self.min_samples_split
            and variance > _EPS
            and (self.max_depth is None or depth < self.max_depth)
        )
        if can_split:
            feat_idx, threshold, gain = self._best_split(X, y, w, total_w, variance)
            if feat_idx is not None:
                mask = X[:, feat_idx] <= threshold
                importances[feat_idx] += gain * total_w
                node.feature_index = feat_idx
                node.threshold = threshold
                node.left = self._grow(X[mask], y[mask], w[mask], depth + 1, importances)
                node.right = self._grow(X[~mask], y[~mask], w[~mask], depth + 1, importances)
        return node

    def _best_split(
        self, X: FloatArray, y: FloatArray, w: FloatArray, total_w: float, parent_variance: float
    ) -> tuple[int | None, float | None, float]:
        """Pick the feature/threshold with the lowest weighted child variance.

        Returns ``(feature, threshold, gain)``; ``(None, None, 0.0)`` when no
        feature improves on ``parent_variance`` by more than ``_EPS``.
        """
        best_feat, best_thr, best_var = None, None, parent_variance
        for feat in range(self.n_features_in_):
            thr, var = self._best_split_feature(X[:, feat], y, w)
            if thr is not None and var < best_var - _EPS:
                best_var, best_feat, best_thr = var, feat, thr
        if best_feat is None:
            return None, None, 0.0
        return best_feat, best_thr, parent_variance - best_var

    def _best_split_feature(
        self, col: FloatArray, y: FloatArray, w: FloatArray
    ) -> tuple[float | None, float]:
        """Evaluate every split point of one feature in a single sorted sweep.

        Returns ``(threshold, weighted_child_variance)`` for the best valid
        split, or ``(None, inf)`` when the feature offers none.
        """
        n = col.shape[0]
        if n < 2:
            return None, np.inf

        order = np.argsort(col, kind="mergesort")
        xs, ys, ws = col[order], y[order], w[order]

        W_left = np.cumsum(ws)
        W_total = W_left[-1]

        # Variance is shift-invariant, so centre the targets on the node's
        # weighted mean first.  Without this, E[y^2] - E[y]^2 loses all
        # significant digits when |mean| is large relative to the spread.
        offset = float(np.dot(ws, ys) / W_total) if W_total > _EPS else 0.0
        yc = ys - offset

        wy, wy2 = yc * ws, (yc**2) * ws
        left_sum, left_sum2 = np.cumsum(wy), np.cumsum(wy2)
        total_sum, total_sum2 = left_sum[-1], left_sum2[-1]
        right_sum, right_sum2 = total_sum - left_sum, total_sum2 - left_sum2
        W_right = W_total - W_left

        # Candidate i splits between sorted rows i and i+1; it is only valid
        # where the feature value changes and both children are large enough.
        left_sizes = np.arange(1, n)
        right_sizes = n - left_sizes
        valid = (
            (xs[1:] != xs[:-1])
            & (left_sizes >= self.min_samples_leaf)
            & (right_sizes >= self.min_samples_leaf)
        )
        if not np.any(valid):
            return None, np.inf

        Wl, Wr = W_left[:-1], W_right[:-1]
        # Avoid dividing by (near-)zero child weight; such children get weight ~0 below.
        safe_Wl = np.where(Wl > _EPS, Wl, 1.0)
        safe_Wr = np.where(Wr > _EPS, Wr, 1.0)
        mean_l = left_sum[:-1] / safe_Wl
        mean_r = right_sum[:-1] / safe_Wr
        # Clamp at zero: rounding can push the difference slightly negative.
        var_l = np.maximum(left_sum2[:-1] / safe_Wl - mean_l**2, 0.0)
        var_r = np.maximum(right_sum2[:-1] / safe_Wr - mean_r**2, 0.0)

        weighted_var = np.where(valid, (Wl * var_l + Wr * var_r) / W_total, np.inf)
        best_i = int(np.argmin(weighted_var))
        if not np.isfinite(weighted_var[best_i]):
            return None, np.inf
        # Guarantee xs[best_i] <= threshold < xs[best_i + 1] so that the mask
        # used when growing reproduces exactly the partition scored here.
        threshold = self._split_threshold(xs[best_i], xs[best_i + 1])
        return threshold, float(weighted_var[best_i])