warpgbm 0.1.19__tar.gz → 0.1.21__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {warpgbm-0.1.19/warpgbm.egg-info → warpgbm-0.1.21}/PKG-INFO +7 -1
- {warpgbm-0.1.19 → warpgbm-0.1.21}/README.md +6 -0
- {warpgbm-0.1.19 → warpgbm-0.1.21}/pyproject.toml +1 -1
- {warpgbm-0.1.19 → warpgbm-0.1.21}/setup.py +1 -0
- warpgbm-0.1.21/version.txt +1 -0
- warpgbm-0.1.21/warpgbm/core.py +341 -0
- {warpgbm-0.1.19 → warpgbm-0.1.21}/warpgbm/cuda/node_kernel.cpp +8 -0
- warpgbm-0.1.21/warpgbm/cuda/predict.cu +77 -0
- {warpgbm-0.1.19 → warpgbm-0.1.21/warpgbm.egg-info}/PKG-INFO +7 -1
- {warpgbm-0.1.19 → warpgbm-0.1.21}/warpgbm.egg-info/SOURCES.txt +2 -1
- warpgbm-0.1.19/version.txt +0 -1
- warpgbm-0.1.19/warpgbm/core.py +0 -552
- {warpgbm-0.1.19 → warpgbm-0.1.21}/LICENSE +0 -0
- {warpgbm-0.1.19 → warpgbm-0.1.21}/MANIFEST.in +0 -0
- {warpgbm-0.1.19 → warpgbm-0.1.21}/setup.cfg +0 -0
- {warpgbm-0.1.19 → warpgbm-0.1.21}/tests/__init__.py +0 -0
- {warpgbm-0.1.19 → warpgbm-0.1.21}/tests/test_fit_predict_corr.py +0 -0
- {warpgbm-0.1.19 → warpgbm-0.1.21}/warpgbm/__init__.py +0 -0
- {warpgbm-0.1.19 → warpgbm-0.1.21}/warpgbm/cuda/__init__.py +0 -0
- {warpgbm-0.1.19 → warpgbm-0.1.21}/warpgbm/cuda/best_split_kernel.cu +0 -0
- {warpgbm-0.1.19 → warpgbm-0.1.21}/warpgbm/cuda/binner.cu +0 -0
- {warpgbm-0.1.19 → warpgbm-0.1.21}/warpgbm/cuda/histogram_kernel.cu +0 -0
- {warpgbm-0.1.19 → warpgbm-0.1.21}/warpgbm.egg-info/dependency_links.txt +0 -0
- {warpgbm-0.1.19 → warpgbm-0.1.21}/warpgbm.egg-info/requires.txt +0 -0
- {warpgbm-0.1.19 → warpgbm-0.1.21}/warpgbm.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: warpgbm
|
3
|
-
Version: 0.1.19
|
3
|
+
Version: 0.1.21
|
4
4
|
Summary: A fast GPU-accelerated Gradient Boosted Decision Tree library with PyTorch + CUDA
|
5
5
|
License: GNU GENERAL PUBLIC LICENSE
|
6
6
|
Version 3, 29 June 2007
|
@@ -897,3 +897,9 @@ WarpGBM builds on the shoulders of PyTorch, scikit-learn, LightGBM, and the CUDA
|
|
897
897
|
|
898
898
|
---
|
899
899
|
|
900
|
+
## Version Notes
|
901
|
+
|
902
|
+
### v0.1.21
|
903
|
+
|
904
|
+
- Replaced the vectorized predict function with a CUDA kernel (`warpgbm/cuda/predict.cu`) that runs one thread per (sample, tree) pair.
|
905
|
+
|
@@ -209,3 +209,9 @@ WarpGBM builds on the shoulders of PyTorch, scikit-learn, LightGBM, and the CUDA
|
|
209
209
|
|
210
210
|
---
|
211
211
|
|
212
|
+
## Version Notes
|
213
|
+
|
214
|
+
### v0.1.21
|
215
|
+
|
216
|
+
- Replaced the vectorized predict function with a CUDA kernel (`warpgbm/cuda/predict.cu`) that runs one thread per (sample, tree) pair.
|
217
|
+
|
@@ -0,0 +1 @@
|
|
1
|
+
0.1.21
|
@@ -0,0 +1,341 @@
|
|
1
|
+
import torch
|
2
|
+
import numpy as np
|
3
|
+
from sklearn.base import BaseEstimator, RegressorMixin
|
4
|
+
from warpgbm.cuda import node_kernel
|
5
|
+
from tqdm import tqdm
|
6
|
+
from typing import Tuple
|
7
|
+
from torch import Tensor
|
8
|
+
|
9
|
+
# Lookup table from the `histogram_computer` constructor option to the
# corresponding histogram kernel exposed by the compiled `node_kernel` extension.
histogram_kernels = dict(
    hist1=node_kernel.compute_histogram,
    hist2=node_kernel.compute_histogram2,
    hist3=node_kernel.compute_histogram3,
)
|
14
|
+
|
15
|
+
class WarpGBM(BaseEstimator, RegressorMixin):
    """Gradient-boosted regression trees trained and evaluated with CUDA kernels.

    Features are binned into ``num_bins`` int8 bins: quantile binning for
    general input, or used as-is for integer input whose per-column maximum is
    below ``num_bins``. Each boosting round grows one depth-limited tree on the
    residuals ``y - current_prediction``; every leaf predicts the mean residual
    of its training samples, shrunk by ``learning_rate``. Predictions start at
    ``base_prediction`` (the training target mean) and add the shrunken leaf
    value of every tree.
    """

    def __init__(
        self,
        num_bins=10,
        max_depth=3,
        learning_rate=0.1,
        n_estimators=100,
        min_child_weight=20,
        min_split_gain=0.0,
        verbosity=True,
        histogram_computer='hist3',
        threads_per_block=64,
        rows_per_thread=4,
        L2_reg=1e-6,
        L1_reg=0.0,
        device='cuda'
    ):
        """Validate and store hyperparameters.

        Args:
            num_bins: Number of feature bins, 2..127 (bin indices are int8).
            max_depth: Depth at which tree growth stops (>= 1).
            learning_rate: Shrinkage applied to every leaf value, in (0, 1].
            n_estimators: Number of boosting rounds / trees (> 0).
            min_child_weight: Forwarded to the split kernel (>= 1).
            min_split_gain: Forwarded to the split kernel (>= 0).
            verbosity: When truthy, print progress messages and show a tqdm bar.
            histogram_computer: Key into ``histogram_kernels``.
            threads_per_block: Launch size forwarded to the histogram and split
                kernels; a positive multiple of 32.
            rows_per_thread: Forwarded to the histogram kernel (1..16).
            L2_reg: Forwarded to the split kernel (>= 0).
            L1_reg: Validated and stored (>= 0).
                NOTE(review): not passed to any kernel in this file.
            device: Torch device on which training tensors are allocated.

        Raises:
            TypeError: A hyperparameter has the wrong type.
            ValueError: A hyperparameter is out of range, or
                ``histogram_computer`` is not a known kernel name.
        """
        self._validate_hyperparams(
            num_bins=num_bins,
            max_depth=max_depth,
            learning_rate=learning_rate,
            n_estimators=n_estimators,
            min_child_weight=min_child_weight,
            min_split_gain=min_split_gain,
            histogram_computer=histogram_computer,
            threads_per_block=threads_per_block,
            rows_per_thread=rows_per_thread,
            L2_reg=L2_reg,
            L1_reg=L1_reg
        )

        self.num_bins = num_bins
        self.max_depth = max_depth
        self.learning_rate = learning_rate
        self.n_estimators = n_estimators
        self.forest = None
        self.bin_edges = None
        self.base_prediction = None
        self.unique_eras = None
        self.device = device
        self.root_gradient_histogram = None
        self.root_hessian_histogram = None
        # Running training predictions (despite the name); see fit/grow_tree.
        self.gradients = None
        self.root_node_indices = None
        self.bin_indices = None
        self.Y_gpu = None
        self.num_features = None
        self.num_samples = None
        self.min_child_weight = min_child_weight
        self.min_split_gain = min_split_gain
        self.best_bin = torch.tensor([-1], dtype=torch.int32, device=self.device)
        # Every constructor argument is stored under its own name so sklearn's
        # get_params()/clone()/repr() can read it back.
        self.histogram_computer = histogram_computer
        self.verbosity = verbosity
        self.compute_histogram = histogram_kernels[histogram_computer]
        self.threads_per_block = threads_per_block
        self.rows_per_thread = rows_per_thread
        self.L2_reg = L2_reg
        self.L1_reg = L1_reg

    def _validate_hyperparams(self, **kwargs):
        """Raise TypeError/ValueError if any hyperparameter is malformed."""
        int_params = [
            "num_bins", "max_depth", "n_estimators", "min_child_weight",
            "threads_per_block", "rows_per_thread"
        ]
        float_params = [
            "learning_rate", "min_split_gain", "L2_reg", "L1_reg"
        ]

        for param in int_params:
            if not isinstance(kwargs[param], int):
                raise TypeError(f"{param} must be an integer, got {type(kwargs[param])}.")

        for param in float_params:
            if not isinstance(kwargs[param], (float, int)):  # Accept ints as valid floats
                raise TypeError(f"{param} must be a float, got {type(kwargs[param])}.")

        # Bin indices are stored as int8, hence the upper bound of 127.
        if not (2 <= kwargs["num_bins"] <= 127):
            raise ValueError("num_bins must be between 2 and 127 inclusive.")
        if kwargs["max_depth"] < 1:
            raise ValueError("max_depth must be at least 1.")
        if not (0.0 < kwargs["learning_rate"] <= 1.0):
            raise ValueError("learning_rate must be in (0.0, 1.0].")
        if kwargs["n_estimators"] <= 0:
            raise ValueError("n_estimators must be positive.")
        if kwargs["min_child_weight"] < 1:
            raise ValueError("min_child_weight must be a positive integer.")
        if kwargs["min_split_gain"] < 0:
            raise ValueError("min_split_gain must be non-negative.")
        if kwargs["threads_per_block"] <= 0 or kwargs["threads_per_block"] % 32 != 0:
            raise ValueError("threads_per_block should be a positive multiple of 32 (warp size).")
        if not (1 <= kwargs["rows_per_thread"] <= 16):
            raise ValueError("rows_per_thread must be positive between 1 and 16 inclusive.")
        if kwargs["L2_reg"] < 0 or kwargs["L1_reg"] < 0:
            raise ValueError("L2_reg and L1_reg must be non-negative.")
        if kwargs["histogram_computer"] not in histogram_kernels:
            raise ValueError(f"Invalid histogram_computer: {kwargs['histogram_computer']}. Choose from {list(histogram_kernels.keys())}.")

    def _log(self, message):
        """Print ``message`` only when ``verbosity`` is truthy."""
        if self.verbosity:
            print(message)

    def fit(self, X, y, era_id=None):
        """Bin ``X`` on the device and grow ``n_estimators`` trees.

        Args:
            X: NumPy array of shape (n_samples, n_features).
            y: NumPy array of n_samples targets.
            era_id: Optional per-row integer era labels. Defaults to all ones.
                NOTE(review): eras are uniqued in preprocessing but not used by
                tree growth in this file.

        Returns:
            self
        """
        if era_id is None:
            era_id = np.ones(X.shape[0], dtype='int32')
        self.bin_indices, era_indices, self.bin_edges, self.unique_eras, self.Y_gpu = self.preprocess_gpu_data(X, y, era_id)
        self.num_samples, self.num_features = X.shape
        self.gradients = torch.zeros_like(self.Y_gpu)
        self.root_node_indices = torch.arange(self.num_samples, device=self.device)
        # Boosting starts from the target mean; predict() must start there too.
        self.base_prediction = self.Y_gpu.mean().item()
        self.gradients += self.base_prediction
        self.best_gains = torch.zeros(self.num_features, device=self.device)
        self.best_bins = torch.zeros(self.num_features, device=self.device, dtype=torch.int32)
        with torch.no_grad():
            self.forest = self.grow_forest()
        return self

    def preprocess_gpu_data(self, X_np, Y_np, era_id_np):
        """Move targets/eras to the device and compute int8 bin indices for X.

        Integer input whose per-column maximum is below ``num_bins`` is used as
        bin indices directly, with synthetic edges ``1..num_bins-1`` per feature.
        Otherwise each feature is quantile-binned with ``num_bins - 1`` edges.

        Returns:
            (bin_indices [N x F] int8, era_indices, bin_edges [F x num_bins-1]
            float32, unique_eras, Y_gpu float32)
        """
        with torch.no_grad():
            self.num_samples, self.num_features = X_np.shape
            Y_gpu = torch.from_numpy(Y_np).type(torch.float32).to(self.device)
            era_id_gpu = torch.from_numpy(era_id_np).type(torch.int32).to(self.device)
            is_integer_type = np.issubdtype(X_np.dtype, np.integer)
            if is_integer_type:
                max_vals = X_np.max(axis=0)
                if np.all(max_vals < self.num_bins):
                    self._log("Detected pre-binned integer input — skipping quantile binning.")
                    bin_indices = torch.from_numpy(X_np).to(self.device).contiguous().to(torch.int8)

                    # Synthetic edges 1..num_bins-1 so predict() can bin new
                    # data through the same binning kernel.
                    bin_edges = torch.arange(1, self.num_bins, dtype=torch.float32).repeat(self.num_features, 1)
                    bin_edges = bin_edges.to(self.device)
                    unique_eras, era_indices = torch.unique(era_id_gpu, return_inverse=True)
                    return bin_indices, era_indices, bin_edges, unique_eras, Y_gpu
                else:
                    self._log("Integer input detected, but values exceed num_bins — falling back to quantile binning.")

            bin_indices = torch.empty((self.num_samples, self.num_features), dtype=torch.int8, device=self.device)
            bin_edges = torch.empty((self.num_features, self.num_bins - 1), dtype=torch.float32, device=self.device)

            X_np = torch.from_numpy(X_np).to(torch.float32).pin_memory()

            for f in range(self.num_features):
                X_f = X_np[:, f].to(self.device, non_blocking=True)
                # Interior quantiles only: num_bins - 1 edges define num_bins bins.
                quantiles = torch.linspace(0, 1, self.num_bins + 1, device=self.device, dtype=X_f.dtype)[1:-1]
                bin_edges_f = torch.quantile(X_f, quantiles, dim=0).contiguous()  # shape: [B-1] for 1D input
                # A column of a row-major tensor is strided; bin into a
                # contiguous copy and write it back.
                bin_indices_f = bin_indices[:, f].contiguous()
                node_kernel.custom_cuda_binner(X_f, bin_edges_f, bin_indices_f)
                bin_indices[:, f] = bin_indices_f
                bin_edges[f, :] = bin_edges_f

            unique_eras, era_indices = torch.unique(era_id_gpu, return_inverse=True)
            return bin_indices, era_indices, bin_edges, unique_eras, Y_gpu

    def compute_histograms(self, bin_indices_sub, gradients):
        """Return (grad_hist, hess_hist), each [num_features x num_bins] float32,
        filled by the selected histogram kernel for the given rows."""
        grad_hist = torch.zeros((self.num_features, self.num_bins), device=self.device, dtype=torch.float32)
        hess_hist = torch.zeros((self.num_features, self.num_bins), device=self.device, dtype=torch.float32)

        self.compute_histogram(
            bin_indices_sub,
            gradients,
            grad_hist,
            hess_hist,
            self.num_bins,
            self.threads_per_block,
            self.rows_per_thread
        )
        return grad_hist, hess_hist

    def find_best_split(self, gradient_histogram, hessian_histogram):
        """Return (feature, bin) of the highest-gain split, or (-1, -1) if the
        split kernel reports no valid bin for any feature."""
        node_kernel.compute_split(
            gradient_histogram,
            hessian_histogram,
            self.min_split_gain,
            self.min_child_weight,
            self.L2_reg,
            self.best_gains,
            self.best_bins,
            self.threads_per_block
        )

        if torch.all(self.best_bins == -1):
            return -1, -1  # No valid split found

        f = torch.argmax(self.best_gains).item()
        b = self.best_bins[f].item()

        return f, b

    def _make_leaf(self, node_indices):
        """Close ``node_indices`` as a leaf predicting their mean residual.

        Also adds the shrunken leaf value to the running training predictions
        of those samples, so the next round sees updated residuals.
        """
        leaf_value = self.residual[node_indices].mean()
        self.gradients[node_indices] += self.learning_rate * leaf_value
        return {"leaf_value": leaf_value.item(), "samples": node_indices.numel()}

    def grow_tree(self, gradient_histogram, hessian_histogram, node_indices, depth):
        """Recursively grow a tree over ``node_indices`` and return it as nested dicts.

        Internal nodes are ``{"feature", "bin", "left", "right"}``; rows whose bin
        for ``feature`` is ``<= bin`` go left. Leaves come from ``_make_leaf``.
        """
        if depth == self.max_depth:
            return self._make_leaf(node_indices)

        best_feature, best_bin = self.find_best_split(gradient_histogram, hessian_histogram)
        if best_feature == -1:
            return self._make_leaf(node_indices)

        split_mask = (self.bin_indices[node_indices, best_feature] <= best_bin)
        left_indices = node_indices[split_mask]
        right_indices = node_indices[~split_mask]

        left_size = left_indices.numel()
        right_size = right_indices.numel()

        # A degenerate split would recurse into an empty child whose mean
        # residual is NaN; keep the node as a leaf instead.
        if left_size == 0 or right_size == 0:
            return self._make_leaf(node_indices)

        # Histogram subtraction: build the smaller child's histogram directly
        # and derive the sibling's as parent minus child.
        if left_size <= right_size:
            grad_hist_left, hess_hist_left = self.compute_histograms(self.bin_indices[left_indices], self.residual[left_indices])
            grad_hist_right = gradient_histogram - grad_hist_left
            hess_hist_right = hessian_histogram - hess_hist_left
        else:
            grad_hist_right, hess_hist_right = self.compute_histograms(self.bin_indices[right_indices], self.residual[right_indices])
            grad_hist_left = gradient_histogram - grad_hist_right
            hess_hist_left = hessian_histogram - hess_hist_right

        new_depth = depth + 1
        left_child = self.grow_tree(grad_hist_left, hess_hist_left, left_indices, new_depth)
        right_child = self.grow_tree(grad_hist_right, hess_hist_right, right_indices, new_depth)

        return {"feature": best_feature, "bin": best_bin, "left": left_child, "right": right_child}

    def grow_forest(self):
        """Run ``n_estimators`` boosting rounds and return the list of trees."""
        forest = [{} for _ in range(self.n_estimators)]
        self.training_loss = []

        for i in tqdm(range(self.n_estimators), disable=not self.verbosity):
            self.residual = self.Y_gpu - self.gradients

            self.root_gradient_histogram, self.root_hessian_histogram = \
                self.compute_histograms(self.bin_indices, self.residual)

            tree = self.grow_tree(
                self.root_gradient_histogram,
                self.root_hessian_histogram,
                self.root_node_indices,
                depth=0
            )
            forest[i] = tree

        self._log("Finished training forest.")
        return forest

    def predict(self, X_np):
        """Predict targets for the rows of ``X_np``.

        Args:
            X_np: NumPy array of shape (n_samples, num_features). It is binned on
                the device with the bin edges learned in ``fit``.

        Returns:
            NumPy float32 array of shape (n_samples,) equal to
            ``base_prediction + learning_rate * sum of leaf values over trees``,
            matching how training predictions were accumulated in ``fit``.
        """
        X_tensor = torch.from_numpy(X_np).to(torch.float32).pin_memory()
        num_samples = X_tensor.size(0)
        bin_indices = torch.zeros((num_samples, self.num_features), dtype=torch.int8, device=self.device)

        with torch.no_grad():
            for f in range(self.num_features):
                X_f = X_tensor[:, f].to(self.device, non_blocking=True)
                bin_edges_f = self.bin_edges[f]
                bin_indices_f = bin_indices[:, f].contiguous()
                node_kernel.custom_cuda_binner(X_f, bin_edges_f, bin_indices_f)
                bin_indices[:, f] = bin_indices_f

            # A tree of depth max_depth has at most 2**(max_depth + 1) - 1 nodes.
            tree_tensor = torch.stack([
                self.flatten_tree(tree, max_nodes=2 ** (self.max_depth + 1))
                for tree in self.forest
            ]).to(self.device)

            # Training predictions started at base_prediction (see fit); the
            # kernel only accumulates learning_rate * leaf_value into `out`.
            out = torch.full(
                (num_samples,), self.base_prediction,
                dtype=torch.float32, device=self.device,
            )
            node_kernel.predict_forest(
                bin_indices.contiguous(),
                tree_tensor.contiguous(),
                self.learning_rate,
                out
            )

        return out.cpu().numpy()

    def flatten_tree(self, tree, max_nodes):
        """
        Convert a recursive tree structure into a flat matrix format.

        Nodes are numbered in pre-order (node, left subtree, right subtree), so
        the root is row 0. Each row represents a node:
        - Columns: [feature, bin, left_id, right_id, is_leaf, value]
        - Internal nodes fill columns 0–3 and set is_leaf = 0
        - Leaf nodes fill only value and set is_leaf = 1
        - Unused rows stay NaN

        Args:
            tree (dict): Root node in the nested-dict form produced by grow_tree.
            max_nodes (int): Max number of nodes to allocate in the flat matrix.

        Returns:
            torch.Tensor: [max_nodes x 6] float32 matrix representing the tree.
        """
        flat = torch.full((max_nodes, 6), float('nan'), dtype=torch.float32)
        next_id = 0

        def walk(node):
            nonlocal next_id
            node_id = next_id
            next_id += 1

            if 'leaf_value' in node:
                flat[node_id, 4] = 1.0
                flat[node_id, 5] = float(node['leaf_value'])
                return

            flat[node_id, 0] = float(node['feature'])
            flat[node_id, 1] = float(node['bin'])
            flat[node_id, 4] = 0.0
            # Pre-order: the left child takes the next id; the right child's id
            # is whatever follows the whole left subtree.
            flat[node_id, 2] = next_id
            walk(node['left'])
            flat[node_id, 3] = next_id
            walk(node['right'])

        walk(tree)
        return flat
|
@@ -44,6 +44,13 @@ void launch_bin_column_kernel(
|
|
44
44
|
at::Tensor bin_edges,
|
45
45
|
at::Tensor bin_indices);
|
46
46
|
|
47
|
+
void predict_with_forest(
|
48
|
+
const at::Tensor &bin_indices, // [N x F], int8
|
49
|
+
const at::Tensor &tree_tensor, // [T x max_nodes x 6], float32
|
50
|
+
float learning_rate,
|
51
|
+
at::Tensor &out // [N], float32
|
52
|
+
);
|
53
|
+
|
47
54
|
// Bindings
|
48
55
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
49
56
|
{
|
@@ -52,4 +59,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
|
52
59
|
m.def("compute_histogram3", &launch_histogram_kernel_cuda_configurable, "Histogram Feature Shared Mem");
|
53
60
|
m.def("compute_split", &launch_best_split_kernel_cuda, "Best Split (CUDA)");
|
54
61
|
m.def("custom_cuda_binner", &launch_bin_column_kernel, "Custom CUDA binning kernel");
|
62
|
+
m.def("predict_forest", &predict_with_forest, "CUDA Predictions");
|
55
63
|
}
|
@@ -0,0 +1,77 @@
|
|
1
|
+
#include <cuda.h>
|
2
|
+
#include <cuda_runtime.h>
|
3
|
+
#include <torch/extension.h>
|
4
|
+
|
5
|
+
// Evaluate one (sample, tree) pair per thread.
//
// Thread idx handles sample i = idx % N and tree t = idx / N. It walks tree t's
// flattened node table from the root (row 0) until it reaches a leaf, then adds
// learning_rate * leaf_value into out[i]. One thread per tree writes to the
// same out[i], hence atomicAdd.
//
// Node row layout (6 floats): [feature, bin, left_id, right_id, is_leaf, value].
// A sample goes left when its bin for `feature` is <= `bin`.
//
// Offsets are 64-bit because N * T and i * F overflow a 32-bit int on large
// inputs.
__global__ void predict_forest_kernel(
    const int8_t *__restrict__ bin_indices, // [N x F], row-major
    const float *__restrict__ tree_tensor,  // [T x max_nodes x 6]
    int64_t N, int64_t F, int64_t T, int64_t max_nodes,
    float learning_rate,
    float *__restrict__ out)                // [N]
{
    const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx >= N * T)
        return;

    const int64_t i = idx % N; // sample index
    const int64_t t = idx / N; // tree index

    const float *tree = tree_tensor + t * max_nodes * 6;
    const int8_t *sample_bins = bin_indices + i * F;

    int64_t node_id = 0;
    // Each step descends one level, so a well-formed tree reaches a leaf in
    // fewer than max_nodes steps. The step bound and the checks below stop a
    // malformed table (e.g. a NaN padding row reached via a bad child id) from
    // hanging the kernel or reading out of bounds; such a tree adds nothing.
    for (int64_t step = 0; step < max_nodes; ++step)
    {
        const float *node = tree + node_id * 6;
        const float is_leaf = node[4];
        if (is_leaf > 0.5f)
        {
            atomicAdd(&out[i], learning_rate * node[5]);
            return;
        }
        if (isnan(is_leaf))
            return; // unused (NaN-filled) row, not a real node

        const int64_t feat = static_cast<int64_t>(node[0]);
        const int split_bin = static_cast<int>(node[1]);
        const int64_t left_id = static_cast<int64_t>(node[2]);
        const int64_t right_id = static_cast<int64_t>(node[3]);
        if (feat < 0 || feat >= F)
            return;

        const int8_t bin = sample_bins[feat];
        node_id = (bin <= split_bin) ? left_id : right_id;
        if (node_id < 0 || node_id >= max_nodes)
            return;
    }
}

// Host launcher bound to Python as node_kernel.predict_forest.
// Validates inputs, then launches one thread per (sample, tree) pair; results
// are accumulated into `out`, so the caller chooses the starting value.
// Errors (bad inputs or a failed launch) are raised as exceptions.
void predict_with_forest(
    const at::Tensor &bin_indices, // [N x F], int8
    const at::Tensor &tree_tensor, // [T x max_nodes x 6], float32
    float learning_rate,
    at::Tensor &out                // [N], float32
)
{
    TORCH_CHECK(bin_indices.is_cuda() && tree_tensor.is_cuda() && out.is_cuda(),
                "predict_forest: all tensors must be CUDA tensors");
    TORCH_CHECK(bin_indices.dim() == 2 && bin_indices.scalar_type() == at::kChar,
                "predict_forest: bin_indices must be a 2-D int8 tensor");
    TORCH_CHECK(tree_tensor.dim() == 3 && tree_tensor.size(2) == 6 &&
                    tree_tensor.scalar_type() == at::kFloat,
                "predict_forest: tree_tensor must be a float32 tensor of shape [T, max_nodes, 6]");
    TORCH_CHECK(out.dim() == 1 && out.size(0) == bin_indices.size(0) &&
                    out.scalar_type() == at::kFloat,
                "predict_forest: out must be a 1-D float32 tensor with one entry per sample");
    TORCH_CHECK(bin_indices.is_contiguous() && tree_tensor.is_contiguous() && out.is_contiguous(),
                "predict_forest: all tensors must be contiguous");

    const int64_t N = bin_indices.size(0);
    const int64_t F = bin_indices.size(1);
    const int64_t T = tree_tensor.size(0);
    const int64_t max_nodes = tree_tensor.size(1);

    const int64_t total_jobs = N * T;
    if (total_jobs == 0)
        return; // nothing to add; a zero-block launch would be an error

    const int threads_per_block = 256;
    const int64_t blocks = (total_jobs + threads_per_block - 1) / threads_per_block;
    TORCH_CHECK(blocks <= 2147483647LL,
                "predict_forest: too many (sample, tree) pairs for a single launch");

    predict_forest_kernel<<<static_cast<unsigned int>(blocks), threads_per_block>>>(
        bin_indices.data_ptr<int8_t>(),
        tree_tensor.data_ptr<float>(),
        N, F, T, max_nodes,
        learning_rate,
        out.data_ptr<float>());

    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "CUDA predict kernel failed: ", cudaGetErrorString(err));
}
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: warpgbm
|
3
|
-
Version: 0.1.19
|
3
|
+
Version: 0.1.21
|
4
4
|
Summary: A fast GPU-accelerated Gradient Boosted Decision Tree library with PyTorch + CUDA
|
5
5
|
License: GNU GENERAL PUBLIC LICENSE
|
6
6
|
Version 3, 29 June 2007
|
@@ -897,3 +897,9 @@ WarpGBM builds on the shoulders of PyTorch, scikit-learn, LightGBM, and the CUDA
|
|
897
897
|
|
898
898
|
---
|
899
899
|
|
900
|
+
## Version Notes
|
901
|
+
|
902
|
+
### v0.1.21
|
903
|
+
|
904
|
+
- Replaced the vectorized predict function with a CUDA kernel (`warpgbm/cuda/predict.cu`) that runs one thread per (sample, tree) pair.
|
905
|
+
|
warpgbm-0.1.19/version.txt
DELETED
@@ -1 +0,0 @@
|
|
1
|
-
0.1.19
|
warpgbm-0.1.19/warpgbm/core.py
DELETED
@@ -1,552 +0,0 @@
|
|
1
|
-
import torch
|
2
|
-
import numpy as np
|
3
|
-
from sklearn.base import BaseEstimator, RegressorMixin
|
4
|
-
from warpgbm.cuda import node_kernel
|
5
|
-
from tqdm import tqdm
|
6
|
-
from typing import Tuple
|
7
|
-
from torch import Tensor
|
8
|
-
|
9
|
-
# Lookup table from the `histogram_computer` constructor option to the
# corresponding histogram kernel exposed by the compiled `node_kernel` extension.
histogram_kernels = dict(
    hist1=node_kernel.compute_histogram,
    hist2=node_kernel.compute_histogram2,
    hist3=node_kernel.compute_histogram3,
)
|
14
|
-
|
15
|
-
class WarpGBM(BaseEstimator, RegressorMixin):
|
16
|
-
def __init__(
|
17
|
-
self,
|
18
|
-
num_bins=10,
|
19
|
-
max_depth=3,
|
20
|
-
learning_rate=0.1,
|
21
|
-
n_estimators=100,
|
22
|
-
min_child_weight=20,
|
23
|
-
min_split_gain=0.0,
|
24
|
-
verbosity=True,
|
25
|
-
histogram_computer='hist3',
|
26
|
-
threads_per_block=64,
|
27
|
-
rows_per_thread=4,
|
28
|
-
L2_reg=1e-6,
|
29
|
-
L1_reg=0.0,
|
30
|
-
device='cuda'
|
31
|
-
):
|
32
|
-
# Validate arguments
|
33
|
-
self._validate_hyperparams(
|
34
|
-
num_bins=num_bins,
|
35
|
-
max_depth=max_depth,
|
36
|
-
learning_rate=learning_rate,
|
37
|
-
n_estimators=n_estimators,
|
38
|
-
min_child_weight=min_child_weight,
|
39
|
-
min_split_gain=min_split_gain,
|
40
|
-
histogram_computer=histogram_computer,
|
41
|
-
threads_per_block=threads_per_block,
|
42
|
-
rows_per_thread=rows_per_thread,
|
43
|
-
L2_reg=L2_reg,
|
44
|
-
L1_reg=L1_reg
|
45
|
-
)
|
46
|
-
|
47
|
-
self.num_bins = num_bins
|
48
|
-
self.max_depth = max_depth
|
49
|
-
self.learning_rate = learning_rate
|
50
|
-
self.n_estimators = n_estimators
|
51
|
-
self.forest = None
|
52
|
-
self.bin_edges = None
|
53
|
-
self.base_prediction = None
|
54
|
-
self.unique_eras = None
|
55
|
-
self.device = device
|
56
|
-
self.root_gradient_histogram = None
|
57
|
-
self.root_hessian_histogram = None
|
58
|
-
self.gradients = None
|
59
|
-
self.root_node_indices = None
|
60
|
-
self.bin_indices = None
|
61
|
-
self.Y_gpu = None
|
62
|
-
self.num_features = None
|
63
|
-
self.num_samples = None
|
64
|
-
self.min_child_weight = min_child_weight
|
65
|
-
self.min_split_gain = min_split_gain
|
66
|
-
self.best_bin = torch.tensor([-1], dtype=torch.int32, device=self.device)
|
67
|
-
self.compute_histogram = histogram_kernels[histogram_computer]
|
68
|
-
self.threads_per_block = threads_per_block
|
69
|
-
self.rows_per_thread = rows_per_thread
|
70
|
-
self.L2_reg = L2_reg
|
71
|
-
self.L1_reg = L1_reg
|
72
|
-
|
73
|
-
def _validate_hyperparams(self, **kwargs):
|
74
|
-
# Type checks
|
75
|
-
int_params = [
|
76
|
-
"num_bins", "max_depth", "n_estimators", "min_child_weight",
|
77
|
-
"threads_per_block", "rows_per_thread"
|
78
|
-
]
|
79
|
-
float_params = [
|
80
|
-
"learning_rate", "min_split_gain", "L2_reg", "L1_reg"
|
81
|
-
]
|
82
|
-
|
83
|
-
for param in int_params:
|
84
|
-
if not isinstance(kwargs[param], int):
|
85
|
-
raise TypeError(f"{param} must be an integer, got {type(kwargs[param])}.")
|
86
|
-
|
87
|
-
for param in float_params:
|
88
|
-
if not isinstance(kwargs[param], (float, int)): # Accept ints as valid floats
|
89
|
-
raise TypeError(f"{param} must be a float, got {type(kwargs[param])}.")
|
90
|
-
|
91
|
-
if not ( 2 <= kwargs["num_bins"] <= 127 ):
|
92
|
-
raise ValueError("num_bins must be between 2 and 127 inclusive.")
|
93
|
-
if kwargs["max_depth"] < 1:
|
94
|
-
raise ValueError("max_depth must be at least 1.")
|
95
|
-
if not (0.0 < kwargs["learning_rate"] <= 1.0):
|
96
|
-
raise ValueError("learning_rate must be in (0.0, 1.0].")
|
97
|
-
if kwargs["n_estimators"] <= 0:
|
98
|
-
raise ValueError("n_estimators must be positive.")
|
99
|
-
if kwargs["min_child_weight"] < 1:
|
100
|
-
raise ValueError("min_child_weight must be a positive integer.")
|
101
|
-
if kwargs["min_split_gain"] < 0:
|
102
|
-
raise ValueError("min_split_gain must be non-negative.")
|
103
|
-
if kwargs["threads_per_block"] <= 0 or kwargs["threads_per_block"] % 32 != 0:
|
104
|
-
raise ValueError("threads_per_block should be a positive multiple of 32 (warp size).")
|
105
|
-
if not ( 1 <= kwargs["rows_per_thread"] <= 16 ):
|
106
|
-
raise ValueError("rows_per_thread must be positive between 1 and 16 inclusive.")
|
107
|
-
if kwargs["L2_reg"] < 0 or kwargs["L1_reg"] < 0:
|
108
|
-
raise ValueError("L2_reg and L1_reg must be non-negative.")
|
109
|
-
if kwargs["histogram_computer"] not in histogram_kernels:
|
110
|
-
raise ValueError(f"Invalid histogram_computer: {kwargs['histogram_computer']}. Choose from {list(histogram_kernels.keys())}.")
|
111
|
-
|
112
|
-
def fit(self, X, y, era_id=None):
|
113
|
-
if era_id is None:
|
114
|
-
era_id = np.ones(X.shape[0], dtype='int32')
|
115
|
-
self.bin_indices, era_indices, self.bin_edges, self.unique_eras, self.Y_gpu = self.preprocess_gpu_data(X, y, era_id)
|
116
|
-
self.num_samples, self.num_features = X.shape
|
117
|
-
self.gradients = torch.zeros_like(self.Y_gpu)
|
118
|
-
self.root_node_indices = torch.arange(self.num_samples, device=self.device)
|
119
|
-
self.base_prediction = self.Y_gpu.mean().item()
|
120
|
-
self.gradients += self.base_prediction
|
121
|
-
self.best_gains = torch.zeros(self.num_features, device=self.device)
|
122
|
-
self.best_bins = torch.zeros(self.num_features, device=self.device, dtype=torch.int32)
|
123
|
-
with torch.no_grad():
|
124
|
-
self.forest = self.grow_forest()
|
125
|
-
return self
|
126
|
-
|
127
|
-
def preprocess_gpu_data(self, X_np, Y_np, era_id_np):
|
128
|
-
with torch.no_grad():
|
129
|
-
self.num_samples, self.num_features = X_np.shape
|
130
|
-
Y_gpu = torch.from_numpy(Y_np).type(torch.float32).to(self.device)
|
131
|
-
era_id_gpu = torch.from_numpy(era_id_np).type(torch.int32).to(self.device)
|
132
|
-
is_integer_type = np.issubdtype(X_np.dtype, np.integer)
|
133
|
-
if is_integer_type:
|
134
|
-
max_vals = X_np.max(axis=0)
|
135
|
-
if np.all(max_vals < self.num_bins):
|
136
|
-
print("Detected pre-binned integer input — skipping quantile binning.")
|
137
|
-
bin_indices = torch.from_numpy(X_np).to(self.device).contiguous().to(torch.int8)
|
138
|
-
|
139
|
-
# We'll store None or an empty tensor in self.bin_edges
|
140
|
-
# to indicate that we skip binning at predict-time
|
141
|
-
bin_edges = torch.arange(1, self.num_bins, dtype=torch.float32).repeat(self.num_features, 1)
|
142
|
-
bin_edges = bin_edges.to(self.device)
|
143
|
-
unique_eras, era_indices = torch.unique(era_id_gpu, return_inverse=True)
|
144
|
-
return bin_indices, era_indices, bin_edges, unique_eras, Y_gpu
|
145
|
-
else:
|
146
|
-
print("Integer input detected, but values exceed num_bins — falling back to quantile binning.")
|
147
|
-
|
148
|
-
bin_indices = torch.empty((self.num_samples, self.num_features), dtype=torch.int8, device='cuda')
|
149
|
-
bin_edges = torch.empty((self.num_features, self.num_bins - 1), dtype=torch.float32, device='cuda')
|
150
|
-
|
151
|
-
X_np = torch.from_numpy(X_np).to(torch.float32).pin_memory()
|
152
|
-
|
153
|
-
for f in range(self.num_features):
|
154
|
-
X_f = X_np[:, f].to('cuda', non_blocking=True)
|
155
|
-
quantiles = torch.linspace(0, 1, self.num_bins + 1, device='cuda', dtype=X_f.dtype)[1:-1]
|
156
|
-
bin_edges_f = torch.quantile(X_f, quantiles, dim=0).contiguous() # shape: [B-1] for 1D input
|
157
|
-
bin_indices_f = bin_indices[:, f].contiguous() # view into output
|
158
|
-
node_kernel.custom_cuda_binner(X_f, bin_edges_f, bin_indices_f)
|
159
|
-
bin_indices[:,f] = bin_indices_f
|
160
|
-
bin_edges[f,:] = bin_edges_f
|
161
|
-
|
162
|
-
unique_eras, era_indices = torch.unique(era_id_gpu, return_inverse=True)
|
163
|
-
return bin_indices, era_indices, bin_edges, unique_eras, Y_gpu
|
164
|
-
|
165
|
-
def compute_histograms(self, bin_indices_sub, gradients):
|
166
|
-
grad_hist = torch.zeros((self.num_features, self.num_bins), device=self.device, dtype=torch.float32)
|
167
|
-
hess_hist = torch.zeros((self.num_features, self.num_bins), device=self.device, dtype=torch.float32)
|
168
|
-
|
169
|
-
self.compute_histogram(
|
170
|
-
bin_indices_sub,
|
171
|
-
gradients,
|
172
|
-
grad_hist,
|
173
|
-
hess_hist,
|
174
|
-
self.num_bins,
|
175
|
-
self.threads_per_block,
|
176
|
-
self.rows_per_thread
|
177
|
-
)
|
178
|
-
return grad_hist, hess_hist
|
179
|
-
|
180
|
-
def find_best_split(self, gradient_histogram, hessian_histogram):
|
181
|
-
node_kernel.compute_split(
|
182
|
-
gradient_histogram,
|
183
|
-
hessian_histogram,
|
184
|
-
self.min_split_gain,
|
185
|
-
self.min_child_weight,
|
186
|
-
self.L2_reg,
|
187
|
-
self.best_gains,
|
188
|
-
self.best_bins,
|
189
|
-
self.threads_per_block
|
190
|
-
)
|
191
|
-
|
192
|
-
if torch.all(self.best_bins == -1):
|
193
|
-
return -1, -1 # No valid split found
|
194
|
-
|
195
|
-
f = torch.argmax(self.best_gains).item()
|
196
|
-
b = self.best_bins[f].item()
|
197
|
-
|
198
|
-
return f, b
|
199
|
-
|
200
|
-
def grow_tree(self, gradient_histogram, hessian_histogram, node_indices, depth):
|
201
|
-
if depth == self.max_depth:
|
202
|
-
leaf_value = self.residual[node_indices].mean()
|
203
|
-
self.gradients[node_indices] += self.learning_rate * leaf_value
|
204
|
-
return {"leaf_value": leaf_value.item(), "samples": node_indices.numel()}
|
205
|
-
|
206
|
-
parent_size = node_indices.numel()
|
207
|
-
best_feature, best_bin = self.find_best_split(gradient_histogram, hessian_histogram)
|
208
|
-
|
209
|
-
if best_feature == -1:
|
210
|
-
leaf_value = self.residual[node_indices].mean()
|
211
|
-
self.gradients[node_indices] += self.learning_rate * leaf_value
|
212
|
-
return {"leaf_value": leaf_value.item(), "samples": parent_size}
|
213
|
-
|
214
|
-
split_mask = (self.bin_indices[node_indices, best_feature] <= best_bin)
|
215
|
-
left_indices = node_indices[split_mask]
|
216
|
-
right_indices = node_indices[~split_mask]
|
217
|
-
|
218
|
-
left_size = left_indices.numel()
|
219
|
-
right_size = right_indices.numel()
|
220
|
-
|
221
|
-
if left_size == 0 or right_size == 0:
|
222
|
-
leaf_value = self.residual[node_indices].mean()
|
223
|
-
self.gradients[node_indices] += self.learning_rate * leaf_value
|
224
|
-
return {"leaf_value": leaf_value.item(), "samples": parent_size}
|
225
|
-
|
226
|
-
if left_size <= right_size:
|
227
|
-
grad_hist_left, hess_hist_left = self.compute_histograms( self.bin_indices[left_indices], self.residual[left_indices] )
|
228
|
-
grad_hist_right = gradient_histogram - grad_hist_left
|
229
|
-
hess_hist_right = hessian_histogram - hess_hist_left
|
230
|
-
else:
|
231
|
-
grad_hist_right, hess_hist_right = self.compute_histograms( self.bin_indices[right_indices], self.residual[right_indices] )
|
232
|
-
grad_hist_left = gradient_histogram - grad_hist_right
|
233
|
-
hess_hist_left = hessian_histogram - hess_hist_right
|
234
|
-
|
235
|
-
new_depth = depth + 1
|
236
|
-
left_child = self.grow_tree(grad_hist_left, hess_hist_left, left_indices, new_depth)
|
237
|
-
right_child = self.grow_tree(grad_hist_right, hess_hist_right, right_indices, new_depth)
|
238
|
-
|
239
|
-
return { "feature": best_feature, "bin": best_bin, "left": left_child, "right": right_child }
|
240
|
-
|
241
|
-
def grow_forest(self):
    """Fit ``self.n_estimators`` trees one after another and return them.

    Before each tree, the residual ``self.Y_gpu - self.gradients`` is
    recomputed and root gradient/hessian histograms are built over all rows
    (``self.bin_indices``). ``self.grow_tree`` then grows the tree starting
    from ``self.root_node_indices`` at depth 0.

    Side effects: resets ``self.training_loss`` to an empty list and
    overwrites ``self.residual``, ``self.root_gradient_histogram`` and
    ``self.root_hessian_histogram`` on every iteration. (``grow_tree`` is
    presumably what advances ``self.gradients`` between trees — confirm.)

    Returns:
        list of tree dicts, one per estimator, in training order.
    """
    self.training_loss = []
    trees = []
    for _ in tqdm(range(self.n_estimators)):
        # Each new tree fits what the current ensemble still gets wrong.
        self.residual = self.Y_gpu - self.gradients

        grad_hist, hess_hist = self.compute_histograms(self.bin_indices, self.residual)
        self.root_gradient_histogram = grad_hist
        self.root_hessian_histogram = hess_hist

        trees.append(
            self.grow_tree(grad_hist, hess_hist, self.root_node_indices, depth=0)
        )

    print("Finished training forest.")
    return trees
|
264
|
-
|
265
|
-
def predict(self, X_np, chunk_size=50000):
    """
    Predict targets for ``X_np`` by walking every tree level by level on
    ``self.device`` over a padded (perfect-binary-tree) array layout.

    Every call rebuilds ``self.flat_forest`` from ``self.forest`` via
    ``flatten_forest_to_tensors``; its "features", "thresholds" and
    "leaf_values" tensors are shaped [n_trees, max_nodes]. In that layout a
    feature of -1 marks a leaf and -2 marks an empty slot.

    Args:
        X_np: 2-D NumPy array (rows x features). Integer dtypes are treated as
            already binned and must be < ``self.num_bins`` in every column.
            Any other dtype is binned with ``torch.bucketize(..., right=False)``
            against ``self.bin_edges``.
        chunk_size: number of rows copied to the device and scored per batch.

    Returns:
        float32 NumPy array of length ``X_np.shape[0]`` equal to
        ``self.base_prediction + self.learning_rate * (sum of reached leaf values)``.

    Raises:
        ValueError: if integer input has a column maximum >= ``self.num_bins``.
    """
    with torch.no_grad():
        # 1) Convert X_np -> bin_indices
        is_integer_type = np.issubdtype(X_np.dtype, np.integer)
        if is_integer_type:
            # NOTE(review): only the upper bound is validated; negative
            # integers pass and then compare <= every threshold (always go left).
            # NOTE(review): astype(np.int8) wraps values >= 128 — presumably
            # num_bins <= 128; confirm.
            max_vals = X_np.max(axis=0)
            if np.all(max_vals < self.num_bins):
                bin_indices = X_np.astype(np.int8)
            else:
                raise ValueError("Pre-binned integers must be < num_bins")
        else:
            X_cpu = torch.from_numpy(X_np).type(torch.float32)
            # Columns beyond self.num_features stay uninitialized (torch.empty);
            # presumably trees never reference them — confirm.
            bin_indices = torch.empty((X_np.shape[0], X_np.shape[1]), dtype=torch.int8)
            bin_edges_cpu = self.bin_edges.to('cpu')
            for f in range(self.num_features):
                bin_indices[:, f] = torch.bucketize(X_cpu[:, f], bin_edges_cpu[f], right=False).type(torch.int8)
            bin_indices = bin_indices.numpy()

        # 2) Ensure we have a padded representation
        self.flat_forest = self.flatten_forest_to_tensors(self.forest)

        features_t = self.flat_forest["features"]  # [n_trees, max_nodes], int16
        thresholds_t = self.flat_forest["thresholds"]  # [n_trees, max_nodes], int16
        values_t = self.flat_forest["leaf_values"]  # [n_trees, max_nodes], float32
        max_nodes = self.flat_forest["max_nodes"]  # not used below

        n_trees = features_t.shape[0]
        N = bin_indices.shape[0]
        out = np.zeros(N, dtype=np.float32)

        # 3) Process rows in chunks
        for start in tqdm(range(0, N, chunk_size)):
            end = min(start + chunk_size, N)
            chunk_np = bin_indices[start:end]  # shape [chunk_size, F]
            chunk_gpu = torch.from_numpy(chunk_np).to(self.device)  # [chunk_size, F], int8

            # Accumulate raw (unscaled) leaf sums
            chunk_preds = torch.zeros((end - start,), dtype=torch.float32, device=self.device)

            # node_idx[i] tracks the current node index in the padded tree for row i
            node_idx = torch.zeros((end - start,), dtype=torch.int32, device=self.device)

            # 'active' is a boolean mask over [0..(end-start-1)], indicating which rows haven't reached a leaf
            active = torch.ones((end - start,), dtype=torch.bool, device=self.device)

            for t in range(n_trees):
                # Reset for each tree (each tree is independent)
                node_idx.fill_(0)
                active.fill_(True)

                tree_features = features_t[t]  # shape [max_nodes], int16
                tree_thresh = thresholds_t[t]  # shape [max_nodes], int16
                tree_values = values_t[t]  # shape [max_nodes], float32

                # Up to self.max_depth+1 layers: a root-to-leaf path in the
                # padded layout visits at most that many slots. Rows still
                # active after the last layer add nothing for this tree.
                for _level in range(self.max_depth + 1):
                    active_idx = active.nonzero(as_tuple=True)[0]
                    if active_idx.numel() == 0:
                        break  # all rows are done in this tree

                    current_node_idx = node_idx[active_idx]
                    f = tree_features[current_node_idx]  # shape [#active], int16
                    thr = tree_thresh[current_node_idx]  # shape [#active], int16
                    vals = tree_values[current_node_idx]  # shape [#active], float32

                    mask_no_node = (f == -2)
                    mask_leaf = (f == -1)

                    # If leaf, add leaf value and mark inactive.
                    # leaf_rows are distinct, so the indexed += does not collide.
                    if mask_leaf.any():
                        leaf_rows = active_idx[mask_leaf]
                        chunk_preds[leaf_rows] += vals[mask_leaf]
                        active[leaf_rows] = False

                    # If no node, mark inactive.
                    if mask_no_node.any():
                        no_node_rows = active_idx[mask_no_node]
                        active[no_node_rows] = False

                    # For internal nodes, perform bin comparison.
                    # Children of slot i live at 2*i + 1 (left) and 2*i + 2 (right).
                    mask_internal = (~mask_leaf & ~mask_no_node)
                    if mask_internal.any():
                        internal_rows = active_idx[mask_internal]
                        act_f = f[mask_internal].long()
                        act_thr = thr[mask_internal]
                        binvals = chunk_gpu[internal_rows, act_f]
                        go_left = (binvals <= act_thr)
                        new_left_idx = current_node_idx[mask_internal] * 2 + 1
                        new_right_idx = current_node_idx[mask_internal] * 2 + 2
                        node_idx[internal_rows[go_left]] = new_left_idx[go_left]
                        node_idx[internal_rows[~go_left]] = new_right_idx[~go_left]
                    # end per-tree layer loop
                # end for each tree

            out[start:end] = (
                self.base_prediction + self.learning_rate * chunk_preds
            ).cpu().numpy()

        return out
|
369
|
-
|
370
|
-
def flatten_forest_to_tensors(self, forest):
    """
    Convert a list of dict-based trees into fixed-size tensors on ``self.device``.

    Each tree is stored in a 'perfect binary tree' layout, up to max_depth:
      - node 0 is the root
      - node i has children (2*i + 1) and (2*i + 2), if within range
      - feature = -2 indicates no node / invalid
      - feature = -1 indicates a leaf node
      - otherwise, an internal node with that feature.

    The packing itself is delegated to ``self.flatten_forest`` so the NumPy
    and tensor prediction paths share exactly one encoding. This method only
    moves the resulting arrays to ``self.device``.

    Args:
        forest: list of tree dicts as produced by ``grow_forest``.

    Returns:
        dict with "features" (int16), "thresholds" (int16) and "leaf_values"
        (float32) tensors, each shaped [n_trees, max_nodes], plus the int
        "max_nodes".
    """
    flat = self.flatten_forest(forest)
    return {
        "features": torch.from_numpy(flat["features"]).to(self.device),  # [n_trees, max_nodes]
        "thresholds": torch.from_numpy(flat["thresholds"]).to(self.device),  # [n_trees, max_nodes]
        "leaf_values": torch.from_numpy(flat["leaf_values"]).to(self.device),  # [n_trees, max_nodes]
        "max_nodes": flat["max_nodes"],
    }
|
429
|
-
|
430
|
-
def predict_numpy(self, X_np, chunk_size=50000):
    """
    NumPy counterpart of ``predict``: score ``X_np`` by walking every tree
    level by level over padded NumPy arrays.

    Every call rebuilds ``self.flat_forest`` with ``self.flatten_forest``
    (NumPy arrays shaped [n_trees, max_nodes]). This replaces any tensor
    version previously stored there by ``predict``.

    Args:
        X_np: 2-D NumPy array (rows x features). Integer dtypes are treated as
            already binned and must lie in ``[0, self.num_bins)``. Any other
            dtype is binned with ``np.searchsorted(..., side='left')`` against
            ``self.bin_edges``.
        chunk_size: number of rows scored per batch.

    Returns:
        float32 array of length ``X_np.shape[0]`` equal to
        ``self.base_prediction + self.learning_rate * (sum of reached leaf values)``.

    Raises:
        ValueError: if pre-binned integers fall outside ``[0, self.num_bins)``.
    """
    # 1) Convert X_np -> bin_indices
    if np.issubdtype(X_np.dtype, np.integer):
        # Reductions over zero rows raise, so only validate non-empty input.
        if X_np.size:
            if not np.all(X_np.max(axis=0) < self.num_bins):
                raise ValueError("Pre-binned integers must be < num_bins")
            # A negative bin compares <= every threshold and would silently
            # be routed left at each split.
            if X_np.min() < 0:
                raise ValueError("Pre-binned integers must be >= 0")
        bin_indices = X_np.astype(np.int8)
    else:
        bin_indices = np.empty_like(X_np, dtype=np.int8)
        # Ensure bin_edges are NumPy arrays
        if isinstance(self.bin_edges[0], torch.Tensor):
            bin_edges_np = [be.cpu().numpy() for be in self.bin_edges]
        else:
            bin_edges_np = self.bin_edges

        for f in range(self.num_features):
            bin_indices[:, f] = np.searchsorted(bin_edges_np[f], X_np[:, f], side='left')

    # 2) Padded forest arrays (NumPy)
    self.flat_forest = self.flatten_forest(self.forest)
    features_t = self.flat_forest["features"]  # [n_trees, max_nodes], int16
    thresholds_t = self.flat_forest["thresholds"]  # [n_trees, max_nodes], int16
    values_t = self.flat_forest["leaf_values"]  # [n_trees, max_nodes], float32

    n_trees = features_t.shape[0]
    N = bin_indices.shape[0]
    out = np.zeros(N, dtype=np.float32)

    # 3) Process in chunks
    for start in tqdm(range(0, N, chunk_size)):
        end = min(start + chunk_size, N)
        chunk = bin_indices[start:end]  # [rows, F]
        chunk_preds = np.zeros(end - start, dtype=np.float32)

        for t in range(n_trees):
            # Every row starts at the root (slot 0) of each tree.
            node_idx = np.zeros(end - start, dtype=np.int32)
            active = np.ones(end - start, dtype=bool)

            tree_features = features_t[t]  # [max_nodes]
            tree_thresh = thresholds_t[t]  # [max_nodes]
            tree_values = values_t[t]  # [max_nodes]

            # A root-to-leaf path visits at most max_depth + 1 slots.
            for _level in range(self.max_depth + 1):
                active_idx = np.nonzero(active)[0]
                if active_idx.size == 0:
                    break

                current_node_idx = node_idx[active_idx]
                f = tree_features[current_node_idx]
                thr = tree_thresh[current_node_idx]
                vals = tree_values[current_node_idx]

                mask_no_node = (f == -2)
                mask_leaf = (f == -1)
                mask_internal = ~(mask_leaf | mask_no_node)

                # Leaves: add the value and retire the row for this tree.
                if np.any(mask_leaf):
                    leaf_rows = active_idx[mask_leaf]
                    chunk_preds[leaf_rows] += vals[mask_leaf]
                    active[leaf_rows] = False

                # Empty slots: retire the row without a contribution.
                if np.any(mask_no_node):
                    no_node_rows = active_idx[mask_no_node]
                    active[no_node_rows] = False

                # Internal nodes: descend to child 2*i + 1 (left) or 2*i + 2 (right).
                if np.any(mask_internal):
                    internal_rows = active_idx[mask_internal]
                    act_f = f[mask_internal].astype(np.int32)
                    act_thr = thr[mask_internal]
                    binvals = chunk[internal_rows, act_f]
                    go_left = binvals <= act_thr

                    new_left_idx = current_node_idx[mask_internal] * 2 + 1
                    new_right_idx = current_node_idx[mask_internal] * 2 + 2
                    node_idx[internal_rows[go_left]] = new_left_idx[go_left]
                    node_idx[internal_rows[~go_left]] = new_right_idx[~go_left]

        out[start:end] = self.base_prediction + self.learning_rate * chunk_preds

    return out
|
518
|
-
|
519
|
-
def flatten_forest(self, forest):
    """Pack dict-based trees into padded NumPy arrays (perfect binary layout).

    Each tree occupies one row of ``max_nodes = 2 ** (self.max_depth + 1) - 1``
    slots. The node stored in slot ``i`` keeps its children in slots
    ``2*i + 1`` (left) and ``2*i + 2`` (right). Slot encoding:

      * feature == -2, threshold == -2: empty slot (no node)
      * feature == -1, threshold == -1: leaf, value in ``leaf_values``
      * feature >= 0: internal node split on that feature at bin ``threshold``

    Children of an internal node sitting at depth ``self.max_depth`` are not
    stored. Internal slots keep a leaf value of 0.

    Args:
        forest: list of tree dicts. A leaf has a "leaf_value" key; an internal
            node has "feature", "bin", "left" and "right".

    Returns:
        dict with "features" (int16), "thresholds" (int16) and "leaf_values"
        (float32), each shaped [len(forest), max_nodes], plus "max_nodes".
    """
    slots = 2 ** (self.max_depth + 1) - 1
    shape = (len(forest), slots)

    features = np.full(shape, -2, dtype=np.int16)
    thresholds = np.full(shape, -2, dtype=np.int16)
    leaf_values = np.zeros(shape, dtype=np.float32)

    for row, root in enumerate(forest):
        # Explicit depth-first stack of (node, slot, depth) instead of recursion.
        pending = [(root, 0, 0)]
        while pending:
            node, slot, level = pending.pop()
            if "leaf_value" in node:
                features[row, slot] = -1
                thresholds[row, slot] = -1
                leaf_values[row, slot] = node["leaf_value"]
                continue

            features[row, slot] = node["feature"]
            thresholds[row, slot] = node["bin"]

            if level < self.max_depth:
                left, right = node["left"], node["right"]
                # Push right first so the left subtree is filled first.
                pending.append((right, 2 * slot + 2, level + 1))
                pending.append((left, 2 * slot + 1, level + 1))

    return {
        "features": features,
        "thresholds": thresholds,
        "leaf_values": leaf_values,
        "max_nodes": slots,
    }
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|