warpgbm 0.1.19__tar.gz → 0.1.21__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25) hide show
  1. {warpgbm-0.1.19/warpgbm.egg-info → warpgbm-0.1.21}/PKG-INFO +7 -1
  2. {warpgbm-0.1.19 → warpgbm-0.1.21}/README.md +6 -0
  3. {warpgbm-0.1.19 → warpgbm-0.1.21}/pyproject.toml +1 -1
  4. {warpgbm-0.1.19 → warpgbm-0.1.21}/setup.py +1 -0
  5. warpgbm-0.1.21/version.txt +1 -0
  6. warpgbm-0.1.21/warpgbm/core.py +341 -0
  7. {warpgbm-0.1.19 → warpgbm-0.1.21}/warpgbm/cuda/node_kernel.cpp +8 -0
  8. warpgbm-0.1.21/warpgbm/cuda/predict.cu +77 -0
  9. {warpgbm-0.1.19 → warpgbm-0.1.21/warpgbm.egg-info}/PKG-INFO +7 -1
  10. {warpgbm-0.1.19 → warpgbm-0.1.21}/warpgbm.egg-info/SOURCES.txt +2 -1
  11. warpgbm-0.1.19/version.txt +0 -1
  12. warpgbm-0.1.19/warpgbm/core.py +0 -552
  13. {warpgbm-0.1.19 → warpgbm-0.1.21}/LICENSE +0 -0
  14. {warpgbm-0.1.19 → warpgbm-0.1.21}/MANIFEST.in +0 -0
  15. {warpgbm-0.1.19 → warpgbm-0.1.21}/setup.cfg +0 -0
  16. {warpgbm-0.1.19 → warpgbm-0.1.21}/tests/__init__.py +0 -0
  17. {warpgbm-0.1.19 → warpgbm-0.1.21}/tests/test_fit_predict_corr.py +0 -0
  18. {warpgbm-0.1.19 → warpgbm-0.1.21}/warpgbm/__init__.py +0 -0
  19. {warpgbm-0.1.19 → warpgbm-0.1.21}/warpgbm/cuda/__init__.py +0 -0
  20. {warpgbm-0.1.19 → warpgbm-0.1.21}/warpgbm/cuda/best_split_kernel.cu +0 -0
  21. {warpgbm-0.1.19 → warpgbm-0.1.21}/warpgbm/cuda/binner.cu +0 -0
  22. {warpgbm-0.1.19 → warpgbm-0.1.21}/warpgbm/cuda/histogram_kernel.cu +0 -0
  23. {warpgbm-0.1.19 → warpgbm-0.1.21}/warpgbm.egg-info/dependency_links.txt +0 -0
  24. {warpgbm-0.1.19 → warpgbm-0.1.21}/warpgbm.egg-info/requires.txt +0 -0
  25. {warpgbm-0.1.19 → warpgbm-0.1.21}/warpgbm.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: warpgbm
3
- Version: 0.1.19
3
+ Version: 0.1.21
4
4
  Summary: A fast GPU-accelerated Gradient Boosted Decision Tree library with PyTorch + CUDA
5
5
  License: GNU GENERAL PUBLIC LICENSE
6
6
  Version 3, 29 June 2007
@@ -897,3 +897,9 @@ WarpGBM builds on the shoulders of PyTorch, scikit-learn, LightGBM, and the CUDA
897
897
 
898
898
  ---
899
899
 
900
+ ## Version Notes
901
+
902
+ ### v0.1.21
903
+
904
+ - Vectorized predict function replaced with CUDA kernel (`warpgbm/cuda/predict.cu`), parallelizing per sample, per tree.
905
+
@@ -209,3 +209,9 @@ WarpGBM builds on the shoulders of PyTorch, scikit-learn, LightGBM, and the CUDA
209
209
 
210
210
  ---
211
211
 
212
+ ## Version Notes
213
+
214
+ ### v0.1.21
215
+
216
+ - Vectorized predict function replaced with CUDA kernel (`warpgbm/cuda/predict.cu`), parallelizing per sample, per tree.
217
+
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "warpgbm"
7
- version = "0.1.19"
7
+ version = "0.1.21"
8
8
  description = "A fast GPU-accelerated Gradient Boosted Decision Tree library with PyTorch + CUDA"
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.8"
@@ -23,6 +23,7 @@ def get_extensions():
23
23
  "warpgbm/cuda/histogram_kernel.cu",
24
24
  "warpgbm/cuda/best_split_kernel.cu",
25
25
  "warpgbm/cuda/binner.cu",
26
+ "warpgbm/cuda/predict.cu",
26
27
  "warpgbm/cuda/node_kernel.cpp",
27
28
  ]
28
29
  )
@@ -0,0 +1 @@
1
+ 0.1.21
@@ -0,0 +1,341 @@
1
+ import torch
2
+ import numpy as np
3
+ from sklearn.base import BaseEstimator, RegressorMixin
4
+ from warpgbm.cuda import node_kernel
5
+ from tqdm import tqdm
6
+ from typing import Tuple
7
+ from torch import Tensor
8
+
9
# Registry of the selectable histogram back-ends: maps each accepted value of the
# `histogram_computer` constructor argument to the compiled entry point it selects.
histogram_kernels = dict(
    hist1=node_kernel.compute_histogram,
    hist2=node_kernel.compute_histogram2,
    hist3=node_kernel.compute_histogram3,
)
14
+
15
class WarpGBM(BaseEstimator, RegressorMixin):
    """Gradient-boosted regression trees trained through the ``node_kernel`` CUDA extension.

    Features are discretised into int8 bin indices. Every boosting round fits one
    tree to the current residual ``y - prediction``; each leaf stores the mean
    residual of the training rows routed to it. The running prediction starts at
    ``mean(y)`` and every leaf contributes ``learning_rate * leaf_value``.

    Parameters (ranges are the ones enforced by ``_validate_hyperparams``):
        num_bins: number of bins per feature, 2..127.
        max_depth: maximum tree depth, >= 1.
        learning_rate: shrinkage applied to every leaf value, in (0, 1].
        n_estimators: number of trees, > 0.
        min_child_weight: forwarded to the split kernel, >= 1.
        min_split_gain: forwarded to the split kernel, >= 0.
        verbosity: stored on the estimator; not consulted by the code in this class.
        histogram_computer: key of ``histogram_kernels`` selecting the histogram kernel.
        threads_per_block: positive multiple of 32, forwarded to the kernels.
        rows_per_thread: 1..16, forwarded to the histogram kernel.
        L2_reg: forwarded to the split kernel, >= 0.
        L1_reg: validated (>= 0) and stored; not consulted by the code in this class.
        device: torch device on which training and inference tensors are placed.
    """

    def __init__(
        self,
        num_bins=10,
        max_depth=3,
        learning_rate=0.1,
        n_estimators=100,
        min_child_weight=20,
        min_split_gain=0.0,
        verbosity=True,
        histogram_computer='hist3',
        threads_per_block=64,
        rows_per_thread=4,
        L2_reg=1e-6,
        L1_reg=0.0,
        device='cuda'
    ):
        # Validate arguments before anything is stored or allocated.
        self._validate_hyperparams(
            num_bins=num_bins,
            max_depth=max_depth,
            learning_rate=learning_rate,
            n_estimators=n_estimators,
            min_child_weight=min_child_weight,
            min_split_gain=min_split_gain,
            histogram_computer=histogram_computer,
            threads_per_block=threads_per_block,
            rows_per_thread=rows_per_thread,
            L2_reg=L2_reg,
            L1_reg=L1_reg
        )

        # Every constructor argument is kept under its own name:
        # BaseEstimator.get_params() (and therefore clone() and repr()) reads
        # each one back with getattr and fails if any is missing.
        self.num_bins = num_bins
        self.max_depth = max_depth
        self.learning_rate = learning_rate
        self.n_estimators = n_estimators
        self.min_child_weight = min_child_weight
        self.min_split_gain = min_split_gain
        self.verbosity = verbosity
        self.histogram_computer = histogram_computer
        self.threads_per_block = threads_per_block
        self.rows_per_thread = rows_per_thread
        self.L2_reg = L2_reg
        self.L1_reg = L1_reg
        self.device = device

        # Kernel selected by `histogram_computer`.
        self.compute_histogram = histogram_kernels[histogram_computer]

        # State populated by fit().
        self.forest = None
        self.bin_edges = None
        self.base_prediction = None
        self.unique_eras = None
        self.root_gradient_histogram = None
        self.root_hessian_histogram = None
        self.gradients = None
        self.root_node_indices = None
        self.bin_indices = None
        self.Y_gpu = None
        self.num_features = None
        self.num_samples = None
        self.best_bin = torch.tensor([-1], dtype=torch.int32, device=self.device)

    def _validate_hyperparams(self, **kwargs):
        """Check hyperparameter types and ranges.

        Raises:
            TypeError: a parameter has the wrong type.
            ValueError: a parameter is outside its allowed range.
        """
        # Type checks
        int_params = [
            "num_bins", "max_depth", "n_estimators", "min_child_weight",
            "threads_per_block", "rows_per_thread"
        ]
        float_params = [
            "learning_rate", "min_split_gain", "L2_reg", "L1_reg"
        ]

        for param in int_params:
            if not isinstance(kwargs[param], int):
                raise TypeError(f"{param} must be an integer, got {type(kwargs[param])}.")

        for param in float_params:
            if not isinstance(kwargs[param], (float, int)):  # Accept ints as valid floats
                raise TypeError(f"{param} must be a float, got {type(kwargs[param])}.")

        # Range checks
        if not (2 <= kwargs["num_bins"] <= 127):
            raise ValueError("num_bins must be between 2 and 127 inclusive.")
        if kwargs["max_depth"] < 1:
            raise ValueError("max_depth must be at least 1.")
        if not (0.0 < kwargs["learning_rate"] <= 1.0):
            raise ValueError("learning_rate must be in (0.0, 1.0].")
        if kwargs["n_estimators"] <= 0:
            raise ValueError("n_estimators must be positive.")
        if kwargs["min_child_weight"] < 1:
            raise ValueError("min_child_weight must be a positive integer.")
        if kwargs["min_split_gain"] < 0:
            raise ValueError("min_split_gain must be non-negative.")
        if kwargs["threads_per_block"] <= 0 or kwargs["threads_per_block"] % 32 != 0:
            raise ValueError("threads_per_block should be a positive multiple of 32 (warp size).")
        if not (1 <= kwargs["rows_per_thread"] <= 16):
            raise ValueError("rows_per_thread must be positive between 1 and 16 inclusive.")
        if kwargs["L2_reg"] < 0 or kwargs["L1_reg"] < 0:
            raise ValueError("L2_reg and L1_reg must be non-negative.")
        if kwargs["histogram_computer"] not in histogram_kernels:
            raise ValueError(f"Invalid histogram_computer: {kwargs['histogram_computer']}. Choose from {list(histogram_kernels.keys())}.")

    def fit(self, X, y, era_id=None):
        """Bin the training data and grow ``n_estimators`` trees.

        Args:
            X: 2-D numpy array of features, one row per sample.
            y: numpy array of targets, one entry per row of ``X``.
            era_id: optional numpy array of per-row era labels; defaults to a
                single era. Only the set of unique eras is kept by this method.

        Returns:
            self

        Raises:
            ValueError: ``X`` and ``y`` disagree on the number of samples.
        """
        if len(y) != X.shape[0]:
            raise ValueError(
                f"X and y have inconsistent numbers of samples: {X.shape[0]} != {len(y)}."
            )
        if era_id is None:
            era_id = np.ones(X.shape[0], dtype='int32')

        (
            self.bin_indices,
            _era_indices,
            self.bin_edges,
            self.unique_eras,
            self.Y_gpu,
        ) = self.preprocess_gpu_data(X, y, era_id)
        self.num_samples, self.num_features = X.shape

        # `gradients` holds the running prediction; it starts at the target mean.
        self.base_prediction = self.Y_gpu.mean().item()
        self.gradients = torch.zeros_like(self.Y_gpu)
        self.gradients += self.base_prediction
        self.root_node_indices = torch.arange(self.num_samples, device=self.device)

        # Per-feature output buffers filled by the split kernel.
        self.best_gains = torch.zeros(self.num_features, device=self.device)
        self.best_bins = torch.zeros(self.num_features, device=self.device, dtype=torch.int32)

        with torch.no_grad():
            self.forest = self.grow_forest()
        return self

    def preprocess_gpu_data(self, X_np, Y_np, era_id_np):
        """Move targets and eras to the device and convert features to int8 bins.

        Integer input whose per-column maxima are all below ``num_bins`` is taken
        as already binned; anything else is quantile-binned per feature.

        Returns:
            ``(bin_indices, era_indices, bin_edges, unique_eras, Y_gpu)`` where
            ``bin_indices`` is ``[num_samples, num_features]`` int8 and
            ``bin_edges`` is ``[num_features, num_bins - 1]`` float32.
        """
        with torch.no_grad():
            self.num_samples, self.num_features = X_np.shape
            Y_gpu = torch.from_numpy(Y_np).type(torch.float32).to(self.device)
            era_id_gpu = torch.from_numpy(era_id_np).type(torch.int32).to(self.device)
            unique_eras, era_indices = torch.unique(era_id_gpu, return_inverse=True)

            if np.issubdtype(X_np.dtype, np.integer):
                max_vals = X_np.max(axis=0)
                if np.all(max_vals < self.num_bins):
                    print("Detected pre-binned integer input — skipping quantile binning.")
                    bin_indices = torch.from_numpy(X_np).to(self.device).contiguous().to(torch.int8)

                    # Edges 1, 2, ..., num_bins - 1 for every feature, so that
                    # predict() can push inputs through the same binner call.
                    # NOTE(review): confirm custom_cuda_binner maps an integer
                    # value v back to bin v with these edges.
                    bin_edges = torch.arange(1, self.num_bins, dtype=torch.float32).repeat(self.num_features, 1)
                    bin_edges = bin_edges.to(self.device)
                    return bin_indices, era_indices, bin_edges, unique_eras, Y_gpu
                else:
                    print("Integer input detected, but values exceed num_bins — falling back to quantile binning.")

            # Allocate on the configured device rather than a hard-coded 'cuda',
            # so a setting such as device='cuda:1' is honoured consistently.
            bin_indices = torch.empty((self.num_samples, self.num_features), dtype=torch.int8, device=self.device)
            bin_edges = torch.empty((self.num_features, self.num_bins - 1), dtype=torch.float32, device=self.device)

            # Pinned host memory lets the per-column copies be non-blocking.
            X_pinned = torch.from_numpy(X_np).to(torch.float32).pin_memory()

            # The interior quantile levels are the same for every feature.
            quantiles = torch.linspace(0, 1, self.num_bins + 1, device=self.device, dtype=torch.float32)[1:-1]

            for f in range(self.num_features):
                X_f = X_pinned[:, f].to(self.device, non_blocking=True)
                bin_edges_f = torch.quantile(X_f, quantiles, dim=0).contiguous()  # shape: [num_bins - 1]
                # Contiguous scratch column for the kernel. For a multi-feature
                # matrix this is a copy, hence the write-back below.
                bin_indices_f = bin_indices[:, f].contiguous()
                node_kernel.custom_cuda_binner(X_f, bin_edges_f, bin_indices_f)
                bin_indices[:, f] = bin_indices_f
                bin_edges[f, :] = bin_edges_f

            return bin_indices, era_indices, bin_edges, unique_eras, Y_gpu

    def compute_histograms(self, bin_indices_sub, gradients):
        """Run the selected histogram kernel over a subset of rows.

        Returns:
            ``(grad_hist, hess_hist)``, both ``[num_features, num_bins]`` float32
            tensors that are zero-initialised here and filled by the kernel.
        """
        grad_hist = torch.zeros((self.num_features, self.num_bins), device=self.device, dtype=torch.float32)
        hess_hist = torch.zeros((self.num_features, self.num_bins), device=self.device, dtype=torch.float32)

        self.compute_histogram(
            bin_indices_sub,
            gradients,
            grad_hist,
            hess_hist,
            self.num_bins,
            self.threads_per_block,
            self.rows_per_thread
        )
        return grad_hist, hess_hist

    def find_best_split(self, gradient_histogram, hessian_histogram):
        """Pick the split feature and bin for a node from its histograms.

        The split kernel writes into ``self.best_gains`` / ``self.best_bins``.

        Returns:
            ``(feature, bin)`` as Python ints, or ``(-1, -1)`` when every entry
            of ``self.best_bins`` is -1 (no valid split).
        """
        node_kernel.compute_split(
            gradient_histogram,
            hessian_histogram,
            self.min_split_gain,
            self.min_child_weight,
            self.L2_reg,
            self.best_gains,
            self.best_bins,
            self.threads_per_block
        )

        if torch.all(self.best_bins == -1):
            return -1, -1  # No valid split found

        f = torch.argmax(self.best_gains).item()
        b = self.best_bins[f].item()

        return f, b

    def _make_leaf(self, node_indices):
        """Create a leaf for ``node_indices`` and fold it into the running prediction."""
        leaf_value = self.residual[node_indices].mean()
        self.gradients[node_indices] += self.learning_rate * leaf_value
        return {"leaf_value": leaf_value.item(), "samples": node_indices.numel()}

    def grow_tree(self, gradient_histogram, hessian_histogram, node_indices, depth):
        """Recursively grow one (sub)tree over the rows in ``node_indices``.

        Returns:
            A nested dict: leaves are ``{"leaf_value", "samples"}``, internal
            nodes are ``{"feature", "bin", "left", "right"}``. Rows whose bin for
            ``feature`` is ``<= bin`` go left.
        """
        if depth == self.max_depth:
            return self._make_leaf(node_indices)

        best_feature, best_bin = self.find_best_split(gradient_histogram, hessian_histogram)
        if best_feature == -1:
            return self._make_leaf(node_indices)

        split_mask = (self.bin_indices[node_indices, best_feature] <= best_bin)
        left_indices = node_indices[split_mask]
        right_indices = node_indices[~split_mask]

        # Build histograms only for the smaller child; the sibling is the
        # parent's histogram minus the child's.
        if left_indices.numel() <= right_indices.numel():
            grad_hist_left, hess_hist_left = self.compute_histograms(
                self.bin_indices[left_indices], self.residual[left_indices]
            )
            grad_hist_right = gradient_histogram - grad_hist_left
            hess_hist_right = hessian_histogram - hess_hist_left
        else:
            grad_hist_right, hess_hist_right = self.compute_histograms(
                self.bin_indices[right_indices], self.residual[right_indices]
            )
            grad_hist_left = gradient_histogram - grad_hist_right
            hess_hist_left = hessian_histogram - hess_hist_right

        new_depth = depth + 1
        left_child = self.grow_tree(grad_hist_left, hess_hist_left, left_indices, new_depth)
        right_child = self.grow_tree(grad_hist_right, hess_hist_right, right_indices, new_depth)

        return {"feature": best_feature, "bin": best_bin, "left": left_child, "right": right_child}

    def grow_forest(self):
        """Grow ``n_estimators`` trees, each on the residual left by the previous ones.

        Returns:
            List of tree dicts in the format produced by ``grow_tree``.
        """
        forest = []
        self.training_loss = []

        for _ in tqdm(range(self.n_estimators)):
            # grow_tree updates self.gradients as leaves are created, so the
            # residual has to be refreshed before every tree.
            self.residual = self.Y_gpu - self.gradients

            self.root_gradient_histogram, self.root_hessian_histogram = \
                self.compute_histograms(self.bin_indices, self.residual)

            forest.append(
                self.grow_tree(
                    self.root_gradient_histogram,
                    self.root_hessian_histogram,
                    self.root_node_indices,
                    depth=0
                )
            )

        print("Finished training forest.")
        return forest

    def predict(self, X_np):
        """Predict targets for ``X_np`` with the fitted forest.

        Args:
            X_np: 2-D numpy array with ``num_features`` columns.

        Returns:
            1-D float32 numpy array:
            ``base_prediction + learning_rate * sum(leaf values over all trees)``.

        Raises:
            NotFittedError: ``fit`` has not been called.
            ValueError: ``X_np`` is not 2-D or has the wrong number of columns.
        """
        if self.forest is None:
            from sklearn.exceptions import NotFittedError
            raise NotFittedError(
                "This WarpGBM instance is not fitted yet; call fit() before predict()."
            )
        if X_np.ndim != 2 or X_np.shape[1] != self.num_features:
            raise ValueError(
                f"X must be 2-D with {self.num_features} features, got shape {X_np.shape}."
            )

        X_tensor = torch.from_numpy(X_np).to(torch.float32).pin_memory()
        num_samples = X_tensor.size(0)
        bin_indices = torch.zeros((num_samples, self.num_features), dtype=torch.int8, device=self.device)

        with torch.no_grad():
            for f in range(self.num_features):
                X_f = X_tensor[:, f].to(self.device, non_blocking=True)
                bin_edges_f = self.bin_edges[f]
                # Contiguous scratch column for the kernel, written back below.
                bin_indices_f = bin_indices[:, f].contiguous()
                node_kernel.custom_cuda_binner(X_f, bin_edges_f, bin_indices_f)
                bin_indices[:, f] = bin_indices_f

        tree_tensor = torch.stack([
            self.flatten_tree(tree, max_nodes=2 ** (self.max_depth + 1))
            for tree in self.forest
        ]).to(self.device)

        # Training seeds the running prediction with base_prediction before any
        # tree is added, so inference must start from the same offset; the
        # kernel then accumulates learning_rate * leaf_value on top of it.
        out = torch.full(
            (num_samples,),
            float(self.base_prediction),
            dtype=torch.float32,
            device=self.device,
        )
        node_kernel.predict_forest(
            bin_indices.contiguous(),
            tree_tensor.contiguous(),
            self.learning_rate,
            out
        )

        return out.cpu().numpy()

    def flatten_tree(self, tree, max_nodes):
        """
        Convert a recursive tree structure into a flat matrix format.

        Each row in the output represents a node, numbered in pre-order
        (parent, then left subtree, then right subtree), root at row 0:
        - Columns: [feature, bin, left_id, right_id, is_leaf, value]
        - Internal nodes fill columns 0–3 and set is_leaf = 0
        - Leaf nodes fill only value and set is_leaf = 1
        - Every cell not written stays NaN (including whole unused rows)

        Args:
            tree (dict): Root node in the nested dict form returned by grow_tree.
            max_nodes (int): Number of rows to allocate in the flat matrix.

        Returns:
            torch.Tensor: [max_nodes x 6] float32 matrix representing the flattened tree.
        """
        flat = torch.full((max_nodes, 6), float('nan'), dtype=torch.float32)
        next_id = [0]  # boxed counter shared with the nested walker

        def walk(node):
            curr_id = next_id[0]
            next_id[0] += 1

            if 'leaf_value' in node:
                flat[curr_id, 4] = 1.0
                flat[curr_id, 5] = float(node['leaf_value'])
                return

            flat[curr_id, 0] = float(node['feature'])
            flat[curr_id, 1] = float(node['bin'])
            flat[curr_id, 4] = 0.0
            # A child's id is whatever the counter holds when its walk starts.
            flat[curr_id, 2] = next_id[0]
            walk(node['left'])
            flat[curr_id, 3] = next_id[0]
            walk(node['right'])

        walk(tree)
        return flat
@@ -44,6 +44,13 @@ void launch_bin_column_kernel(
44
44
  at::Tensor bin_edges,
45
45
  at::Tensor bin_indices);
46
46
 
47
+ void predict_with_forest(
48
+ const at::Tensor &bin_indices, // [N x F], int8
49
+ const at::Tensor &tree_tensor, // [T x max_nodes x 6], float32
50
+ float learning_rate,
51
+ at::Tensor &out // [N], float32
52
+ );
53
+
47
54
  // Bindings
48
55
  PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
49
56
  {
@@ -52,4 +59,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
52
59
  m.def("compute_histogram3", &launch_histogram_kernel_cuda_configurable, "Histogram Feature Shared Mem");
53
60
  m.def("compute_split", &launch_best_split_kernel_cuda, "Best Split (CUDA)");
54
61
  m.def("custom_cuda_binner", &launch_bin_column_kernel, "Custom CUDA binning kernel");
62
+ m.def("predict_forest", &predict_with_forest, "CUDA Predictions");
55
63
  }
@@ -0,0 +1,77 @@
1
+ #include <cuda.h>
2
+ #include <cuda_runtime.h>
3
+ #include <torch/extension.h>
4
+
5
// Forest inference: one thread per (sample, tree) pair.
//
// Each thread walks tree `t` for sample `i` starting at node 0 and, on reaching a
// leaf, atomically adds `learning_rate * leaf_value` to `out[i]`. `out` is only
// accumulated into, never initialised, so the caller provides the starting value.
//
// Node row layout (6 floats): [feature, bin, left_id, right_id, is_leaf, value].
// A sample goes left when its bin for `feature` is <= `bin`.
__global__ void predict_forest_kernel(
    const int8_t *__restrict__ bin_indices, // [N x F]
    const float *__restrict__ tree_tensor,  // [T x max_nodes x 6]
    int N, int F, int T, int max_nodes,
    float learning_rate,
    float *__restrict__ out // [N]
)
{
    // 64-bit arithmetic: the global thread index and N * T can exceed INT_MAX.
    const long long idx = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
    const long long total_jobs = static_cast<long long>(N) * T;
    if (idx >= total_jobs)
        return;

    const int i = static_cast<int>(idx % N); // sample index
    const int t = static_cast<int>(idx / N); // tree index

    // size_t offsets: i * F and t * max_nodes * 6 can overflow int on large inputs.
    const float *tree = tree_tensor + static_cast<size_t>(t) * max_nodes * 6;
    const int8_t *row = bin_indices + static_cast<size_t>(i) * F;

    int node_id = 0;
    // A root-to-leaf path never visits more than max_nodes nodes, so bounding the
    // walk guarantees termination even when a tree is malformed.
    for (int step = 0; step < max_nodes; ++step)
    {
        if (node_id < 0 || node_id >= max_nodes)
            return; // child id points outside this tree

        const float *node = tree + static_cast<size_t>(node_id) * 6;
        const float is_leaf = node[4];
        if (is_leaf > 0.5f)
        {
            atomicAdd(&out[i], learning_rate * node[5]);
            return;
        }
        if (isnan(is_leaf))
            return; // row was never filled in: neither a leaf nor a split

        const int feat = static_cast<int>(node[0]);
        if (feat < 0 || feat >= F)
            return; // split feature outside the input matrix

        const int split_bin = static_cast<int>(node[1]);
        const int left_id = static_cast<int>(node[2]);
        const int right_id = static_cast<int>(node[3]);

        const int8_t bin = row[feat];
        node_id = (bin <= split_bin) ? left_id : right_id;
    }
}
48
+
49
// Host launcher for predict_forest_kernel.
//
// Validates the tensors, then launches one thread per (sample, tree) pair. The
// kernel accumulates into `out`, so `out` must already hold the desired starting
// value for every sample.
//
// Throws (via TORCH_CHECK, surfaced to Python as an exception) on invalid inputs
// or when the kernel launch fails.
void predict_with_forest(
    const at::Tensor &bin_indices, // [N x F], int8
    const at::Tensor &tree_tensor, // [T x max_nodes x 6], float32
    float learning_rate,
    at::Tensor &out // [N], float32
)
{
    // Raw pointers are handed to the kernel, so device, dtype, shape and layout
    // all have to be checked here.
    TORCH_CHECK(bin_indices.is_cuda() && tree_tensor.is_cuda() && out.is_cuda(),
                "predict_forest: bin_indices, tree_tensor and out must be CUDA tensors");
    TORCH_CHECK(bin_indices.scalar_type() == at::kChar,
                "predict_forest: bin_indices must be int8");
    TORCH_CHECK(tree_tensor.scalar_type() == at::kFloat && out.scalar_type() == at::kFloat,
                "predict_forest: tree_tensor and out must be float32");
    TORCH_CHECK(bin_indices.dim() == 2,
                "predict_forest: bin_indices must be [N x F]");
    TORCH_CHECK(tree_tensor.dim() == 3 && tree_tensor.size(2) == 6,
                "predict_forest: tree_tensor must be [T x max_nodes x 6]");
    TORCH_CHECK(out.dim() == 1 && out.size(0) == bin_indices.size(0),
                "predict_forest: out must be [N]");
    TORCH_CHECK(bin_indices.is_contiguous() && tree_tensor.is_contiguous() && out.is_contiguous(),
                "predict_forest: all tensors must be contiguous");

    const int64_t N = bin_indices.size(0);
    const int64_t F = bin_indices.size(1);
    const int64_t T = tree_tensor.size(0);
    const int64_t max_nodes = tree_tensor.size(1);

    const int64_t total_jobs = N * T;
    if (total_jobs == 0)
        return; // nothing to accumulate; a zero-block launch is an invalid configuration

    // The job count is passed on as int (grid size and kernel arguments), so
    // reject inputs that would silently wrap instead of producing wrong output.
    constexpr int64_t kMaxJobs = 2147483647; // INT_MAX
    TORCH_CHECK(total_jobs <= kMaxJobs,
                "predict_forest: N * T = ", total_jobs,
                " exceeds the supported maximum of ", kMaxJobs,
                "; predict in smaller batches");

    const int threads_per_block = 256;
    const int blocks = static_cast<int>((total_jobs + threads_per_block - 1) / threads_per_block);

    predict_forest_kernel<<<blocks, threads_per_block>>>(
        bin_indices.data_ptr<int8_t>(),
        tree_tensor.data_ptr<float>(),
        static_cast<int>(N), static_cast<int>(F), static_cast<int>(T), static_cast<int>(max_nodes),
        learning_rate,
        out.data_ptr<float>());

    // Raise rather than print: a failed launch leaves `out` without the tree
    // contributions, which must not be returned as a valid prediction.
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "CUDA predict kernel failed: ", cudaGetErrorString(err));
}
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: warpgbm
3
- Version: 0.1.19
3
+ Version: 0.1.21
4
4
  Summary: A fast GPU-accelerated Gradient Boosted Decision Tree library with PyTorch + CUDA
5
5
  License: GNU GENERAL PUBLIC LICENSE
6
6
  Version 3, 29 June 2007
@@ -897,3 +897,9 @@ WarpGBM builds on the shoulders of PyTorch, scikit-learn, LightGBM, and the CUDA
897
897
 
898
898
  ---
899
899
 
900
+ ## Version Notes
901
+
902
+ ### v0.1.21
903
+
904
+ - Vectorized predict function replaced with CUDA kernel (`warpgbm/cuda/predict.cu`), parallelizing per sample, per tree.
905
+
@@ -17,4 +17,5 @@ warpgbm/cuda/__init__.py
17
17
  warpgbm/cuda/best_split_kernel.cu
18
18
  warpgbm/cuda/binner.cu
19
19
  warpgbm/cuda/histogram_kernel.cu
20
- warpgbm/cuda/node_kernel.cpp
20
+ warpgbm/cuda/node_kernel.cpp
21
+ warpgbm/cuda/predict.cu
@@ -1 +0,0 @@
1
- 0.1.19
@@ -1,552 +0,0 @@
1
- import torch
2
- import numpy as np
3
- from sklearn.base import BaseEstimator, RegressorMixin
4
- from warpgbm.cuda import node_kernel
5
- from tqdm import tqdm
6
- from typing import Tuple
7
- from torch import Tensor
8
-
9
- histogram_kernels = {
10
- 'hist1': node_kernel.compute_histogram,
11
- 'hist2': node_kernel.compute_histogram2,
12
- 'hist3': node_kernel.compute_histogram3
13
- }
14
-
15
- class WarpGBM(BaseEstimator, RegressorMixin):
16
- def __init__(
17
- self,
18
- num_bins=10,
19
- max_depth=3,
20
- learning_rate=0.1,
21
- n_estimators=100,
22
- min_child_weight=20,
23
- min_split_gain=0.0,
24
- verbosity=True,
25
- histogram_computer='hist3',
26
- threads_per_block=64,
27
- rows_per_thread=4,
28
- L2_reg=1e-6,
29
- L1_reg=0.0,
30
- device='cuda'
31
- ):
32
- # Validate arguments
33
- self._validate_hyperparams(
34
- num_bins=num_bins,
35
- max_depth=max_depth,
36
- learning_rate=learning_rate,
37
- n_estimators=n_estimators,
38
- min_child_weight=min_child_weight,
39
- min_split_gain=min_split_gain,
40
- histogram_computer=histogram_computer,
41
- threads_per_block=threads_per_block,
42
- rows_per_thread=rows_per_thread,
43
- L2_reg=L2_reg,
44
- L1_reg=L1_reg
45
- )
46
-
47
- self.num_bins = num_bins
48
- self.max_depth = max_depth
49
- self.learning_rate = learning_rate
50
- self.n_estimators = n_estimators
51
- self.forest = None
52
- self.bin_edges = None
53
- self.base_prediction = None
54
- self.unique_eras = None
55
- self.device = device
56
- self.root_gradient_histogram = None
57
- self.root_hessian_histogram = None
58
- self.gradients = None
59
- self.root_node_indices = None
60
- self.bin_indices = None
61
- self.Y_gpu = None
62
- self.num_features = None
63
- self.num_samples = None
64
- self.min_child_weight = min_child_weight
65
- self.min_split_gain = min_split_gain
66
- self.best_bin = torch.tensor([-1], dtype=torch.int32, device=self.device)
67
- self.compute_histogram = histogram_kernels[histogram_computer]
68
- self.threads_per_block = threads_per_block
69
- self.rows_per_thread = rows_per_thread
70
- self.L2_reg = L2_reg
71
- self.L1_reg = L1_reg
72
-
73
- def _validate_hyperparams(self, **kwargs):
74
- # Type checks
75
- int_params = [
76
- "num_bins", "max_depth", "n_estimators", "min_child_weight",
77
- "threads_per_block", "rows_per_thread"
78
- ]
79
- float_params = [
80
- "learning_rate", "min_split_gain", "L2_reg", "L1_reg"
81
- ]
82
-
83
- for param in int_params:
84
- if not isinstance(kwargs[param], int):
85
- raise TypeError(f"{param} must be an integer, got {type(kwargs[param])}.")
86
-
87
- for param in float_params:
88
- if not isinstance(kwargs[param], (float, int)): # Accept ints as valid floats
89
- raise TypeError(f"{param} must be a float, got {type(kwargs[param])}.")
90
-
91
- if not ( 2 <= kwargs["num_bins"] <= 127 ):
92
- raise ValueError("num_bins must be between 2 and 127 inclusive.")
93
- if kwargs["max_depth"] < 1:
94
- raise ValueError("max_depth must be at least 1.")
95
- if not (0.0 < kwargs["learning_rate"] <= 1.0):
96
- raise ValueError("learning_rate must be in (0.0, 1.0].")
97
- if kwargs["n_estimators"] <= 0:
98
- raise ValueError("n_estimators must be positive.")
99
- if kwargs["min_child_weight"] < 1:
100
- raise ValueError("min_child_weight must be a positive integer.")
101
- if kwargs["min_split_gain"] < 0:
102
- raise ValueError("min_split_gain must be non-negative.")
103
- if kwargs["threads_per_block"] <= 0 or kwargs["threads_per_block"] % 32 != 0:
104
- raise ValueError("threads_per_block should be a positive multiple of 32 (warp size).")
105
- if not ( 1 <= kwargs["rows_per_thread"] <= 16 ):
106
- raise ValueError("rows_per_thread must be positive between 1 and 16 inclusive.")
107
- if kwargs["L2_reg"] < 0 or kwargs["L1_reg"] < 0:
108
- raise ValueError("L2_reg and L1_reg must be non-negative.")
109
- if kwargs["histogram_computer"] not in histogram_kernels:
110
- raise ValueError(f"Invalid histogram_computer: {kwargs['histogram_computer']}. Choose from {list(histogram_kernels.keys())}.")
111
-
112
- def fit(self, X, y, era_id=None):
113
- if era_id is None:
114
- era_id = np.ones(X.shape[0], dtype='int32')
115
- self.bin_indices, era_indices, self.bin_edges, self.unique_eras, self.Y_gpu = self.preprocess_gpu_data(X, y, era_id)
116
- self.num_samples, self.num_features = X.shape
117
- self.gradients = torch.zeros_like(self.Y_gpu)
118
- self.root_node_indices = torch.arange(self.num_samples, device=self.device)
119
- self.base_prediction = self.Y_gpu.mean().item()
120
- self.gradients += self.base_prediction
121
- self.best_gains = torch.zeros(self.num_features, device=self.device)
122
- self.best_bins = torch.zeros(self.num_features, device=self.device, dtype=torch.int32)
123
- with torch.no_grad():
124
- self.forest = self.grow_forest()
125
- return self
126
-
127
- def preprocess_gpu_data(self, X_np, Y_np, era_id_np):
128
- with torch.no_grad():
129
- self.num_samples, self.num_features = X_np.shape
130
- Y_gpu = torch.from_numpy(Y_np).type(torch.float32).to(self.device)
131
- era_id_gpu = torch.from_numpy(era_id_np).type(torch.int32).to(self.device)
132
- is_integer_type = np.issubdtype(X_np.dtype, np.integer)
133
- if is_integer_type:
134
- max_vals = X_np.max(axis=0)
135
- if np.all(max_vals < self.num_bins):
136
- print("Detected pre-binned integer input — skipping quantile binning.")
137
- bin_indices = torch.from_numpy(X_np).to(self.device).contiguous().to(torch.int8)
138
-
139
- # We'll store None or an empty tensor in self.bin_edges
140
- # to indicate that we skip binning at predict-time
141
- bin_edges = torch.arange(1, self.num_bins, dtype=torch.float32).repeat(self.num_features, 1)
142
- bin_edges = bin_edges.to(self.device)
143
- unique_eras, era_indices = torch.unique(era_id_gpu, return_inverse=True)
144
- return bin_indices, era_indices, bin_edges, unique_eras, Y_gpu
145
- else:
146
- print("Integer input detected, but values exceed num_bins — falling back to quantile binning.")
147
-
148
- bin_indices = torch.empty((self.num_samples, self.num_features), dtype=torch.int8, device='cuda')
149
- bin_edges = torch.empty((self.num_features, self.num_bins - 1), dtype=torch.float32, device='cuda')
150
-
151
- X_np = torch.from_numpy(X_np).to(torch.float32).pin_memory()
152
-
153
- for f in range(self.num_features):
154
- X_f = X_np[:, f].to('cuda', non_blocking=True)
155
- quantiles = torch.linspace(0, 1, self.num_bins + 1, device='cuda', dtype=X_f.dtype)[1:-1]
156
- bin_edges_f = torch.quantile(X_f, quantiles, dim=0).contiguous() # shape: [B-1] for 1D input
157
- bin_indices_f = bin_indices[:, f].contiguous() # view into output
158
- node_kernel.custom_cuda_binner(X_f, bin_edges_f, bin_indices_f)
159
- bin_indices[:,f] = bin_indices_f
160
- bin_edges[f,:] = bin_edges_f
161
-
162
- unique_eras, era_indices = torch.unique(era_id_gpu, return_inverse=True)
163
- return bin_indices, era_indices, bin_edges, unique_eras, Y_gpu
164
-
165
- def compute_histograms(self, bin_indices_sub, gradients):
166
- grad_hist = torch.zeros((self.num_features, self.num_bins), device=self.device, dtype=torch.float32)
167
- hess_hist = torch.zeros((self.num_features, self.num_bins), device=self.device, dtype=torch.float32)
168
-
169
- self.compute_histogram(
170
- bin_indices_sub,
171
- gradients,
172
- grad_hist,
173
- hess_hist,
174
- self.num_bins,
175
- self.threads_per_block,
176
- self.rows_per_thread
177
- )
178
- return grad_hist, hess_hist
179
-
180
- def find_best_split(self, gradient_histogram, hessian_histogram):
181
- node_kernel.compute_split(
182
- gradient_histogram,
183
- hessian_histogram,
184
- self.min_split_gain,
185
- self.min_child_weight,
186
- self.L2_reg,
187
- self.best_gains,
188
- self.best_bins,
189
- self.threads_per_block
190
- )
191
-
192
- if torch.all(self.best_bins == -1):
193
- return -1, -1 # No valid split found
194
-
195
- f = torch.argmax(self.best_gains).item()
196
- b = self.best_bins[f].item()
197
-
198
- return f, b
199
-
200
- def grow_tree(self, gradient_histogram, hessian_histogram, node_indices, depth):
201
- if depth == self.max_depth:
202
- leaf_value = self.residual[node_indices].mean()
203
- self.gradients[node_indices] += self.learning_rate * leaf_value
204
- return {"leaf_value": leaf_value.item(), "samples": node_indices.numel()}
205
-
206
- parent_size = node_indices.numel()
207
- best_feature, best_bin = self.find_best_split(gradient_histogram, hessian_histogram)
208
-
209
- if best_feature == -1:
210
- leaf_value = self.residual[node_indices].mean()
211
- self.gradients[node_indices] += self.learning_rate * leaf_value
212
- return {"leaf_value": leaf_value.item(), "samples": parent_size}
213
-
214
- split_mask = (self.bin_indices[node_indices, best_feature] <= best_bin)
215
- left_indices = node_indices[split_mask]
216
- right_indices = node_indices[~split_mask]
217
-
218
- left_size = left_indices.numel()
219
- right_size = right_indices.numel()
220
-
221
- if left_size == 0 or right_size == 0:
222
- leaf_value = self.residual[node_indices].mean()
223
- self.gradients[node_indices] += self.learning_rate * leaf_value
224
- return {"leaf_value": leaf_value.item(), "samples": parent_size}
225
-
226
- if left_size <= right_size:
227
- grad_hist_left, hess_hist_left = self.compute_histograms( self.bin_indices[left_indices], self.residual[left_indices] )
228
- grad_hist_right = gradient_histogram - grad_hist_left
229
- hess_hist_right = hessian_histogram - hess_hist_left
230
- else:
231
- grad_hist_right, hess_hist_right = self.compute_histograms( self.bin_indices[right_indices], self.residual[right_indices] )
232
- grad_hist_left = gradient_histogram - grad_hist_right
233
- hess_hist_left = hessian_histogram - hess_hist_right
234
-
235
- new_depth = depth + 1
236
- left_child = self.grow_tree(grad_hist_left, hess_hist_left, left_indices, new_depth)
237
- right_child = self.grow_tree(grad_hist_right, hess_hist_right, right_indices, new_depth)
238
-
239
- return { "feature": best_feature, "bin": best_bin, "left": left_child, "right": right_child }
240
-
241
- def grow_forest(self):
242
- forest = [{} for _ in range(self.n_estimators)]
243
- self.training_loss = []
244
-
245
- for i in tqdm( range(self.n_estimators) ):
246
- self.residual = self.Y_gpu - self.gradients
247
-
248
- self.root_gradient_histogram, self.root_hessian_histogram = \
249
- self.compute_histograms(self.bin_indices, self.residual)
250
-
251
- tree = self.grow_tree(
252
- self.root_gradient_histogram,
253
- self.root_hessian_histogram,
254
- self.root_node_indices,
255
- depth=0
256
- )
257
- forest[i] = tree
258
- # loss = ((self.Y_gpu - self.gradients) ** 2).mean().item()
259
- # self.training_loss.append(loss)
260
- # print(f"🌲 Tree {i+1}/{self.n_estimators} - MSE: {loss:.6f}")
261
-
262
- print("Finished training forest.")
263
- return forest
264
-
265
def predict(self, X_np, chunk_size=50000):
    """
    Predict with the padded forest, traversing trees level by level on
    ``self.device``.

    Input handling:
      - integer ``X_np`` is treated as already binned; every column maximum
        must be below ``self.num_bins`` and the data is cast to int8;
      - any other dtype is cast to float32 and bucketized per feature
        against ``self.bin_edges`` (on CPU), then stored as int8.

    The forest is re-flattened with ``flatten_forest_to_tensors`` on every
    call and kept in ``self.flat_forest``. Rows are processed
    ``chunk_size`` at a time.

    Args:
        X_np: 2-D NumPy array of raw features or pre-binned integers.
        chunk_size: number of rows moved to the device per batch.

    Returns:
        np.ndarray: float32 vector of length ``X_np.shape[0]`` equal to
        ``self.base_prediction + self.learning_rate * sum(leaf values)``.

    Raises:
        ValueError: pre-binned integer input has a value >= ``num_bins``.
    """
    with torch.no_grad():
        # 1) Turn the input into int8 bin indices held in a NumPy array.
        if np.issubdtype(X_np.dtype, np.integer):
            if not np.all(X_np.max(axis=0) < self.num_bins):
                raise ValueError("Pre-binned integers must be < num_bins")
            binned = X_np.astype(np.int8)
        else:
            features_cpu = torch.from_numpy(X_np).type(torch.float32)
            binned_t = torch.empty((X_np.shape[0], X_np.shape[1]), dtype=torch.int8)
            edges_cpu = self.bin_edges.to('cpu')
            for col in range(self.num_features):
                col_bins = torch.bucketize(features_cpu[:, col], edges_cpu[col], right=False)
                binned_t[:, col] = col_bins.type(torch.int8)
            binned = binned_t.numpy()

        # 2) Padded forest: each entry is [n_trees, max_nodes].
        self.flat_forest = self.flatten_forest_to_tensors(self.forest)
        forest_features = self.flat_forest["features"]
        forest_thresholds = self.flat_forest["thresholds"]
        forest_values = self.flat_forest["leaf_values"]

        n_rows = binned.shape[0]
        out = np.zeros(n_rows, dtype=np.float32)

        # 3) Walk the forest one chunk of rows at a time.
        for start in tqdm(range(0, n_rows, chunk_size)):
            end = min(start + chunk_size, n_rows)
            rows_in_chunk = end - start
            chunk_gpu = torch.from_numpy(binned[start:end]).to(self.device)

            # Unscaled sum of leaf values reached by each row.
            chunk_preds = torch.zeros((rows_in_chunk,), dtype=torch.float32, device=self.device)
            # Current slot of each row inside the padded tree.
            node_idx = torch.zeros((rows_in_chunk,), dtype=torch.int32, device=self.device)
            # True while a row has not yet hit a leaf or an empty slot.
            active = torch.ones((rows_in_chunk,), dtype=torch.bool, device=self.device)

            for tree_features, tree_thresh, tree_values in zip(
                forest_features, forest_thresholds, forest_values
            ):
                # Trees are independent: restart every row at the root.
                node_idx.fill_(0)
                active.fill_(True)

                for _level in range(self.max_depth + 1):
                    active_idx = active.nonzero(as_tuple=True)[0]
                    if active_idx.numel() == 0:
                        break  # every row finished this tree

                    here = node_idx[active_idx]
                    feat_here = tree_features[here]
                    thr_here = tree_thresh[here]
                    val_here = tree_values[here]

                    is_leaf = feat_here == -1
                    is_empty = feat_here == -2

                    # Leaves contribute their value and retire the row.
                    if is_leaf.any():
                        leaf_rows = active_idx[is_leaf]
                        chunk_preds[leaf_rows] += val_here[is_leaf]
                        active[leaf_rows] = False

                    # Empty slots retire the row without contributing.
                    if is_empty.any():
                        active[active_idx[is_empty]] = False

                    # Internal nodes: compare the row's bin to the split bin
                    # and descend to child 2*i+1 (left) or 2*i+2 (right).
                    is_split = ~(is_leaf | is_empty)
                    if is_split.any():
                        split_rows = active_idx[is_split]
                        split_feat = feat_here[is_split].long()
                        split_thr = thr_here[is_split]
                        split_node = here[is_split]
                        go_left = chunk_gpu[split_rows, split_feat] <= split_thr
                        node_idx[split_rows] = torch.where(
                            go_left, split_node * 2 + 1, split_node * 2 + 2
                        )

            out[start:end] = (
                self.base_prediction + self.learning_rate * chunk_preds
            ).cpu().numpy()

        return out
369
-
370
def flatten_forest_to_tensors(self, forest):
    """
    Pack dict-based trees into fixed-size per-tree arrays and return them
    as torch tensors on ``self.device``.

    Each tree gets ``2 ** (self.max_depth + 1) - 1`` slots laid out like a
    binary heap: slot 0 is the root and slot ``i`` has children
    ``2*i + 1`` and ``2*i + 2``. Slot encoding:
      - feature == -2 : slot not occupied by any node
      - feature == -1 : leaf; its value is stored in ``leaf_values``
      - otherwise     : internal node splitting on that feature at ``bin``

    Children of a node are only visited while its depth is below
    ``self.max_depth``.

    Args:
        forest: list of nested dicts; a leaf has ``"leaf_value"``, an
            internal node has ``"feature"``, ``"bin"``, ``"left"``, ``"right"``.

    Returns:
        dict with ``"features"`` (int16), ``"thresholds"`` (int16) and
        ``"leaf_values"`` (float32), each ``[n_trees, max_nodes]``, plus
        the integer ``"max_nodes"``.
    """
    n_trees = len(forest)
    max_nodes = 2 ** (self.max_depth + 1) - 1

    # Filled on the host with NumPy, moved to the device at the end.
    feat_arr = np.full((n_trees, max_nodes), -2, dtype=np.int16)
    thresh_arr = np.full((n_trees, max_nodes), -2, dtype=np.int16)
    value_arr = np.zeros((n_trees, max_nodes), dtype=np.float32)

    for tree_idx, root in enumerate(forest):
        # Explicit stack of (node, slot, depth) instead of recursion.
        pending = [(root, 0, 0)]
        while pending:
            node, slot, depth = pending.pop()

            if "leaf_value" in node:
                feat_arr[tree_idx, slot] = -1
                thresh_arr[tree_idx, slot] = -1
                value_arr[tree_idx, slot] = node["leaf_value"]
                continue

            # Internal node: its value slot stays 0.
            feat_arr[tree_idx, slot] = node["feature"]
            thresh_arr[tree_idx, slot] = node["bin"]

            if depth < self.max_depth:
                left_child, right_child = node["left"], node["right"]
                # Right is pushed first so the left subtree is handled first.
                pending.append((right_child, 2 * slot + 2, depth + 1))
                pending.append((left_child, 2 * slot + 1, depth + 1))

    return {
        "features": torch.from_numpy(feat_arr).to(self.device),        # [n_trees, max_nodes]
        "thresholds": torch.from_numpy(thresh_arr).to(self.device),    # [n_trees, max_nodes]
        "leaf_values": torch.from_numpy(value_arr).to(self.device),    # [n_trees, max_nodes]
        "max_nodes": max_nodes,
    }
429
-
430
def predict_numpy(self, X_np, chunk_size=50000):
    """
    NumPy-only prediction over the padded forest.

    Input handling:
      - integer ``X_np`` is treated as already binned; every column maximum
        must be below ``self.num_bins`` and the data is cast to int8;
      - any other dtype is binned per feature with ``np.searchsorted``
        (``side='left'``) against ``self.bin_edges``; torch-tensor edges
        are converted to NumPy first.

    The forest is re-flattened with ``flatten_forest`` on every call and
    kept in ``self.flat_forest`` (NumPy arrays). Rows are processed
    ``chunk_size`` at a time.

    Args:
        X_np: 2-D NumPy array of raw features or pre-binned integers.
        chunk_size: number of rows handled per batch.

    Returns:
        np.ndarray: float32 vector of length ``X_np.shape[0]`` equal to
        ``self.base_prediction + self.learning_rate * sum(leaf values)``.

    Raises:
        ValueError: pre-binned integer input has a value >= ``num_bins``.
    """
    # 1) Turn the input into int8 bin indices.
    if np.issubdtype(X_np.dtype, np.integer):
        if not np.all(X_np.max(axis=0) < self.num_bins):
            raise ValueError("Pre-binned integers must be < num_bins")
        binned = X_np.astype(np.int8)
    else:
        binned = np.empty_like(X_np, dtype=np.int8)
        # np.searchsorted needs NumPy edges.
        if isinstance(self.bin_edges[0], torch.Tensor):
            edges = [be.cpu().numpy() for be in self.bin_edges]
        else:
            edges = self.bin_edges

        for col in range(self.num_features):
            binned[:, col] = np.searchsorted(edges[col], X_np[:, col], side='left')

    # 2) Padded forest: each array is [n_trees, max_nodes].
    self.flat_forest = self.flatten_forest(self.forest)
    forest_features = self.flat_forest["features"]
    forest_thresholds = self.flat_forest["thresholds"]
    forest_values = self.flat_forest["leaf_values"]

    n_rows = binned.shape[0]
    out = np.zeros(n_rows, dtype=np.float32)

    # 3) Walk the forest one chunk of rows at a time.
    for start in tqdm(range(0, n_rows, chunk_size)):
        end = min(start + chunk_size, n_rows)
        rows_in_chunk = end - start
        chunk = binned[start:end]
        chunk_preds = np.zeros(rows_in_chunk, dtype=np.float32)

        for tree_features, tree_thresh, tree_values in zip(
            forest_features, forest_thresholds, forest_values
        ):
            # Every row starts each tree at the root and is active.
            node_idx = np.zeros(rows_in_chunk, dtype=np.int32)
            active = np.ones(rows_in_chunk, dtype=bool)

            for _level in range(self.max_depth + 1):
                active_idx = np.nonzero(active)[0]
                if active_idx.size == 0:
                    break

                here = node_idx[active_idx]
                feat_here = tree_features[here]
                thr_here = tree_thresh[here]
                val_here = tree_values[here]

                is_leaf = feat_here == -1
                is_empty = feat_here == -2
                is_split = ~(is_leaf | is_empty)

                # Leaves contribute their value and retire the row.
                if np.any(is_leaf):
                    leaf_rows = active_idx[is_leaf]
                    chunk_preds[leaf_rows] += val_here[is_leaf]
                    active[leaf_rows] = False

                # Empty slots retire the row without contributing.
                if np.any(is_empty):
                    active[active_idx[is_empty]] = False

                # Internal nodes: descend to child 2*i+1 (left) or 2*i+2 (right).
                if np.any(is_split):
                    split_rows = active_idx[is_split]
                    split_feat = feat_here[is_split].astype(np.int32)
                    split_thr = thr_here[is_split]
                    split_node = here[is_split]
                    go_left = chunk[split_rows, split_feat] <= split_thr
                    node_idx[split_rows] = np.where(
                        go_left, split_node * 2 + 1, split_node * 2 + 2
                    )

        out[start:end] = self.base_prediction + self.learning_rate * chunk_preds

    return out
518
-
519
def flatten_forest(self, forest):
    """
    Pack dict-based trees into fixed-size per-tree NumPy arrays.

    Each tree gets ``2 ** (self.max_depth + 1) - 1`` slots laid out like a
    binary heap: slot 0 is the root and slot ``i`` has children
    ``2*i + 1`` and ``2*i + 2``. A feature of -2 marks an unused slot,
    -1 marks a leaf (value stored in ``leaf_values``), anything else is the
    split feature of an internal node whose split bin is in ``thresholds``.

    Children of a node are only visited while its depth is below
    ``self.max_depth``.

    Args:
        forest: list of nested dicts; a leaf has ``"leaf_value"``, an
            internal node has ``"feature"``, ``"bin"``, ``"left"``, ``"right"``.

    Returns:
        dict with ``"features"`` (int16), ``"thresholds"`` (int16) and
        ``"leaf_values"`` (float32), each ``[n_trees, max_nodes]``, plus
        the integer ``"max_nodes"``.
    """
    n_trees = len(forest)
    max_nodes = 2 ** (self.max_depth + 1) - 1

    feat_arr = np.full((n_trees, max_nodes), -2, dtype=np.int16)
    thresh_arr = np.full((n_trees, max_nodes), -2, dtype=np.int16)
    value_arr = np.zeros((n_trees, max_nodes), dtype=np.float32)

    for tree_idx, root in enumerate(forest):
        # Explicit stack of (node, slot, depth) instead of recursion.
        pending = [(root, 0, 0)]
        while pending:
            node, slot, depth = pending.pop()

            if "leaf_value" in node:
                feat_arr[tree_idx, slot] = -1
                thresh_arr[tree_idx, slot] = -1
                value_arr[tree_idx, slot] = node["leaf_value"]
                continue

            feat_arr[tree_idx, slot] = node["feature"]
            thresh_arr[tree_idx, slot] = node["bin"]

            if depth < self.max_depth:
                left_child, right_child = node["left"], node["right"]
                # Right is pushed first so the left subtree is handled first.
                pending.append((right_child, 2 * slot + 2, depth + 1))
                pending.append((left_child, 2 * slot + 1, depth + 1))

    return {
        "features": feat_arr,
        "thresholds": thresh_arr,
        "leaf_values": value_arr,
        "max_nodes": max_nodes,
    }
File without changes
File without changes
File without changes
File without changes
File without changes