warpgbm 0.1.22__tar.gz → 0.1.24__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. {warpgbm-0.1.22/warpgbm.egg-info → warpgbm-0.1.24}/PKG-INFO +25 -3
  2. {warpgbm-0.1.22 → warpgbm-0.1.24}/README.md +24 -2
  3. {warpgbm-0.1.22 → warpgbm-0.1.24}/pyproject.toml +1 -1
  4. warpgbm-0.1.24/tests/numerai_test.py +62 -0
  5. {warpgbm-0.1.22 → warpgbm-0.1.24}/tests/test_fit_predict_corr.py +9 -1
  6. warpgbm-0.1.24/version.txt +1 -0
  7. {warpgbm-0.1.22 → warpgbm-0.1.24}/warpgbm/core.py +207 -35
  8. {warpgbm-0.1.22 → warpgbm-0.1.24}/warpgbm/cuda/best_split_kernel.cu +1 -1
  9. {warpgbm-0.1.22 → warpgbm-0.1.24/warpgbm.egg-info}/PKG-INFO +25 -3
  10. {warpgbm-0.1.22 → warpgbm-0.1.24}/warpgbm.egg-info/SOURCES.txt +1 -0
  11. warpgbm-0.1.22/version.txt +0 -1
  12. {warpgbm-0.1.22 → warpgbm-0.1.24}/LICENSE +0 -0
  13. {warpgbm-0.1.22 → warpgbm-0.1.24}/MANIFEST.in +0 -0
  14. {warpgbm-0.1.22 → warpgbm-0.1.24}/setup.cfg +0 -0
  15. {warpgbm-0.1.22 → warpgbm-0.1.24}/setup.py +0 -0
  16. {warpgbm-0.1.22 → warpgbm-0.1.24}/tests/__init__.py +0 -0
  17. {warpgbm-0.1.22 → warpgbm-0.1.24}/warpgbm/__init__.py +0 -0
  18. {warpgbm-0.1.22 → warpgbm-0.1.24}/warpgbm/cuda/__init__.py +0 -0
  19. {warpgbm-0.1.22 → warpgbm-0.1.24}/warpgbm/cuda/binner.cu +0 -0
  20. {warpgbm-0.1.22 → warpgbm-0.1.24}/warpgbm/cuda/histogram_kernel.cu +0 -0
  21. {warpgbm-0.1.22 → warpgbm-0.1.24}/warpgbm/cuda/node_kernel.cpp +0 -0
  22. {warpgbm-0.1.22 → warpgbm-0.1.24}/warpgbm/cuda/predict.cu +0 -0
  23. {warpgbm-0.1.22 → warpgbm-0.1.24}/warpgbm.egg-info/dependency_links.txt +0 -0
  24. {warpgbm-0.1.22 → warpgbm-0.1.24}/warpgbm.egg-info/requires.txt +0 -0
  25. {warpgbm-0.1.22 → warpgbm-0.1.24}/warpgbm.egg-info/top_level.txt +0 -0

{warpgbm-0.1.22/warpgbm.egg-info → warpgbm-0.1.24}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: warpgbm
-Version: 0.1.22
+Version: 0.1.24
 Summary: A fast GPU-accelerated Gradient Boosted Decision Tree library with PyTorch + CUDA
 License: GNU GENERAL PUBLIC LICENSE
                        Version 3, 29 June 2007
@@ -879,8 +879,26 @@ No installation required — just press **"Open in Playground"**, then **Run All
 - `L2_reg`: L2 regularizer (default: 1e-6)
 
 ### Methods:
-- `.fit(X, y, era_id=None)`: Train the model. `X` can be raw floats or pre-binned `int8` data. `era_id` is optional and used internally.
-- `.predict(X)`: Predict on new data, using parallelized CUDA kernel.
+```
+.fit(
+    X,                           # numpy array (float or int) 2 dimensions (num_samples, num_features)
+    y,                           # numpy array (float or int) 1 dimension (num_samples)
+    era_id=None,                 # numpy array (int) 1 dimension (num_samples)
+    X_eval=None,                 # numpy array (float or int) 2 dimensions (eval_num_samples, num_features)
+    y_eval=None,                 # numpy array (float or int) 1 dimension (eval_num_samples)
+    eval_every_n_trees=None,     # const (int) >= 1
+    early_stopping_rounds=None,  # const (int) >= 1
+)
+```
+Train with optional validation set and early stopping.
+
+
+```
+.predict(
+    X  # numpy array (float or int) 2 dimensions (predict_num_samples, num_features)
+)
+```
+Predict on new data, using parallelized CUDA kernel.
 
 ---
 
@@ -896,3 +914,7 @@ WarpGBM builds on the shoulders of PyTorch, scikit-learn, LightGBM, and the CUDA
 
 - Vectorized predict function replaced with CUDA kernel (`warpgbm/cuda/predict.cu`), parallelizing per sample, per tree.
 
+### v0.1.23
+
+- Adjust gain in split kernel and added support for an eval set with early stopping based on MSE.
+

{warpgbm-0.1.22 → warpgbm-0.1.24}/README.md

@@ -191,8 +191,26 @@ No installation required — just press **"Open in Playground"**, then **Run All
 - `L2_reg`: L2 regularizer (default: 1e-6)
 
 ### Methods:
-- `.fit(X, y, era_id=None)`: Train the model. `X` can be raw floats or pre-binned `int8` data. `era_id` is optional and used internally.
-- `.predict(X)`: Predict on new data, using parallelized CUDA kernel.
+```
+.fit(
+    X,                           # numpy array (float or int) 2 dimensions (num_samples, num_features)
+    y,                           # numpy array (float or int) 1 dimension (num_samples)
+    era_id=None,                 # numpy array (int) 1 dimension (num_samples)
+    X_eval=None,                 # numpy array (float or int) 2 dimensions (eval_num_samples, num_features)
+    y_eval=None,                 # numpy array (float or int) 1 dimension (eval_num_samples)
+    eval_every_n_trees=None,     # const (int) >= 1
+    early_stopping_rounds=None,  # const (int) >= 1
+)
+```
+Train with optional validation set and early stopping.
+
+
+```
+.predict(
+    X  # numpy array (float or int) 2 dimensions (predict_num_samples, num_features)
+)
+```
+Predict on new data, using parallelized CUDA kernel.
 
 ---
 
@@ -208,3 +226,7 @@ WarpGBM builds on the shoulders of PyTorch, scikit-learn, LightGBM, and the CUDA
 
 - Vectorized predict function replaced with CUDA kernel (`warpgbm/cuda/predict.cu`), parallelizing per sample, per tree.
 
+### v0.1.23
+
+- Adjust gain in split kernel and added support for an eval set with early stopping based on MSE.
+
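
For orientation, here is a minimal usage sketch of the new `fit`/`predict` signature documented in the README hunk above. The synthetic data, hyperparameter values, and stopping settings are illustrative assumptions, not taken from the package:

```python
# Hypothetical example of the eval-set / early-stopping API added in v0.1.23.
# Data and hyperparameter values are assumptions for illustration only.
import numpy as np
from warpgbm import WarpGBM

rng = np.random.default_rng(0)
X_train = rng.normal(size=(10_000, 50)).astype(np.float32)
y_train = rng.normal(size=10_000).astype(np.float32)
X_valid = rng.normal(size=(2_000, 50)).astype(np.float32)
y_valid = rng.normal(size=2_000).astype(np.float32)

model = WarpGBM(
    max_depth=6,
    num_bins=32,
    n_estimators=200,
    learning_rate=0.1,
    threads_per_block=64,
    rows_per_thread=4,
)

model.fit(
    X_train,
    y_train,
    X_eval=X_valid,              # held-out rows, binned with the training bin edges
    y_eval=y_valid,
    eval_every_n_trees=10,       # report train/eval MSE every 10 trees
    early_stopping_rounds=3,     # stop once eval MSE is worse than 3 evaluations ago
)

preds = model.predict(X_valid)   # numpy array of predictions
```

With these settings, `compute_eval` prints train and eval MSE every tenth tree and sets the stop flag once the eval MSE stops improving, so training can end before all `n_estimators` trees are built.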

{warpgbm-0.1.22 → warpgbm-0.1.24}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "warpgbm"
-version = "0.1.22"
+version = "0.1.24"
 description = "A fast GPU-accelerated Gradient Boosted Decision Tree library with PyTorch + CUDA"
 readme = "README.md"
 requires-python = ">=3.8"

warpgbm-0.1.24/tests/numerai_test.py

@@ -0,0 +1,62 @@
+from numerapi import NumerAPI
+import pandas as pd
+import numpy as np
+from warpgbm import WarpGBM
+import time
+from sklearn.metrics import mean_squared_error
+
+
+def predict_in_chunks(model, X, chunk_size=100_000):
+    preds = []
+    for i in range(0, X.shape[0], chunk_size):
+        X_chunk = X[i : i + chunk_size]
+        preds.append(model.predict(X_chunk))
+    return np.concatenate(preds)
+
+
+def test_numerai_data():
+    napi = NumerAPI()
+    napi.download_dataset("v5.0/train.parquet", "numerai_train.parquet")
+
+    data = pd.read_parquet("numerai_train.parquet")
+    features = [f for f in list(data) if "feature" in f][:1000]
+    target = "target"
+
+    X = data[features].astype("int8").values[:]
+    y = data[target].values
+
+    model = WarpGBM(
+        max_depth=10,
+        num_bins=5,
+        n_estimators=100,
+        learning_rate=1,
+        threads_per_block=64,
+        rows_per_thread=4,
+        colsample_bytree=0.8,
+    )
+
+    start_fit = time.time()
+    model.fit(
+        X,
+        y,
+        # era_id=era,
+        # X_eval=X,
+        # y_eval=y,
+        # eval_every_n_trees=10,
+        # early_stopping_rounds=1,
+    )
+    fit_time = time.time() - start_fit
+    print(f" Fit time: {fit_time:.3f} seconds")
+
+    start_pred = time.time()
+    preds = predict_in_chunks(model, X, chunk_size=500_000)
+    pred_time = time.time() - start_pred
+    print(f" Predict time: {pred_time:.3f} seconds")
+
+    corr = np.corrcoef(preds, y)[0, 1]
+    mse = mean_squared_error(preds, y)
+    print(f" Correlation: {corr:.4f}")
+    print(f" MSE: {mse:.4f}")
+
+    assert corr > 0.68, f"In-sample correlation too low: {corr}"
+    assert mse < 0.03, f"In-sample mse too high: {mse}"

{warpgbm-0.1.22 → warpgbm-0.1.24}/tests/test_fit_predict_corr.py

@@ -29,7 +29,15 @@ def test_fit_predictpytee_correlation():
     )
 
     start_fit = time.time()
-    model.fit(X, y, era_id=era)
+    model.fit(
+        X,
+        y,
+        era_id=era,
+        X_eval=X,
+        y_eval=y,
+        eval_every_n_trees=10,
+        early_stopping_rounds=1,
+    )
     fit_time = time.time() - start_fit
     print(f" Fit time: {fit_time:.3f} seconds")
 

warpgbm-0.1.24/version.txt

@@ -0,0 +1 @@
+0.1.24

{warpgbm-0.1.22 → warpgbm-0.1.24}/warpgbm/core.py

@@ -5,6 +5,7 @@ from warpgbm.cuda import node_kernel
 from tqdm import tqdm
 from typing import Tuple
 from torch import Tensor
+import gc
 
 histogram_kernels = {
     "hist1": node_kernel.compute_histogram,
@@ -29,6 +30,7 @@ class WarpGBM(BaseEstimator, RegressorMixin):
         L2_reg=1e-6,
         L1_reg=0.0,
         device="cuda",
+        colsample_bytree=1.0,
     ):
         # Validate arguments
         self._validate_hyperparams(
@@ -43,6 +45,7 @@ class WarpGBM(BaseEstimator, RegressorMixin):
            rows_per_thread=rows_per_thread,
            L2_reg=L2_reg,
            L1_reg=L1_reg,
+           colsample_bytree=colsample_bytree,
        )
 
        self.num_bins = num_bins
@@ -70,6 +73,8 @@ class WarpGBM(BaseEstimator, RegressorMixin):
        self.rows_per_thread = rows_per_thread
        self.L2_reg = L2_reg
        self.L1_reg = L1_reg
+       self.forest = [{} for _ in range(self.n_estimators)]
+       self.colsample_bytree = colsample_bytree
 
    def _validate_hyperparams(self, **kwargs):
        # Type checks
@@ -81,7 +86,13 @@ class WarpGBM(BaseEstimator, RegressorMixin):
            "threads_per_block",
            "rows_per_thread",
        ]
-       float_params = ["learning_rate", "min_split_gain", "L2_reg", "L1_reg"]
+       float_params = [
+           "learning_rate",
+           "min_split_gain",
+           "L2_reg",
+           "L1_reg",
+           "colsample_bytree",
+       ]
 
        for param in int_params:
            if not isinstance(kwargs[param], int):
@@ -121,10 +132,100 @@ class WarpGBM(BaseEstimator, RegressorMixin):
            raise ValueError(
                f"Invalid histogram_computer: {kwargs['histogram_computer']}. Choose from {list(histogram_kernels.keys())}."
            )
+       if kwargs["colsample_bytree"] <= 0 or kwargs["colsample_bytree"] > 1:
+           raise ValueError(
+               f"Invalid colsample_bytree: {kwargs['colsample_bytree']}. Must be a float value > 0 and <= 1."
+           )
+
+   def validate_fit_params(
+       self, X, y, era_id, X_eval, y_eval, eval_every_n_trees, early_stopping_rounds
+   ):
+       # ─── Required: X and y ───
+       if not isinstance(X, np.ndarray) or not isinstance(y, np.ndarray):
+           raise TypeError("X and y must be numpy arrays.")
+       if X.ndim != 2:
+           raise ValueError(f"X must be 2-dimensional, got shape {X.shape}")
+       if y.ndim != 1:
+           raise ValueError(f"y must be 1-dimensional, got shape {y.shape}")
+       if X.shape[0] != y.shape[0]:
+           raise ValueError(
+               f"X and y must have the same number of rows. Got {X.shape[0]} and {y.shape[0]}."
+           )
+
+       # ─── Optional: era_id ───
+       if era_id is not None:
+           if not isinstance(era_id, np.ndarray):
+               raise TypeError("era_id must be a numpy array.")
+           if era_id.ndim != 1:
+               raise ValueError(
+                   f"era_id must be 1-dimensional, got shape {era_id.shape}"
+               )
+           if len(era_id) != len(y):
+               raise ValueError(
+                   f"era_id must have same length as y. Got {len(era_id)} and {len(y)}."
+               )
+
+       # ─── Optional: Eval Set ───
+       eval_args = [X_eval, y_eval, eval_every_n_trees]
+       if any(arg is not None for arg in eval_args):
+           # Require all of them
+           if X_eval is None or y_eval is None or eval_every_n_trees is None:
+               raise ValueError(
+                   "If using eval set, X_eval, y_eval, and eval_every_n_trees must all be defined."
+               )
+
+           if not isinstance(X_eval, np.ndarray) or not isinstance(y_eval, np.ndarray):
+               raise TypeError("X_eval and y_eval must be numpy arrays.")
+           if X_eval.ndim != 2:
+               raise ValueError(
+                   f"X_eval must be 2-dimensional, got shape {X_eval.shape}"
+               )
+           if y_eval.ndim != 1:
+               raise ValueError(
+                   f"y_eval must be 1-dimensional, got shape {y_eval.shape}"
+               )
+           if X_eval.shape[0] != y_eval.shape[0]:
+               raise ValueError(
+                   f"X_eval and y_eval must have same number of rows. Got {X_eval.shape[0]} and {y_eval.shape[0]}."
+               )
+
+           if not isinstance(eval_every_n_trees, int) or eval_every_n_trees <= 0:
+               raise ValueError(
+                   f"eval_every_n_trees must be a positive integer, got {eval_every_n_trees}."
+               )
+
+           if early_stopping_rounds is not None:
+               if (
+                   not isinstance(early_stopping_rounds, int)
+                   or early_stopping_rounds <= 0
+               ):
+                   raise ValueError(
+                       f"early_stopping_rounds must be a positive integer, got {early_stopping_rounds}."
+                   )
+           else:
+               # No early stopping = set to "never trigger"
+               early_stopping_rounds = self.n_estimators + 1
+
+       return early_stopping_rounds  # May have been defaulted here
+
+   def fit(
+       self,
+       X,
+       y,
+       era_id=None,
+       X_eval=None,
+       y_eval=None,
+       eval_every_n_trees=None,
+       early_stopping_rounds=None,
+   ):
+       early_stopping_rounds = self.validate_fit_params(
+           X, y, era_id, X_eval, y_eval, eval_every_n_trees, early_stopping_rounds
+       )
 
-   def fit(self, X, y, era_id=None):
        if era_id is None:
            era_id = np.ones(X.shape[0], dtype="int32")
+
+       # Train data preprocessing
        self.bin_indices, era_indices, self.bin_edges, self.unique_eras, self.Y_gpu = (
            self.preprocess_gpu_data(X, y, era_id)
        )
@@ -137,8 +238,29 @@ class WarpGBM(BaseEstimator, RegressorMixin):
        self.best_bins = torch.zeros(
            self.num_features, device=self.device, dtype=torch.int32
        )
+       self.feature_indices = torch.arange(self.num_features, device=self.device)
+
+       # ─── Optional Eval Set ───
+       if X_eval is not None and y_eval is not None:
+           self.bin_indices_eval = self.bin_data_with_existing_edges(X_eval)
+           self.Y_gpu_eval = torch.from_numpy(y_eval).to(torch.float32).to(self.device)
+           self.eval_every_n_trees = eval_every_n_trees
+           self.early_stopping_rounds = early_stopping_rounds
+       else:
+           self.bin_indices_eval = None
+           self.Y_gpu_eval = None
+           self.eval_every_n_trees = None
+           self.early_stopping_rounds = None
+
+       # ─── Grow the forest ───
        with torch.no_grad():
-           self.forest = self.grow_forest()
+           self.grow_forest()
+
+       del self.bin_indices
+       del self.Y_gpu
+
+       gc.collect()
+
        return self
 
    def preprocess_gpu_data(self, X_np, Y_np, era_id_np):
@@ -248,16 +370,16 @@ class WarpGBM(BaseEstimator, RegressorMixin):
            return {"leaf_value": leaf_value.item(), "samples": node_indices.numel()}
 
        parent_size = node_indices.numel()
-       best_feature, best_bin = self.find_best_split(
+       local_feature, best_bin = self.find_best_split(
            gradient_histogram, hessian_histogram
        )
 
-       if best_feature == -1:
+       if local_feature == -1:
            leaf_value = self.residual[node_indices].mean()
            self.gradients[node_indices] += self.learning_rate * leaf_value
            return {"leaf_value": leaf_value.item(), "samples": parent_size}
 
-       split_mask = self.bin_indices[node_indices, best_feature] <= best_bin
+       split_mask = self.bin_indices_tree[node_indices, local_feature] <= best_bin
        left_indices = node_indices[split_mask]
        right_indices = node_indices[~split_mask]
 
@@ -266,13 +388,13 @@ class WarpGBM(BaseEstimator, RegressorMixin):
 
        if left_size <= right_size:
            grad_hist_left, hess_hist_left = self.compute_histograms(
-               self.bin_indices[left_indices], self.residual[left_indices]
+               self.bin_indices_tree[left_indices], self.residual[left_indices]
            )
            grad_hist_right = gradient_histogram - grad_hist_left
            hess_hist_right = hessian_histogram - hess_hist_left
        else:
            grad_hist_right, hess_hist_right = self.compute_histograms(
-               self.bin_indices[right_indices], self.residual[right_indices]
+               self.bin_indices_tree[right_indices], self.residual[right_indices]
            )
            grad_hist_left = gradient_histogram - grad_hist_right
            hess_hist_left = hessian_histogram - hess_hist_right
@@ -286,44 +408,79 @@ class WarpGBM(BaseEstimator, RegressorMixin):
            )
 
        return {
-           "feature": best_feature,
+           "feature": self.feat_indices_tree[local_feature],
            "bin": best_bin,
            "left": left_child,
            "right": right_child,
        }
 
+   def compute_eval(self, i):
+       if self.eval_every_n_trees == None:
+           return
+
+       if i % self.eval_every_n_trees == 0:
+           eval_preds = self.predict_binned(self.bin_indices_eval)
+           eval_loss = ((self.Y_gpu_eval - eval_preds) ** 2).mean().item()
+           self.eval_loss.append(eval_loss)
+
+           train_loss = ((self.Y_gpu - self.gradients) ** 2).mean().item()
+           self.training_loss.append(train_loss)
+
+           if len(self.eval_loss) > self.early_stopping_rounds:
+               if self.eval_loss[-self.early_stopping_rounds] < self.eval_loss[-1]:
+                   self.stop = True
+
+           print(
+               f"🌲 Tree {i+1}/{self.n_estimators} | Train MSE: {train_loss:.6f} | Eval MSE: {eval_loss:.6f}"
+           )
+
+           del eval_preds, eval_loss, train_loss
+
    def grow_forest(self):
-       forest = [{} for _ in range(self.n_estimators)]
        self.training_loss = []
+       self.eval_loss = []  # if eval set is given
+       self.stop = False
 
-       for i in tqdm(range(self.n_estimators)):
+       if self.colsample_bytree < 1.0:
+           k = max(1, int(self.colsample_bytree * self.num_features))
+       else:
+           self.feat_indices_tree = self.feature_indices
+           self.bin_indices_tree = self.bin_indices
+
+       for i in range(self.n_estimators):
            self.residual = self.Y_gpu - self.gradients
 
+           if self.colsample_bytree < 1.0:
+               self.feat_indices_tree = torch.randperm(
+                   self.num_features, device=self.device
+               )[:k]
+               self.bin_indices_tree = self.bin_indices[:, self.feat_indices_tree]
+
            self.root_gradient_histogram, self.root_hessian_histogram = (
-               self.compute_histograms(self.bin_indices, self.residual)
+               self.compute_histograms(self.bin_indices_tree, self.residual)
            )
 
            tree = self.grow_tree(
                self.root_gradient_histogram,
                self.root_hessian_histogram,
                self.root_node_indices,
-               depth=0,
+               0,
            )
-           forest[i] = tree
-           # loss = ((self.Y_gpu - self.gradients) ** 2).mean().item()
-           # self.training_loss.append(loss)
-           # print(f"🌲 Tree {i+1}/{self.n_estimators} - MSE: {loss:.6f}")
+           self.forest[i] = tree
+
+           self.compute_eval(i)
+
+           if self.stop:
+               break
 
        print("Finished training forest.")
-       return forest
 
-   def predict(self, X_np):
-       X_tensor = torch.from_numpy(X_np).to(torch.float32).pin_memory()
+   def bin_data_with_existing_edges(self, X_np):
+       X_tensor = torch.from_numpy(X_np).type(torch.float32).pin_memory()
        num_samples = X_tensor.size(0)
        bin_indices = torch.zeros(
            (num_samples, self.num_features), dtype=torch.int8, device=self.device
        )
-
        with torch.no_grad():
            for f in range(self.num_features):
                X_f = X_tensor[:, f].to(self.device, non_blocking=True)
@@ -332,10 +489,16 @@ class WarpGBM(BaseEstimator, RegressorMixin):
                node_kernel.custom_cuda_binner(X_f, bin_edges_f, bin_indices_f)
                bin_indices[:, f] = bin_indices_f
 
+       return bin_indices
+
+   def predict_binned(self, bin_indices):
+       num_samples = bin_indices.size(0)
+
        tree_tensor = torch.stack(
            [
                self.flatten_tree(tree, max_nodes=2 ** (self.max_depth + 1))
                for tree in self.forest
+               if tree
            ]
        ).to(self.device)
 
@@ -344,24 +507,33 @@ class WarpGBM(BaseEstimator, RegressorMixin):
            bin_indices.contiguous(), tree_tensor.contiguous(), self.learning_rate, out
        )
 
-       return out.cpu().numpy()
+       return out
 
-   def flatten_tree(self, tree, max_nodes):
-       """
-       Convert a recursive tree structure into a flat matrix format.
+   def predict(self, X_np):
+       is_integer_type = np.issubdtype(X_np.dtype, np.integer)
 
-       Each row in the output represents a node:
-       - Columns: [feature, bin, left_id, right_id, is_leaf, value]
-       - Internal nodes fill columns 0–3 and set is_leaf = 0
-       - Leaf nodes fill only value and set is_leaf = 1
+       if is_integer_type and X_np.shape[1] == self.num_features:
+           max_vals = X_np.max(axis=0)
+           if np.all(max_vals < self.num_bins):
+               print("Detected pre-binned input at predict-time skipping binning.")
+               is_prebinned = True
+           else:
+               is_prebinned = False
+       else:
+           is_prebinned = False
 
-       Args:
-           tree (list): A list containing a single root node (recursive dict form).
-           max_nodes (int): Max number of nodes to allocate in the flat matrix.
+       if is_prebinned:
+           bin_indices = (
+               torch.from_numpy(X_np).to(self.device).contiguous().to(torch.int8)
+           )
+       else:
+           bin_indices = self.bin_data_with_existing_edges(X_np)
+
+       preds = self.predict_binned(bin_indices).cpu().numpy()
+       del bin_indices
+       return preds
 
-       Returns:
-           torch.Tensor: [max_nodes x 6] matrix representing the flattened tree.
-       """
+   def flatten_tree(self, tree, max_nodes):
        flat = torch.full((max_nodes, 6), float("nan"), dtype=torch.float32)
        node_counter = [0]
        node_list = []
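
The `colsample_bytree` changes above thread a per-tree column subsample through training: `find_best_split` works in the subsampled (local) column space, so the chosen index is mapped back through `feat_indices_tree` before being stored in the tree node. The following is a standalone sketch of that bookkeeping with illustrative values, not code from the package:

```python
# Illustrative sketch of per-tree column subsampling and the local-to-global
# feature-index mapping used by the new colsample_bytree code path.
import torch

num_features = 20
colsample_bytree = 0.8
bin_indices = torch.randint(0, 10, (1_000, num_features), dtype=torch.int8)  # stand-in for binned X

k = max(1, int(colsample_bytree * num_features))        # columns each tree sees
feat_indices_tree = torch.randperm(num_features)[:k]    # global column ids drawn for this tree
bin_indices_tree = bin_indices[:, feat_indices_tree]    # data the tree actually trains on

local_feature = 3                                        # column index returned by the split search
global_feature = int(feat_indices_tree[local_feature])   # index recorded in the tree node
split_mask = bin_indices_tree[:, local_feature] <= 2     # rows partitioned using the local view
```

Storing the global index keeps prediction unchanged, since the flattened trees always refer to columns of the full feature matrix.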

{warpgbm-0.1.22 → warpgbm-0.1.24}/warpgbm/cuda/best_split_kernel.cu

@@ -38,7 +38,7 @@ __global__ void best_split_kernel_global_only(
 
    if (H_L >= min_child_samples && H_R >= min_child_samples)
    {
-       float gain = (G_L * G_L) / (H_L + eps) + (G_R * G_R) / (H_R + eps);
+       float gain = (G_L * G_L) / (H_L + eps) + (G_R * G_R) / (H_R + eps) - (G_total * G_total) / (H_total + eps);
        if (gain > best_gain)
        {
            best_gain = gain;
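
The kernel change above rescores a split as an improvement over the parent node rather than the raw sum of child scores, which matches the usual GBDT gain criterion up to a constant factor and regularization details. A restatement in Python, assuming `G_total`/`H_total` are the parent's gradient and hessian sums (i.e. `G_L + G_R` and `H_L + H_R`):

```python
# Not the CUDA kernel itself, just the gain formula it now computes,
# assuming G_total = G_L + G_R and H_total = H_L + H_R for the parent node.
def split_gain(G_L, H_L, G_R, H_R, eps=1e-6):
    G_total = G_L + G_R
    H_total = H_L + H_R
    return (
        (G_L * G_L) / (H_L + eps)
        + (G_R * G_R) / (H_R + eps)
        - (G_total * G_total) / (H_total + eps)
    )
```

With the parent term subtracted, a split that does not improve on leaving the node as a leaf scores near zero instead of a large positive value.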

{warpgbm-0.1.22 → warpgbm-0.1.24/warpgbm.egg-info}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: warpgbm
-Version: 0.1.22
+Version: 0.1.24
 Summary: A fast GPU-accelerated Gradient Boosted Decision Tree library with PyTorch + CUDA
 License: GNU GENERAL PUBLIC LICENSE
                        Version 3, 29 June 2007
@@ -879,8 +879,26 @@ No installation required — just press **"Open in Playground"**, then **Run All
 - `L2_reg`: L2 regularizer (default: 1e-6)
 
 ### Methods:
-- `.fit(X, y, era_id=None)`: Train the model. `X` can be raw floats or pre-binned `int8` data. `era_id` is optional and used internally.
-- `.predict(X)`: Predict on new data, using parallelized CUDA kernel.
+```
+.fit(
+    X,                           # numpy array (float or int) 2 dimensions (num_samples, num_features)
+    y,                           # numpy array (float or int) 1 dimension (num_samples)
+    era_id=None,                 # numpy array (int) 1 dimension (num_samples)
+    X_eval=None,                 # numpy array (float or int) 2 dimensions (eval_num_samples, num_features)
+    y_eval=None,                 # numpy array (float or int) 1 dimension (eval_num_samples)
+    eval_every_n_trees=None,     # const (int) >= 1
+    early_stopping_rounds=None,  # const (int) >= 1
+)
+```
+Train with optional validation set and early stopping.
+
+
+```
+.predict(
+    X  # numpy array (float or int) 2 dimensions (predict_num_samples, num_features)
+)
+```
+Predict on new data, using parallelized CUDA kernel.
 
 ---
 
@@ -896,3 +914,7 @@ WarpGBM builds on the shoulders of PyTorch, scikit-learn, LightGBM, and the CUDA
 
 - Vectorized predict function replaced with CUDA kernel (`warpgbm/cuda/predict.cu`), parallelizing per sample, per tree.
 
+### v0.1.23
+
+- Adjust gain in split kernel and added support for an eval set with early stopping based on MSE.
+

{warpgbm-0.1.22 → warpgbm-0.1.24}/warpgbm.egg-info/SOURCES.txt

@@ -5,6 +5,7 @@ pyproject.toml
 setup.py
 version.txt
 tests/__init__.py
+tests/numerai_test.py
 tests/test_fit_predict_corr.py
 warpgbm/__init__.py
 warpgbm/core.py

warpgbm-0.1.22/version.txt

@@ -1 +0,0 @@
-0.1.22

The remaining files listed above with +0 -0 are unchanged between 0.1.22 and 0.1.24.