scratchkit 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlscratch/__init__.py +56 -0
- mlscratch/__main__.py +118 -0
- mlscratch/bayesian/__init__.py +53 -0
- mlscratch/bayesian/bayesian_linear_regression.py +171 -0
- mlscratch/bayesian/bayesian_network.py +248 -0
- mlscratch/bayesian/bayesian_nn.py +315 -0
- mlscratch/bayesian/gaussian_process.py +207 -0
- mlscratch/bayesian/hmm.py +277 -0
- mlscratch/bayesian/init.py +52 -0
- mlscratch/bayesian/kalman_filter.py +182 -0
- mlscratch/bayesian/naive_bayes.py +209 -0
- mlscratch/metrics/__init__.py +59 -0
- mlscratch/metrics/classification.py +365 -0
- mlscratch/metrics/regression.py +79 -0
- mlscratch/neural/__init__.py +121 -0
- mlscratch/neural/attention.py +420 -0
- mlscratch/neural/autoencoder.py +543 -0
- mlscratch/neural/boltzmann.py +231 -0
- mlscratch/neural/cnn.py +593 -0
- mlscratch/neural/cvnn.py +322 -0
- mlscratch/neural/gan.py +364 -0
- mlscratch/neural/hopfield.py +193 -0
- mlscratch/neural/perceptron.py +398 -0
- mlscratch/neural/rbf_network.py +230 -0
- mlscratch/neural/recurrent.py +569 -0
- mlscratch/preprocessing/__init__.py +38 -0
- mlscratch/preprocessing/encoders.py +140 -0
- mlscratch/preprocessing/model_selection.py +119 -0
- mlscratch/preprocessing/polynomial.py +105 -0
- mlscratch/preprocessing/scalers.py +220 -0
- mlscratch/py.typed +0 -0
- mlscratch/reinforcement/__init__.py +59 -0
- mlscratch/reinforcement/ddpg.py +363 -0
- mlscratch/reinforcement/dqn.py +319 -0
- mlscratch/reinforcement/ppo.py +452 -0
- mlscratch/reinforcement/q_learning.py +352 -0
- mlscratch/reinforcement/sac.py +382 -0
- mlscratch/reinforcement/utils.py +594 -0
- mlscratch/supervised/__init__.py +76 -0
- mlscratch/supervised/_validation.py +50 -0
- mlscratch/supervised/adaboost.py +255 -0
- mlscratch/supervised/decision_tree.py +495 -0
- mlscratch/supervised/gradient_boosting.py +354 -0
- mlscratch/supervised/knn.py +234 -0
- mlscratch/supervised/lasso_regression.py +125 -0
- mlscratch/supervised/linear_models.py +459 -0
- mlscratch/supervised/linear_regression.py +197 -0
- mlscratch/supervised/logistic_regression.py +119 -0
- mlscratch/supervised/naive_bayes.py +113 -0
- mlscratch/supervised/random_forest.py +321 -0
- mlscratch/supervised/ridge_regression.py +93 -0
- mlscratch/supervised/svm.py +356 -0
- mlscratch/unsupervised/__init__.py +39 -0
- mlscratch/unsupervised/apriori.py +178 -0
- mlscratch/unsupervised/dbscan.py +141 -0
- mlscratch/unsupervised/gmm.py +204 -0
- mlscratch/unsupervised/hierarchical_clustering.py +137 -0
- mlscratch/unsupervised/ica.py +167 -0
- mlscratch/unsupervised/kmeans.py +135 -0
- mlscratch/unsupervised/kmedoids.py +133 -0
- mlscratch/unsupervised/pca.py +103 -0
- mlscratch/unsupervised/tsne.py +200 -0
- scratchkit-0.2.0.dist-info/METADATA +241 -0
- scratchkit-0.2.0.dist-info/RECORD +68 -0
- scratchkit-0.2.0.dist-info/WHEEL +5 -0
- scratchkit-0.2.0.dist-info/entry_points.txt +2 -0
- scratchkit-0.2.0.dist-info/licenses/LICENSE +201 -0
- scratchkit-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,495 @@
|
|
|
1
|
+
r"""
|
|
2
|
+
Decision Trees
|
|
3
|
+
==============
|
|
4
|
+
From-scratch CART (Classification And Regression Trees), pure numpy.
|
|
5
|
+
|
|
6
|
+
DecisionTreeClassifier
|
|
7
|
+
-----------------------
|
|
8
|
+
Binary or multiclass classification. At each node the split that
|
|
9
|
+
minimises the weighted child impurity is chosen, where impurity is
|
|
10
|
+
either Gini:
|
|
11
|
+
|
|
12
|
+
.. math::
|
|
13
|
+
G = 1 - \sum_{k=1}^K p_k^2
|
|
14
|
+
|
|
15
|
+
or Shannon entropy:
|
|
16
|
+
|
|
17
|
+
.. math::
|
|
18
|
+
H = -\sum_{k=1}^K p_k \log_2 p_k
|
|
19
|
+
|
|
20
|
+
DecisionTreeRegressor
|
|
21
|
+
-----------------------
|
|
22
|
+
Minimises weighted variance (mean squared error) of the target within
|
|
23
|
+
each child:
|
|
24
|
+
|
|
25
|
+
.. math::
|
|
26
|
+
\mathrm{MSE} = \frac{1}{W}\sum_i w_i (y_i - \bar y)^2
|
|
27
|
+
|
|
28
|
+
Both trees support per-sample weights (``sample_weight``), which is
|
|
29
|
+
what lets :mod:`mlscratch.supervised.adaboost` and
|
|
30
|
+
:mod:`mlscratch.supervised.random_forest` reuse the exact same split
|
|
31
|
+
logic instead of re-deriving it.
|
|
32
|
+
|
|
33
|
+
Algorithm
|
|
34
|
+
---------
|
|
35
|
+
For each candidate feature the rows are sorted once (`O(n log n)`),
|
|
36
|
+
then a single left-to-right vectorised sweep evaluates every possible
|
|
37
|
+
split point in `O(n)` using running (weighted) cumulative sums — no
|
|
38
|
+
per-split re-scan of the data. Overall a node with `n` samples and
|
|
39
|
+
`d` features costs `O(n d log n)`.
|
|
40
|
+
|
|
41
|
+
Complexity
|
|
42
|
+
----------
|
|
43
|
+
- Training : O(n d log n) per node; the nodes on one level share at most n samples, so a tree of depth D costs O(D n d log n) overall
|
|
44
|
+
- Inference: O(depth) per sample
|
|
45
|
+
- Space : O(n_nodes)
|
|
46
|
+
"""
|
|
47
|
+
|
|
48
|
+
from __future__ import annotations
|
|
49
|
+
|
|
50
|
+
from dataclasses import dataclass
|
|
51
|
+
|
|
52
|
+
import numpy as np
|
|
53
|
+
from numpy.typing import ArrayLike, NDArray
|
|
54
|
+
|
|
55
|
+
from ._validation import validate_sample_weight, validate_x, validate_xy
|
|
56
|
+
|
|
57
|
+
FloatArray = NDArray[np.float64]
|
|
58
|
+
IntArray = NDArray[np.int64]
|
|
59
|
+
|
|
60
|
+
_EPS = 1e-12
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
# ──────────────────────────────────────────────────────────────────────────
|
|
64
|
+
# Shared node / tree-walking helpers
|
|
65
|
+
# ──────────────────────────────────────────────────────────────────────────
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@dataclass
class _Node:
    """One vertex (internal split or terminal leaf) of a binary decision tree.

    Every node stores the prediction it would make if it were a leaf in
    ``value``: a class-probability vector when built by
    :class:`DecisionTreeClassifier`, a scalar mean when built by
    :class:`DecisionTreeRegressor`.  A node that has been split also
    records the rule (``feature_index`` / ``threshold``) and its two
    children; on a leaf those four fields keep their ``None`` defaults.
    """

    # Number of training rows that reached this node.
    n_samples: int
    # Sum of the sample weights of those rows.
    weighted_n_samples: float
    # Gini / entropy (classifier) or weighted variance (regressor).
    impurity: float
    # Leaf prediction; see the class docstring for its shape.
    value: object
    # Column tested by the split rule, or None for a leaf.
    feature_index: int | None = None
    # Rows with ``x[feature_index] <= threshold`` are routed to ``left``.
    threshold: float | None = None
    left: _Node | None = None
    right: _Node | None = None

    @property
    def is_leaf(self) -> bool:
        """Whether this node is terminal, i.e. carries no split rule."""
        has_split_rule = self.feature_index is not None
        return not has_split_rule
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _apply(root: _Node, X: FloatArray) -> list[_Node]:
    """Send each row of ``X`` down the tree and collect the leaf it lands in.

    Parameters
    ----------
    root : _Node
        Root of a fitted tree.
    X : FloatArray
        Iterable of rows; each row is indexed by ``feature_index``.

    Returns
    -------
    list[_Node]
        One leaf per row, in row order.
    """

    def _descend(sample) -> _Node:
        current = root
        while not current.is_leaf:
            # Same rule as during training: "<= threshold" goes left.
            if sample[current.feature_index] <= current.threshold:
                current = current.left
            else:
                current = current.right
        return current

    return [_descend(sample) for sample in X]
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def group_by_leaf(leaves: list[_Node]) -> dict[int, tuple[_Node, list[int]]]:
    """Bucket sample positions by the leaf object they ended up in.

    Leaves are told apart by identity (``id``), not by equality, so two
    distinct leaves holding equal values stay in separate buckets.

    Parameters
    ----------
    leaves : list[_Node]
        One leaf per sample, e.g. the output of a tree's ``apply``.

    Returns
    -------
    dict[int, tuple[_Node, list[int]]]
        Maps ``id(leaf)`` to ``(leaf, indices)`` where ``indices`` are the
        positions in ``leaves`` routed to that leaf, in ascending order.
    """
    buckets: dict[int, tuple[_Node, list[int]]] = {}
    for position, leaf in enumerate(leaves):
        # setdefault creates the bucket on first sight of a leaf.
        _, members = buckets.setdefault(id(leaf), (leaf, []))
        members.append(position)
    return buckets
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
# ──────────────────────────────────────────────────────────────────────────
|
|
119
|
+
# DecisionTreeClassifier
|
|
120
|
+
# ──────────────────────────────────────────────────────────────────────────
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class DecisionTreeClassifier:
    """A binary or multiclass CART decision tree classifier.

    Parameters
    ----------
    max_depth : int | None, default=None
        Maximum tree depth. ``None`` grows nodes until they are pure or
        too small to split.
    min_samples_split : int, default=2
        Minimum number of samples a node must have to be eligible for
        splitting.
    min_samples_leaf : int, default=1
        Minimum number of samples required in each child of a split.
    criterion : str, default='gini'
        Split quality measure: ``'gini'`` or ``'entropy'``.
    random_state : int | None, default=None
        Unused by the splitting rule itself (which is deterministic);
        accepted for API symmetry with ensembles that seed their trees.

    Attributes
    ----------
    tree_ : the fitted root node
    classes_ : sorted unique labels seen during fit
    n_classes_ : number of classes
    n_features_in_ : number of features seen during fit
    feature_importances_ : impurity-decrease-based importances, sums to 1

    Raises
    ------
    ValueError
        If ``criterion`` is unknown, ``min_samples_split < 2`` or
        ``min_samples_leaf < 1``.
    """

    def __init__(
        self,
        max_depth: int | None = None,
        min_samples_split: int = 2,
        min_samples_leaf: int = 1,
        criterion: str = "gini",
        random_state: int | None = None,
    ) -> None:
        if criterion not in ("gini", "entropy"):
            raise ValueError("criterion must be 'gini' or 'entropy'.")
        if int(min_samples_split) < 2:
            raise ValueError("min_samples_split must be >= 2.")
        if int(min_samples_leaf) < 1:
            raise ValueError("min_samples_leaf must be >= 1.")
        self.max_depth = max_depth
        self.min_samples_split = int(min_samples_split)
        self.min_samples_leaf = int(min_samples_leaf)
        self.criterion = criterion
        self.random_state = random_state

        self.tree_: _Node | None = None
        self.classes_: IntArray | None = None
        self.n_classes_: int | None = None
        self.n_features_in_: int | None = None
        self.feature_importances_: FloatArray | None = None

    # -- public API ---------------------------------------------------------

    def fit(
        self, X: ArrayLike, y: ArrayLike, sample_weight: ArrayLike | None = None
    ) -> DecisionTreeClassifier:
        """Grow the decision tree from training data.

        Labels are mapped to contiguous indices ``0..n_classes_-1`` (the
        positions in ``classes_``) before the tree is grown.

        Returns
        -------
        DecisionTreeClassifier
            ``self``, to allow call chaining.
        """
        X_arr, y_raw = validate_xy(X, y)
        self.classes_, y_idx = np.unique(y_raw, return_inverse=True)
        y_idx = y_idx.astype(np.int64)
        self.n_classes_ = int(self.classes_.size)
        self.n_features_in_ = X_arr.shape[1]
        w = validate_sample_weight(sample_weight, X_arr.shape[0])

        # _grow accumulates the weighted impurity decrease per feature here.
        importances = np.zeros(self.n_features_in_, dtype=np.float64)
        self.tree_ = self._grow(X_arr, y_idx, w, depth=0, importances=importances)
        total = importances.sum()
        # A tree that never split has all-zero importances; leave them as is.
        self.feature_importances_ = importances / total if total > _EPS else importances
        return self

    def predict_proba(self, X: ArrayLike) -> FloatArray:
        """Return class-probability estimates, columns ordered as ``classes_``.

        Raises
        ------
        RuntimeError
            If called before :meth:`fit`.
        """
        if self.tree_ is None:
            raise RuntimeError("Call fit() before predict_proba().")
        X_arr = validate_x(X)
        leaves = _apply(self.tree_, X_arr)
        if not leaves:
            # np.vstack rejects an empty sequence; zero rows in, zero rows out.
            return np.empty((0, self.n_classes_), dtype=np.float64)
        return np.vstack([leaf.value for leaf in leaves])

    def predict(self, X: ArrayLike) -> NDArray:
        """Predict the most likely class label for each row of X."""
        proba = self.predict_proba(X)
        return self.classes_[np.argmax(proba, axis=1)]

    def score(self, X: ArrayLike, y: ArrayLike, sample_weight: ArrayLike | None = None) -> float:
        """Return (optionally weighted) classification accuracy."""
        X_arr, y_arr = validate_xy(X, y)
        w = validate_sample_weight(sample_weight, X_arr.shape[0])
        preds = self.predict(X_arr)
        return float(np.average(preds == y_arr, weights=w))

    def apply(self, X: ArrayLike) -> list[_Node]:
        """Return the terminal leaf node each row of X is routed to.

        Raises
        ------
        RuntimeError
            If called before :meth:`fit`.
        """
        if self.tree_ is None:
            raise RuntimeError("Call fit() before apply().")
        return _apply(self.tree_, validate_x(X))

    # -- tree construction ----------------------------------------------------

    @staticmethod
    def _midpoint(lo: float, hi: float) -> float:
        """Return a split threshold ``t`` with ``lo <= t < hi``.

        The plain average ``(lo + hi) / 2`` can round up to ``hi`` when the
        two values are adjacent floats, or overflow for huge magnitudes.
        Either would make ``x <= t`` send the ``hi`` rows left, so the
        realised partition would differ from the one that was scored.  In
        those cases ``lo`` itself is used, which yields exactly the scored
        partition.
        """
        lo, hi = float(lo), float(hi)
        mid = lo + (hi - lo) / 2.0
        if not (lo <= mid < hi):
            mid = lo
        return mid

    def _impurity(self, weighted_counts: FloatArray, total_w: float) -> float:
        """Gini or entropy of a node given its per-class weight totals."""
        if total_w <= _EPS:
            return 0.0
        p = weighted_counts / total_w
        if self.criterion == "gini":
            return float(1.0 - np.sum(p**2))
        with np.errstate(divide="ignore", invalid="ignore"):
            # Zero-probability classes are mapped to log2(1) = 0.
            log_p = np.log2(np.where(p > 0, p, 1.0))
        return float(-np.sum(np.where(p > 0, p * log_p, 0.0)))

    def _grow(
        self, X: FloatArray, y: IntArray, w: FloatArray, depth: int, importances: FloatArray
    ) -> _Node:
        """Recursively build the subtree for the given rows.

        ``importances`` is updated in place with ``gain * total_w`` for the
        feature chosen at every split.
        """
        weighted_counts = np.bincount(y, weights=w, minlength=self.n_classes_)
        total_w = float(weighted_counts.sum())
        if total_w > _EPS:
            proba = weighted_counts / total_w
        else:
            # No usable weight at this node: fall back to a uniform guess.
            proba = np.full(self.n_classes_, 1.0 / self.n_classes_)
        impurity = self._impurity(weighted_counts, total_w)

        node = _Node(
            n_samples=X.shape[0], weighted_n_samples=total_w, impurity=impurity, value=proba
        )

        can_split = (
            X.shape[0] >= self.min_samples_split
            and impurity > _EPS
            and self.n_classes_ > 1
            and (self.max_depth is None or depth < self.max_depth)
        )
        if can_split:
            feat_idx, threshold, gain = self._best_split(X, y, w, total_w, impurity)
            if feat_idx is not None:
                mask = X[:, feat_idx] <= threshold
                n_left = int(np.count_nonzero(mask))
                # Only recurse on a genuine two-sided partition.  An empty
                # side (e.g. a NaN threshold) would hand one child the exact
                # same rows and recurse forever when max_depth is None.
                if 0 < n_left < X.shape[0]:
                    importances[feat_idx] += gain * total_w
                    node.feature_index = feat_idx
                    node.threshold = threshold
                    node.left = self._grow(X[mask], y[mask], w[mask], depth + 1, importances)
                    node.right = self._grow(
                        X[~mask], y[~mask], w[~mask], depth + 1, importances
                    )
        return node

    def _best_split(
        self, X: FloatArray, y: IntArray, w: FloatArray, total_w: float, parent_impurity: float
    ) -> tuple[int | None, float | None, float]:
        """Pick the feature/threshold with the lowest weighted child impurity.

        Returns ``(feature, threshold, gain)``, or ``(None, None, 0.0)`` when
        no split improves on the parent by more than ``_EPS``.
        """
        best_feat, best_thr, best_impurity = None, None, parent_impurity
        for feat in range(self.n_features_in_):
            thr, impurity = self._best_split_feature(X[:, feat], y, w)
            if thr is not None and impurity < best_impurity - _EPS:
                best_impurity, best_feat, best_thr = impurity, feat, thr
        if best_feat is None:
            return None, None, 0.0
        return best_feat, best_thr, parent_impurity - best_impurity

    def _best_split_feature(
        self, col: FloatArray, y: IntArray, w: FloatArray
    ) -> tuple[float | None, float]:
        """Evaluate every split point of one feature in a single sorted sweep.

        Returns ``(threshold, weighted_child_impurity)`` for the best
        admissible split, or ``(None, inf)`` if there is none.
        """
        n = col.shape[0]
        if n < 2:
            return None, np.inf

        # Stable sort keeps the tie order deterministic.
        order = np.argsort(col, kind="mergesort")
        xs, ys, ws = col[order], y[order], w[order]

        # Row i of left_cum holds the per-class weight of sorted rows 0..i.
        one_hot = np.zeros((n, self.n_classes_), dtype=np.float64)
        one_hot[np.arange(n), ys] = ws
        left_cum = np.cumsum(one_hot, axis=0)
        W_left = np.cumsum(ws)
        total_counts, W_total = left_cum[-1], W_left[-1]
        right_cum = total_counts - left_cum
        W_right = W_total - W_left

        # Candidate i puts sorted rows 0..i left and i+1..n-1 right.
        left_sizes = np.arange(1, n)
        right_sizes = n - left_sizes
        valid = (
            (xs[1:] != xs[:-1])  # cannot separate equal feature values
            & (left_sizes >= self.min_samples_leaf)
            & (right_sizes >= self.min_samples_leaf)
        )
        if not np.any(valid):
            return None, np.inf

        Wl, Wr = W_left[:-1], W_right[:-1]
        # Avoid division by ~0; such sides contribute ~0 weight below anyway.
        safe_Wl = np.where(Wl > _EPS, Wl, 1.0)
        safe_Wr = np.where(Wr > _EPS, Wr, 1.0)
        pl = left_cum[:-1] / safe_Wl[:, None]
        pr = right_cum[:-1] / safe_Wr[:, None]

        if self.criterion == "gini":
            imp_l = 1.0 - np.sum(pl**2, axis=1)
            imp_r = 1.0 - np.sum(pr**2, axis=1)
        else:
            with np.errstate(divide="ignore", invalid="ignore"):
                log_pl = np.log2(np.where(pl > 0, pl, 1.0))
                log_pr = np.log2(np.where(pr > 0, pr, 1.0))
            imp_l = -np.sum(np.where(pl > 0, pl * log_pl, 0.0), axis=1)
            imp_r = -np.sum(np.where(pr > 0, pr * log_pr, 0.0), axis=1)

        weighted_impurity = np.where(valid, (Wl * imp_l + Wr * imp_r) / W_total, np.inf)
        best_i = int(np.argmin(weighted_impurity))
        if not np.isfinite(weighted_impurity[best_i]):
            return None, np.inf
        threshold = self._midpoint(xs[best_i], xs[best_i + 1])
        return threshold, float(weighted_impurity[best_i])
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
# ──────────────────────────────────────────────────────────────────────────
|
|
331
|
+
# DecisionTreeRegressor
|
|
332
|
+
# ──────────────────────────────────────────────────────────────────────────
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
class DecisionTreeRegressor:
    """A CART decision tree regressor minimising weighted MSE.

    Parameters
    ----------
    max_depth : int | None, default=None
        Maximum tree depth; ``None`` places no limit on depth.
    min_samples_split : int, default=2
        Minimum number of samples a node must have to be split.
    min_samples_leaf : int, default=1
        Minimum number of samples required in each child of a split.
    random_state : int | None, default=None
        Unused by the (deterministic) splitting rule; kept for API
        symmetry with ensembles.

    Attributes
    ----------
    tree_ : the fitted root node
    n_features_in_ : number of features seen during fit
    feature_importances_ : impurity-decrease-based importances, sums to 1

    Raises
    ------
    ValueError
        If ``min_samples_split < 2`` or ``min_samples_leaf < 1``.
    """

    def __init__(
        self,
        max_depth: int | None = None,
        min_samples_split: int = 2,
        min_samples_leaf: int = 1,
        random_state: int | None = None,
    ) -> None:
        if int(min_samples_split) < 2:
            raise ValueError("min_samples_split must be >= 2.")
        if int(min_samples_leaf) < 1:
            raise ValueError("min_samples_leaf must be >= 1.")
        self.max_depth = max_depth
        self.min_samples_split = int(min_samples_split)
        self.min_samples_leaf = int(min_samples_leaf)
        self.random_state = random_state

        self.tree_: _Node | None = None
        self.n_features_in_: int | None = None
        self.feature_importances_: FloatArray | None = None

    # -- public API ---------------------------------------------------------

    def fit(
        self, X: ArrayLike, y: ArrayLike, sample_weight: ArrayLike | None = None
    ) -> DecisionTreeRegressor:
        """Grow the regression tree from training data.

        Returns
        -------
        DecisionTreeRegressor
            ``self``, to allow call chaining.
        """
        X_arr, y_arr = validate_xy(X, y)
        y_arr = y_arr.astype(np.float64)
        self.n_features_in_ = X_arr.shape[1]
        w = validate_sample_weight(sample_weight, X_arr.shape[0])

        # _grow accumulates the weighted variance reduction per feature here.
        importances = np.zeros(self.n_features_in_, dtype=np.float64)
        self.tree_ = self._grow(X_arr, y_arr, w, depth=0, importances=importances)
        total = importances.sum()
        # A tree that never split has all-zero importances; leave them as is.
        self.feature_importances_ = importances / total if total > _EPS else importances
        return self

    def predict(self, X: ArrayLike) -> FloatArray:
        """Predict the target value for each row of X.

        Raises
        ------
        RuntimeError
            If called before :meth:`fit`.
        """
        if self.tree_ is None:
            raise RuntimeError("Call fit() before predict().")
        leaves = self.apply(X)
        return np.array([leaf.value for leaf in leaves], dtype=np.float64)

    def score(self, X: ArrayLike, y: ArrayLike, sample_weight: ArrayLike | None = None) -> float:
        """Return the coefficient of determination R^2 of the prediction.

        Returns ``0.0`` when the (weighted) total sum of squares of ``y``
        is not above ``_EPS``, where R^2 is undefined.
        """
        X_arr, y_arr = validate_xy(X, y)
        w = validate_sample_weight(sample_weight, X_arr.shape[0])
        preds = self.predict(X_arr)
        y_mean = float(np.average(y_arr, weights=w))
        ss_res = float(np.sum(w * (y_arr - preds) ** 2))
        ss_tot = float(np.sum(w * (y_arr - y_mean) ** 2))
        return 1.0 - ss_res / ss_tot if ss_tot > _EPS else 0.0

    def apply(self, X: ArrayLike) -> list[_Node]:
        """Return the terminal leaf node each row of X is routed to.

        Raises
        ------
        RuntimeError
            If called before :meth:`fit`.
        """
        if self.tree_ is None:
            raise RuntimeError("Call fit() before apply().")
        return _apply(self.tree_, validate_x(X))

    # -- tree construction ----------------------------------------------------

    @staticmethod
    def _midpoint(lo: float, hi: float) -> float:
        """Return a split threshold ``t`` with ``lo <= t < hi``.

        The plain average ``(lo + hi) / 2`` can round up to ``hi`` when the
        two values are adjacent floats, or overflow for huge magnitudes.
        Either would make ``x <= t`` send the ``hi`` rows left, so the
        realised partition would differ from the one that was scored.  In
        those cases ``lo`` itself is used, which yields exactly the scored
        partition.
        """
        lo, hi = float(lo), float(hi)
        mid = lo + (hi - lo) / 2.0
        if not (lo <= mid < hi):
            mid = lo
        return mid

    def _grow(
        self, X: FloatArray, y: FloatArray, w: FloatArray, depth: int, importances: FloatArray
    ) -> _Node:
        """Recursively build the subtree for the given rows.

        ``importances`` is updated in place with ``gain * total_w`` for the
        feature chosen at every split.
        """
        total_w = float(w.sum())
        # Without usable weight fall back to the plain mean and zero variance.
        mean = float(np.average(y, weights=w)) if total_w > _EPS else float(np.mean(y))
        variance = float(np.average((y - mean) ** 2, weights=w)) if total_w > _EPS else 0.0

        node = _Node(
            n_samples=X.shape[0], weighted_n_samples=total_w, impurity=variance, value=mean
        )

        can_split = (
            X.shape[0] >= self.min_samples_split
            and variance > _EPS
            and (self.max_depth is None or depth < self.max_depth)
        )
        if can_split:
            feat_idx, threshold, gain = self._best_split(X, y, w, total_w, variance)
            if feat_idx is not None:
                mask = X[:, feat_idx] <= threshold
                n_left = int(np.count_nonzero(mask))
                # Only recurse on a genuine two-sided partition.  An empty
                # side (e.g. a NaN threshold) would hand one child the exact
                # same rows and recurse forever when max_depth is None.
                if 0 < n_left < X.shape[0]:
                    importances[feat_idx] += gain * total_w
                    node.feature_index = feat_idx
                    node.threshold = threshold
                    node.left = self._grow(X[mask], y[mask], w[mask], depth + 1, importances)
                    node.right = self._grow(
                        X[~mask], y[~mask], w[~mask], depth + 1, importances
                    )
        return node

    def _best_split(
        self, X: FloatArray, y: FloatArray, w: FloatArray, total_w: float, parent_variance: float
    ) -> tuple[int | None, float | None, float]:
        """Pick the feature/threshold with the lowest weighted child variance.

        Returns ``(feature, threshold, gain)``, or ``(None, None, 0.0)`` when
        no split improves on the parent by more than ``_EPS``.
        """
        best_feat, best_thr, best_var = None, None, parent_variance
        for feat in range(self.n_features_in_):
            thr, var = self._best_split_feature(X[:, feat], y, w)
            if thr is not None and var < best_var - _EPS:
                best_var, best_feat, best_thr = var, feat, thr
        if best_feat is None:
            return None, None, 0.0
        return best_feat, best_thr, parent_variance - best_var

    def _best_split_feature(
        self, col: FloatArray, y: FloatArray, w: FloatArray
    ) -> tuple[float | None, float]:
        """Evaluate every split point of one feature in a single sorted sweep.

        Returns ``(threshold, weighted_child_variance)`` for the best
        admissible split, or ``(None, inf)`` if there is none.
        """
        n = col.shape[0]
        if n < 2:
            return None, np.inf

        # Stable sort keeps the tie order deterministic.
        order = np.argsort(col, kind="mergesort")
        xs, ys, ws = col[order], y[order], w[order]

        # Variance is shift-invariant, so centre the targets first.  The
        # E[y^2] - E[y]^2 form below otherwise cancels catastrophically
        # when the targets carry a large common offset.
        ys = ys - float(np.mean(ys))

        wy, wy2 = ys * ws, (ys**2) * ws
        left_sum, left_sum2 = np.cumsum(wy), np.cumsum(wy2)
        W_left = np.cumsum(ws)
        total_sum, total_sum2, W_total = left_sum[-1], left_sum2[-1], W_left[-1]
        right_sum, right_sum2 = total_sum - left_sum, total_sum2 - left_sum2
        W_right = W_total - W_left

        # Candidate i puts sorted rows 0..i left and i+1..n-1 right.
        left_sizes = np.arange(1, n)
        right_sizes = n - left_sizes
        valid = (
            (xs[1:] != xs[:-1])  # cannot separate equal feature values
            & (left_sizes >= self.min_samples_leaf)
            & (right_sizes >= self.min_samples_leaf)
        )
        if not np.any(valid):
            return None, np.inf

        Wl, Wr = W_left[:-1], W_right[:-1]
        # Avoid division by ~0; such sides contribute ~0 weight below anyway.
        safe_Wl = np.where(Wl > _EPS, Wl, 1.0)
        safe_Wr = np.where(Wr > _EPS, Wr, 1.0)
        mean_l = left_sum[:-1] / safe_Wl
        mean_r = right_sum[:-1] / safe_Wr
        # Clamp tiny negative values produced by rounding.
        var_l = np.maximum(left_sum2[:-1] / safe_Wl - mean_l**2, 0.0)
        var_r = np.maximum(right_sum2[:-1] / safe_Wr - mean_r**2, 0.0)

        weighted_var = np.where(valid, (Wl * var_l + Wr * var_r) / W_total, np.inf)
        best_i = int(np.argmin(weighted_var))
        if not np.isfinite(weighted_var[best_i]):
            return None, np.inf
        threshold = self._midpoint(xs[best_i], xs[best_i + 1])
        return threshold, float(weighted_var[best_i])
|