warpgbm 0.1.20__tar.gz → 0.1.22__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {warpgbm-0.1.20/warpgbm.egg-info → warpgbm-0.1.22}/PKG-INFO +15 -16
- {warpgbm-0.1.20 → warpgbm-0.1.22}/README.md +14 -15
- {warpgbm-0.1.20 → warpgbm-0.1.22}/pyproject.toml +1 -1
- {warpgbm-0.1.20 → warpgbm-0.1.22}/setup.py +1 -0
- {warpgbm-0.1.20 → warpgbm-0.1.22}/tests/test_fit_predict_corr.py +11 -8
- warpgbm-0.1.22/version.txt +1 -0
- warpgbm-0.1.22/warpgbm/core.py +401 -0
- {warpgbm-0.1.20 → warpgbm-0.1.22}/warpgbm/cuda/histogram_kernel.cu +0 -14
- {warpgbm-0.1.20 → warpgbm-0.1.22}/warpgbm/cuda/node_kernel.cpp +8 -0
- warpgbm-0.1.22/warpgbm/cuda/predict.cu +77 -0
- {warpgbm-0.1.20 → warpgbm-0.1.22/warpgbm.egg-info}/PKG-INFO +15 -16
- {warpgbm-0.1.20 → warpgbm-0.1.22}/warpgbm.egg-info/SOURCES.txt +2 -1
- warpgbm-0.1.20/version.txt +0 -1
- warpgbm-0.1.20/warpgbm/core.py +0 -552
- {warpgbm-0.1.20 → warpgbm-0.1.22}/LICENSE +0 -0
- {warpgbm-0.1.20 → warpgbm-0.1.22}/MANIFEST.in +0 -0
- {warpgbm-0.1.20 → warpgbm-0.1.22}/setup.cfg +0 -0
- {warpgbm-0.1.20 → warpgbm-0.1.22}/tests/__init__.py +0 -0
- {warpgbm-0.1.20 → warpgbm-0.1.22}/warpgbm/__init__.py +0 -0
- {warpgbm-0.1.20 → warpgbm-0.1.22}/warpgbm/cuda/__init__.py +0 -0
- {warpgbm-0.1.20 → warpgbm-0.1.22}/warpgbm/cuda/best_split_kernel.cu +0 -0
- {warpgbm-0.1.20 → warpgbm-0.1.22}/warpgbm/cuda/binner.cu +0 -0
- {warpgbm-0.1.20 → warpgbm-0.1.22}/warpgbm.egg-info/dependency_links.txt +0 -0
- {warpgbm-0.1.20 → warpgbm-0.1.22}/warpgbm.egg-info/requires.txt +0 -0
- {warpgbm-0.1.20 → warpgbm-0.1.22}/warpgbm.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: warpgbm
|
3
|
-
Version: 0.1.20
|
3
|
+
Version: 0.1.22
|
4
4
|
Summary: A fast GPU-accelerated Gradient Boosted Decision Tree library with PyTorch + CUDA
|
5
5
|
License: GNU GENERAL PUBLIC LICENSE
|
6
6
|
Version 3, 29 June 2007
|
@@ -704,26 +704,20 @@ WarpGBM is a high-performance, GPU-accelerated Gradient Boosted Decision Tree (G
|
|
704
704
|
|
705
705
|
---
|
706
706
|
|
707
|
-
## Performance Note
|
708
|
-
|
709
|
-
In our initial tests on an NVIDIA 3090 (local) and A100 (Google Colab Pro), WarpGBM achieves **14x to 20x faster training times** compared to LightGBM's CPU version and **2x faster** on the GPU version using default configurations. Speed also outperforms XGBoost and CatBoost on regression problems. It also consumes **significantly less RAM and CPU**. These early results hint at more thorough benchmarking to come.
|
710
|
-
|
711
|
-
---
|
712
|
-
|
713
707
|
## Benchmarks
|
714
708
|
|
715
709
|
### Scikit-Learn Synthetic Data: 1 Million Rows and 1,000 Features
|
716
710
|
|
717
|
-
In this benchmark we compare the speed and in-sample correlation of **WarpGBM v0.1.
|
711
|
+
In this benchmark we compare the speed and in-sample correlation of **WarpGBM v0.1.21** against LightGBM, XGBoost, and CatBoost, all using their GPU-enabled versions. This benchmark runs on Google Colab in the L4 GPU environment.
|
718
712
|
|
719
713
|
```
|
720
|
-
WarpGBM: corr = 0.8882, train =
|
721
|
-
XGBoost: corr = 0.8877, train = 33.
|
722
|
-
LightGBM: corr = 0.8604, train = 30.
|
723
|
-
CatBoost: corr = 0.8935, train =
|
714
|
+
WarpGBM: corr = 0.8882, train = 18.7s, infer = 4.9s
|
715
|
+
XGBoost: corr = 0.8877, train = 33.1s, infer = 8.1s
|
716
|
+
LightGBM: corr = 0.8604, train = 30.3s, infer = 1.4s
|
717
|
+
CatBoost: corr = 0.8935, train = 400.0s, infer = 382.6s
|
724
718
|
```
|
725
719
|
|
726
|
-
Colab Notebook: https://colab.research.google.com/drive/16U1kbYlD5HibGbnF5NGsjChZ1p1IA2pK
|
720
|
+
Colab Notebook: https://colab.research.google.com/drive/16U1kbYlD5HibGbnF5NGsjChZ1p1IA2pK?usp=sharing
|
727
721
|
|
728
722
|
---
|
729
723
|
|
@@ -746,7 +740,7 @@ pip install warpgbm
|
|
746
740
|
This installs from PyPI and also compiles CUDA code locally during installation. This method works well **if your environment already has PyTorch with GPU support** installed and configured.
|
747
741
|
|
748
742
|
> **Tip:**\
|
749
|
-
> If you encounter an error related to mismatched or missing CUDA versions, try installing with the following flag
|
743
|
+
> If you encounter an error related to mismatched or missing CUDA versions, try installing with the following flag (currently required in Colab environments):
|
750
744
|
>
|
751
745
|
> ```bash
|
752
746
|
> pip install warpgbm --no-build-isolation
|
@@ -886,8 +880,7 @@ No installation required — just press **"Open in Playground"**, then **Run All
|
|
886
880
|
|
887
881
|
### Methods:
|
888
882
|
- `.fit(X, y, era_id=None)`: Train the model. `X` can be raw floats or pre-binned `int8` data. `era_id` is optional and used internally.
|
889
|
-
- `.predict(X
|
890
|
-
- `.predict_numpy(X, chunksize=50_000)`: Same as `.predict(X)` but without using the GPU.
|
883
|
+
- `.predict(X)`: Predict on new data using a parallelized CUDA kernel.
|
891
884
|
|
892
885
|
---
|
893
886
|
|
@@ -897,3 +890,9 @@ WarpGBM builds on the shoulders of PyTorch, scikit-learn, LightGBM, and the CUDA
|
|
897
890
|
|
898
891
|
---
|
899
892
|
|
893
|
+
## Version Notes
|
894
|
+
|
895
|
+
### v0.1.21
|
896
|
+
|
897
|
+
- Replaced the vectorized predict function with a CUDA kernel (`warpgbm/cuda/predict.cu`) that parallelizes per sample and per tree.
|
898
|
+
|
@@ -16,26 +16,20 @@ WarpGBM is a high-performance, GPU-accelerated Gradient Boosted Decision Tree (G
|
|
16
16
|
|
17
17
|
---
|
18
18
|
|
19
|
-
## Performance Note
|
20
|
-
|
21
|
-
In our initial tests on an NVIDIA 3090 (local) and A100 (Google Colab Pro), WarpGBM achieves **14x to 20x faster training times** compared to LightGBM's CPU version and **2x faster** on the GPU version using default configurations. Speed also outperforms XGBoost and CatBoost on regression problems. It also consumes **significantly less RAM and CPU**. These early results hint at more thorough benchmarking to come.
|
22
|
-
|
23
|
-
---
|
24
|
-
|
25
19
|
## Benchmarks
|
26
20
|
|
27
21
|
### Scikit-Learn Synthetic Data: 1 Million Rows and 1,000 Features
|
28
22
|
|
29
|
-
In this benchmark we compare the speed and in-sample correlation of **WarpGBM v0.1.
|
23
|
+
In this benchmark we compare the speed and in-sample correlation of **WarpGBM v0.1.21** against LightGBM, XGBoost, and CatBoost, all using their GPU-enabled versions. This benchmark runs on Google Colab in the L4 GPU environment.
|
30
24
|
|
31
25
|
```
|
32
|
-
WarpGBM: corr = 0.8882, train =
|
33
|
-
XGBoost: corr = 0.8877, train = 33.
|
34
|
-
LightGBM: corr = 0.8604, train = 30.
|
35
|
-
CatBoost: corr = 0.8935, train =
|
26
|
+
WarpGBM: corr = 0.8882, train = 18.7s, infer = 4.9s
|
27
|
+
XGBoost: corr = 0.8877, train = 33.1s, infer = 8.1s
|
28
|
+
LightGBM: corr = 0.8604, train = 30.3s, infer = 1.4s
|
29
|
+
CatBoost: corr = 0.8935, train = 400.0s, infer = 382.6s
|
36
30
|
```
|
37
31
|
|
38
|
-
Colab Notebook: https://colab.research.google.com/drive/16U1kbYlD5HibGbnF5NGsjChZ1p1IA2pK
|
32
|
+
Colab Notebook: https://colab.research.google.com/drive/16U1kbYlD5HibGbnF5NGsjChZ1p1IA2pK?usp=sharing
|
39
33
|
|
40
34
|
---
|
41
35
|
|
@@ -58,7 +52,7 @@ pip install warpgbm
|
|
58
52
|
This installs from PyPI and also compiles CUDA code locally during installation. This method works well **if your environment already has PyTorch with GPU support** installed and configured.
|
59
53
|
|
60
54
|
> **Tip:**\
|
61
|
-
> If you encounter an error related to mismatched or missing CUDA versions, try installing with the following flag
|
55
|
+
> If you encounter an error related to mismatched or missing CUDA versions, try installing with the following flag (currently required in Colab environments):
|
62
56
|
>
|
63
57
|
> ```bash
|
64
58
|
> pip install warpgbm --no-build-isolation
|
@@ -198,8 +192,7 @@ No installation required — just press **"Open in Playground"**, then **Run All
|
|
198
192
|
|
199
193
|
### Methods:
|
200
194
|
- `.fit(X, y, era_id=None)`: Train the model. `X` can be raw floats or pre-binned `int8` data. `era_id` is optional and used internally.
|
201
|
-
- `.predict(X
|
202
|
-
- `.predict_numpy(X, chunksize=50_000)`: Same as `.predict(X)` but without using the GPU.
|
195
|
+
- `.predict(X)`: Predict on new data using a parallelized CUDA kernel.
|
203
196
|
|
204
197
|
---
|
205
198
|
|
@@ -209,3 +202,9 @@ WarpGBM builds on the shoulders of PyTorch, scikit-learn, LightGBM, and the CUDA
|
|
209
202
|
|
210
203
|
---
|
211
204
|
|
205
|
+
## Version Notes
|
206
|
+
|
207
|
+
### v0.1.21
|
208
|
+
|
209
|
+
- Replaced the vectorized predict function with a CUDA kernel (`warpgbm/cuda/predict.cu`) that parallelizes per sample and per tree.
|
210
|
+
|
@@ -1,11 +1,9 @@
|
|
1
1
|
import numpy as np
|
2
2
|
from warpgbm import WarpGBM
|
3
3
|
from sklearn.datasets import make_regression
|
4
|
-
|
5
|
-
import numpy as np
|
6
4
|
import time
|
7
|
-
from
|
8
|
-
|
5
|
+
from sklearn.metrics import mean_squared_error
|
6
|
+
|
9
7
|
|
10
8
|
def test_fit_predictpytee_correlation():
|
11
9
|
np.random.seed(42)
|
@@ -14,19 +12,20 @@ def test_fit_predictpytee_correlation():
|
|
14
12
|
X, y = make_regression(n_samples=N, n_features=F, noise=0.1, random_state=42)
|
15
13
|
era = np.zeros(N, dtype=np.int32)
|
16
14
|
corrs = []
|
15
|
+
mses = []
|
17
16
|
|
18
|
-
for hist_type in [
|
17
|
+
for hist_type in ["hist1", "hist2", "hist3"]:
|
19
18
|
print(f"\nTesting histogram method: {hist_type}")
|
20
19
|
|
21
20
|
model = WarpGBM(
|
22
21
|
max_depth=10,
|
23
22
|
num_bins=10,
|
24
|
-
n_estimators=
|
23
|
+
n_estimators=100,
|
25
24
|
learning_rate=1,
|
26
25
|
verbosity=False,
|
27
26
|
histogram_computer=hist_type,
|
28
27
|
threads_per_block=64,
|
29
|
-
rows_per_thread=4
|
28
|
+
rows_per_thread=4,
|
30
29
|
)
|
31
30
|
|
32
31
|
start_fit = time.time()
|
@@ -40,7 +39,11 @@ def test_fit_predictpytee_correlation():
|
|
40
39
|
print(f" Predict time: {pred_time:.3f} seconds")
|
41
40
|
|
42
41
|
corr = np.corrcoef(preds, y)[0, 1]
|
42
|
+
mse = mean_squared_error(preds, y)
|
43
43
|
print(f" Correlation: {corr:.4f}")
|
44
|
+
print(f" MSE: {mse:.4f}")
|
44
45
|
corrs.append(corr)
|
46
|
+
mses.append(mse)
|
45
47
|
|
46
|
-
assert (np.array(corrs) > 0.
|
48
|
+
assert (np.array(corrs) > 0.9).all(), f"In-sample correlation too low: {corrs}"
|
49
|
+
assert (np.array(mses) < 2).all(), f"In-sample mse too high: {mses}"
|
@@ -0,0 +1 @@
|
|
1
|
+
0.1.22
|
@@ -0,0 +1,401 @@
|
|
1
|
+
import torch
|
2
|
+
import numpy as np
|
3
|
+
from sklearn.base import BaseEstimator, RegressorMixin
|
4
|
+
from warpgbm.cuda import node_kernel
|
5
|
+
from tqdm import tqdm
|
6
|
+
from typing import Tuple
|
7
|
+
from torch import Tensor
|
8
|
+
|
9
|
+
# Maps each `histogram_computer` name accepted by WarpGBM to the compiled
# CUDA binding that builds the gradient/hessian histograms.
histogram_kernels = dict(
    hist1=node_kernel.compute_histogram,
    hist2=node_kernel.compute_histogram2,
    hist3=node_kernel.compute_histogram3,
)
|
14
|
+
|
15
|
+
|
16
|
+
class WarpGBM(BaseEstimator, RegressorMixin):
    """GPU-accelerated gradient-boosted decision trees for regression.

    Features are quantile-binned into int8 codes and trees are grown on the
    device with the CUDA kernels in ``warpgbm.cuda.node_kernel`` (histograms,
    split search, binning and forest prediction). Each tree is fit to the
    current residual ``y - prediction``, and every leaf stores the mean
    residual of the rows that reach it (squared-error boosting).

    Parameters
    ----------
    num_bins : int, default=10
        Quantile bins per feature; 2..127 so bin codes fit in int8.
    max_depth : int, default=3
        Maximum depth of every tree (>= 1).
    learning_rate : float, default=0.1
        Shrinkage applied to each leaf value, in (0, 1].
    n_estimators : int, default=100
        Number of boosting rounds.
    min_child_weight : int, default=20
        Forwarded to the split kernel.
    min_split_gain : float, default=0.0
        Forwarded to the split kernel.
    verbosity : bool, default=True
        Print progress messages and show a tqdm progress bar while training.
    histogram_computer : {"hist1", "hist2", "hist3"}, default="hist3"
        Which histogram kernel to use.
    threads_per_block : int, default=64
        CUDA block size; a positive multiple of 32.
    rows_per_thread : int, default=4
        Rows handled per thread by the histogram kernel (1..16).
    L2_reg : float, default=1e-6
        L2 regularization forwarded to the split kernel.
    L1_reg : float, default=0.0
        Validated and stored; not used by the training code in this class.
    device : str, default="cuda"
        Torch device that holds the training data and runs the kernels.
    """

    def __init__(
        self,
        num_bins=10,
        max_depth=3,
        learning_rate=0.1,
        n_estimators=100,
        min_child_weight=20,
        min_split_gain=0.0,
        verbosity=True,
        histogram_computer="hist3",
        threads_per_block=64,
        rows_per_thread=4,
        L2_reg=1e-6,
        L1_reg=0.0,
        device="cuda",
    ):
        # Validate before storing anything so a bad argument fails fast.
        self._validate_hyperparams(
            num_bins=num_bins,
            max_depth=max_depth,
            learning_rate=learning_rate,
            n_estimators=n_estimators,
            min_child_weight=min_child_weight,
            min_split_gain=min_split_gain,
            histogram_computer=histogram_computer,
            threads_per_block=threads_per_block,
            rows_per_thread=rows_per_thread,
            L2_reg=L2_reg,
            L1_reg=L1_reg,
        )

        # --- Hyperparameters ----------------------------------------------
        # scikit-learn's get_params()/clone()/repr() read every __init__
        # argument back with getattr(self, name). Each one must therefore be
        # stored, unmodified, under its own name (this includes
        # histogram_computer and verbosity).
        self.num_bins = num_bins
        self.max_depth = max_depth
        self.learning_rate = learning_rate
        self.n_estimators = n_estimators
        self.min_child_weight = min_child_weight
        self.min_split_gain = min_split_gain
        self.verbosity = verbosity
        self.histogram_computer = histogram_computer
        self.threads_per_block = threads_per_block
        self.rows_per_thread = rows_per_thread
        self.L2_reg = L2_reg
        self.L1_reg = L1_reg
        self.device = device

        # Kernel chosen by name. fit() resolves it again so that
        # set_params(histogram_computer=...) takes effect.
        self.compute_histogram = histogram_kernels[histogram_computer]

        # --- Fitted state (populated by fit) ------------------------------
        self.forest = None
        self.bin_edges = None
        self.base_prediction = None
        self.unique_eras = None
        self.root_gradient_histogram = None
        self.root_hessian_histogram = None
        self.gradients = None
        self.root_node_indices = None
        self.bin_indices = None
        self.Y_gpu = None
        self.num_features = None
        self.num_samples = None
        # Not read by the training or prediction code; kept so that external
        # code touching this attribute keeps working.
        self.best_bin = torch.tensor([-1], dtype=torch.int32, device=self.device)

    def _validate_hyperparams(self, **kwargs):
        """Type- and range-check constructor arguments.

        Raises:
            TypeError: a parameter has the wrong type.
            ValueError: a parameter is out of range or an unknown kernel name.
        """
        int_params = [
            "num_bins",
            "max_depth",
            "n_estimators",
            "min_child_weight",
            "threads_per_block",
            "rows_per_thread",
        ]
        float_params = ["learning_rate", "min_split_gain", "L2_reg", "L1_reg"]

        # bool is a subclass of int, so reject it explicitly. numpy scalars
        # (e.g. values from np.arange-based parameter grids) are accepted.
        for param in int_params:
            value = kwargs[param]
            if isinstance(value, bool) or not isinstance(value, (int, np.integer)):
                raise TypeError(
                    f"{param} must be an integer, got {type(kwargs[param])}."
                )

        for param in float_params:
            value = kwargs[param]
            # Integers are accepted as valid floats.
            if isinstance(value, bool) or not isinstance(
                value, (float, int, np.floating, np.integer)
            ):
                raise TypeError(f"{param} must be a float, got {type(kwargs[param])}.")

        if not (2 <= kwargs["num_bins"] <= 127):
            raise ValueError("num_bins must be between 2 and 127 inclusive.")
        if kwargs["max_depth"] < 1:
            raise ValueError("max_depth must be at least 1.")
        if not (0.0 < kwargs["learning_rate"] <= 1.0):
            raise ValueError("learning_rate must be in (0.0, 1.0].")
        if kwargs["n_estimators"] <= 0:
            raise ValueError("n_estimators must be positive.")
        if kwargs["min_child_weight"] < 1:
            raise ValueError("min_child_weight must be a positive integer.")
        if kwargs["min_split_gain"] < 0:
            raise ValueError("min_split_gain must be non-negative.")
        if kwargs["threads_per_block"] <= 0 or kwargs["threads_per_block"] % 32 != 0:
            raise ValueError(
                "threads_per_block should be a positive multiple of 32 (warp size)."
            )
        if not (1 <= kwargs["rows_per_thread"] <= 16):
            raise ValueError(
                "rows_per_thread must be positive between 1 and 16 inclusive."
            )
        if kwargs["L2_reg"] < 0 or kwargs["L1_reg"] < 0:
            raise ValueError("L2_reg and L1_reg must be non-negative.")
        if kwargs["histogram_computer"] not in histogram_kernels:
            raise ValueError(
                f"Invalid histogram_computer: {kwargs['histogram_computer']}. Choose from {list(histogram_kernels.keys())}."
            )

    def fit(self, X, y, era_id=None):
        """Fit the boosted forest.

        Args:
            X: 2-D array-like (n_samples, n_features). Either raw numeric
                features, or integers already binned into [0, num_bins),
                which is detected automatically.
            y: array-like of regression targets, one per row of X.
            era_id: optional array-like of integer era labels (defaults to a
                single era). It is reduced to ``unique_eras``; the tree
                growing below does not otherwise use it.

        Returns:
            self
        """
        # torch.from_numpy needs real ndarrays; this also accepts lists and
        # pandas objects.
        X = np.asarray(X)
        y = np.asarray(y)
        if era_id is None:
            era_id = np.ones(X.shape[0], dtype="int32")
        else:
            era_id = np.asarray(era_id)

        # Re-resolve in case histogram_computer was changed via set_params().
        self.compute_histogram = histogram_kernels[self.histogram_computer]

        self.bin_indices, era_indices, self.bin_edges, self.unique_eras, self.Y_gpu = (
            self.preprocess_gpu_data(X, y, era_id)
        )
        self.num_samples, self.num_features = X.shape

        # `gradients` holds the running prediction for every training row. It
        # starts at the target mean, and each leaf adds its shrunken value.
        self.gradients = torch.zeros_like(self.Y_gpu)
        self.root_node_indices = torch.arange(self.num_samples, device=self.device)
        self.base_prediction = self.Y_gpu.mean().item()
        self.gradients += self.base_prediction

        # Per-feature scratch outputs written by the split kernel.
        self.best_gains = torch.zeros(self.num_features, device=self.device)
        self.best_bins = torch.zeros(
            self.num_features, device=self.device, dtype=torch.int32
        )
        with torch.no_grad():
            self.forest = self.grow_forest()
        return self

    def preprocess_gpu_data(self, X_np, Y_np, era_id_np):
        """Move targets and eras to the device, and bin X into int8 codes.

        Integer X whose values all lie in [0, num_bins) is used as-is.
        Anything else is quantile-binned per feature on the device.

        Returns:
            (bin_indices [N, F] int8, era_indices, bin_edges [F, num_bins - 1]
            float32, unique_eras, Y_gpu float32), all on ``self.device``.
        """
        with torch.no_grad():
            self.num_samples, self.num_features = X_np.shape
            Y_gpu = torch.from_numpy(Y_np).type(torch.float32).to(self.device)
            era_id_gpu = torch.from_numpy(era_id_np).type(torch.int32).to(self.device)

            if np.issubdtype(X_np.dtype, np.integer):
                # Negative codes would be invalid bin indices, so check both
                # ends of the range before trusting the input as pre-binned.
                in_range = (X_np.min(axis=0) >= 0).all() and (
                    X_np.max(axis=0) < self.num_bins
                ).all()
                if in_range:
                    if self.verbosity:
                        print(
                            "Detected pre-binned integer input — skipping quantile binning."
                        )
                    bin_indices = (
                        torch.from_numpy(X_np)
                        .to(self.device)
                        .contiguous()
                        .to(torch.int8)
                    )
                    # predict() always re-bins its input with custom_cuda_binner.
                    # Store evenly spaced edges 1..num_bins-1 for it to use; these
                    # are meant to reproduce the integer codes (confirm against
                    # the binner's edge semantics).
                    bin_edges = torch.arange(
                        1, self.num_bins, dtype=torch.float32, device=self.device
                    ).repeat(self.num_features, 1)
                    unique_eras, era_indices = torch.unique(
                        era_id_gpu, return_inverse=True
                    )
                    return bin_indices, era_indices, bin_edges, unique_eras, Y_gpu

                if self.verbosity:
                    print(
                        "Integer input detected, but values fall outside [0, num_bins) — falling back to quantile binning."
                    )

            bin_indices = torch.empty(
                (self.num_samples, self.num_features),
                dtype=torch.int8,
                device=self.device,
            )
            bin_edges = torch.empty(
                (self.num_features, self.num_bins - 1),
                dtype=torch.float32,
                device=self.device,
            )

            # Pinned host memory allows the per-column copies to be non-blocking.
            X_host = torch.from_numpy(X_np).to(torch.float32).pin_memory()
            # Interior quantile levels 1/B, ..., (B-1)/B (loop-invariant).
            quantiles = torch.linspace(
                0, 1, self.num_bins + 1, device=self.device, dtype=torch.float32
            )[1:-1]

            for f in range(self.num_features):
                X_f = X_host[:, f].to(self.device, non_blocking=True)
                bin_edges_f = torch.quantile(X_f, quantiles, dim=0).contiguous()
                # A column slice is strided, so .contiguous() makes a copy for
                # the kernel to fill; write it back into the matrix afterwards.
                bin_indices_f = bin_indices[:, f].contiguous()
                node_kernel.custom_cuda_binner(X_f, bin_edges_f, bin_indices_f)
                bin_indices[:, f] = bin_indices_f
                bin_edges[f, :] = bin_edges_f

            unique_eras, era_indices = torch.unique(era_id_gpu, return_inverse=True)
            return bin_indices, era_indices, bin_edges, unique_eras, Y_gpu

    def compute_histograms(self, bin_indices_sub, gradients):
        """Build gradient and hessian histograms for a subset of rows.

        Returns two float32 tensors of shape (num_features, num_bins). They
        are zero-initialised here and filled by the selected histogram kernel.
        """
        shape = (self.num_features, self.num_bins)
        grad_hist = torch.zeros(shape, device=self.device, dtype=torch.float32)
        hess_hist = torch.zeros(shape, device=self.device, dtype=torch.float32)

        # Plain Python scalars for the compiled binding, even if the
        # hyperparameters were numpy scalars.
        self.compute_histogram(
            bin_indices_sub,
            gradients,
            grad_hist,
            hess_hist,
            int(self.num_bins),
            int(self.threads_per_block),
            int(self.rows_per_thread),
        )
        return grad_hist, hess_hist

    def find_best_split(self, gradient_histogram, hessian_histogram):
        """Pick the best (feature, bin) split for a node.

        compute_split writes one candidate per feature into self.best_gains
        and self.best_bins, where bin -1 marks a feature with no valid split.

        Returns:
            (feature, bin) of the highest-gain valid candidate, or (-1, -1)
            when no feature has a valid split.
        """
        node_kernel.compute_split(
            gradient_histogram,
            hessian_histogram,
            float(self.min_split_gain),
            int(self.min_child_weight),
            float(self.L2_reg),
            self.best_gains,
            self.best_bins,
            int(self.threads_per_block),
        )

        valid = self.best_bins != -1
        if not torch.any(valid):
            return -1, -1  # No valid split found

        # Never let argmax land on a feature without a valid split.
        gains = self.best_gains.masked_fill(~valid, float("-inf"))
        f = torch.argmax(gains).item()
        b = self.best_bins[f].item()
        return f, b

    def _make_leaf(self, node_indices):
        # A leaf's value is the mean residual of its rows. Its shrunken value
        # is added to those rows' running predictions.
        leaf_value = self.residual[node_indices].mean()
        self.gradients[node_indices] += self.learning_rate * leaf_value
        return {"leaf_value": leaf_value.item(), "samples": node_indices.numel()}

    def grow_tree(self, gradient_histogram, hessian_histogram, node_indices, depth):
        """Recursively grow a tree from a node and return it as nested dicts.

        Leaves are {"leaf_value", "samples"}. Internal nodes are
        {"feature", "bin", "left", "right"}, and rows whose bin code is
        <= "bin" go left.
        """
        if depth == self.max_depth:
            return self._make_leaf(node_indices)

        best_feature, best_bin = self.find_best_split(
            gradient_histogram, hessian_histogram
        )
        if best_feature == -1:
            return self._make_leaf(node_indices)

        split_mask = self.bin_indices[node_indices, best_feature] <= best_bin
        left_indices = node_indices[split_mask]
        right_indices = node_indices[~split_mask]

        # Histogram subtraction: build only the smaller child's histogram,
        # then derive the sibling's as parent minus child.
        if left_indices.numel() <= right_indices.numel():
            grad_hist_left, hess_hist_left = self.compute_histograms(
                self.bin_indices[left_indices], self.residual[left_indices]
            )
            grad_hist_right = gradient_histogram - grad_hist_left
            hess_hist_right = hessian_histogram - hess_hist_left
        else:
            grad_hist_right, hess_hist_right = self.compute_histograms(
                self.bin_indices[right_indices], self.residual[right_indices]
            )
            grad_hist_left = gradient_histogram - grad_hist_right
            hess_hist_left = hessian_histogram - hess_hist_right

        left_child = self.grow_tree(
            grad_hist_left, hess_hist_left, left_indices, depth + 1
        )
        right_child = self.grow_tree(
            grad_hist_right, hess_hist_right, right_indices, depth + 1
        )

        return {
            "feature": best_feature,
            "bin": best_bin,
            "left": left_child,
            "right": right_child,
        }

    def grow_forest(self):
        """Run n_estimators boosting rounds and return the list of trees."""
        forest = [{} for _ in range(self.n_estimators)]
        self.training_loss = []

        for i in tqdm(range(self.n_estimators), disable=not self.verbosity):
            # Each round fits the residual of the current predictions.
            self.residual = self.Y_gpu - self.gradients

            self.root_gradient_histogram, self.root_hessian_histogram = (
                self.compute_histograms(self.bin_indices, self.residual)
            )

            forest[i] = self.grow_tree(
                self.root_gradient_histogram,
                self.root_hessian_histogram,
                self.root_node_indices,
                depth=0,
            )

        if self.verbosity:
            print("Finished training forest.")
        return forest

    def predict(self, X_np):
        """Predict targets with the fitted forest.

        Args:
            X_np: 2-D array-like with the same number of columns as the
                training data.

        Returns:
            1-D numpy array of predictions (float32).

        Raises:
            NotFittedError: called before fit().
            ValueError: X_np is not 2-D or has the wrong number of columns.
        """
        if self.forest is None:
            # Local import: sklearn is already a dependency of this module.
            from sklearn.exceptions import NotFittedError

            raise NotFittedError(
                "This WarpGBM instance is not fitted yet. Call 'fit' before 'predict'."
            )

        X_np = np.asarray(X_np)
        if X_np.ndim != 2 or X_np.shape[1] != self.num_features:
            raise ValueError(
                f"Expected 2-D input with {self.num_features} features, got shape {X_np.shape}."
            )
        num_samples = X_np.shape[0]
        if num_samples == 0:
            # Nothing to predict; avoid launching kernels on empty tensors.
            return np.empty(0, dtype=np.float32)

        X_tensor = torch.from_numpy(X_np).to(torch.float32).pin_memory()
        bin_indices = torch.zeros(
            (num_samples, self.num_features), dtype=torch.int8, device=self.device
        )

        with torch.no_grad():
            # Bin every column with the edges learned during fit().
            for f in range(self.num_features):
                X_f = X_tensor[:, f].to(self.device, non_blocking=True)
                bin_edges_f = self.bin_edges[f]
                bin_indices_f = bin_indices[:, f].contiguous()
                node_kernel.custom_cuda_binner(X_f, bin_edges_f, bin_indices_f)
                bin_indices[:, f] = bin_indices_f

        # [n_trees, max_nodes, 6]. A tree of depth max_depth has at most
        # 2**(max_depth + 1) - 1 nodes.
        tree_tensor = torch.stack(
            [
                self.flatten_tree(tree, max_nodes=2 ** (self.max_depth + 1))
                for tree in self.forest
            ]
        ).to(self.device)

        out = torch.zeros(num_samples, device=self.device) + self.base_prediction
        node_kernel.predict_forest(
            bin_indices.contiguous(),
            tree_tensor.contiguous(),
            float(self.learning_rate),
            out,
        )

        return out.cpu().numpy()

    def flatten_tree(self, tree, max_nodes):
        """
        Convert a recursive tree structure into a flat matrix format.

        Each row in the output represents a node, numbered in pre-order:
        - Columns: [feature, bin, left_id, right_id, is_leaf, value]
        - Internal nodes fill columns 0–3 and set is_leaf = 0
        - Leaf nodes fill only value and set is_leaf = 1
        - Unused cells (and rows beyond the tree's size) stay NaN

        Args:
            tree (dict): Root node in the nested-dict form built by grow_tree.
            max_nodes (int): Number of rows to allocate in the flat matrix.

        Returns:
            torch.Tensor: [max_nodes x 6] float32 matrix on the CPU.
        """
        flat = torch.full((max_nodes, 6), float("nan"), dtype=torch.float32)
        node_counter = [0]
        node_list = []

        def walk(node):
            curr_id = node_counter[0]
            node_counter[0] += 1

            new_node = {"node_id": curr_id}
            if "leaf_value" in node:
                new_node["leaf_value"] = float(node["leaf_value"])
            else:
                new_node["best_feature"] = float(node["feature"])
                new_node["split_bin"] = float(node["bin"])
                # Child ids are the counter values at the moment each
                # subtree's walk starts (pre-order numbering).
                new_node["left_id"] = node_counter[0]
                walk(node["left"])
                new_node["right_id"] = node_counter[0]
                walk(node["right"])

            node_list.append(new_node)
            return new_node

        walk(tree)

        for node in node_list:
            i = node["node_id"]
            if "leaf_value" in node:
                flat[i, 4] = 1.0
                flat[i, 5] = node["leaf_value"]
            else:
                flat[i, 0] = node["best_feature"]
                flat[i, 1] = node["split_bin"]
                flat[i, 2] = node["left_id"]
                flat[i, 3] = node["right_id"]
                flat[i, 4] = 0.0

        return flat
|
@@ -107,12 +107,6 @@ void launch_histogram_kernel_cuda(
|
|
107
107
|
N, F, B);
|
108
108
|
}
|
109
109
|
|
110
|
-
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
|
111
|
-
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
112
|
-
#define CHECK_INPUT(x) \
|
113
|
-
CHECK_CUDA(x); \
|
114
|
-
CHECK_CONTIGUOUS(x)
|
115
|
-
|
116
110
|
// CUDA kernel: tiled, 64-bit safe
|
117
111
|
__global__ void histogram_tiled_kernel(
|
118
112
|
const int8_t *__restrict__ bin_indices, // [N, F]
|
@@ -148,10 +142,6 @@ void launch_histogram_kernel_cuda_2(
|
|
148
142
|
int threads_per_block = 256,
|
149
143
|
int rows_per_thread = 1)
|
150
144
|
{
|
151
|
-
CHECK_INPUT(bin_indices);
|
152
|
-
CHECK_INPUT(gradients);
|
153
|
-
CHECK_INPUT(grad_hist);
|
154
|
-
CHECK_INPUT(hess_hist);
|
155
145
|
|
156
146
|
int64_t N = bin_indices.size(0);
|
157
147
|
int64_t F = bin_indices.size(1);
|
@@ -233,10 +223,6 @@ void launch_histogram_kernel_cuda_configurable(
|
|
233
223
|
int threads_per_block = 256,
|
234
224
|
int rows_per_thread = 1)
|
235
225
|
{
|
236
|
-
CHECK_INPUT(bin_indices);
|
237
|
-
CHECK_INPUT(gradients);
|
238
|
-
CHECK_INPUT(grad_hist);
|
239
|
-
CHECK_INPUT(hess_hist);
|
240
226
|
|
241
227
|
int64_t N = bin_indices.size(0);
|
242
228
|
int64_t F = bin_indices.size(1);
|
@@ -44,6 +44,13 @@ void launch_bin_column_kernel(
|
|
44
44
|
at::Tensor bin_edges,
|
45
45
|
at::Tensor bin_indices);
|
46
46
|
|
47
|
+
// Forest inference entry point, exposed to Python as node_kernel.predict_forest
// by the PYBIND11_MODULE below. It is only declared in this file; this release
// adds warpgbm/cuda/predict.cu, which presumably holds the definition — confirm.
//
// As called from WarpGBM.predict in warpgbm/core.py:
//   - bin_indices: contiguous int8 bin codes, one row per sample.
//   - tree_tensor: contiguous stack of flattened trees. Each node row is
//     [feature, bin, left_id, right_id, is_leaf, value], with NaN in unused cells.
//   - out: pre-filled with the base prediction. NOTE(review): the kernel is
//     expected to add learning_rate * leaf value for every tree — confirm in predict.cu.
void predict_with_forest(
|
48
|
+
const at::Tensor &bin_indices, // [N x F], int8
|
49
|
+
const at::Tensor &tree_tensor, // [T x max_nodes x 6], float32
|
50
|
+
float learning_rate,
|
51
|
+
at::Tensor &out // [N], float32
|
52
|
+
);
|
53
|
+
|
47
54
|
// Bindings
|
48
55
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
49
56
|
{
|
@@ -52,4 +59,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
|
52
59
|
m.def("compute_histogram3", &launch_histogram_kernel_cuda_configurable, "Histogram Feature Shared Mem");
|
53
60
|
m.def("compute_split", &launch_best_split_kernel_cuda, "Best Split (CUDA)");
|
54
61
|
m.def("custom_cuda_binner", &launch_bin_column_kernel, "Custom CUDA binning kernel");
|
62
|
+
m.def("predict_forest", &predict_with_forest, "CUDA Predictions");
|
55
63
|
}
|