warpgbm 0.1.21__tar.gz → 0.1.23__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. {warpgbm-0.1.21/warpgbm.egg-info → warpgbm-0.1.23}/PKG-INFO +9 -16
  2. {warpgbm-0.1.21 → warpgbm-0.1.23}/README.md +8 -15
  3. {warpgbm-0.1.21 → warpgbm-0.1.23}/pyproject.toml +1 -1
  4. {warpgbm-0.1.21 → warpgbm-0.1.23}/tests/test_fit_predict_corr.py +20 -9
  5. warpgbm-0.1.23/version.txt +1 -0
  6. warpgbm-0.1.23/warpgbm/core.py +537 -0
  7. {warpgbm-0.1.21 → warpgbm-0.1.23}/warpgbm/cuda/best_split_kernel.cu +1 -1
  8. {warpgbm-0.1.21 → warpgbm-0.1.23}/warpgbm/cuda/histogram_kernel.cu +0 -14
  9. {warpgbm-0.1.21 → warpgbm-0.1.23/warpgbm.egg-info}/PKG-INFO +9 -16
  10. warpgbm-0.1.21/version.txt +0 -1
  11. warpgbm-0.1.21/warpgbm/core.py +0 -341
  12. {warpgbm-0.1.21 → warpgbm-0.1.23}/LICENSE +0 -0
  13. {warpgbm-0.1.21 → warpgbm-0.1.23}/MANIFEST.in +0 -0
  14. {warpgbm-0.1.21 → warpgbm-0.1.23}/setup.cfg +0 -0
  15. {warpgbm-0.1.21 → warpgbm-0.1.23}/setup.py +0 -0
  16. {warpgbm-0.1.21 → warpgbm-0.1.23}/tests/__init__.py +0 -0
  17. {warpgbm-0.1.21 → warpgbm-0.1.23}/warpgbm/__init__.py +0 -0
  18. {warpgbm-0.1.21 → warpgbm-0.1.23}/warpgbm/cuda/__init__.py +0 -0
  19. {warpgbm-0.1.21 → warpgbm-0.1.23}/warpgbm/cuda/binner.cu +0 -0
  20. {warpgbm-0.1.21 → warpgbm-0.1.23}/warpgbm/cuda/node_kernel.cpp +0 -0
  21. {warpgbm-0.1.21 → warpgbm-0.1.23}/warpgbm/cuda/predict.cu +0 -0
  22. {warpgbm-0.1.21 → warpgbm-0.1.23}/warpgbm.egg-info/SOURCES.txt +0 -0
  23. {warpgbm-0.1.21 → warpgbm-0.1.23}/warpgbm.egg-info/dependency_links.txt +0 -0
  24. {warpgbm-0.1.21 → warpgbm-0.1.23}/warpgbm.egg-info/requires.txt +0 -0
  25. {warpgbm-0.1.21 → warpgbm-0.1.23}/warpgbm.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: warpgbm
- Version: 0.1.21
+ Version: 0.1.23
  Summary: A fast GPU-accelerated Gradient Boosted Decision Tree library with PyTorch + CUDA
  License: GNU GENERAL PUBLIC LICENSE
  Version 3, 29 June 2007
@@ -704,26 +704,20 @@ WarpGBM is a high-performance, GPU-accelerated Gradient Boosted Decision Tree (G

  ---

- ## Performance Note
-
- In our initial tests on an NVIDIA 3090 (local) and A100 (Google Colab Pro), WarpGBM achieves **14x to 20x faster training times** compared to LightGBM's CPU version and **2x faster** on the GPU version using default configurations. Speed also outperforms XGBoost and CatBoost on regression problems. It also consumes **significantly less RAM and CPU**. These early results hint at more thorough benchmarking to come.
-
- ---
-
  ## Benchmarks

  ### Scikit-Learn Synthetic Data: 1 Million Rows and 1,000 Features

- In this benchmark we compare the speed and in-sample correlation of **WarpGBM v0.1.19** against LightGBM, XGBoost and CatBoost, all with their GPU-enabled versions. This benchmark runs on Google Colab with the L4 GPU environment. The CPU versions don't even come close to the speed here so we didn't test them.
+ In this benchmark we compare the speed and in-sample correlation of **WarpGBM v0.1.21** against LightGBM, XGBoost and CatBoost, all with their GPU-enabled versions. This benchmark runs on Google Colab with the L4 GPU environment.

  ```
- WarpGBM: corr = 0.8882, train = 21.8s, infer = 11.6s
- XGBoost: corr = 0.8877, train = 33.4s, infer = 8.1s
- LightGBM: corr = 0.8604, train = 30.2s, infer = 1.4s
- CatBoost: corr = 0.8935, train = 377.9s, infer = 375.8s
+ WarpGBM: corr = 0.8882, train = 18.7s, infer = 4.9s
+ XGBoost: corr = 0.8877, train = 33.1s, infer = 8.1s
+ LightGBM: corr = 0.8604, train = 30.3s, infer = 1.4s
+ CatBoost: corr = 0.8935, train = 400.0s, infer = 382.6s
  ```

- Colab Notebook: https://colab.research.google.com/drive/16U1kbYlD5HibGbnF5NGsjChZ1p1IA2pK
+ Colab Notebook: https://colab.research.google.com/drive/16U1kbYlD5HibGbnF5NGsjChZ1p1IA2pK?usp=sharing

  ---

@@ -746,7 +740,7 @@ pip install warpgbm
  This installs from PyPI and also compiles CUDA code locally during installation. This method works well **if your environment already has PyTorch with GPU support** installed and configured.

  > **Tip:**\
- > If you encounter an error related to mismatched or missing CUDA versions, try installing with the following flag:
+ > If you encounter an error related to mismatched or missing CUDA versions, try installing with the following flag. This is currently required in the Colab environments.
  >
  > ```bash
  > pip install warpgbm --no-build-isolation
@@ -886,8 +880,7 @@ No installation required — just press **"Open in Playground"**, then **Run All

  ### Methods:
  - `.fit(X, y, era_id=None)`: Train the model. `X` can be raw floats or pre-binned `int8` data. `era_id` is optional and used internally.
- - `.predict(X, chunksize=50_000)`: Predict on new raw float or pre-binned data.
- - `.predict_numpy(X, chunksize=50_000)`: Same as `.predict(X)` but without using the GPU.
+ - `.predict(X)`: Predict on new data, using parallelized CUDA kernel.

  ---

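The Methods list above drops the `chunksize` argument and the CPU-only `predict_numpy` in favor of a single CUDA-backed `.predict(X)`. A minimal usage sketch of the 0.1.23 API, assuming a CUDA-capable environment; the synthetic-regression setup mirrors the package's own test rather than anything prescribed by the README:

```python
# Minimal sketch of the fit/predict API described above (assumes a CUDA GPU
# and warpgbm 0.1.23; dataset sizes here are illustrative only).
import numpy as np
from sklearn.datasets import make_regression
from warpgbm import WarpGBM

X, y = make_regression(n_samples=100_000, n_features=100, noise=0.1, random_state=0)

model = WarpGBM(max_depth=6, num_bins=10, n_estimators=50, learning_rate=0.1)
model.fit(X, y)              # era_id is optional, as the Methods list notes
preds = model.predict(X)     # no chunksize argument in 0.1.23

print(np.corrcoef(preds, y)[0, 1])
```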
@@ -16,26 +16,20 @@ WarpGBM is a high-performance, GPU-accelerated Gradient Boosted Decision Tree (G

  ---

- ## Performance Note
-
- In our initial tests on an NVIDIA 3090 (local) and A100 (Google Colab Pro), WarpGBM achieves **14x to 20x faster training times** compared to LightGBM's CPU version and **2x faster** on the GPU version using default configurations. Speed also outperforms XGBoost and CatBoost on regression problems. It also consumes **significantly less RAM and CPU**. These early results hint at more thorough benchmarking to come.
-
- ---
-
  ## Benchmarks

  ### Scikit-Learn Synthetic Data: 1 Million Rows and 1,000 Features

- In this benchmark we compare the speed and in-sample correlation of **WarpGBM v0.1.19** against LightGBM, XGBoost and CatBoost, all with their GPU-enabled versions. This benchmark runs on Google Colab with the L4 GPU environment. The CPU versions don't even come close to the speed here so we didn't test them.
+ In this benchmark we compare the speed and in-sample correlation of **WarpGBM v0.1.21** against LightGBM, XGBoost and CatBoost, all with their GPU-enabled versions. This benchmark runs on Google Colab with the L4 GPU environment.

  ```
- WarpGBM: corr = 0.8882, train = 21.8s, infer = 11.6s
- XGBoost: corr = 0.8877, train = 33.4s, infer = 8.1s
- LightGBM: corr = 0.8604, train = 30.2s, infer = 1.4s
- CatBoost: corr = 0.8935, train = 377.9s, infer = 375.8s
+ WarpGBM: corr = 0.8882, train = 18.7s, infer = 4.9s
+ XGBoost: corr = 0.8877, train = 33.1s, infer = 8.1s
+ LightGBM: corr = 0.8604, train = 30.3s, infer = 1.4s
+ CatBoost: corr = 0.8935, train = 400.0s, infer = 382.6s
  ```

- Colab Notebook: https://colab.research.google.com/drive/16U1kbYlD5HibGbnF5NGsjChZ1p1IA2pK
+ Colab Notebook: https://colab.research.google.com/drive/16U1kbYlD5HibGbnF5NGsjChZ1p1IA2pK?usp=sharing

  ---

@@ -58,7 +52,7 @@ pip install warpgbm
  This installs from PyPI and also compiles CUDA code locally during installation. This method works well **if your environment already has PyTorch with GPU support** installed and configured.

  > **Tip:**\
- > If you encounter an error related to mismatched or missing CUDA versions, try installing with the following flag:
+ > If you encounter an error related to mismatched or missing CUDA versions, try installing with the following flag. This is currently required in the Colab environments.
  >
  > ```bash
  > pip install warpgbm --no-build-isolation
@@ -198,8 +192,7 @@ No installation required — just press **"Open in Playground"**, then **Run All

  ### Methods:
  - `.fit(X, y, era_id=None)`: Train the model. `X` can be raw floats or pre-binned `int8` data. `era_id` is optional and used internally.
- - `.predict(X, chunksize=50_000)`: Predict on new raw float or pre-binned data.
- - `.predict_numpy(X, chunksize=50_000)`: Same as `.predict(X)` but without using the GPU.
+ - `.predict(X)`: Predict on new data, using parallelized CUDA kernel.

  ---

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "warpgbm"
- version = "0.1.21"
+ version = "0.1.23"
  description = "A fast GPU-accelerated Gradient Boosted Decision Tree library with PyTorch + CUDA"
  readme = "README.md"
  requires-python = ">=3.8"
@@ -1,11 +1,9 @@
  import numpy as np
  from warpgbm import WarpGBM
  from sklearn.datasets import make_regression
-
- import numpy as np
  import time
- from warpgbm import WarpGBM
- from sklearn.datasets import make_regression
+ from sklearn.metrics import mean_squared_error
+

  def test_fit_predictpytee_correlation():
  np.random.seed(42)
@@ -14,23 +12,32 @@ def test_fit_predictpytee_correlation():
  X, y = make_regression(n_samples=N, n_features=F, noise=0.1, random_state=42)
  era = np.zeros(N, dtype=np.int32)
  corrs = []
+ mses = []

- for hist_type in ['hist1', 'hist2', 'hist3']:
+ for hist_type in ["hist1", "hist2", "hist3"]:
  print(f"\nTesting histogram method: {hist_type}")

  model = WarpGBM(
  max_depth=10,
  num_bins=10,
- n_estimators=10,
+ n_estimators=100,
  learning_rate=1,
  verbosity=False,
  histogram_computer=hist_type,
  threads_per_block=64,
- rows_per_thread=4
+ rows_per_thread=4,
  )

  start_fit = time.time()
- model.fit(X, y, era_id=era)
+ model.fit(
+ X,
+ y,
+ era_id=era,
+ X_eval=X,
+ y_eval=y,
+ eval_every_n_trees=10,
+ early_stopping_rounds=1,
+ )
  fit_time = time.time() - start_fit
  print(f" Fit time: {fit_time:.3f} seconds")

@@ -40,7 +47,11 @@ def test_fit_predictpytee_correlation():
  print(f" Predict time: {pred_time:.3f} seconds")

  corr = np.corrcoef(preds, y)[0, 1]
+ mse = mean_squared_error(preds, y)
  print(f" Correlation: {corr:.4f}")
+ print(f" MSE: {mse:.4f}")
  corrs.append(corr)
+ mses.append(mse)

- assert (np.array(corrs) > 0.95).all(), f"In-sample correlation too low: {corrs}"
+ assert (np.array(corrs) > 0.9).all(), f"In-sample correlation too low: {corrs}"
+ assert (np.array(mses) < 2).all(), f"In-sample mse too high: {mses}"
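The rewritten test exercises the new optional eval-set arguments (`X_eval`, `y_eval`, `eval_every_n_trees`, `early_stopping_rounds`) added to `.fit` in this release. The stopping rule itself lives in `compute_eval` in the new `core.py` further down this diff; a standalone sketch of that rule, with an illustrative loss history:

```python
# Sketch of the early-stopping rule from compute_eval in the new core.py:
# once more than `early_stopping_rounds` evaluations have been recorded,
# training stops when eval_loss[-early_stopping_rounds] is lower than the
# most recent eval loss (i.e. the eval MSE has stopped improving).
def should_stop(eval_loss_history, early_stopping_rounds):
    if len(eval_loss_history) > early_stopping_rounds:
        return eval_loss_history[-early_stopping_rounds] < eval_loss_history[-1]
    return False

print(should_stop([5.0, 4.0, 4.2], early_stopping_rounds=2))  # True: 4.0 < 4.2
print(should_stop([5.0, 4.0, 3.5], early_stopping_rounds=2))  # False: still improving
```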
@@ -0,0 +1 @@
+ 0.1.23
@@ -0,0 +1,537 @@
1
+ import torch
2
+ import numpy as np
3
+ from sklearn.base import BaseEstimator, RegressorMixin
4
+ from warpgbm.cuda import node_kernel
5
+ from tqdm import tqdm
6
+ from typing import Tuple
7
+ from torch import Tensor
8
+
9
+ histogram_kernels = {
10
+ "hist1": node_kernel.compute_histogram,
11
+ "hist2": node_kernel.compute_histogram2,
12
+ "hist3": node_kernel.compute_histogram3,
13
+ }
14
+
15
+
16
+ class WarpGBM(BaseEstimator, RegressorMixin):
17
+ def __init__(
18
+ self,
19
+ num_bins=10,
20
+ max_depth=3,
21
+ learning_rate=0.1,
22
+ n_estimators=100,
23
+ min_child_weight=20,
24
+ min_split_gain=0.0,
25
+ verbosity=True,
26
+ histogram_computer="hist3",
27
+ threads_per_block=64,
28
+ rows_per_thread=4,
29
+ L2_reg=1e-6,
30
+ L1_reg=0.0,
31
+ device="cuda",
32
+ ):
33
+ # Validate arguments
34
+ self._validate_hyperparams(
35
+ num_bins=num_bins,
36
+ max_depth=max_depth,
37
+ learning_rate=learning_rate,
38
+ n_estimators=n_estimators,
39
+ min_child_weight=min_child_weight,
40
+ min_split_gain=min_split_gain,
41
+ histogram_computer=histogram_computer,
42
+ threads_per_block=threads_per_block,
43
+ rows_per_thread=rows_per_thread,
44
+ L2_reg=L2_reg,
45
+ L1_reg=L1_reg,
46
+ )
47
+
48
+ self.num_bins = num_bins
49
+ self.max_depth = max_depth
50
+ self.learning_rate = learning_rate
51
+ self.n_estimators = n_estimators
52
+ self.forest = None
53
+ self.bin_edges = None
54
+ self.base_prediction = None
55
+ self.unique_eras = None
56
+ self.device = device
57
+ self.root_gradient_histogram = None
58
+ self.root_hessian_histogram = None
59
+ self.gradients = None
60
+ self.root_node_indices = None
61
+ self.bin_indices = None
62
+ self.Y_gpu = None
63
+ self.num_features = None
64
+ self.num_samples = None
65
+ self.min_child_weight = min_child_weight
66
+ self.min_split_gain = min_split_gain
67
+ self.best_bin = torch.tensor([-1], dtype=torch.int32, device=self.device)
68
+ self.compute_histogram = histogram_kernels[histogram_computer]
69
+ self.threads_per_block = threads_per_block
70
+ self.rows_per_thread = rows_per_thread
71
+ self.L2_reg = L2_reg
72
+ self.L1_reg = L1_reg
73
+ self.forest = [{} for _ in range(self.n_estimators)]
74
+
75
+ def _validate_hyperparams(self, **kwargs):
76
+ # Type checks
77
+ int_params = [
78
+ "num_bins",
79
+ "max_depth",
80
+ "n_estimators",
81
+ "min_child_weight",
82
+ "threads_per_block",
83
+ "rows_per_thread",
84
+ ]
85
+ float_params = ["learning_rate", "min_split_gain", "L2_reg", "L1_reg"]
86
+
87
+ for param in int_params:
88
+ if not isinstance(kwargs[param], int):
89
+ raise TypeError(
90
+ f"{param} must be an integer, got {type(kwargs[param])}."
91
+ )
92
+
93
+ for param in float_params:
94
+ if not isinstance(
95
+ kwargs[param], (float, int)
96
+ ): # Accept ints as valid floats
97
+ raise TypeError(f"{param} must be a float, got {type(kwargs[param])}.")
98
+
99
+ if not (2 <= kwargs["num_bins"] <= 127):
100
+ raise ValueError("num_bins must be between 2 and 127 inclusive.")
101
+ if kwargs["max_depth"] < 1:
102
+ raise ValueError("max_depth must be at least 1.")
103
+ if not (0.0 < kwargs["learning_rate"] <= 1.0):
104
+ raise ValueError("learning_rate must be in (0.0, 1.0].")
105
+ if kwargs["n_estimators"] <= 0:
106
+ raise ValueError("n_estimators must be positive.")
107
+ if kwargs["min_child_weight"] < 1:
108
+ raise ValueError("min_child_weight must be a positive integer.")
109
+ if kwargs["min_split_gain"] < 0:
110
+ raise ValueError("min_split_gain must be non-negative.")
111
+ if kwargs["threads_per_block"] <= 0 or kwargs["threads_per_block"] % 32 != 0:
112
+ raise ValueError(
113
+ "threads_per_block should be a positive multiple of 32 (warp size)."
114
+ )
115
+ if not (1 <= kwargs["rows_per_thread"] <= 16):
116
+ raise ValueError(
117
+ "rows_per_thread must be positive between 1 and 16 inclusive."
118
+ )
119
+ if kwargs["L2_reg"] < 0 or kwargs["L1_reg"] < 0:
120
+ raise ValueError("L2_reg and L1_reg must be non-negative.")
121
+ if kwargs["histogram_computer"] not in histogram_kernels:
122
+ raise ValueError(
123
+ f"Invalid histogram_computer: {kwargs['histogram_computer']}. Choose from {list(histogram_kernels.keys())}."
124
+ )
125
+
126
+ def validate_fit_params(
127
+ self, X, y, era_id, X_eval, y_eval, eval_every_n_trees, early_stopping_rounds
128
+ ):
129
+ # ─── Required: X and y ───
130
+ if not isinstance(X, np.ndarray) or not isinstance(y, np.ndarray):
131
+ raise TypeError("X and y must be numpy arrays.")
132
+ if X.ndim != 2:
133
+ raise ValueError(f"X must be 2-dimensional, got shape {X.shape}")
134
+ if y.ndim != 1:
135
+ raise ValueError(f"y must be 1-dimensional, got shape {y.shape}")
136
+ if X.shape[0] != y.shape[0]:
137
+ raise ValueError(
138
+ f"X and y must have the same number of rows. Got {X.shape[0]} and {y.shape[0]}."
139
+ )
140
+
141
+ # ─── Optional: era_id ───
142
+ if era_id is not None:
143
+ if not isinstance(era_id, np.ndarray):
144
+ raise TypeError("era_id must be a numpy array.")
145
+ if era_id.ndim != 1:
146
+ raise ValueError(
147
+ f"era_id must be 1-dimensional, got shape {era_id.shape}"
148
+ )
149
+ if len(era_id) != len(y):
150
+ raise ValueError(
151
+ f"era_id must have same length as y. Got {len(era_id)} and {len(y)}."
152
+ )
153
+
154
+ # ─── Optional: Eval Set ───
155
+ eval_args = [X_eval, y_eval, eval_every_n_trees]
156
+ if any(arg is not None for arg in eval_args):
157
+ # Require all of them
158
+ if X_eval is None or y_eval is None or eval_every_n_trees is None:
159
+ raise ValueError(
160
+ "If using eval set, X_eval, y_eval, and eval_every_n_trees must all be defined."
161
+ )
162
+
163
+ if not isinstance(X_eval, np.ndarray) or not isinstance(y_eval, np.ndarray):
164
+ raise TypeError("X_eval and y_eval must be numpy arrays.")
165
+ if X_eval.ndim != 2:
166
+ raise ValueError(
167
+ f"X_eval must be 2-dimensional, got shape {X_eval.shape}"
168
+ )
169
+ if y_eval.ndim != 1:
170
+ raise ValueError(
171
+ f"y_eval must be 1-dimensional, got shape {y_eval.shape}"
172
+ )
173
+ if X_eval.shape[0] != y_eval.shape[0]:
174
+ raise ValueError(
175
+ f"X_eval and y_eval must have same number of rows. Got {X_eval.shape[0]} and {y_eval.shape[0]}."
176
+ )
177
+
178
+ if not isinstance(eval_every_n_trees, int) or eval_every_n_trees <= 0:
179
+ raise ValueError(
180
+ f"eval_every_n_trees must be a positive integer, got {eval_every_n_trees}."
181
+ )
182
+
183
+ if early_stopping_rounds is not None:
184
+ if (
185
+ not isinstance(early_stopping_rounds, int)
186
+ or early_stopping_rounds <= 0
187
+ ):
188
+ raise ValueError(
189
+ f"early_stopping_rounds must be a positive integer, got {early_stopping_rounds}."
190
+ )
191
+ else:
192
+ # No early stopping = set to "never trigger"
193
+ early_stopping_rounds = self.n_estimators + 1
194
+
195
+ return early_stopping_rounds # May have been defaulted here
196
+
197
+ def fit(
198
+ self,
199
+ X,
200
+ y,
201
+ era_id=None,
202
+ X_eval=None,
203
+ y_eval=None,
204
+ eval_every_n_trees=None,
205
+ early_stopping_rounds=None,
206
+ ):
207
+ early_stopping_rounds = self.validate_fit_params(
208
+ X, y, era_id, X_eval, y_eval, eval_every_n_trees, early_stopping_rounds
209
+ )
210
+
211
+ if era_id is None:
212
+ era_id = np.ones(X.shape[0], dtype="int32")
213
+
214
+ # Train data preprocessing
215
+ self.bin_indices, era_indices, self.bin_edges, self.unique_eras, self.Y_gpu = (
216
+ self.preprocess_gpu_data(X, y, era_id)
217
+ )
218
+ self.num_samples, self.num_features = X.shape
219
+ self.gradients = torch.zeros_like(self.Y_gpu)
220
+ self.root_node_indices = torch.arange(self.num_samples, device=self.device)
221
+ self.base_prediction = self.Y_gpu.mean().item()
222
+ self.gradients += self.base_prediction
223
+ self.best_gains = torch.zeros(self.num_features, device=self.device)
224
+ self.best_bins = torch.zeros(
225
+ self.num_features, device=self.device, dtype=torch.int32
226
+ )
227
+
228
+ # ─── Optional Eval Set ───
229
+ if X_eval is not None and y_eval is not None:
230
+ self.bin_indices_eval = self.bin_data_with_existing_edges(X_eval)
231
+ self.Y_gpu_eval = torch.from_numpy(y_eval).to(torch.float32).to(self.device)
232
+ self.eval_every_n_trees = eval_every_n_trees
233
+ self.early_stopping_rounds = early_stopping_rounds
234
+ else:
235
+ self.bin_indices_eval = None
236
+ self.Y_gpu_eval = None
237
+ self.eval_every_n_trees = None
238
+ self.early_stopping_rounds = None
239
+
240
+ # ─── Grow the forest ───
241
+ with torch.no_grad():
242
+ self.grow_forest()
243
+
244
+ return self
245
+
246
+ def preprocess_gpu_data(self, X_np, Y_np, era_id_np):
247
+ with torch.no_grad():
248
+ self.num_samples, self.num_features = X_np.shape
249
+ Y_gpu = torch.from_numpy(Y_np).type(torch.float32).to(self.device)
250
+ era_id_gpu = torch.from_numpy(era_id_np).type(torch.int32).to(self.device)
251
+ is_integer_type = np.issubdtype(X_np.dtype, np.integer)
252
+ if is_integer_type:
253
+ max_vals = X_np.max(axis=0)
254
+ if np.all(max_vals < self.num_bins):
255
+ print(
256
+ "Detected pre-binned integer input — skipping quantile binning."
257
+ )
258
+ bin_indices = (
259
+ torch.from_numpy(X_np)
260
+ .to(self.device)
261
+ .contiguous()
262
+ .to(torch.int8)
263
+ )
264
+
265
+ # We'll store None or an empty tensor in self.bin_edges
266
+ # to indicate that we skip binning at predict-time
267
+ bin_edges = torch.arange(
268
+ 1, self.num_bins, dtype=torch.float32
269
+ ).repeat(self.num_features, 1)
270
+ bin_edges = bin_edges.to(self.device)
271
+ unique_eras, era_indices = torch.unique(
272
+ era_id_gpu, return_inverse=True
273
+ )
274
+ return bin_indices, era_indices, bin_edges, unique_eras, Y_gpu
275
+ else:
276
+ print(
277
+ "Integer input detected, but values exceed num_bins — falling back to quantile binning."
278
+ )
279
+
280
+ bin_indices = torch.empty(
281
+ (self.num_samples, self.num_features), dtype=torch.int8, device="cuda"
282
+ )
283
+ bin_edges = torch.empty(
284
+ (self.num_features, self.num_bins - 1),
285
+ dtype=torch.float32,
286
+ device="cuda",
287
+ )
288
+
289
+ X_np = torch.from_numpy(X_np).to(torch.float32).pin_memory()
290
+
291
+ for f in range(self.num_features):
292
+ X_f = X_np[:, f].to("cuda", non_blocking=True)
293
+ quantiles = torch.linspace(
294
+ 0, 1, self.num_bins + 1, device="cuda", dtype=X_f.dtype
295
+ )[1:-1]
296
+ bin_edges_f = torch.quantile(
297
+ X_f, quantiles, dim=0
298
+ ).contiguous() # shape: [B-1] for 1D input
299
+ bin_indices_f = bin_indices[:, f].contiguous() # view into output
300
+ node_kernel.custom_cuda_binner(X_f, bin_edges_f, bin_indices_f)
301
+ bin_indices[:, f] = bin_indices_f
302
+ bin_edges[f, :] = bin_edges_f
303
+
304
+ unique_eras, era_indices = torch.unique(era_id_gpu, return_inverse=True)
305
+ return bin_indices, era_indices, bin_edges, unique_eras, Y_gpu
306
+
307
+ def compute_histograms(self, bin_indices_sub, gradients):
308
+ grad_hist = torch.zeros(
309
+ (self.num_features, self.num_bins), device=self.device, dtype=torch.float32
310
+ )
311
+ hess_hist = torch.zeros(
312
+ (self.num_features, self.num_bins), device=self.device, dtype=torch.float32
313
+ )
314
+
315
+ self.compute_histogram(
316
+ bin_indices_sub,
317
+ gradients,
318
+ grad_hist,
319
+ hess_hist,
320
+ self.num_bins,
321
+ self.threads_per_block,
322
+ self.rows_per_thread,
323
+ )
324
+ return grad_hist, hess_hist
325
+
326
+ def find_best_split(self, gradient_histogram, hessian_histogram):
327
+ node_kernel.compute_split(
328
+ gradient_histogram,
329
+ hessian_histogram,
330
+ self.min_split_gain,
331
+ self.min_child_weight,
332
+ self.L2_reg,
333
+ self.best_gains,
334
+ self.best_bins,
335
+ self.threads_per_block,
336
+ )
337
+
338
+ if torch.all(self.best_bins == -1):
339
+ return -1, -1 # No valid split found
340
+
341
+ f = torch.argmax(self.best_gains).item()
342
+ b = self.best_bins[f].item()
343
+
344
+ return f, b
345
+
346
+ def grow_tree(self, gradient_histogram, hessian_histogram, node_indices, depth):
347
+ if depth == self.max_depth:
348
+ leaf_value = self.residual[node_indices].mean()
349
+ self.gradients[node_indices] += self.learning_rate * leaf_value
350
+ return {"leaf_value": leaf_value.item(), "samples": node_indices.numel()}
351
+
352
+ parent_size = node_indices.numel()
353
+ best_feature, best_bin = self.find_best_split(
354
+ gradient_histogram, hessian_histogram
355
+ )
356
+
357
+ if best_feature == -1:
358
+ leaf_value = self.residual[node_indices].mean()
359
+ self.gradients[node_indices] += self.learning_rate * leaf_value
360
+ return {"leaf_value": leaf_value.item(), "samples": parent_size}
361
+
362
+ split_mask = self.bin_indices[node_indices, best_feature] <= best_bin
363
+ left_indices = node_indices[split_mask]
364
+ right_indices = node_indices[~split_mask]
365
+
366
+ left_size = left_indices.numel()
367
+ right_size = right_indices.numel()
368
+
369
+ if left_size <= right_size:
370
+ grad_hist_left, hess_hist_left = self.compute_histograms(
371
+ self.bin_indices[left_indices], self.residual[left_indices]
372
+ )
373
+ grad_hist_right = gradient_histogram - grad_hist_left
374
+ hess_hist_right = hessian_histogram - hess_hist_left
375
+ else:
376
+ grad_hist_right, hess_hist_right = self.compute_histograms(
377
+ self.bin_indices[right_indices], self.residual[right_indices]
378
+ )
379
+ grad_hist_left = gradient_histogram - grad_hist_right
380
+ hess_hist_left = hessian_histogram - hess_hist_right
381
+
382
+ new_depth = depth + 1
383
+ left_child = self.grow_tree(
384
+ grad_hist_left, hess_hist_left, left_indices, new_depth
385
+ )
386
+ right_child = self.grow_tree(
387
+ grad_hist_right, hess_hist_right, right_indices, new_depth
388
+ )
389
+
390
+ return {
391
+ "feature": best_feature,
392
+ "bin": best_bin,
393
+ "left": left_child,
394
+ "right": right_child,
395
+ }
396
+
397
+ def compute_eval(self, i):
398
+ if self.eval_every_n_trees == None:
399
+ return
400
+
401
+ if i % self.eval_every_n_trees == 0:
402
+ eval_preds = self.predict_binned(self.bin_indices_eval)
403
+ eval_loss = ((self.Y_gpu_eval - eval_preds) ** 2).mean().item()
404
+ self.eval_loss.append(eval_loss)
405
+
406
+ train_loss = ((self.Y_gpu - self.gradients) ** 2).mean().item()
407
+ self.training_loss.append(train_loss)
408
+
409
+ if len(self.eval_loss) > self.early_stopping_rounds:
410
+ if self.eval_loss[-self.early_stopping_rounds] < self.eval_loss[-1]:
411
+ self.stop = True
412
+
413
+ print(
414
+ f"🌲 Tree {i+1}/{self.n_estimators} | Train MSE: {train_loss:.6f} | Eval MSE: {eval_loss:.6f}"
415
+ )
416
+
417
+ del eval_preds, eval_loss, train_loss
418
+
419
+ def grow_forest(self):
420
+ self.training_loss = []
421
+ self.eval_loss = [] # <-- if eval set is given
422
+ self.stop = False
423
+
424
+ for i in range(self.n_estimators):
425
+ self.residual = self.Y_gpu - self.gradients
426
+
427
+ self.root_gradient_histogram, self.root_hessian_histogram = (
428
+ self.compute_histograms(self.bin_indices, self.residual)
429
+ )
430
+
431
+ tree = self.grow_tree(
432
+ self.root_gradient_histogram,
433
+ self.root_hessian_histogram,
434
+ self.root_node_indices,
435
+ depth=0,
436
+ )
437
+ self.forest[i] = tree
438
+
439
+ self.compute_eval(i)
440
+
441
+ if self.stop:
442
+ break
443
+
444
+ print("Finished training forest.")
445
+
446
+ def bin_data_with_existing_edges(self, X_np):
447
+ X_tensor = torch.from_numpy(X_np).to(torch.float32).pin_memory()
448
+ num_samples = X_tensor.size(0)
449
+ bin_indices = torch.zeros(
450
+ (num_samples, self.num_features), dtype=torch.int8, device=self.device
451
+ )
452
+ with torch.no_grad():
453
+ for f in range(self.num_features):
454
+ X_f = X_tensor[:, f].to(self.device, non_blocking=True)
455
+ bin_edges_f = self.bin_edges[f]
456
+ bin_indices_f = bin_indices[:, f].contiguous()
457
+ node_kernel.custom_cuda_binner(X_f, bin_edges_f, bin_indices_f)
458
+ bin_indices[:, f] = bin_indices_f
459
+
460
+ return bin_indices
461
+
462
+ def predict_binned(self, bin_indices):
463
+ num_samples = bin_indices.size(0)
464
+
465
+ tree_tensor = torch.stack(
466
+ [
467
+ self.flatten_tree(tree, max_nodes=2 ** (self.max_depth + 1))
468
+ for tree in self.forest
469
+ if tree
470
+ ]
471
+ ).to(self.device)
472
+
473
+ out = torch.zeros(num_samples, device=self.device) + self.base_prediction
474
+ node_kernel.predict_forest(
475
+ bin_indices.contiguous(), tree_tensor.contiguous(), self.learning_rate, out
476
+ )
477
+
478
+ return out
479
+
480
+ def predict(self, X_np):
481
+ bin_indices = self.bin_data_with_existing_edges(X_np)
482
+ out = self.predict_binned(bin_indices)
483
+ return out.cpu().numpy()
484
+
485
+ def flatten_tree(self, tree, max_nodes):
486
+ """
487
+ Convert a recursive tree structure into a flat matrix format.
488
+
489
+ Each row in the output represents a node:
490
+ - Columns: [feature, bin, left_id, right_id, is_leaf, value]
491
+ - Internal nodes fill columns 0–3 and set is_leaf = 0
492
+ - Leaf nodes fill only value and set is_leaf = 1
493
+
494
+ Args:
495
+ tree (list): A list containing a single root node (recursive dict form).
496
+ max_nodes (int): Max number of nodes to allocate in the flat matrix.
497
+
498
+ Returns:
499
+ torch.Tensor: [max_nodes x 6] matrix representing the flattened tree.
500
+ """
501
+ flat = torch.full((max_nodes, 6), float("nan"), dtype=torch.float32)
502
+ node_counter = [0]
503
+ node_list = []
504
+
505
+ def walk(node):
506
+ curr_id = node_counter[0]
507
+ node_counter[0] += 1
508
+
509
+ new_node = {"node_id": curr_id}
510
+ if "leaf_value" in node:
511
+ new_node["leaf_value"] = float(node["leaf_value"])
512
+ else:
513
+ new_node["best_feature"] = float(node["feature"])
514
+ new_node["split_bin"] = float(node["bin"])
515
+ new_node["left_id"] = node_counter[0]
516
+ walk(node["left"])
517
+ new_node["right_id"] = node_counter[0]
518
+ walk(node["right"])
519
+
520
+ node_list.append(new_node)
521
+ return new_node
522
+
523
+ walk(tree)
524
+
525
+ for node in node_list:
526
+ i = node["node_id"]
527
+ if "leaf_value" in node:
528
+ flat[i, 4] = 1.0
529
+ flat[i, 5] = node["leaf_value"]
530
+ else:
531
+ flat[i, 0] = node["best_feature"]
532
+ flat[i, 1] = node["split_bin"]
533
+ flat[i, 2] = node["left_id"]
534
+ flat[i, 3] = node["right_id"]
535
+ flat[i, 4] = 0.0
536
+
537
+ return flat
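The `flatten_tree` docstring above defines the flat node layout consumed by the `predict_forest` kernel: one row per node with columns `[feature, bin, left_id, right_id, is_leaf, value]`. A small CPU-side walk of that layout, for readability only; it follows the same split rule as `grow_tree` (`bin <= split_bin` goes left), while the real prediction path is the CUDA kernel, which also applies the learning rate and base prediction as in `predict_binned`:

```python
# Illustrative CPU traversal of the flat tree matrix built by flatten_tree.
# Columns: [feature, bin, left_id, right_id, is_leaf, value].
import numpy as np

def predict_one_tree(flat, binned_row):
    node = 0
    while flat[node, 4] == 0.0:              # internal node
        feature = int(flat[node, 0])
        split_bin = flat[node, 1]
        node = int(flat[node, 2]) if binned_row[feature] <= split_bin else int(flat[node, 3])
    return flat[node, 5]                     # leaf value (unscaled residual mean)

# A depth-1 stump: root splits feature 0 at bin 3; node ids follow flatten_tree's
# pre-order numbering (root=0, left leaf=1, right leaf=2).
flat = np.full((4, 6), np.nan, dtype=np.float32)
flat[0] = [0.0, 3.0, 1.0, 2.0, 0.0, np.nan]
flat[1, 4], flat[1, 5] = 1.0, -0.5
flat[2, 4], flat[2, 5] = 1.0, 0.5
print(predict_one_tree(flat, np.array([2], dtype=np.int8)))  # -0.5
print(predict_one_tree(flat, np.array([7], dtype=np.int8)))  # 0.5
```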
@@ -38,7 +38,7 @@ __global__ void best_split_kernel_global_only(

  if (H_L >= min_child_samples && H_R >= min_child_samples)
  {
- float gain = (G_L * G_L) / (H_L + eps) + (G_R * G_R) / (H_R + eps);
+ float gain = (G_L * G_L) / (H_L + eps) + (G_R * G_R) / (H_R + eps) - (G_total * G_total) / (H_total + eps);
  if (gain > best_gain)
  {
  best_gain = gain;
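The one-line change above subtracts the parent term `(G_total * G_total) / (H_total + eps)`, so the gain measures the improvement of splitting over not splitting, which is the standard GBDT split score and the quantity a `min_split_gain` threshold should be compared against. A NumPy reference of the same scan for a single feature's histograms, offered as an illustration rather than the kernel itself; `eps` here stands in for the L2 term, and `H` behaves as a sample count under squared loss:

```python
# Reference scan mirroring best_split_kernel_global_only's corrected gain:
# gain = G_L^2/(H_L+eps) + G_R^2/(H_R+eps) - G_total^2/(H_total+eps)
import numpy as np

def best_split_for_feature(G, H, min_child_samples=20, eps=1e-6):
    G_total, H_total = G.sum(), H.sum()
    parent = (G_total * G_total) / (H_total + eps)
    best_gain, best_bin = 0.0, -1             # keep only splits with positive gain
    G_L = H_L = 0.0
    for b in range(len(G) - 1):               # candidate split: "bin <= b" vs "bin > b"
        G_L += G[b]
        H_L += H[b]
        G_R, H_R = G_total - G_L, H_total - H_L
        if H_L >= min_child_samples and H_R >= min_child_samples:
            gain = G_L * G_L / (H_L + eps) + G_R * G_R / (H_R + eps) - parent
            if gain > best_gain:
                best_gain, best_bin = gain, b
    return best_bin, best_gain

G = np.array([-30.0, -10.0, 15.0, 25.0])
H = np.array([25.0, 25.0, 25.0, 25.0])
print(best_split_for_feature(G, H))  # -> (1, ~64.0): split where the gradients change sign
```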
@@ -107,12 +107,6 @@ void launch_histogram_kernel_cuda(
  N, F, B);
  }

- #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
- #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
- #define CHECK_INPUT(x) \
- CHECK_CUDA(x); \
- CHECK_CONTIGUOUS(x)
-
  // CUDA kernel: tiled, 64-bit safe
  __global__ void histogram_tiled_kernel(
  const int8_t *__restrict__ bin_indices, // [N, F]
@@ -148,10 +142,6 @@ void launch_histogram_kernel_cuda_2(
  int threads_per_block = 256,
  int rows_per_thread = 1)
  {
- CHECK_INPUT(bin_indices);
- CHECK_INPUT(gradients);
- CHECK_INPUT(grad_hist);
- CHECK_INPUT(hess_hist);

  int64_t N = bin_indices.size(0);
  int64_t F = bin_indices.size(1);
@@ -233,10 +223,6 @@ void launch_histogram_kernel_cuda_configurable(
  int threads_per_block = 256,
  int rows_per_thread = 1)
  {
- CHECK_INPUT(bin_indices);
- CHECK_INPUT(gradients);
- CHECK_INPUT(grad_hist);
- CHECK_INPUT(hess_hist);

  int64_t N = bin_indices.size(0);
  int64_t F = bin_indices.size(1);
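For orientation, the launchers above drive kernels that accumulate per-(feature, bin) gradient and hessian sums over pre-binned `int8` data; the removed `CHECK_INPUT` macros were CUDA/contiguity asserts on those tensors, now dropped from the launch path. A NumPy reference for what the histograms contain; since the Python wrapper passes only gradients, the hessian histogram is taken here to be a per-row count, which matches squared-error loss, and is an assumption rather than something shown in this hunk:

```python
# NumPy reference for the gradient/hessian histograms the CUDA kernels build.
import numpy as np

def histograms_reference(bin_indices, gradients, num_bins):
    N, F = bin_indices.shape                            # bin_indices: int8, shape [N, F]
    grad_hist = np.zeros((F, num_bins), dtype=np.float32)
    hess_hist = np.zeros((F, num_bins), dtype=np.float32)
    for f in range(F):
        np.add.at(grad_hist[f], bin_indices[:, f], gradients)  # sum of gradients per bin
        np.add.at(hess_hist[f], bin_indices[:, f], 1.0)        # row count per bin
    return grad_hist, hess_hist

rng = np.random.default_rng(0)
bins = rng.integers(0, 10, size=(1000, 3)).astype(np.int8)
grads = rng.normal(size=1000).astype(np.float32)
g, h = histograms_reference(bins, grads, num_bins=10)
print(g.shape, h.sum(axis=1))   # (3, 10), each feature's hessian histogram sums to 1000
```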
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: warpgbm
- Version: 0.1.21
+ Version: 0.1.23
  Summary: A fast GPU-accelerated Gradient Boosted Decision Tree library with PyTorch + CUDA
  License: GNU GENERAL PUBLIC LICENSE
  Version 3, 29 June 2007
@@ -704,26 +704,20 @@ WarpGBM is a high-performance, GPU-accelerated Gradient Boosted Decision Tree (G

  ---

- ## Performance Note
-
- In our initial tests on an NVIDIA 3090 (local) and A100 (Google Colab Pro), WarpGBM achieves **14x to 20x faster training times** compared to LightGBM's CPU version and **2x faster** on the GPU version using default configurations. Speed also outperforms XGBoost and CatBoost on regression problems. It also consumes **significantly less RAM and CPU**. These early results hint at more thorough benchmarking to come.
-
- ---
-
  ## Benchmarks

  ### Scikit-Learn Synthetic Data: 1 Million Rows and 1,000 Features

- In this benchmark we compare the speed and in-sample correlation of **WarpGBM v0.1.19** against LightGBM, XGBoost and CatBoost, all with their GPU-enabled versions. This benchmark runs on Google Colab with the L4 GPU environment. The CPU versions don't even come close to the speed here so we didn't test them.
+ In this benchmark we compare the speed and in-sample correlation of **WarpGBM v0.1.21** against LightGBM, XGBoost and CatBoost, all with their GPU-enabled versions. This benchmark runs on Google Colab with the L4 GPU environment.

  ```
- WarpGBM: corr = 0.8882, train = 21.8s, infer = 11.6s
- XGBoost: corr = 0.8877, train = 33.4s, infer = 8.1s
- LightGBM: corr = 0.8604, train = 30.2s, infer = 1.4s
- CatBoost: corr = 0.8935, train = 377.9s, infer = 375.8s
+ WarpGBM: corr = 0.8882, train = 18.7s, infer = 4.9s
+ XGBoost: corr = 0.8877, train = 33.1s, infer = 8.1s
+ LightGBM: corr = 0.8604, train = 30.3s, infer = 1.4s
+ CatBoost: corr = 0.8935, train = 400.0s, infer = 382.6s
  ```

- Colab Notebook: https://colab.research.google.com/drive/16U1kbYlD5HibGbnF5NGsjChZ1p1IA2pK
+ Colab Notebook: https://colab.research.google.com/drive/16U1kbYlD5HibGbnF5NGsjChZ1p1IA2pK?usp=sharing

  ---

@@ -746,7 +740,7 @@ pip install warpgbm
  This installs from PyPI and also compiles CUDA code locally during installation. This method works well **if your environment already has PyTorch with GPU support** installed and configured.

  > **Tip:**\
- > If you encounter an error related to mismatched or missing CUDA versions, try installing with the following flag:
+ > If you encounter an error related to mismatched or missing CUDA versions, try installing with the following flag. This is currently required in the Colab environments.
  >
  > ```bash
  > pip install warpgbm --no-build-isolation
@@ -886,8 +880,7 @@ No installation required — just press **"Open in Playground"**, then **Run All

  ### Methods:
  - `.fit(X, y, era_id=None)`: Train the model. `X` can be raw floats or pre-binned `int8` data. `era_id` is optional and used internally.
- - `.predict(X, chunksize=50_000)`: Predict on new raw float or pre-binned data.
- - `.predict_numpy(X, chunksize=50_000)`: Same as `.predict(X)` but without using the GPU.
+ - `.predict(X)`: Predict on new data, using parallelized CUDA kernel.

  ---

@@ -1 +0,0 @@
- 0.1.21
@@ -1,341 +0,0 @@
1
- import torch
2
- import numpy as np
3
- from sklearn.base import BaseEstimator, RegressorMixin
4
- from warpgbm.cuda import node_kernel
5
- from tqdm import tqdm
6
- from typing import Tuple
7
- from torch import Tensor
8
-
9
- histogram_kernels = {
10
- 'hist1': node_kernel.compute_histogram,
11
- 'hist2': node_kernel.compute_histogram2,
12
- 'hist3': node_kernel.compute_histogram3
13
- }
14
-
15
- class WarpGBM(BaseEstimator, RegressorMixin):
16
- def __init__(
17
- self,
18
- num_bins=10,
19
- max_depth=3,
20
- learning_rate=0.1,
21
- n_estimators=100,
22
- min_child_weight=20,
23
- min_split_gain=0.0,
24
- verbosity=True,
25
- histogram_computer='hist3',
26
- threads_per_block=64,
27
- rows_per_thread=4,
28
- L2_reg=1e-6,
29
- L1_reg=0.0,
30
- device='cuda'
31
- ):
32
- # Validate arguments
33
- self._validate_hyperparams(
34
- num_bins=num_bins,
35
- max_depth=max_depth,
36
- learning_rate=learning_rate,
37
- n_estimators=n_estimators,
38
- min_child_weight=min_child_weight,
39
- min_split_gain=min_split_gain,
40
- histogram_computer=histogram_computer,
41
- threads_per_block=threads_per_block,
42
- rows_per_thread=rows_per_thread,
43
- L2_reg=L2_reg,
44
- L1_reg=L1_reg
45
- )
46
-
47
- self.num_bins = num_bins
48
- self.max_depth = max_depth
49
- self.learning_rate = learning_rate
50
- self.n_estimators = n_estimators
51
- self.forest = None
52
- self.bin_edges = None
53
- self.base_prediction = None
54
- self.unique_eras = None
55
- self.device = device
56
- self.root_gradient_histogram = None
57
- self.root_hessian_histogram = None
58
- self.gradients = None
59
- self.root_node_indices = None
60
- self.bin_indices = None
61
- self.Y_gpu = None
62
- self.num_features = None
63
- self.num_samples = None
64
- self.min_child_weight = min_child_weight
65
- self.min_split_gain = min_split_gain
66
- self.best_bin = torch.tensor([-1], dtype=torch.int32, device=self.device)
67
- self.compute_histogram = histogram_kernels[histogram_computer]
68
- self.threads_per_block = threads_per_block
69
- self.rows_per_thread = rows_per_thread
70
- self.L2_reg = L2_reg
71
- self.L1_reg = L1_reg
72
-
73
- def _validate_hyperparams(self, **kwargs):
74
- # Type checks
75
- int_params = [
76
- "num_bins", "max_depth", "n_estimators", "min_child_weight",
77
- "threads_per_block", "rows_per_thread"
78
- ]
79
- float_params = [
80
- "learning_rate", "min_split_gain", "L2_reg", "L1_reg"
81
- ]
82
-
83
- for param in int_params:
84
- if not isinstance(kwargs[param], int):
85
- raise TypeError(f"{param} must be an integer, got {type(kwargs[param])}.")
86
-
87
- for param in float_params:
88
- if not isinstance(kwargs[param], (float, int)): # Accept ints as valid floats
89
- raise TypeError(f"{param} must be a float, got {type(kwargs[param])}.")
90
-
91
- if not ( 2 <= kwargs["num_bins"] <= 127 ):
92
- raise ValueError("num_bins must be between 2 and 127 inclusive.")
93
- if kwargs["max_depth"] < 1:
94
- raise ValueError("max_depth must be at least 1.")
95
- if not (0.0 < kwargs["learning_rate"] <= 1.0):
96
- raise ValueError("learning_rate must be in (0.0, 1.0].")
97
- if kwargs["n_estimators"] <= 0:
98
- raise ValueError("n_estimators must be positive.")
99
- if kwargs["min_child_weight"] < 1:
100
- raise ValueError("min_child_weight must be a positive integer.")
101
- if kwargs["min_split_gain"] < 0:
102
- raise ValueError("min_split_gain must be non-negative.")
103
- if kwargs["threads_per_block"] <= 0 or kwargs["threads_per_block"] % 32 != 0:
104
- raise ValueError("threads_per_block should be a positive multiple of 32 (warp size).")
105
- if not ( 1 <= kwargs["rows_per_thread"] <= 16 ):
106
- raise ValueError("rows_per_thread must be positive between 1 and 16 inclusive.")
107
- if kwargs["L2_reg"] < 0 or kwargs["L1_reg"] < 0:
108
- raise ValueError("L2_reg and L1_reg must be non-negative.")
109
- if kwargs["histogram_computer"] not in histogram_kernels:
110
- raise ValueError(f"Invalid histogram_computer: {kwargs['histogram_computer']}. Choose from {list(histogram_kernels.keys())}.")
111
-
112
- def fit(self, X, y, era_id=None):
113
- if era_id is None:
114
- era_id = np.ones(X.shape[0], dtype='int32')
115
- self.bin_indices, era_indices, self.bin_edges, self.unique_eras, self.Y_gpu = self.preprocess_gpu_data(X, y, era_id)
116
- self.num_samples, self.num_features = X.shape
117
- self.gradients = torch.zeros_like(self.Y_gpu)
118
- self.root_node_indices = torch.arange(self.num_samples, device=self.device)
119
- self.base_prediction = self.Y_gpu.mean().item()
120
- self.gradients += self.base_prediction
121
- self.best_gains = torch.zeros(self.num_features, device=self.device)
122
- self.best_bins = torch.zeros(self.num_features, device=self.device, dtype=torch.int32)
123
- with torch.no_grad():
124
- self.forest = self.grow_forest()
125
- return self
126
-
127
- def preprocess_gpu_data(self, X_np, Y_np, era_id_np):
128
- with torch.no_grad():
129
- self.num_samples, self.num_features = X_np.shape
130
- Y_gpu = torch.from_numpy(Y_np).type(torch.float32).to(self.device)
131
- era_id_gpu = torch.from_numpy(era_id_np).type(torch.int32).to(self.device)
132
- is_integer_type = np.issubdtype(X_np.dtype, np.integer)
133
- if is_integer_type:
134
- max_vals = X_np.max(axis=0)
135
- if np.all(max_vals < self.num_bins):
136
- print("Detected pre-binned integer input — skipping quantile binning.")
137
- bin_indices = torch.from_numpy(X_np).to(self.device).contiguous().to(torch.int8)
138
-
139
- # We'll store None or an empty tensor in self.bin_edges
140
- # to indicate that we skip binning at predict-time
141
- bin_edges = torch.arange(1, self.num_bins, dtype=torch.float32).repeat(self.num_features, 1)
142
- bin_edges = bin_edges.to(self.device)
143
- unique_eras, era_indices = torch.unique(era_id_gpu, return_inverse=True)
144
- return bin_indices, era_indices, bin_edges, unique_eras, Y_gpu
145
- else:
146
- print("Integer input detected, but values exceed num_bins — falling back to quantile binning.")
147
-
148
- bin_indices = torch.empty((self.num_samples, self.num_features), dtype=torch.int8, device='cuda')
149
- bin_edges = torch.empty((self.num_features, self.num_bins - 1), dtype=torch.float32, device='cuda')
150
-
151
- X_np = torch.from_numpy(X_np).to(torch.float32).pin_memory()
152
-
153
- for f in range(self.num_features):
154
- X_f = X_np[:, f].to('cuda', non_blocking=True)
155
- quantiles = torch.linspace(0, 1, self.num_bins + 1, device='cuda', dtype=X_f.dtype)[1:-1]
156
- bin_edges_f = torch.quantile(X_f, quantiles, dim=0).contiguous() # shape: [B-1] for 1D input
157
- bin_indices_f = bin_indices[:, f].contiguous() # view into output
158
- node_kernel.custom_cuda_binner(X_f, bin_edges_f, bin_indices_f)
159
- bin_indices[:,f] = bin_indices_f
160
- bin_edges[f,:] = bin_edges_f
161
-
162
- unique_eras, era_indices = torch.unique(era_id_gpu, return_inverse=True)
163
- return bin_indices, era_indices, bin_edges, unique_eras, Y_gpu
164
-
165
- def compute_histograms(self, bin_indices_sub, gradients):
166
- grad_hist = torch.zeros((self.num_features, self.num_bins), device=self.device, dtype=torch.float32)
167
- hess_hist = torch.zeros((self.num_features, self.num_bins), device=self.device, dtype=torch.float32)
168
-
169
- self.compute_histogram(
170
- bin_indices_sub,
171
- gradients,
172
- grad_hist,
173
- hess_hist,
174
- self.num_bins,
175
- self.threads_per_block,
176
- self.rows_per_thread
177
- )
178
- return grad_hist, hess_hist
179
-
180
- def find_best_split(self, gradient_histogram, hessian_histogram):
181
- node_kernel.compute_split(
182
- gradient_histogram,
183
- hessian_histogram,
184
- self.min_split_gain,
185
- self.min_child_weight,
186
- self.L2_reg,
187
- self.best_gains,
188
- self.best_bins,
189
- self.threads_per_block
190
- )
191
-
192
- if torch.all(self.best_bins == -1):
193
- return -1, -1 # No valid split found
194
-
195
- f = torch.argmax(self.best_gains).item()
196
- b = self.best_bins[f].item()
197
-
198
- return f, b
199
-
200
- def grow_tree(self, gradient_histogram, hessian_histogram, node_indices, depth):
201
- if depth == self.max_depth:
202
- leaf_value = self.residual[node_indices].mean()
203
- self.gradients[node_indices] += self.learning_rate * leaf_value
204
- return {"leaf_value": leaf_value.item(), "samples": node_indices.numel()}
205
-
206
- parent_size = node_indices.numel()
207
- best_feature, best_bin = self.find_best_split(gradient_histogram, hessian_histogram)
208
-
209
- if best_feature == -1:
210
- leaf_value = self.residual[node_indices].mean()
211
- self.gradients[node_indices] += self.learning_rate * leaf_value
212
- return {"leaf_value": leaf_value.item(), "samples": parent_size}
213
-
214
- split_mask = (self.bin_indices[node_indices, best_feature] <= best_bin)
215
- left_indices = node_indices[split_mask]
216
- right_indices = node_indices[~split_mask]
217
-
218
- left_size = left_indices.numel()
219
- right_size = right_indices.numel()
220
-
221
-
222
- if left_size <= right_size:
223
- grad_hist_left, hess_hist_left = self.compute_histograms( self.bin_indices[left_indices], self.residual[left_indices] )
224
- grad_hist_right = gradient_histogram - grad_hist_left
225
- hess_hist_right = hessian_histogram - hess_hist_left
226
- else:
227
- grad_hist_right, hess_hist_right = self.compute_histograms( self.bin_indices[right_indices], self.residual[right_indices] )
228
- grad_hist_left = gradient_histogram - grad_hist_right
229
- hess_hist_left = hessian_histogram - hess_hist_right
230
-
231
- new_depth = depth + 1
232
- left_child = self.grow_tree(grad_hist_left, hess_hist_left, left_indices, new_depth)
233
- right_child = self.grow_tree(grad_hist_right, hess_hist_right, right_indices, new_depth)
234
-
235
- return { "feature": best_feature, "bin": best_bin, "left": left_child, "right": right_child }
236
-
237
- def grow_forest(self):
238
- forest = [{} for _ in range(self.n_estimators)]
239
- self.training_loss = []
240
-
241
- for i in tqdm( range(self.n_estimators) ):
242
- self.residual = self.Y_gpu - self.gradients
243
-
244
- self.root_gradient_histogram, self.root_hessian_histogram = \
245
- self.compute_histograms(self.bin_indices, self.residual)
246
-
247
- tree = self.grow_tree(
248
- self.root_gradient_histogram,
249
- self.root_hessian_histogram,
250
- self.root_node_indices,
251
- depth=0
252
- )
253
- forest[i] = tree
254
- # loss = ((self.Y_gpu - self.gradients) ** 2).mean().item()
255
- # self.training_loss.append(loss)
256
- # print(f"🌲 Tree {i+1}/{self.n_estimators} - MSE: {loss:.6f}")
257
-
258
- print("Finished training forest.")
259
- return forest
260
-
261
- def predict(self, X_np):
262
- X_tensor = torch.from_numpy(X_np).to(torch.float32).pin_memory()
263
- num_samples = X_tensor.size(0)
264
- bin_indices = torch.zeros((num_samples, self.num_features), dtype=torch.int8, device=self.device)
265
-
266
- with torch.no_grad():
267
- for f in range(self.num_features):
268
- X_f = X_tensor[:, f].to(self.device, non_blocking=True)
269
- bin_edges_f = self.bin_edges[f]
270
- bin_indices_f = bin_indices[:, f].contiguous()
271
- node_kernel.custom_cuda_binner(X_f, bin_edges_f, bin_indices_f)
272
- bin_indices[:, f] = bin_indices_f
273
-
274
- tree_tensor = torch.stack([
275
- self.flatten_tree(tree, max_nodes=2**(self.max_depth + 1))
276
- for tree in self.forest
277
- ]).to(self.device)
278
-
279
- out = torch.zeros(num_samples, device=self.device)
280
- node_kernel.predict_forest(
281
- bin_indices.contiguous(),
282
- tree_tensor.contiguous(),
283
- self.learning_rate,
284
- out
285
- )
286
-
287
- return out.cpu().numpy()
288
-
289
- def flatten_tree(self, tree, max_nodes):
290
- """
291
- Convert a recursive tree structure into a flat matrix format.
292
-
293
- Each row in the output represents a node:
294
- - Columns: [feature, bin, left_id, right_id, is_leaf, value]
295
- - Internal nodes fill columns 0–3 and set is_leaf = 0
296
- - Leaf nodes fill only value and set is_leaf = 1
297
-
298
- Args:
299
- tree (list): A list containing a single root node (recursive dict form).
300
- max_nodes (int): Max number of nodes to allocate in the flat matrix.
301
-
302
- Returns:
303
- torch.Tensor: [max_nodes x 6] matrix representing the flattened tree.
304
- """
305
- flat = torch.full((max_nodes, 6), float('nan'), dtype=torch.float32)
306
- node_counter = [0]
307
- node_list = []
308
-
309
- def walk(node):
310
- curr_id = node_counter[0]
311
- node_counter[0] += 1
312
-
313
- new_node = {'node_id': curr_id}
314
- if 'leaf_value' in node:
315
- new_node['leaf_value'] = float(node['leaf_value'])
316
- else:
317
- new_node['best_feature'] = float(node['feature'])
318
- new_node['split_bin'] = float(node['bin'])
319
- new_node['left_id'] = node_counter[0]
320
- walk(node['left'])
321
- new_node['right_id'] = node_counter[0]
322
- walk(node['right'])
323
-
324
- node_list.append(new_node)
325
- return new_node
326
-
327
- walk(tree)
328
-
329
- for node in node_list:
330
- i = node['node_id']
331
- if 'leaf_value' in node:
332
- flat[i, 4] = 1.0
333
- flat[i, 5] = node['leaf_value']
334
- else:
335
- flat[i, 0] = node['best_feature']
336
- flat[i, 1] = node['split_bin']
337
- flat[i, 2] = node['left_id']
338
- flat[i, 3] = node['right_id']
339
- flat[i, 4] = 0.0
340
-
341
- return flat
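Both the removed `core.py` above and its replacement earlier in this diff bin each float feature by interior quantile edges (`torch.linspace(0, 1, num_bins + 1)[1:-1]` fed to `torch.quantile`) before handing `int8` bin indices to the CUDA kernels. A CPU-only sketch of that step; the actual binner is `node_kernel.custom_cuda_binner`, whose exact boundary convention is not shown in this diff, so `searchsorted(..., right=True)` is one plausible reading:

```python
# CPU sketch of per-feature quantile binning as in preprocess_gpu_data
# (illustrative only; the package performs this per feature on the GPU).
import torch

def quantile_bin_feature(x, num_bins):
    qs = torch.linspace(0, 1, num_bins + 1, dtype=x.dtype)[1:-1]   # interior quantiles
    edges = torch.quantile(x, qs)                                  # [num_bins - 1] edges
    bins = torch.searchsorted(edges, x, right=True)                # 0 .. num_bins - 1
    return bins.to(torch.int8), edges

x = torch.randn(1000)
bins, edges = quantile_bin_feature(x, num_bins=10)
print(int(bins.min()), int(bins.max()), tuple(edges.shape))        # 0 9 (9,)
```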