warpgbm 0.1.20__tar.gz → 0.1.22__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. {warpgbm-0.1.20/warpgbm.egg-info → warpgbm-0.1.22}/PKG-INFO +15 -16
  2. {warpgbm-0.1.20 → warpgbm-0.1.22}/README.md +14 -15
  3. {warpgbm-0.1.20 → warpgbm-0.1.22}/pyproject.toml +1 -1
  4. {warpgbm-0.1.20 → warpgbm-0.1.22}/setup.py +1 -0
  5. {warpgbm-0.1.20 → warpgbm-0.1.22}/tests/test_fit_predict_corr.py +11 -8
  6. warpgbm-0.1.22/version.txt +1 -0
  7. warpgbm-0.1.22/warpgbm/core.py +401 -0
  8. {warpgbm-0.1.20 → warpgbm-0.1.22}/warpgbm/cuda/histogram_kernel.cu +0 -14
  9. {warpgbm-0.1.20 → warpgbm-0.1.22}/warpgbm/cuda/node_kernel.cpp +8 -0
  10. warpgbm-0.1.22/warpgbm/cuda/predict.cu +77 -0
  11. {warpgbm-0.1.20 → warpgbm-0.1.22/warpgbm.egg-info}/PKG-INFO +15 -16
  12. {warpgbm-0.1.20 → warpgbm-0.1.22}/warpgbm.egg-info/SOURCES.txt +2 -1
  13. warpgbm-0.1.20/version.txt +0 -1
  14. warpgbm-0.1.20/warpgbm/core.py +0 -552
  15. {warpgbm-0.1.20 → warpgbm-0.1.22}/LICENSE +0 -0
  16. {warpgbm-0.1.20 → warpgbm-0.1.22}/MANIFEST.in +0 -0
  17. {warpgbm-0.1.20 → warpgbm-0.1.22}/setup.cfg +0 -0
  18. {warpgbm-0.1.20 → warpgbm-0.1.22}/tests/__init__.py +0 -0
  19. {warpgbm-0.1.20 → warpgbm-0.1.22}/warpgbm/__init__.py +0 -0
  20. {warpgbm-0.1.20 → warpgbm-0.1.22}/warpgbm/cuda/__init__.py +0 -0
  21. {warpgbm-0.1.20 → warpgbm-0.1.22}/warpgbm/cuda/best_split_kernel.cu +0 -0
  22. {warpgbm-0.1.20 → warpgbm-0.1.22}/warpgbm/cuda/binner.cu +0 -0
  23. {warpgbm-0.1.20 → warpgbm-0.1.22}/warpgbm.egg-info/dependency_links.txt +0 -0
  24. {warpgbm-0.1.20 → warpgbm-0.1.22}/warpgbm.egg-info/requires.txt +0 -0
  25. {warpgbm-0.1.20 → warpgbm-0.1.22}/warpgbm.egg-info/top_level.txt +0 -0
{warpgbm-0.1.20/warpgbm.egg-info → warpgbm-0.1.22}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: warpgbm
- Version: 0.1.20
+ Version: 0.1.22
  Summary: A fast GPU-accelerated Gradient Boosted Decision Tree library with PyTorch + CUDA
  License: GNU GENERAL PUBLIC LICENSE
  Version 3, 29 June 2007
@@ -704,26 +704,20 @@ WarpGBM is a high-performance, GPU-accelerated Gradient Boosted Decision Tree (G

  ---

- ## Performance Note
-
- In our initial tests on an NVIDIA 3090 (local) and A100 (Google Colab Pro), WarpGBM achieves **14x to 20x faster training times** compared to LightGBM's CPU version and **2x faster** on the GPU version using default configurations. Speed also outperforms XGBoost and CatBoost on regression problems. It also consumes **significantly less RAM and CPU**. These early results hint at more thorough benchmarking to come.
-
- ---
-
  ## Benchmarks

  ### Scikit-Learn Synthetic Data: 1 Million Rows and 1,000 Features

- In this benchmark we compare the speed and in-sample correlation of **WarpGBM v0.1.19** against LightGBM, XGBoost and CatBoost, all with their GPU-enabled versions. This benchmark runs on Google Colab with the L4 GPU environment. The CPU versions don't even come close to the speed here so we didn't test them.
+ In this benchmark we compare the speed and in-sample correlation of **WarpGBM v0.1.21** against LightGBM, XGBoost and CatBoost, all with their GPU-enabled versions. This benchmark runs on Google Colab with the L4 GPU environment.

  ```
- WarpGBM: corr = 0.8882, train = 21.8s, infer = 11.6s
- XGBoost: corr = 0.8877, train = 33.4s, infer = 8.1s
- LightGBM: corr = 0.8604, train = 30.2s, infer = 1.4s
- CatBoost: corr = 0.8935, train = 377.9s, infer = 375.8s
+ WarpGBM: corr = 0.8882, train = 18.7s, infer = 4.9s
+ XGBoost: corr = 0.8877, train = 33.1s, infer = 8.1s
+ LightGBM: corr = 0.8604, train = 30.3s, infer = 1.4s
+ CatBoost: corr = 0.8935, train = 400.0s, infer = 382.6s
  ```

- Colab Notebook: https://colab.research.google.com/drive/16U1kbYlD5HibGbnF5NGsjChZ1p1IA2pK
+ Colab Notebook: https://colab.research.google.com/drive/16U1kbYlD5HibGbnF5NGsjChZ1p1IA2pK?usp=sharing

  ---

@@ -746,7 +740,7 @@ pip install warpgbm
  This installs from PyPI and also compiles CUDA code locally during installation. This method works well **if your environment already has PyTorch with GPU support** installed and configured.

  > **Tip:**\
- > If you encounter an error related to mismatched or missing CUDA versions, try installing with the following flag:
+ > If you encounter an error related to mismatched or missing CUDA versions, try installing with the following flag. This is currently required in the Colab environments.
  >
  > ```bash
  > pip install warpgbm --no-build-isolation
@@ -886,8 +880,7 @@ No installation required — just press **"Open in Playground"**, then **Run All

  ### Methods:
  - `.fit(X, y, era_id=None)`: Train the model. `X` can be raw floats or pre-binned `int8` data. `era_id` is optional and used internally.
- - `.predict(X, chunksize=50_000)`: Predict on new raw float or pre-binned data.
- - `.predict_numpy(X, chunksize=50_000)`: Same as `.predict(X)` but without using the GPU.
+ - `.predict(X)`: Predict on new data, using parallelized CUDA kernel.

  ---

@@ -897,3 +890,9 @@ WarpGBM builds on the shoulders of PyTorch, scikit-learn, LightGBM, and the CUDA

  ---

+ ## Version Notes
+
+ ### v0.1.21
+
+ - Vectorized predict function replaced with CUDA kernel (`warpgbm/cuda/predict.cu`), parallelizing per sample, per tree.
+
{warpgbm-0.1.20 → warpgbm-0.1.22}/README.md
@@ -16,26 +16,20 @@ WarpGBM is a high-performance, GPU-accelerated Gradient Boosted Decision Tree (G

  ---

- ## Performance Note
-
- In our initial tests on an NVIDIA 3090 (local) and A100 (Google Colab Pro), WarpGBM achieves **14x to 20x faster training times** compared to LightGBM's CPU version and **2x faster** on the GPU version using default configurations. Speed also outperforms XGBoost and CatBoost on regression problems. It also consumes **significantly less RAM and CPU**. These early results hint at more thorough benchmarking to come.
-
- ---
-
  ## Benchmarks

  ### Scikit-Learn Synthetic Data: 1 Million Rows and 1,000 Features

- In this benchmark we compare the speed and in-sample correlation of **WarpGBM v0.1.19** against LightGBM, XGBoost and CatBoost, all with their GPU-enabled versions. This benchmark runs on Google Colab with the L4 GPU environment. The CPU versions don't even come close to the speed here so we didn't test them.
+ In this benchmark we compare the speed and in-sample correlation of **WarpGBM v0.1.21** against LightGBM, XGBoost and CatBoost, all with their GPU-enabled versions. This benchmark runs on Google Colab with the L4 GPU environment.

  ```
- WarpGBM: corr = 0.8882, train = 21.8s, infer = 11.6s
- XGBoost: corr = 0.8877, train = 33.4s, infer = 8.1s
- LightGBM: corr = 0.8604, train = 30.2s, infer = 1.4s
- CatBoost: corr = 0.8935, train = 377.9s, infer = 375.8s
+ WarpGBM: corr = 0.8882, train = 18.7s, infer = 4.9s
+ XGBoost: corr = 0.8877, train = 33.1s, infer = 8.1s
+ LightGBM: corr = 0.8604, train = 30.3s, infer = 1.4s
+ CatBoost: corr = 0.8935, train = 400.0s, infer = 382.6s
  ```

- Colab Notebook: https://colab.research.google.com/drive/16U1kbYlD5HibGbnF5NGsjChZ1p1IA2pK
+ Colab Notebook: https://colab.research.google.com/drive/16U1kbYlD5HibGbnF5NGsjChZ1p1IA2pK?usp=sharing

  ---

@@ -58,7 +52,7 @@ pip install warpgbm
  This installs from PyPI and also compiles CUDA code locally during installation. This method works well **if your environment already has PyTorch with GPU support** installed and configured.

  > **Tip:**\
- > If you encounter an error related to mismatched or missing CUDA versions, try installing with the following flag:
+ > If you encounter an error related to mismatched or missing CUDA versions, try installing with the following flag. This is currently required in the Colab environments.
  >
  > ```bash
  > pip install warpgbm --no-build-isolation
@@ -198,8 +192,7 @@ No installation required — just press **"Open in Playground"**, then **Run All

  ### Methods:
  - `.fit(X, y, era_id=None)`: Train the model. `X` can be raw floats or pre-binned `int8` data. `era_id` is optional and used internally.
- - `.predict(X, chunksize=50_000)`: Predict on new raw float or pre-binned data.
- - `.predict_numpy(X, chunksize=50_000)`: Same as `.predict(X)` but without using the GPU.
+ - `.predict(X)`: Predict on new data, using parallelized CUDA kernel.

  ---

@@ -209,3 +202,9 @@ WarpGBM builds on the shoulders of PyTorch, scikit-learn, LightGBM, and the CUDA

  ---

+ ## Version Notes
+
+ ### v0.1.21
+
+ - Vectorized predict function replaced with CUDA kernel (`warpgbm/cuda/predict.cu`), parallelizing per sample, per tree.
+
{warpgbm-0.1.20 → warpgbm-0.1.22}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "warpgbm"
- version = "0.1.20"
+ version = "0.1.22"
  description = "A fast GPU-accelerated Gradient Boosted Decision Tree library with PyTorch + CUDA"
  readme = "README.md"
  requires-python = ">=3.8"
{warpgbm-0.1.20 → warpgbm-0.1.22}/setup.py
@@ -23,6 +23,7 @@ def get_extensions():
                "warpgbm/cuda/histogram_kernel.cu",
                "warpgbm/cuda/best_split_kernel.cu",
                "warpgbm/cuda/binner.cu",
+               "warpgbm/cuda/predict.cu",
                "warpgbm/cuda/node_kernel.cpp",
            ]
        )
{warpgbm-0.1.20 → warpgbm-0.1.22}/tests/test_fit_predict_corr.py
@@ -1,11 +1,9 @@
  import numpy as np
  from warpgbm import WarpGBM
  from sklearn.datasets import make_regression
-
- import numpy as np
  import time
- from warpgbm import WarpGBM
- from sklearn.datasets import make_regression
+ from sklearn.metrics import mean_squared_error
+

  def test_fit_predictpytee_correlation():
      np.random.seed(42)
@@ -14,19 +12,20 @@ def test_fit_predictpytee_correlation():
      X, y = make_regression(n_samples=N, n_features=F, noise=0.1, random_state=42)
      era = np.zeros(N, dtype=np.int32)
      corrs = []
+     mses = []

-     for hist_type in ['hist1', 'hist2', 'hist3']:
+     for hist_type in ["hist1", "hist2", "hist3"]:
          print(f"\nTesting histogram method: {hist_type}")

          model = WarpGBM(
              max_depth=10,
              num_bins=10,
-             n_estimators=10,
+             n_estimators=100,
              learning_rate=1,
              verbosity=False,
              histogram_computer=hist_type,
              threads_per_block=64,
-             rows_per_thread=4
+             rows_per_thread=4,
          )

          start_fit = time.time()
@@ -40,7 +39,11 @@ def test_fit_predictpytee_correlation():
          print(f" Predict time: {pred_time:.3f} seconds")

          corr = np.corrcoef(preds, y)[0, 1]
+         mse = mean_squared_error(preds, y)
          print(f" Correlation: {corr:.4f}")
+         print(f" MSE: {mse:.4f}")
          corrs.append(corr)
+         mses.append(mse)

-     assert (np.array(corrs) > 0.95).all(), f"In-sample correlation too low: {corrs}"
+     assert (np.array(corrs) > 0.9).all(), f"In-sample correlation too low: {corrs}"
+     assert (np.array(mses) < 2).all(), f"In-sample mse too high: {mses}"
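The updated test asserts two in-sample metrics for every histogram method:

- the Pearson correlation from `np.corrcoef` must exceed 0.9;
- the mean squared error from scikit-learn's `mean_squared_error` must be below 2.

The sketch below is a minimal, dependency-free illustration of what those two numbers measure. The helper names `pearson_corr`, `mse` and `passes_test_thresholds` are hypothetical and are not part of WarpGBM or its test suite.

```python
import math


def pearson_corr(a, b):
    """Pearson correlation: the off-diagonal entry np.corrcoef(a, b)[0, 1]."""
    n = len(a)
    mean_a = sum(a) / n
    mean_b = sum(b) / n
    # Covariance and both variances share the same 1/n factor, so it cancels.
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    return cov / math.sqrt(var_a * var_b)


def mse(y_true, y_pred):
    """Mean squared error with uniform weights, like sklearn's mean_squared_error."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def passes_test_thresholds(preds, y, corr_min=0.9, mse_max=2.0):
    """Hypothetical helper mirroring the test's two assertions for one model run."""
    return pearson_corr(preds, y) > corr_min and mse(preds, y) < mse_max
```

The test passes its arguments as `mean_squared_error(preds, y)` rather than in the documented `(y_true, y_pred)` order. Squared error is symmetric, so this does not change the value.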
warpgbm-0.1.22/version.txt
@@ -0,0 +1 @@
+ 0.1.22
warpgbm-0.1.22/warpgbm/core.py
@@ -0,0 +1,401 @@
+ import torch
+ import numpy as np
+ from sklearn.base import BaseEstimator, RegressorMixin
+ from warpgbm.cuda import node_kernel
+ from tqdm import tqdm
+ from typing import Tuple
+ from torch import Tensor
+
+ histogram_kernels = {
+     "hist1": node_kernel.compute_histogram,
+     "hist2": node_kernel.compute_histogram2,
+     "hist3": node_kernel.compute_histogram3,
+ }
+
+
+ class WarpGBM(BaseEstimator, RegressorMixin):
+     def __init__(
+         self,
+         num_bins=10,
+         max_depth=3,
+         learning_rate=0.1,
+         n_estimators=100,
+         min_child_weight=20,
+         min_split_gain=0.0,
+         verbosity=True,
+         histogram_computer="hist3",
+         threads_per_block=64,
+         rows_per_thread=4,
+         L2_reg=1e-6,
+         L1_reg=0.0,
+         device="cuda",
+     ):
+         # Validate arguments
+         self._validate_hyperparams(
+             num_bins=num_bins,
+             max_depth=max_depth,
+             learning_rate=learning_rate,
+             n_estimators=n_estimators,
+             min_child_weight=min_child_weight,
+             min_split_gain=min_split_gain,
+             histogram_computer=histogram_computer,
+             threads_per_block=threads_per_block,
+             rows_per_thread=rows_per_thread,
+             L2_reg=L2_reg,
+             L1_reg=L1_reg,
+         )
+
+         self.num_bins = num_bins
+         self.max_depth = max_depth
+         self.learning_rate = learning_rate
+         self.n_estimators = n_estimators
+         self.forest = None
+         self.bin_edges = None
+         self.base_prediction = None
+         self.unique_eras = None
+         self.device = device
+         self.root_gradient_histogram = None
+         self.root_hessian_histogram = None
+         self.gradients = None
+         self.root_node_indices = None
+         self.bin_indices = None
+         self.Y_gpu = None
+         self.num_features = None
+         self.num_samples = None
+         self.min_child_weight = min_child_weight
+         self.min_split_gain = min_split_gain
+         self.best_bin = torch.tensor([-1], dtype=torch.int32, device=self.device)
+         self.compute_histogram = histogram_kernels[histogram_computer]
+         self.threads_per_block = threads_per_block
+         self.rows_per_thread = rows_per_thread
+         self.L2_reg = L2_reg
+         self.L1_reg = L1_reg
+
+     def _validate_hyperparams(self, **kwargs):
+         # Type checks
+         int_params = [
+             "num_bins",
+             "max_depth",
+             "n_estimators",
+             "min_child_weight",
+             "threads_per_block",
+             "rows_per_thread",
+         ]
+         float_params = ["learning_rate", "min_split_gain", "L2_reg", "L1_reg"]
+
+         for param in int_params:
+             if not isinstance(kwargs[param], int):
+                 raise TypeError(
+                     f"{param} must be an integer, got {type(kwargs[param])}."
+                 )
+
+         for param in float_params:
+             if not isinstance(
+                 kwargs[param], (float, int)
+             ):  # Accept ints as valid floats
+                 raise TypeError(f"{param} must be a float, got {type(kwargs[param])}.")
+
+         if not (2 <= kwargs["num_bins"] <= 127):
+             raise ValueError("num_bins must be between 2 and 127 inclusive.")
+         if kwargs["max_depth"] < 1:
+             raise ValueError("max_depth must be at least 1.")
+         if not (0.0 < kwargs["learning_rate"] <= 1.0):
+             raise ValueError("learning_rate must be in (0.0, 1.0].")
+         if kwargs["n_estimators"] <= 0:
+             raise ValueError("n_estimators must be positive.")
+         if kwargs["min_child_weight"] < 1:
+             raise ValueError("min_child_weight must be a positive integer.")
+         if kwargs["min_split_gain"] < 0:
+             raise ValueError("min_split_gain must be non-negative.")
+         if kwargs["threads_per_block"] <= 0 or kwargs["threads_per_block"] % 32 != 0:
+             raise ValueError(
+                 "threads_per_block should be a positive multiple of 32 (warp size)."
+             )
+         if not (1 <= kwargs["rows_per_thread"] <= 16):
+             raise ValueError(
+                 "rows_per_thread must be positive between 1 and 16 inclusive."
+             )
+         if kwargs["L2_reg"] < 0 or kwargs["L1_reg"] < 0:
+             raise ValueError("L2_reg and L1_reg must be non-negative.")
+         if kwargs["histogram_computer"] not in histogram_kernels:
+             raise ValueError(
+                 f"Invalid histogram_computer: {kwargs['histogram_computer']}. Choose from {list(histogram_kernels.keys())}."
+             )
+
+     def fit(self, X, y, era_id=None):
+         if era_id is None:
+             era_id = np.ones(X.shape[0], dtype="int32")
+         self.bin_indices, era_indices, self.bin_edges, self.unique_eras, self.Y_gpu = (
+             self.preprocess_gpu_data(X, y, era_id)
+         )
+         self.num_samples, self.num_features = X.shape
+         self.gradients = torch.zeros_like(self.Y_gpu)
+         self.root_node_indices = torch.arange(self.num_samples, device=self.device)
+         self.base_prediction = self.Y_gpu.mean().item()
+         self.gradients += self.base_prediction
+         self.best_gains = torch.zeros(self.num_features, device=self.device)
+         self.best_bins = torch.zeros(
+             self.num_features, device=self.device, dtype=torch.int32
+         )
+         with torch.no_grad():
+             self.forest = self.grow_forest()
+         return self
+
+     def preprocess_gpu_data(self, X_np, Y_np, era_id_np):
+         with torch.no_grad():
+             self.num_samples, self.num_features = X_np.shape
+             Y_gpu = torch.from_numpy(Y_np).type(torch.float32).to(self.device)
+             era_id_gpu = torch.from_numpy(era_id_np).type(torch.int32).to(self.device)
+             is_integer_type = np.issubdtype(X_np.dtype, np.integer)
+             if is_integer_type:
+                 max_vals = X_np.max(axis=0)
+                 if np.all(max_vals < self.num_bins):
+                     print(
+                         "Detected pre-binned integer input — skipping quantile binning."
+                     )
+                     bin_indices = (
+                         torch.from_numpy(X_np)
+                         .to(self.device)
+                         .contiguous()
+                         .to(torch.int8)
+                     )
+
+                     # We'll store None or an empty tensor in self.bin_edges
+                     # to indicate that we skip binning at predict-time
+                     bin_edges = torch.arange(
+                         1, self.num_bins, dtype=torch.float32
+                     ).repeat(self.num_features, 1)
+                     bin_edges = bin_edges.to(self.device)
+                     unique_eras, era_indices = torch.unique(
+                         era_id_gpu, return_inverse=True
+                     )
+                     return bin_indices, era_indices, bin_edges, unique_eras, Y_gpu
+                 else:
+                     print(
+                         "Integer input detected, but values exceed num_bins — falling back to quantile binning."
+                     )
+
+             bin_indices = torch.empty(
+                 (self.num_samples, self.num_features), dtype=torch.int8, device="cuda"
+             )
+             bin_edges = torch.empty(
+                 (self.num_features, self.num_bins - 1),
+                 dtype=torch.float32,
+                 device="cuda",
+             )
+
+             X_np = torch.from_numpy(X_np).to(torch.float32).pin_memory()
+
+             for f in range(self.num_features):
+                 X_f = X_np[:, f].to("cuda", non_blocking=True)
+                 quantiles = torch.linspace(
+                     0, 1, self.num_bins + 1, device="cuda", dtype=X_f.dtype
+                 )[1:-1]
+                 bin_edges_f = torch.quantile(
+                     X_f, quantiles, dim=0
+                 ).contiguous()  # shape: [B-1] for 1D input
+                 bin_indices_f = bin_indices[:, f].contiguous()  # view into output
+                 node_kernel.custom_cuda_binner(X_f, bin_edges_f, bin_indices_f)
+                 bin_indices[:, f] = bin_indices_f
+                 bin_edges[f, :] = bin_edges_f
+
+             unique_eras, era_indices = torch.unique(era_id_gpu, return_inverse=True)
+             return bin_indices, era_indices, bin_edges, unique_eras, Y_gpu
+
+     def compute_histograms(self, bin_indices_sub, gradients):
+         grad_hist = torch.zeros(
+             (self.num_features, self.num_bins), device=self.device, dtype=torch.float32
+         )
+         hess_hist = torch.zeros(
+             (self.num_features, self.num_bins), device=self.device, dtype=torch.float32
+         )
+
+         self.compute_histogram(
+             bin_indices_sub,
+             gradients,
+             grad_hist,
+             hess_hist,
+             self.num_bins,
+             self.threads_per_block,
+             self.rows_per_thread,
+         )
+         return grad_hist, hess_hist
+
+     def find_best_split(self, gradient_histogram, hessian_histogram):
+         node_kernel.compute_split(
+             gradient_histogram,
+             hessian_histogram,
+             self.min_split_gain,
+             self.min_child_weight,
+             self.L2_reg,
+             self.best_gains,
+             self.best_bins,
+             self.threads_per_block,
+         )
+
+         if torch.all(self.best_bins == -1):
+             return -1, -1  # No valid split found
+
+         f = torch.argmax(self.best_gains).item()
+         b = self.best_bins[f].item()
+
+         return f, b
+
+     def grow_tree(self, gradient_histogram, hessian_histogram, node_indices, depth):
+         if depth == self.max_depth:
+             leaf_value = self.residual[node_indices].mean()
+             self.gradients[node_indices] += self.learning_rate * leaf_value
+             return {"leaf_value": leaf_value.item(), "samples": node_indices.numel()}
+
+         parent_size = node_indices.numel()
+         best_feature, best_bin = self.find_best_split(
+             gradient_histogram, hessian_histogram
+         )
+
+         if best_feature == -1:
+             leaf_value = self.residual[node_indices].mean()
+             self.gradients[node_indices] += self.learning_rate * leaf_value
+             return {"leaf_value": leaf_value.item(), "samples": parent_size}
+
+         split_mask = self.bin_indices[node_indices, best_feature] <= best_bin
+         left_indices = node_indices[split_mask]
+         right_indices = node_indices[~split_mask]
+
+         left_size = left_indices.numel()
+         right_size = right_indices.numel()
+
+         if left_size <= right_size:
+             grad_hist_left, hess_hist_left = self.compute_histograms(
+                 self.bin_indices[left_indices], self.residual[left_indices]
+             )
+             grad_hist_right = gradient_histogram - grad_hist_left
+             hess_hist_right = hessian_histogram - hess_hist_left
+         else:
+             grad_hist_right, hess_hist_right = self.compute_histograms(
+                 self.bin_indices[right_indices], self.residual[right_indices]
+             )
+             grad_hist_left = gradient_histogram - grad_hist_right
+             hess_hist_left = hessian_histogram - hess_hist_right
+
+         new_depth = depth + 1
+         left_child = self.grow_tree(
+             grad_hist_left, hess_hist_left, left_indices, new_depth
+         )
+         right_child = self.grow_tree(
+             grad_hist_right, hess_hist_right, right_indices, new_depth
+         )
+
+         return {
+             "feature": best_feature,
+             "bin": best_bin,
+             "left": left_child,
+             "right": right_child,
+         }
+
+     def grow_forest(self):
+         forest = [{} for _ in range(self.n_estimators)]
+         self.training_loss = []
+
+         for i in tqdm(range(self.n_estimators)):
+             self.residual = self.Y_gpu - self.gradients
+
+             self.root_gradient_histogram, self.root_hessian_histogram = (
+                 self.compute_histograms(self.bin_indices, self.residual)
+             )
+
+             tree = self.grow_tree(
+                 self.root_gradient_histogram,
+                 self.root_hessian_histogram,
+                 self.root_node_indices,
+                 depth=0,
+             )
+             forest[i] = tree
+             # loss = ((self.Y_gpu - self.gradients) ** 2).mean().item()
+             # self.training_loss.append(loss)
+             # print(f"🌲 Tree {i+1}/{self.n_estimators} - MSE: {loss:.6f}")
+
+         print("Finished training forest.")
+         return forest
+
+     def predict(self, X_np):
+         X_tensor = torch.from_numpy(X_np).to(torch.float32).pin_memory()
+         num_samples = X_tensor.size(0)
+         bin_indices = torch.zeros(
+             (num_samples, self.num_features), dtype=torch.int8, device=self.device
+         )
+
+         with torch.no_grad():
+             for f in range(self.num_features):
+                 X_f = X_tensor[:, f].to(self.device, non_blocking=True)
+                 bin_edges_f = self.bin_edges[f]
+                 bin_indices_f = bin_indices[:, f].contiguous()
+                 node_kernel.custom_cuda_binner(X_f, bin_edges_f, bin_indices_f)
+                 bin_indices[:, f] = bin_indices_f
+
+         tree_tensor = torch.stack(
+             [
+                 self.flatten_tree(tree, max_nodes=2 ** (self.max_depth + 1))
+                 for tree in self.forest
+             ]
+         ).to(self.device)
+
+         out = torch.zeros(num_samples, device=self.device) + self.base_prediction
+         node_kernel.predict_forest(
+             bin_indices.contiguous(), tree_tensor.contiguous(), self.learning_rate, out
+         )
+
+         return out.cpu().numpy()
+
+     def flatten_tree(self, tree, max_nodes):
+         """
+         Convert a recursive tree structure into a flat matrix format.
+
+         Each row in the output represents a node:
+         - Columns: [feature, bin, left_id, right_id, is_leaf, value]
+         - Internal nodes fill columns 0–3 and set is_leaf = 0
+         - Leaf nodes fill only value and set is_leaf = 1
+
+         Args:
+             tree (list): A list containing a single root node (recursive dict form).
+             max_nodes (int): Max number of nodes to allocate in the flat matrix.
+
+         Returns:
+             torch.Tensor: [max_nodes x 6] matrix representing the flattened tree.
+         """
+         flat = torch.full((max_nodes, 6), float("nan"), dtype=torch.float32)
+         node_counter = [0]
+         node_list = []
+
+         def walk(node):
+             curr_id = node_counter[0]
+             node_counter[0] += 1
+
+             new_node = {"node_id": curr_id}
+             if "leaf_value" in node:
+                 new_node["leaf_value"] = float(node["leaf_value"])
+             else:
+                 new_node["best_feature"] = float(node["feature"])
+                 new_node["split_bin"] = float(node["bin"])
+                 new_node["left_id"] = node_counter[0]
+                 walk(node["left"])
+                 new_node["right_id"] = node_counter[0]
+                 walk(node["right"])
+
+             node_list.append(new_node)
+             return new_node
+
+         walk(tree)
+
+         for node in node_list:
+             i = node["node_id"]
+             if "leaf_value" in node:
+                 flat[i, 4] = 1.0
+                 flat[i, 5] = node["leaf_value"]
+             else:
+                 flat[i, 0] = node["best_feature"]
+                 flat[i, 1] = node["split_bin"]
+                 flat[i, 2] = node["left_id"]
+                 flat[i, 3] = node["right_id"]
+                 flat[i, 4] = 0.0
+
+         return flat
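`preprocess_gpu_data` computes `num_bins - 1` edges per feature at the interior quantiles `torch.linspace(0, 1, num_bins + 1)[1:-1]`. It then passes each column to `custom_cuda_binner`. The pure-Python sketch below reproduces the edge computation using linear interpolation, which is `torch.quantile`'s default.

The bin-assignment rule is an assumption. `binner.cu` is unchanged and not shown in this diff, so the sketch takes a value's bin to be the number of edges less than or equal to it (`bisect.bisect_right`). Under that rule, the placeholder edges `1` through `num_bins - 1` stored by the pre-binned path map each integer bin back to itself. This matters because `predict` always re-bins its input with `self.bin_edges`. Resulting indices fall in `0` to `num_bins - 1`, which fits the `int8` storage and is consistent with the 127 cap in `_validate_hyperparams`. All helper names are hypothetical.

```python
import bisect
import math


def linear_quantile(sorted_vals, q):
    """Quantile by linear interpolation (torch.quantile's default method)."""
    pos = q * (len(sorted_vals) - 1)
    lo = math.floor(pos)
    hi = min(lo + 1, len(sorted_vals) - 1)
    return sorted_vals[lo] + (sorted_vals[hi] - sorted_vals[lo]) * (pos - lo)


def quantile_edges(column, num_bins):
    """num_bins - 1 interior edges, like torch.linspace(0, 1, num_bins + 1)[1:-1]."""
    s = sorted(column)
    return [linear_quantile(s, k / num_bins) for k in range(1, num_bins)]


def bin_column(column, edges):
    """ASSUMED rule (binner.cu is not in this diff): bin = count of edges <= x."""
    return [bisect.bisect_right(edges, x) for x in column]
```

For the column `0.0` through `9.0` with `num_bins=4`, the edges are `[2.25, 4.5, 6.75]`. The ten values land in bins `[0, 0, 0, 1, 1, 2, 2, 3, 3, 3]`.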
{warpgbm-0.1.20 → warpgbm-0.1.22}/warpgbm/cuda/histogram_kernel.cu
@@ -107,12 +107,6 @@ void launch_histogram_kernel_cuda(
          N, F, B);
  }

- #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
- #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
- #define CHECK_INPUT(x) \
-     CHECK_CUDA(x);     \
-     CHECK_CONTIGUOUS(x)
-
  // CUDA kernel: tiled, 64-bit safe
  __global__ void histogram_tiled_kernel(
      const int8_t *__restrict__ bin_indices, // [N, F]
@@ -148,10 +142,6 @@ void launch_histogram_kernel_cuda_2(
      int threads_per_block = 256,
      int rows_per_thread = 1)
  {
-     CHECK_INPUT(bin_indices);
-     CHECK_INPUT(gradients);
-     CHECK_INPUT(grad_hist);
-     CHECK_INPUT(hess_hist);

      int64_t N = bin_indices.size(0);
      int64_t F = bin_indices.size(1);
@@ -233,10 +223,6 @@ void launch_histogram_kernel_cuda_configurable(
      int threads_per_block = 256,
      int rows_per_thread = 1)
  {
-     CHECK_INPUT(bin_indices);
-     CHECK_INPUT(gradients);
-     CHECK_INPUT(grad_hist);
-     CHECK_INPUT(hess_hist);

      int64_t N = bin_indices.size(0);
      int64_t F = bin_indices.size(1);
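Each histogram kernel fills `grad_hist` and `hess_hist`, both allocated as `[num_features, num_bins]` by `compute_histograms` in `core.py`. The inputs are the int8 `bin_indices` and one gradient value per row; in `grow_tree` that value is the residual. No separate Hessian tensor is passed, so the sketch below assumes `hess_hist` holds per-bin sample counts, i.e. the constant unit Hessian of squared-error loss. The CUDA source that would confirm this is not part of the diff.

This is a CPU reference, not the tiled kernel. It also shows the sibling-subtraction step that `grow_tree` uses: only the smaller child's histogram is built, and the other is `parent - child`. The function names are hypothetical.

```python
def histograms_reference(bin_rows, grads, num_features, num_bins):
    """CPU reference for one histogram pass (not the tiled CUDA kernel).

    bin_rows: N rows of num_features bin indices; grads: N gradient values.
    Returns (grad_hist, hess_hist), each a [num_features][num_bins] list.
    ASSUMPTION: hess_hist counts samples per bin (unit Hessian of squared error).
    """
    grad_hist = [[0.0] * num_bins for _ in range(num_features)]
    hess_hist = [[0.0] * num_bins for _ in range(num_features)]
    for row, g in zip(bin_rows, grads):
        for f, b in enumerate(row):
            grad_hist[f][b] += g      # sum of gradients falling in (feature, bin)
            hess_hist[f][b] += 1.0    # assumed: one unit of Hessian per sample
    return grad_hist, hess_hist


def subtract_histograms(parent, child):
    """Sibling histogram as parent - child, the shortcut grow_tree takes."""
    return [[p - c for p, c in zip(p_row, c_row)] for p_row, c_row in zip(parent, child)]
```

The smaller child never has more than half of the node's rows. Building only its histogram and deriving the sibling by subtraction therefore caps the per-split histogram work at half a pass over the node.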
{warpgbm-0.1.20 → warpgbm-0.1.22}/warpgbm/cuda/node_kernel.cpp
@@ -44,6 +44,13 @@ void launch_bin_column_kernel(
      at::Tensor bin_edges,
      at::Tensor bin_indices);

+ void predict_with_forest(
+     const at::Tensor &bin_indices, // [N x F], int8
+     const at::Tensor &tree_tensor, // [T x max_nodes x 6], float32
+     float learning_rate,
+     at::Tensor &out // [N], float32
+ );
+
  // Bindings
  PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
  {
@@ -52,4 +59,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
      m.def("compute_histogram3", &launch_histogram_kernel_cuda_configurable, "Histogram Feature Shared Mem");
      m.def("compute_split", &launch_best_split_kernel_cuda, "Best Split (CUDA)");
      m.def("custom_cuda_binner", &launch_bin_column_kernel, "Custom CUDA binning kernel");
+     m.def("predict_forest", &predict_with_forest, "CUDA Predictions");
  }
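The new `predict_forest` binding receives the stacked `[T x max_nodes x 6]` tensor that `predict` builds with `flatten_tree`. `predict.cu` itself is not shown in this diff, so the traversal below is only a pure-Python sketch of what that kernel is presumably expected to compute. Each rule is inferred from the Python side:

- Nodes are numbered in pre-order.
- A sample goes left when its bin index is `<=` the split bin, the rule `grow_tree` applies during training.
- Each tree adds `learning_rate * leaf_value` to an output initialized to `base_prediction`, matching how training updates its running predictions.

The function names are hypothetical.

```python
import math

# Column layout documented in core.py's flatten_tree docstring.
FEATURE, SPLIT_BIN, LEFT, RIGHT, IS_LEAF, VALUE = range(6)


def flatten_tree(tree):
    """Pre-order flattening of a nested-dict tree into 6-column rows.

    Unused columns stay NaN, as in core.py; unlike core.py, the list is sized
    to the real node count instead of padding to 2 ** (max_depth + 1) rows.
    """
    rows = []

    def walk(node):
        node_id = len(rows)  # pre-order: a node's id is assigned before its children
        rows.append([math.nan] * 6)
        if "leaf_value" in node:
            rows[node_id][IS_LEAF] = 1.0
            rows[node_id][VALUE] = float(node["leaf_value"])
        else:
            rows[node_id][FEATURE] = float(node["feature"])
            rows[node_id][SPLIT_BIN] = float(node["bin"])
            rows[node_id][IS_LEAF] = 0.0
            rows[node_id][LEFT] = float(walk(node["left"]))
            rows[node_id][RIGHT] = float(walk(node["right"]))
        return node_id

    walk(tree)
    return rows


def predict_row(bins, flat_trees, base_prediction, learning_rate):
    """Hypothetical sequential stand-in for what predict_forest is expected to compute."""
    total = base_prediction
    # v0.1.21 notes: the CUDA kernel parallelizes per sample, per tree; here it is a loop.
    for flat in flat_trees:
        node = 0
        while flat[node][IS_LEAF] != 1.0:
            go_left = bins[int(flat[node][FEATURE])] <= flat[node][SPLIT_BIN]
            node = int(flat[node][LEFT] if go_left else flat[node][RIGHT])
        total += learning_rate * flat[node][VALUE]
    return total
```

Storing child ids in the row, rather than relying on a fixed heap layout, keeps the table compact for unbalanced trees. In `core.py`, early leaves leave the trailing rows of the `2 ** (max_depth + 1)` allocation as NaN padding.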