warpgbm 0.1.25__tar.gz → 0.1.27__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29):
  1. {warpgbm-0.1.25/warpgbm.egg-info → warpgbm-0.1.27}/PKG-INFO +6 -1
  2. {warpgbm-0.1.25 → warpgbm-0.1.27}/README.md +6 -1
  3. {warpgbm-0.1.25 → warpgbm-0.1.27}/pyproject.toml +2 -2
  4. warpgbm-0.1.27/tests/full_numerai_test.py +67 -0
  5. warpgbm-0.1.27/tests/test_fit_predict_corr.py +52 -0
  6. warpgbm-0.1.27/version.txt +1 -0
  7. {warpgbm-0.1.25 → warpgbm-0.1.27}/warpgbm/core.py +106 -91
  8. warpgbm-0.1.27/warpgbm/cuda/histogram_kernel.cu +95 -0
  9. {warpgbm-0.1.25 → warpgbm-0.1.27}/warpgbm/cuda/node_kernel.cpp +3 -20
  10. {warpgbm-0.1.25 → warpgbm-0.1.27}/warpgbm/cuda/predict.cu +20 -21
  11. warpgbm-0.1.27/warpgbm/metrics.py +10 -0
  12. {warpgbm-0.1.25 → warpgbm-0.1.27/warpgbm.egg-info}/PKG-INFO +6 -1
  13. {warpgbm-0.1.25 → warpgbm-0.1.27}/warpgbm.egg-info/SOURCES.txt +2 -0
  14. warpgbm-0.1.25/tests/test_fit_predict_corr.py +0 -57
  15. warpgbm-0.1.25/version.txt +0 -1
  16. warpgbm-0.1.25/warpgbm/cuda/histogram_kernel.cu +0 -250
  17. {warpgbm-0.1.25 → warpgbm-0.1.27}/LICENSE +0 -0
  18. {warpgbm-0.1.25 → warpgbm-0.1.27}/MANIFEST.in +0 -0
  19. {warpgbm-0.1.25 → warpgbm-0.1.27}/setup.cfg +0 -0
  20. {warpgbm-0.1.25 → warpgbm-0.1.27}/setup.py +0 -0
  21. {warpgbm-0.1.25 → warpgbm-0.1.27}/tests/__init__.py +0 -0
  22. {warpgbm-0.1.25 → warpgbm-0.1.27}/tests/numerai_test.py +0 -0
  23. {warpgbm-0.1.25 → warpgbm-0.1.27}/warpgbm/__init__.py +0 -0
  24. {warpgbm-0.1.25 → warpgbm-0.1.27}/warpgbm/cuda/__init__.py +0 -0
  25. {warpgbm-0.1.25 → warpgbm-0.1.27}/warpgbm/cuda/best_split_kernel.cu +0 -0
  26. {warpgbm-0.1.25 → warpgbm-0.1.27}/warpgbm/cuda/binner.cu +0 -0
  27. {warpgbm-0.1.25 → warpgbm-0.1.27}/warpgbm.egg-info/dependency_links.txt +0 -0
  28. {warpgbm-0.1.25 → warpgbm-0.1.27}/warpgbm.egg-info/requires.txt +0 -0
  29. {warpgbm-0.1.25 → warpgbm-0.1.27}/warpgbm.egg-info/top_level.txt +0 -0
{warpgbm-0.1.25/warpgbm.egg-info → warpgbm-0.1.27}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: warpgbm
- Version: 0.1.25
+ Version: 0.1.27
  Summary: A fast GPU-accelerated Gradient Boosted Decision Tree library with PyTorch + CUDA
  License: GNU GENERAL PUBLIC LICENSE
  Version 3, 29 June 2007
@@ -889,6 +889,7 @@ No installation required — just press **"Open in Playground"**, then **Run All
  y_eval=None, # numpy array (float or int) 1 dimension (eval_num_samples)
  eval_every_n_trees=None, # const (int) >= 1
  early_stopping_rounds=None, # const (int) >= 1
+ eval_metric='mse' # string, one of 'mse' or 'corr'. For corr, loss is 1 - correlation(y_true, preds)
  )
  ```
  Train with optional validation set and early stopping.
@@ -922,3 +923,7 @@ WarpGBM builds on the shoulders of PyTorch, scikit-learn, LightGBM, and the CUDA
  ### v0.1.25

  - Added `colsample_bytree` parameter and new test using Numerai data.
+
+ ### v0.1.26
+
+ - Fix Memory bugs in prediction and colsample bytree logic. Added "corr" eval metric.
{warpgbm-0.1.25 → warpgbm-0.1.27}/README.md
@@ -201,6 +201,7 @@ No installation required — just press **"Open in Playground"**, then **Run All
  y_eval=None, # numpy array (float or int) 1 dimension (eval_num_samples)
  eval_every_n_trees=None, # const (int) >= 1
  early_stopping_rounds=None, # const (int) >= 1
+ eval_metric='mse' # string, one of 'mse' or 'corr'. For corr, loss is 1 - correlation(y_true, preds)
  )
  ```
  Train with optional validation set and early stopping.
@@ -233,4 +234,8 @@ WarpGBM builds on the shoulders of PyTorch, scikit-learn, LightGBM, and the CUDA

  ### v0.1.25

- - Added `colsample_bytree` parameter and new test using Numerai data.
+ - Added `colsample_bytree` parameter and new test using Numerai data.
+
+ ### v0.1.26
+
+ - Fix Memory bugs in prediction and colsample bytree logic. Added "corr" eval metric.
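The `eval_metric` line added to the PKG-INFO and README snippets above is the user-facing surface of this release. A minimal sketch of how a caller might exercise it, assuming a CUDA-capable machine; the synthetic data and hyperparameters are illustrative, not taken from the package's documentation:

```python
import numpy as np
from warpgbm import WarpGBM

# Illustrative synthetic regression data (shapes and values are arbitrary).
rng = np.random.default_rng(0)
X = rng.standard_normal((10_000, 20)).astype(np.float32)
y = X[:, 0] * 2.0 + 0.1 * rng.standard_normal(10_000).astype(np.float32)

model = WarpGBM(max_depth=4, num_bins=16, n_estimators=50)
model.fit(
    X, y,
    X_eval=X, y_eval=y,      # in-sample eval set, purely for illustration
    eval_every_n_trees=10,
    early_stopping_rounds=2,
    eval_metric="corr",      # new in 0.1.26: eval loss = 1 - correlation(y_true, preds)
)
preds = model.predict(X)
```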
{warpgbm-0.1.25 → warpgbm-0.1.27}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "warpgbm"
- version = "0.1.25"
+ version = "0.1.27"
  description = "A fast GPU-accelerated Gradient Boosted Decision Tree library with PyTorch + CUDA"
  readme = "README.md"
  requires-python = ">=3.8"
@@ -13,5 +13,5 @@ dependencies = [
  "torch",
  "numpy",
  "tqdm",
- "scikit-learn"
+ "scikit-learn"
  ]
warpgbm-0.1.27/tests/full_numerai_test.py (new file)
@@ -0,0 +1,67 @@
+ from numerapi import NumerAPI
+ import pandas as pd
+ import numpy as np
+ from warpgbm import WarpGBM
+ import time
+ from sklearn.metrics import mean_squared_error
+
+
+ def predict_in_chunks(model, X, chunk_size=100_000):
+     preds = []
+     for i in range(0, X.shape[0], chunk_size):
+         X_chunk = X[i : i + chunk_size]
+         preds.append(model.predict(X_chunk))
+     return np.concatenate(preds)
+
+
+ def test_numerai_data():
+     napi = NumerAPI()
+     napi.download_dataset("v5.0/train.parquet", "numerai_train.parquet")
+     napi.download_dataset("v5.0/validation.parquet", "numerai_validation.parquet")
+
+     data = pd.concat([
+         pd.read_parquet("numerai_train.parquet"),
+         pd.read_parquet("numerai_validation.parquet")
+     ])
+     features = [f for f in list(data) if "feature" in f]
+     target = "target"
+     data = data.loc[data[ target].isna() == False ]
+
+     X = data[features].astype("int8").values[:]
+     y = data[target].values
+
+     model = WarpGBM(
+         max_depth=3,
+         num_bins=5,
+         n_estimators=10,
+         learning_rate=1,
+         threads_per_block=64,
+         rows_per_thread=4,
+         colsample_bytree=0.8,
+     )
+
+     start_fit = time.time()
+     model.fit(
+         X,
+         y,
+         # era_id=era,
+         # X_eval=X,
+         # y_eval=y,
+         # eval_every_n_trees=10,
+         # early_stopping_rounds=1,
+     )
+     fit_time = time.time() - start_fit
+     print(f" Fit time: {fit_time:.3f} seconds")
+
+     start_pred = time.time()
+     preds = predict_in_chunks(model, X, chunk_size=500_000)
+     pred_time = time.time() - start_pred
+     print(f" Predict time: {pred_time:.3f} seconds")
+
+     corr = np.corrcoef(preds, y)[0, 1]
+     mse = mean_squared_error(preds, y)
+     print(f" Correlation: {corr:.4f}")
+     print(f" MSE: {mse:.4f}")
+
+     # assert corr > 0.68, f"In-sample correlation too low: {corr}"
+     # assert mse < 0.03, f"In-sample mse too high: {mse}"
warpgbm-0.1.27/tests/test_fit_predict_corr.py (new file)
@@ -0,0 +1,52 @@
+ import numpy as np
+ from warpgbm import WarpGBM
+ from sklearn.datasets import make_regression
+ import time
+ from sklearn.metrics import mean_squared_error
+
+
+ def test_fit_predictpytee_correlation():
+     np.random.seed(42)
+     N = 100_000
+     F = 1000
+     X, y = make_regression(n_samples=N, n_features=F, noise=0.1, random_state=42)
+     era = np.zeros(N, dtype=np.int32)
+     corrs = []
+     mses = []
+
+     model = WarpGBM(
+         max_depth=10,
+         num_bins=10,
+         n_estimators=100,
+         learning_rate=1,
+         threads_per_block=64,
+         rows_per_thread=4,
+         colsample_bytree=1.0,
+     )
+
+     start_fit = time.time()
+     model.fit(
+         X,
+         y,
+         era_id=era,
+         X_eval=X,
+         y_eval=y,
+         eval_every_n_trees=10,
+         early_stopping_rounds=1,
+         eval_metric="corr",
+     )
+     fit_time = time.time() - start_fit
+     print(f" Fit time: {fit_time:.3f} seconds")
+
+     start_pred = time.time()
+     preds = model.predict(X)
+     pred_time = time.time() - start_pred
+     print(f" Predict time: {pred_time:.3f} seconds")
+
+     corr = np.corrcoef(preds, y)[0, 1]
+     mse = mean_squared_error(preds, y)
+     print(f" Correlation: {corr:.4f}")
+     print(f" MSE: {mse:.4f}")
+
+     assert (corr > 0.9), f"In-sample correlation too low: {corrs}"
+     assert (mse < 2), f"In-sample mse too high: {mses}"
warpgbm-0.1.27/version.txt (new file)
@@ -0,0 +1 @@
+ 0.1.27
{warpgbm-0.1.25 → warpgbm-0.1.27}/warpgbm/core.py
@@ -1,19 +1,14 @@
  import torch
  import numpy as np
  from sklearn.base import BaseEstimator, RegressorMixin
+ from sklearn.metrics import mean_squared_log_error
  from warpgbm.cuda import node_kernel
+ from warpgbm.metrics import rmsle_torch
  from tqdm import tqdm
  from typing import Tuple
  from torch import Tensor
  import gc

- histogram_kernels = {
-     "hist1": node_kernel.compute_histogram,
-     "hist2": node_kernel.compute_histogram2,
-     "hist3": node_kernel.compute_histogram3,
- }
-
-
  class WarpGBM(BaseEstimator, RegressorMixin):
      def __init__(
          self,
@@ -23,8 +18,6 @@ class WarpGBM(BaseEstimator, RegressorMixin):
          n_estimators=100,
          min_child_weight=20,
          min_split_gain=0.0,
-         verbosity=True,
-         histogram_computer="hist3",
          threads_per_block=64,
          rows_per_thread=4,
          L2_reg=1e-6,
@@ -40,7 +33,6 @@
              n_estimators=n_estimators,
              min_child_weight=min_child_weight,
              min_split_gain=min_split_gain,
-             histogram_computer=histogram_computer,
              threads_per_block=threads_per_block,
              rows_per_thread=rows_per_thread,
              L2_reg=L2_reg,
@@ -68,7 +60,6 @@
          self.min_child_weight = min_child_weight
          self.min_split_gain = min_split_gain
          self.best_bin = torch.tensor([-1], dtype=torch.int32, device=self.device)
-         self.compute_histogram = histogram_kernels[histogram_computer]
          self.threads_per_block = threads_per_block
          self.rows_per_thread = rows_per_thread
          self.L2_reg = L2_reg
@@ -128,17 +119,13 @@
          )
          if kwargs["L2_reg"] < 0 or kwargs["L1_reg"] < 0:
              raise ValueError("L2_reg and L1_reg must be non-negative.")
-         if kwargs["histogram_computer"] not in histogram_kernels:
-             raise ValueError(
-                 f"Invalid histogram_computer: {kwargs['histogram_computer']}. Choose from {list(histogram_kernels.keys())}."
-             )
          if kwargs["colsample_bytree"] <= 0 or kwargs["colsample_bytree"] > 1:
              raise ValueError(
                  f"Invalid colsample_bytree: {kwargs['colsample_bytree']}. Must be a float value > 0 and <= 1."
              )

      def validate_fit_params(
-         self, X, y, era_id, X_eval, y_eval, eval_every_n_trees, early_stopping_rounds
+         self, X, y, era_id, X_eval, y_eval, eval_every_n_trees, early_stopping_rounds, eval_metric
      ):
          # ─── Required: X and y ───
          if not isinstance(X, np.ndarray) or not isinstance(y, np.ndarray):
@@ -206,6 +193,11 @@ class WarpGBM(BaseEstimator, RegressorMixin):
              # No early stopping = set to "never trigger"
              early_stopping_rounds = self.n_estimators + 1

+         if eval_metric not in ["mse", "corr", "rmsle"]:
+             raise ValueError(
+                 f"Invalid eval_metric: {eval_metric}. Choose 'mse' or 'corr', 'rmsle'."
+             )
+
          return early_stopping_rounds  # May have been defaulted here

      def fit(
@@ -217,9 +209,10 @@
          y_eval=None,
          eval_every_n_trees=None,
          early_stopping_rounds=None,
+         eval_metric = "mse",
      ):
          early_stopping_rounds = self.validate_fit_params(
-             X, y, era_id, X_eval, y_eval, eval_every_n_trees, early_stopping_rounds
+             X, y, era_id, X_eval, y_eval, eval_every_n_trees, early_stopping_rounds, eval_metric
          )

          if era_id is None:
@@ -231,21 +224,24 @@
          )
          self.num_samples, self.num_features = X.shape
          self.gradients = torch.zeros_like(self.Y_gpu)
-         self.root_node_indices = torch.arange(self.num_samples, device=self.device)
+         self.root_node_indices = torch.arange(self.num_samples, device=self.device, dtype=torch.int32)
          self.base_prediction = self.Y_gpu.mean().item()
          self.gradients += self.base_prediction
-         self.best_gains = torch.zeros(self.num_features, device=self.device)
-         self.best_bins = torch.zeros(
-             self.num_features, device=self.device, dtype=torch.int32
-         )
-         self.feature_indices = torch.arange(self.num_features, device=self.device)
+         if self.colsample_bytree < 1.0:
+             k = max(1, int(self.colsample_bytree * self.num_features))
+         else:
+             k = self.num_features
+         self.best_gains = torch.zeros(k, device=self.device)
+         self.best_bins = torch.zeros(k, device=self.device, dtype=torch.int32)
+         self.feature_indices = torch.arange(self.num_features, device=self.device, dtype=torch.int32)

          # ─── Optional Eval Set ───
          if X_eval is not None and y_eval is not None:
-             self.bin_indices_eval = self.bin_data_with_existing_edges(X_eval)
+             self.bin_indices_eval = self.bin_inference_data(X_eval)
              self.Y_gpu_eval = torch.from_numpy(y_eval).to(torch.float32).to(self.device)
              self.eval_every_n_trees = eval_every_n_trees
              self.early_stopping_rounds = early_stopping_rounds
+             self.eval_metric = eval_metric
          else:
              self.bin_indices_eval = None
              self.Y_gpu_eval = None
@@ -266,50 +262,47 @@
      def preprocess_gpu_data(self, X_np, Y_np, era_id_np):
          with torch.no_grad():
              self.num_samples, self.num_features = X_np.shape
+
              Y_gpu = torch.from_numpy(Y_np).type(torch.float32).to(self.device)
-             era_id_gpu = torch.from_numpy(era_id_np).type(torch.int32).to(self.device)
-             is_integer_type = np.issubdtype(X_np.dtype, np.integer)
-             if is_integer_type:
-                 max_vals = X_np.max(axis=0)
-                 if np.all(max_vals < self.num_bins):
-                     print(
-                         "Detected pre-binned integer input — skipping quantile binning."
-                     )
-                     bin_indices = (
-                         torch.from_numpy(X_np)
-                         .to(self.device)
-                         .contiguous()
-                         .to(torch.int8)
-                     )

-                     # We'll store None or an empty tensor in self.bin_edges
-                     # to indicate that we skip binning at predict-time
-                     bin_edges = torch.arange(
-                         1, self.num_bins, dtype=torch.float32
-                     ).repeat(self.num_features, 1)
-                     bin_edges = bin_edges.to(self.device)
-                     unique_eras, era_indices = torch.unique(
-                         era_id_gpu, return_inverse=True
-                     )
-                     return bin_indices, era_indices, bin_edges, unique_eras, Y_gpu
-                 else:
-                     print(
-                         "Integer input detected, but values exceed num_bins — falling back to quantile binning."
-                     )
+             era_id_gpu = torch.from_numpy(era_id_np).type(torch.int32).to(self.device)

              bin_indices = torch.empty(
                  (self.num_samples, self.num_features), dtype=torch.int8, device="cuda"
              )
+
+             is_integer_type = np.issubdtype(X_np.dtype, np.integer)
+             max_vals = X_np.max(axis=0)
+
+             if is_integer_type and np.all(max_vals < self.num_bins):
+                 print(
+                     "Detected pre-binned integer input — skipping quantile binning."
+                 )
+                 for f in range(self.num_features):
+                     bin_indices[:,f] = torch.as_tensor( X_np[:, f], device=self.device).contiguous()
+                 # bin_indices = X_np.to("cuda", non_blocking=True).contiguous()
+
+                 # We'll store None or an empty tensor in self.bin_edges
+                 # to indicate that we skip binning at predict-time
+                 bin_edges = torch.arange(
+                     1, self.num_bins, dtype=torch.float32
+                 ).repeat(self.num_features, 1)
+                 bin_edges = bin_edges.to(self.device)
+                 unique_eras, era_indices = torch.unique(
+                     era_id_gpu, return_inverse=True
+                 )
+                 return bin_indices, era_indices, bin_edges, unique_eras, Y_gpu
+
+             print("quantile binning.")

              bin_edges = torch.empty(
                  (self.num_features, self.num_bins - 1),
                  dtype=torch.float32,
                  device="cuda",
              )

-             X_np = torch.from_numpy(X_np).to(torch.float32).pin_memory()
-
              for f in range(self.num_features):
-                 X_f = X_np[:, f].to("cuda", non_blocking=True)
+                 X_f = torch.as_tensor( X_np[:, f], device=self.device, dtype=torch.float32 ).contiguous()
                  quantiles = torch.linspace(
                      0, 1, self.num_bins + 1, device="cuda", dtype=X_f.dtype
                  )[1:-1]
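The reworked `preprocess_gpu_data` hunk above now fills `bin_indices` column by column in both paths, and the pre-binned fast path hinges on a single check: integer dtype and every per-column maximum below `num_bins`. A standalone sketch of that check (`is_prebinned` is a hypothetical helper written here for illustration, not a function in the package):

```python
import numpy as np

def is_prebinned(X_np: np.ndarray, num_bins: int) -> bool:
    # Pre-binned means integer values already in [0, num_bins),
    # so quantile binning can be skipped entirely.
    return np.issubdtype(X_np.dtype, np.integer) and bool(
        np.all(X_np.max(axis=0) < num_bins)
    )

# Each column holds a constant value in 0..4 (cf. Numerai's 5-valued features).
X = np.tile(np.arange(5, dtype=np.int8), (200, 2))  # shape (200, 10)
assert is_prebinned(X, num_bins=5)      # max value 4 fits in 5 bins
assert not is_prebinned(X, num_bins=4)  # a value of 4 overflows 4 bins
```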
@@ -324,17 +317,19 @@
              unique_eras, era_indices = torch.unique(era_id_gpu, return_inverse=True)
              return bin_indices, era_indices, bin_edges, unique_eras, Y_gpu

-     def compute_histograms(self, bin_indices_sub, gradients):
+     def compute_histograms(self, sample_indices, feature_indices):
          grad_hist = torch.zeros(
-             (self.num_features, self.num_bins), device=self.device, dtype=torch.float32
+             (len(feature_indices), self.num_bins), device=self.device, dtype=torch.float32
          )
          hess_hist = torch.zeros(
-             (self.num_features, self.num_bins), device=self.device, dtype=torch.float32
+             (len(feature_indices), self.num_bins), device=self.device, dtype=torch.float32
          )

-         self.compute_histogram(
-             bin_indices_sub,
-             gradients,
+         node_kernel.compute_histogram3(
+             self.bin_indices,
+             self.residual,
+             sample_indices,
+             feature_indices,
              grad_hist,
              hess_hist,
              self.num_bins,
@@ -357,6 +352,9 @@
          if torch.all(self.best_bins == -1):
              return -1, -1  # No valid split found
+
+         # print(self.best_bins)
+         # print(self.best_gains)

          f = torch.argmax(self.best_gains).item()
          b = self.best_bins[f].item()
@@ -374,28 +372,38 @@
              gradient_histogram, hessian_histogram
          )

+         # print(local_feature, best_bin)
+
          if local_feature == -1:
              leaf_value = self.residual[node_indices].mean()
              self.gradients[node_indices] += self.learning_rate * leaf_value
              return {"leaf_value": leaf_value.item(), "samples": parent_size}
-
-         split_mask = self.bin_indices_tree[node_indices, local_feature] <= best_bin
+
+         # print("DEBUG SHAPES -> bin_indices:", self.bin_indices.shape,
+         #       "| node_indices max:", node_indices.max().item(),
+         #       "| local_feature:", local_feature,
+         #       "| feat_indices_tree len:", len(self.feat_indices_tree),
+         #       "| feat index:", self.feat_indices_tree[local_feature])
+
+         split_mask = self.bin_indices[node_indices, self.feat_indices_tree[local_feature]] <= best_bin
          left_indices = node_indices[split_mask]
          right_indices = node_indices[~split_mask]

+         # print("DEBUG SHAPES -> left_indices:", left_indices.shape,
+         #       "| right_indices:", right_indices.shape,
+         #       "| parent_size:", parent_size,
+         #       "| local_feature:", local_feature,
+         #       "| best_bin:", best_bin)
+
          left_size = left_indices.numel()
          right_size = right_indices.numel()

          if left_size <= right_size:
-             grad_hist_left, hess_hist_left = self.compute_histograms(
-                 self.bin_indices_tree[left_indices], self.residual[left_indices]
-             )
+             grad_hist_left, hess_hist_left = self.compute_histograms( left_indices, self.feat_indices_tree )
              grad_hist_right = gradient_histogram - grad_hist_left
              hess_hist_right = hessian_histogram - hess_hist_left
          else:
-             grad_hist_right, hess_hist_right = self.compute_histograms(
-                 self.bin_indices_tree[right_indices], self.residual[right_indices]
-             )
+             grad_hist_right, hess_hist_right = self.compute_histograms( right_indices, self.feat_indices_tree )
              grad_hist_left = gradient_histogram - grad_hist_right
              hess_hist_left = hessian_histogram - hess_hist_right
@@ -413,25 +421,35 @@
                  "left": left_child,
                  "right": right_child,
              }
+
+     def get_eval_metric(self, y_true, y_pred):
+         if self.eval_metric == "mse":
+             return ((y_true - y_pred) ** 2).mean().item()
+         elif self.eval_metric == "corr":
+             return 1 - torch.corrcoef(torch.vstack([y_true, y_pred]))[0, 1].item()
+         elif self.eval_metric == "rmsle":
+             return rmsle_torch(y_true, y_pred).item()
+         else:
+             raise ValueError(f"Invalid eval_metric: {self.eval_metric}.")

      def compute_eval(self, i):
          if self.eval_every_n_trees == None:
              return
+
+         train_loss = ((self.Y_gpu - self.gradients) ** 2).mean().item()
+         self.training_loss.append(train_loss)

          if i % self.eval_every_n_trees == 0:
              eval_preds = self.predict_binned(self.bin_indices_eval)
-             eval_loss = ((self.Y_gpu_eval - eval_preds) ** 2).mean().item()
+             eval_loss = self.get_eval_metric( self.Y_gpu_eval, eval_preds )
              self.eval_loss.append(eval_loss)

-             train_loss = ((self.Y_gpu - self.gradients) ** 2).mean().item()
-             self.training_loss.append(train_loss)
-
              if len(self.eval_loss) > self.early_stopping_rounds:
-                 if self.eval_loss[-self.early_stopping_rounds] < self.eval_loss[-1]:
+                 if self.eval_loss[-(self.early_stopping_rounds+1)] < self.eval_loss[-1]:
                      self.stop = True

              print(
-                 f"🌲 Tree {i+1}/{self.n_estimators} | Train MSE: {train_loss:.6f} | Eval MSE: {eval_loss:.6f}"
+                 f"🌲 Tree {i+1}/{self.n_estimators} | Train MSE: {train_loss:.6f} | Eval {self.eval_metric}: {eval_loss:.6f}"
              )

              del eval_preds, eval_loss, train_loss
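The early-stopping comparison in `compute_eval` above is the subtle one-line fix in this release: with the old `eval_loss[-self.early_stopping_rounds]` the comparison window was off by one, and with `early_stopping_rounds=1` the trigger could never fire at all, since the latest loss was compared against itself. A small self-contained sketch of the corrected window logic (a hypothetical helper, mirroring the condition in the hunk):

```python
def should_stop(eval_loss, early_stopping_rounds):
    # Stop once the score from (early_stopping_rounds + 1) evaluations ago
    # is still better than the latest, i.e. no improvement over the last
    # `early_stopping_rounds` evaluations.
    if len(eval_loss) <= early_stopping_rounds:
        return False
    return eval_loss[-(early_stopping_rounds + 1)] < eval_loss[-1]

assert not should_stop([0.50, 0.40], early_stopping_rounds=2)        # too few evals yet
assert should_stop([0.30, 0.35, 0.36], early_stopping_rounds=2)      # 0.30 beats 0.36
assert not should_stop([0.40, 0.35, 0.30], early_stopping_rounds=2)  # still improving
```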
@@ -445,20 +463,14 @@
              k = max(1, int(self.colsample_bytree * self.num_features))
          else:
              self.feat_indices_tree = self.feature_indices
-             self.bin_indices_tree = self.bin_indices

          for i in range(self.n_estimators):
              self.residual = self.Y_gpu - self.gradients

              if self.colsample_bytree < 1.0:
-                 self.feat_indices_tree = torch.randperm(
-                     self.num_features, device=self.device
-                 )[:k]
-                 self.bin_indices_tree = self.bin_indices[:, self.feat_indices_tree]
+                 self.feat_indices_tree = torch.randperm(self.num_features, device=self.device, dtype=torch.int32)[:k]

-             self.root_gradient_histogram, self.root_hessian_histogram = (
-                 self.compute_histograms(self.bin_indices_tree, self.residual)
-             )
+             self.root_gradient_histogram, self.root_hessian_histogram = self.compute_histograms( self.root_node_indices, self.feat_indices_tree )

              tree = self.grow_tree(
                  self.root_gradient_histogram,
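This hunk is the second half of the `colsample_bytree` memory fix: instead of materializing `self.bin_indices[:, self.feat_indices_tree]` (a full copy of the binned matrix every tree), only the index vector is sampled and handed to the histogram kernel. A minimal sketch of the sampling step on its own, assuming PyTorch (`sample_features` is a hypothetical name for illustration):

```python
import torch

def sample_features(num_features, colsample_bytree, device="cpu"):
    # Draw a per-tree random subset of feature indices; the binned
    # matrix itself is never sliced or copied.
    k = max(1, int(colsample_bytree * num_features))
    return torch.randperm(num_features, device=device, dtype=torch.int32)[:k]

feats = sample_features(1000, 0.8)
print(feats.shape)  # torch.Size([800])
```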
@@ -476,14 +488,13 @@
          print("Finished training forest.")

      def bin_data_with_existing_edges(self, X_np):
-         X_tensor = torch.from_numpy(X_np).type(torch.float32).pin_memory()
-         num_samples = X_tensor.size(0)
+         num_samples = X_np.shape[0]
          bin_indices = torch.zeros(
              (num_samples, self.num_features), dtype=torch.int8, device=self.device
          )
          with torch.no_grad():
              for f in range(self.num_features):
-                 X_f = X_tensor[:, f].to(self.device, non_blocking=True)
+                 X_f = torch.as_tensor( X_np[:, f], device=self.device, dtype=torch.float32 ).contiguous()
                  bin_edges_f = self.bin_edges[f]
                  bin_indices_f = bin_indices[:, f].contiguous()
                  node_kernel.custom_cuda_binner(X_f, bin_edges_f, bin_indices_f)
@@ -493,7 +504,6 @@
      def predict_binned(self, bin_indices):
          num_samples = bin_indices.size(0)
-
          tree_tensor = torch.stack(
              [
                  self.flatten_tree(tree, max_nodes=2 ** (self.max_depth + 1))
@@ -508,8 +518,8 @@
          )

          return out
-
-     def predict(self, X_np):
+
+     def bin_inference_data(self, X_np):
          is_integer_type = np.issubdtype(X_np.dtype, np.integer)

          if is_integer_type and X_np.shape[1] == self.num_features:
@@ -523,12 +533,17 @@
              is_prebinned = False

          if is_prebinned:
-             bin_indices = (
-                 torch.from_numpy(X_np).to(self.device).contiguous().to(torch.int8)
+             bin_indices = torch.empty(
+                 X_np.shape, dtype=torch.int8, device="cuda"
              )
+             for f in range(self.num_features):
+                 bin_indices[:,f] = torch.as_tensor( X_np[:, f], device=self.device).contiguous()
          else:
              bin_indices = self.bin_data_with_existing_edges(X_np)
+         return bin_indices

+     def predict(self, X_np):
+         bin_indices = self.bin_inference_data(X_np)
          preds = self.predict_binned(bin_indices).cpu().numpy()
          del bin_indices
          return preds
warpgbm-0.1.27/warpgbm/cuda/histogram_kernel.cu (new file)
@@ -0,0 +1,95 @@
+ #include <cuda.h>
+ #include <cuda_runtime.h>
+ #include <torch/extension.h>
+
+ __global__ void histogram_tiled_configurable_kernel(
+     const int8_t *__restrict__ bin_indices,      // [N, F]
+     const float *__restrict__ residuals,         // [N]
+     const int32_t *__restrict__ sample_indices,  // [N]
+     const int32_t *__restrict__ feature_indices, // [F]
+     float *__restrict__ grad_hist,               // [F * B]
+     float *__restrict__ hess_hist,               // [F * B]
+     int64_t N, int64_t F, int64_t B,
+     int rows_per_thread)
+ {
+     int hist_feat_idx = blockIdx.x;
+     int feat = feature_indices[ hist_feat_idx ]; // 1 block per feature
+     int row_start = (blockIdx.y * blockDim.x + threadIdx.x) * rows_per_thread;
+
+     extern __shared__ float shmem[];
+     float *sh_grad = shmem;       // [B]
+     float *sh_hess = &sh_grad[B]; // [B]
+
+     // Initialize shared memory histograms
+     for (int b = threadIdx.x; b < B; b += blockDim.x)
+     {
+         sh_grad[b] = 0.0f;
+         sh_hess[b] = 0.0f;
+     }
+     __syncthreads();
+
+     // Each thread processes multiple rows
+     for (int r = 0; r < rows_per_thread; ++r)
+     {
+         int row = row_start + r;
+         if (row < N)
+         {
+             int sample = sample_indices[row];
+             int8_t bin = bin_indices[sample * F + feat];
+             if (bin >= 0 && bin < B)
+             {
+                 atomicAdd(&sh_grad[bin], residuals[sample]);
+                 atomicAdd(&sh_hess[bin], 1.0f);
+             }
+         }
+     }
+     __syncthreads();
+
+     // One thread per bin writes results back to global memory
+     for (int b = threadIdx.x; b < B; b += blockDim.x)
+     {
+         int64_t idx = hist_feat_idx * B + b;
+         atomicAdd(&grad_hist[idx], sh_grad[b]);
+         atomicAdd(&hess_hist[idx], sh_hess[b]);
+     }
+ }
+
+ void launch_histogram_kernel_cuda_configurable(
+     const at::Tensor &bin_indices,
+     const at::Tensor &residuals,
+     const at::Tensor &sample_indices,
+     const at::Tensor &feature_indices,
+     at::Tensor &grad_hist,
+     at::Tensor &hess_hist,
+     int num_bins,
+     int threads_per_block = 256,
+     int rows_per_thread = 1)
+ {
+
+     int64_t N = sample_indices.size(0);
+     int64_t F = feature_indices.size(0);
+     int num_features_master = bin_indices.size(1);
+
+     int64_t rows_per_block = threads_per_block * rows_per_thread;
+     int64_t row_tiles = (N + rows_per_block - 1) / rows_per_block;
+
+     dim3 blocks(F, row_tiles); // grid.x = F, grid.y = row_tiles
+     dim3 threads(threads_per_block);
+     int shared_mem_bytes = 2 * num_bins * sizeof(float);
+
+     histogram_tiled_configurable_kernel<<<blocks, threads, shared_mem_bytes>>>(
+         bin_indices.data_ptr<int8_t>(),
+         residuals.data_ptr<float>(),
+         sample_indices.data_ptr<int32_t>(),
+         feature_indices.data_ptr<int32_t>(),
+         grad_hist.data_ptr<float>(),
+         hess_hist.data_ptr<float>(),
+         N, num_features_master, num_bins,
+         rows_per_thread);
+
+     cudaError_t err = cudaGetLastError();
+     if (err != cudaSuccess)
+     {
+         printf("CUDA kernel launch failed: %s\n", cudaGetErrorString(err));
+     }
+ }
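For readers who would rather not trace the CUDA: for each selected feature f and bin b, the kernel above accumulates the sum of residuals of the selected samples whose bin index is b (the gradient histogram) and their count (the hessian histogram; for squared error the per-sample hessian is 1). A plain-PyTorch reference of that contract, written here purely as documentation and not part of the package:

```python
import torch

def histograms_reference(bin_indices, residuals, sample_indices, feature_indices, num_bins):
    # bin_indices: [N, F] int8; residuals: [N] float32;
    # sample_indices / feature_indices: int32 index vectors.
    sub = bin_indices[sample_indices.long()][:, feature_indices.long()].long()  # [n, f]
    res = residuals[sample_indices.long()]                                      # [n]
    f = feature_indices.numel()
    # Flatten each (feature, bin) pair into one slot, then scatter-add.
    flat = (sub + torch.arange(f) * num_bins).reshape(-1)
    grad = torch.zeros(f * num_bins).scatter_add_(0, flat, res.repeat_interleave(f))
    hess = torch.zeros(f * num_bins).scatter_add_(0, flat, torch.ones_like(flat, dtype=torch.float32))
    return grad.view(f, num_bins), hess.view(f, num_bins)

# Tiny smoke test of the contract: N=3 samples, F=2 features, 3 bins.
bins = torch.tensor([[0, 1], [1, 1], [2, 0]], dtype=torch.int8)
res = torch.tensor([1.0, 2.0, 4.0])
g, h = histograms_reference(bins, res,
                            torch.arange(3, dtype=torch.int32),
                            torch.arange(2, dtype=torch.int32), num_bins=3)
print(g)  # [[1., 2., 4.], [4., 3., 0.]]
print(h)  # [[1., 1., 1.], [1., 2., 0.]]
```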
{warpgbm-0.1.25 → warpgbm-0.1.27}/warpgbm/cuda/node_kernel.cpp
@@ -2,23 +2,6 @@
  #include <vector>

  // Declare the function from histogram_kernel.cu
- void launch_histogram_kernel_cuda(
-     const at::Tensor &bin_indices,
-     const at::Tensor &gradients,
-     at::Tensor &grad_hist,
-     at::Tensor &hess_hist,
-     int num_bins,
-     int threads_per_block = 256,
-     int rows_per_thread = 1);
-
- void launch_histogram_kernel_cuda_2(
-     const at::Tensor &bin_indices, // int8 [N, F]
-     const at::Tensor &gradients,   // float32 [N]
-     at::Tensor &grad_hist,         // float32 [F * B]
-     at::Tensor &hess_hist,         // float32 [F * B]
-     int num_bins,
-     int threads_per_block = 256,
-     int rows_per_thread = 1);

  void launch_best_split_kernel_cuda(
      const at::Tensor &G, // [F x B]
@@ -32,7 +15,9 @@ void launch_best_split_kernel_cuda(

  void launch_histogram_kernel_cuda_configurable(
      const at::Tensor &bin_indices,
-     const at::Tensor &gradients,
+     const at::Tensor &residual,
+     const at::Tensor &sample_indices,
+     const at::Tensor &feature_indices,
      at::Tensor &grad_hist,
      at::Tensor &hess_hist,
      int num_bins,
@@ -54,8 +39,6 @@ void predict_with_forest(
  // Bindings
  PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
  {
-     m.def("compute_histogram", &launch_histogram_kernel_cuda, "Histogram (CUDA)");
-     m.def("compute_histogram2", &launch_histogram_kernel_cuda_2, "Histogram (CUDA) 2");
      m.def("compute_histogram3", &launch_histogram_kernel_cuda_configurable, "Histogram Feature Shared Mem");
      m.def("compute_split", &launch_best_split_kernel_cuda, "Best Split (CUDA)");
      m.def("custom_cuda_binner", &launch_bin_column_kernel, "Custom CUDA binning kernel");
{warpgbm-0.1.25 → warpgbm-0.1.27}/warpgbm/cuda/predict.cu
@@ -5,23 +5,18 @@
  __global__ void predict_forest_kernel(
      const int8_t *__restrict__ bin_indices, // [N x F]
      const float *__restrict__ tree_tensor,  // [T x max_nodes x 6]
-     int N, int F, int T, int max_nodes,
+     int64_t N, int64_t F, int64_t T, int64_t max_nodes,
      float learning_rate,
      float *__restrict__ out // [N]
  )
  {
-     int idx = blockIdx.x * blockDim.x + threadIdx.x;
-     int total_jobs = N * T;
+     int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+     int64_t total_jobs = N * T;
      if (idx >= total_jobs)
          return;

-     int i = idx % N; // sample index
-     int t = idx / N; // tree index
-
-     // if (i == 0 && t == 0)
-     // {
-     //     printf("[DEBUG] Thread (i=%d, t=%d): starting prediction\n", i, t);
-     // }
+     int64_t i = idx % N; // sample index
+     int64_t t = idx / N; // tree index

      const float *tree = tree_tensor + t * max_nodes * 6;

@@ -35,32 +30,36 @@ __global__ void predict_forest_kernel(
          atomicAdd(&out[i], learning_rate * val);
          return;
      }
+
      int feat = static_cast<int>(tree[node_id * 6 + 0]);
      int split_bin = static_cast<int>(tree[node_id * 6 + 1]);
      int left_id = static_cast<int>(tree[node_id * 6 + 2]);
      int right_id = static_cast<int>(tree[node_id * 6 + 3]);

-     int8_t bin = bin_indices[i * F + feat];
+     // prevent overflow
+     int64_t bin_idx = i * F + feat;
+     int8_t bin = bin_indices[bin_idx];
+
      node_id = (bin <= split_bin) ? left_id : right_id;
-     // printf("sample %d, tree %d, feat %d, bin %d, split %d → %s\n", i, t, feat, bin, split_bin, (bin <= split_bin ? "L" : "R"));
      }
  }

+
  void predict_with_forest(
-     const at::Tensor &bin_indices, // [N x F], int8
-     const at::Tensor &tree_tensor, // [T x max_nodes x 6], float32
+     const at::Tensor &bin_indices,
+     const at::Tensor &tree_tensor,
      float learning_rate,
-     at::Tensor &out // [N], float32
+     at::Tensor &out
  )
  {
-     int N = bin_indices.size(0);
-     int F = bin_indices.size(1);
-     int T = tree_tensor.size(0);
-     int max_nodes = tree_tensor.size(1);
+     int64_t N = bin_indices.size(0);
+     int64_t F = bin_indices.size(1);
+     int64_t T = tree_tensor.size(0);
+     int64_t max_nodes = tree_tensor.size(1);

-     int total_jobs = N * T;
+     int64_t total_jobs = N * T;
      int threads_per_block = 256;
-     int blocks = (total_jobs + threads_per_block - 1) / threads_per_block;
+     int64_t blocks = (total_jobs + threads_per_block - 1) / threads_per_block;

      predict_forest_kernel<<<blocks, threads_per_block>>>(
          bin_indices.data_ptr<int8_t>(),
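All of the predict.cu changes above serve 64-bit indexing: with 32-bit `int`, both `N * T` (total thread jobs) and `i * F + feat` (the flat index into `bin_indices`) wrap past 2^31 - 1, which a few million rows times a thousand features reaches easily; this is likely the prediction memory bug named in the changelog. A quick numpy illustration of the failure mode, mimicking the old `int` arithmetic:

```python
import numpy as np

N, F = 3_000_000, 1_000            # plausible Numerai-scale shapes
i, feat = N - 1, F - 1             # last sample, last feature

flat64 = np.int64(i) * F + feat    # what the fixed kernel computes
print(flat64)                      # 2999999999 > INT32_MAX (2147483647)

with np.errstate(over="ignore"):   # silence numpy's overflow warning
    flat32 = np.int32(i) * np.int32(F) + np.int32(feat)
print(flat32)                      # wraps to a negative, garbage index
```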
warpgbm-0.1.27/warpgbm/metrics.py (new file)
@@ -0,0 +1,10 @@
+ # warpgbm/metrics.py
+
+ import torch
+
+ def rmsle_torch(y_true, y_pred, eps=1e-7):
+     y_true = torch.clamp(y_true, min=0)
+     y_pred = torch.clamp(y_pred, min=0)
+     log_true = torch.log1p(y_true + eps)
+     log_pred = torch.log1p(y_pred + eps)
+     return torch.sqrt(torch.mean((log_true - log_pred) ** 2))
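The new `warpgbm/metrics.py` backs the `rmsle` branch of `get_eval_metric` in core.py (note that the README hunks above document only `'mse'` and `'corr'`, while core.py additionally accepts `'rmsle'`). A short usage sketch with arbitrary values:

```python
import torch
from warpgbm.metrics import rmsle_torch

y_true = torch.tensor([1.0, 10.0, 100.0])
y_pred = torch.tensor([1.5, 8.0, 120.0])
print(rmsle_torch(y_true, y_pred))  # ≈ 0.20
```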
{warpgbm-0.1.25 → warpgbm-0.1.27/warpgbm.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: warpgbm
- Version: 0.1.25
+ Version: 0.1.27
  Summary: A fast GPU-accelerated Gradient Boosted Decision Tree library with PyTorch + CUDA
  License: GNU GENERAL PUBLIC LICENSE
  Version 3, 29 June 2007
@@ -889,6 +889,7 @@ No installation required — just press **"Open in Playground"**, then **Run All
  y_eval=None, # numpy array (float or int) 1 dimension (eval_num_samples)
  eval_every_n_trees=None, # const (int) >= 1
  early_stopping_rounds=None, # const (int) >= 1
+ eval_metric='mse' # string, one of 'mse' or 'corr'. For corr, loss is 1 - correlation(y_true, preds)
  )
  ```
  Train with optional validation set and early stopping.
@@ -922,3 +923,7 @@ WarpGBM builds on the shoulders of PyTorch, scikit-learn, LightGBM, and the CUDA
  ### v0.1.25

  - Added `colsample_bytree` parameter and new test using Numerai data.
+
+ ### v0.1.26
+
+ - Fix Memory bugs in prediction and colsample bytree logic. Added "corr" eval metric.
{warpgbm-0.1.25 → warpgbm-0.1.27}/warpgbm.egg-info/SOURCES.txt
@@ -5,10 +5,12 @@ pyproject.toml
  setup.py
  version.txt
  tests/__init__.py
+ tests/full_numerai_test.py
  tests/numerai_test.py
  tests/test_fit_predict_corr.py
  warpgbm/__init__.py
  warpgbm/core.py
+ warpgbm/metrics.py
  warpgbm.egg-info/PKG-INFO
  warpgbm.egg-info/SOURCES.txt
  warpgbm.egg-info/dependency_links.txt
warpgbm-0.1.25/tests/test_fit_predict_corr.py (deleted)
@@ -1,57 +0,0 @@
- import numpy as np
- from warpgbm import WarpGBM
- from sklearn.datasets import make_regression
- import time
- from sklearn.metrics import mean_squared_error
-
-
- def test_fit_predictpytee_correlation():
-     np.random.seed(42)
-     N = 100_000
-     F = 1000
-     X, y = make_regression(n_samples=N, n_features=F, noise=0.1, random_state=42)
-     era = np.zeros(N, dtype=np.int32)
-     corrs = []
-     mses = []
-
-     for hist_type in ["hist1", "hist2", "hist3"]:
-         print(f"\nTesting histogram method: {hist_type}")
-
-         model = WarpGBM(
-             max_depth=10,
-             num_bins=10,
-             n_estimators=100,
-             learning_rate=1,
-             verbosity=False,
-             histogram_computer=hist_type,
-             threads_per_block=64,
-             rows_per_thread=4,
-         )
-
-         start_fit = time.time()
-         model.fit(
-             X,
-             y,
-             era_id=era,
-             X_eval=X,
-             y_eval=y,
-             eval_every_n_trees=10,
-             early_stopping_rounds=1,
-         )
-         fit_time = time.time() - start_fit
-         print(f" Fit time: {fit_time:.3f} seconds")
-
-         start_pred = time.time()
-         preds = model.predict(X)
-         pred_time = time.time() - start_pred
-         print(f" Predict time: {pred_time:.3f} seconds")
-
-         corr = np.corrcoef(preds, y)[0, 1]
-         mse = mean_squared_error(preds, y)
-         print(f" Correlation: {corr:.4f}")
-         print(f" MSE: {mse:.4f}")
-         corrs.append(corr)
-         mses.append(mse)
-
-     assert (np.array(corrs) > 0.9).all(), f"In-sample correlation too low: {corrs}"
-     assert (np.array(mses) < 2).all(), f"In-sample mse too high: {mses}"
warpgbm-0.1.25/version.txt (deleted)
@@ -1 +0,0 @@
- 0.1.25
warpgbm-0.1.25/warpgbm/cuda/histogram_kernel.cu (deleted)
@@ -1,250 +0,0 @@
- #include <cuda.h>
- #include <cuda_runtime.h>
- #include <torch/extension.h>
-
- #define F_TILE 128 // Number of features processed per block (tile)
-
- // Each block processes a tile of features (of size up to F_TILE) and a chunk of samples.
- __global__ void histogram_kernel_shared_sample(
-     const int8_t *__restrict__ bin_indices, // [N, F] bin indices
-     const float *__restrict__ gradients,    // [N] gradient values
-     float *__restrict__ grad_hist,          // [F * B] global gradient histogram (flattened)
-     float *__restrict__ hess_hist,          // [F * B] global hessian histogram (flattened)
-     int64_t N, int64_t F, int64_t B)
- {
-     // Use dynamic shared memory to hold the histogram for a tile.
-     // Allocate 2 arrays: one for gradients and one for hessians.
-     extern __shared__ float shmem[];
-     float *shared_grad = shmem;                // size: tile_features * B floats
-     float *shared_hess = shmem + (F_TILE * B); // same size
-
-     int tid = threadIdx.x; // Use a 1D block (for sample processing)
-     int block_size = blockDim.x;
-
-     // Each block is assigned a tile of features:
-     int feature_offset = blockIdx.x * F_TILE;
-     // Adjust tile width if we're near the end of the feature dimension.
-     int tile_features = (feature_offset + F_TILE > F) ? (F - feature_offset) : F_TILE;
-     int tile_size = tile_features * B; // total number of bins in this feature tile
-
-     // Initialize the tile’s shared memory histograms.
-     for (int i = tid; i < tile_size; i += block_size)
-     {
-         shared_grad[i] = 0.0f;
-         shared_hess[i] = 0.0f;
-     }
-     __syncthreads();
-
-     // Each block also covers a chunk of samples. Determine the sample index
-     int sample = blockIdx.y * block_size + tid;
-     if (sample < N)
-     {
-         // For each feature in this tile, compute the bin and update shared histograms.
-         for (int j = 0; j < tile_features; j++)
-         {
-             // Global feature index.
-             int f_idx = feature_offset + j;
-             int64_t idx = sample * F + f_idx; // index into the [N, F] bin_indices tensor
-             int8_t b = bin_indices[idx];      // get bin index
-             if (b >= 0 && b < B)
-             {
-                 int shared_idx = j * B + b; // index into the tile histogram in shared memory
-                 // Using atomics because several threads may update the same bin.
-                 atomicAdd(&shared_grad[shared_idx], gradients[sample]);
-                 atomicAdd(&shared_hess[shared_idx], 1.0f);
-             }
-         }
-     }
-     __syncthreads();
-
-     // Flush the per-tile histograms from shared memory to global memory.
-     // Each bin in the tile is added to the global histogram (which is sized [F, B]).
-     for (int i = tid; i < tile_size; i += block_size)
-     {
-         int local_feature = i / B; // feature index relative to the tile
-         int bin = i % B;           // bin index
-         int f_idx = feature_offset + local_feature;
-         if (f_idx < F)
-         {
-             int global_idx = f_idx * B + bin;
-             atomicAdd(&grad_hist[global_idx], shared_grad[i]);
-             atomicAdd(&hess_hist[global_idx], shared_hess[i]);
-         }
-     }
- }
-
- void launch_histogram_kernel_cuda(
-     const at::Tensor &bin_indices, // [N, F] int8 tensor
-     const at::Tensor &gradients,   // [N] float tensor
-     at::Tensor &grad_hist,         // [F * B] float tensor (preallocated)
-     at::Tensor &hess_hist,         // [F * B] float tensor (preallocated)
-     int num_bins,
-     int threads_per_block = 256,
-     int rows_per_thread = 1)
- {
-     int64_t N = bin_indices.size(0);
-     int64_t F = bin_indices.size(1);
-     int64_t B = num_bins;
-
-     // Define grid and block dimensions.
-     // blockDim.x: number of threads per block (for processing samples).
-     // gridDim.x: number of feature tiles.
-     int grid_x = (F + F_TILE - 1) / F_TILE;
-     // gridDim.y: number of sample chunks.
-     int grid_y = (N + threads_per_block - 1) / threads_per_block;
-     dim3 blocks(grid_x, grid_y);
-     dim3 threads(threads_per_block);
-
-     // Calculate shared memory size:
-     // We allocate 2 arrays of size (F_TILE * B) floats (one for grad and one for hess).
-     size_t shared_mem_size = 2 * F_TILE * B * sizeof(float);
-
-     histogram_kernel_shared_sample<<<blocks, threads, shared_mem_size>>>(
-         bin_indices.data_ptr<int8_t>(),
-         gradients.data_ptr<float>(),
-         grad_hist.data_ptr<float>(),
-         hess_hist.data_ptr<float>(),
-         N, F, B);
- }
-
- // CUDA kernel: tiled, 64-bit safe
- __global__ void histogram_tiled_kernel(
-     const int8_t *__restrict__ bin_indices, // [N, F]
-     const float *__restrict__ gradients,    // [N]
-     float *__restrict__ grad_hist,          // [F * B]
-     float *__restrict__ hess_hist,          // [F * B]
-     int64_t F, int64_t B, int64_t tile_size)
- {
-     int64_t feature_tiles = (F + tile_size - 1) / tile_size;
-     int64_t row = static_cast<int64_t>(blockIdx.x) / feature_tiles;
-     int64_t tile = static_cast<int64_t>(blockIdx.x) % feature_tiles;
-     int64_t feat = tile * tile_size + threadIdx.x;
-
-     if (feat >= F)
-         return;
-
-     int8_t bin = bin_indices[row * F + feat];
-     if (bin >= 0 && bin < B)
-     {
-         int64_t idx = feat * B + bin;
-         atomicAdd(&grad_hist[idx], gradients[row]);
-         atomicAdd(&hess_hist[idx], 1.0f);
-     }
- }
-
- // Host function exposed to PyTorch
- void launch_histogram_kernel_cuda_2(
-     const at::Tensor &bin_indices, // int8 [N, F]
-     const at::Tensor &gradients,   // float32 [N]
-     at::Tensor &grad_hist,         // float32 [F * B]
-     at::Tensor &hess_hist,         // float32 [F * B]
-     int num_bins,
-     int threads_per_block = 256,
-     int rows_per_thread = 1)
- {
-
-     int64_t N = bin_indices.size(0);
-     int64_t F = bin_indices.size(1);
-     int64_t tile_size = threads_per_block;
-     int64_t feature_tiles = (F + tile_size - 1) / tile_size;
-     int64_t total_blocks = N * feature_tiles;
-
-     histogram_tiled_kernel<<<
-         static_cast<int>(total_blocks),
-         static_cast<int>(tile_size)>>>(
-         bin_indices.data_ptr<int8_t>(),
-         gradients.data_ptr<float>(),
-         grad_hist.data_ptr<float>(),
-         hess_hist.data_ptr<float>(),
-         F, num_bins, tile_size);
-
-     // Optional: check for kernel launch failure
-     cudaError_t err = cudaGetLastError();
-     if (err != cudaSuccess)
-     {
-         printf("CUDA kernel launch failed: %s\n", cudaGetErrorString(err));
-     }
- }
-
- __global__ void histogram_tiled_configurable_kernel(
-     const int8_t *__restrict__ bin_indices, // [N, F]
-     const float *__restrict__ gradients,    // [N]
-     float *__restrict__ grad_hist,          // [F * B]
-     float *__restrict__ hess_hist,          // [F * B]
-     int64_t N, int64_t F, int64_t B,
-     int rows_per_thread)
- {
-     int feat = blockIdx.x; // 1 block per feature
-     int row_start = (blockIdx.y * blockDim.x + threadIdx.x) * rows_per_thread;
-
-     extern __shared__ float shmem[];
-     float *sh_grad = shmem;       // [B]
-     float *sh_hess = &sh_grad[B]; // [B]
-
-     // Initialize shared memory histograms
-     for (int b = threadIdx.x; b < B; b += blockDim.x)
-     {
-         sh_grad[b] = 0.0f;
-         sh_hess[b] = 0.0f;
-     }
-     __syncthreads();
-
-     // Each thread processes multiple rows
-     for (int r = 0; r < rows_per_thread; ++r)
-     {
-         int row = row_start + r;
-         if (row < N)
-         {
-             int8_t bin = bin_indices[row * F + feat];
-             if (bin >= 0 && bin < B)
-             {
-                 atomicAdd(&sh_grad[bin], gradients[row]);
-                 atomicAdd(&sh_hess[bin], 1.0f);
-             }
-         }
-     }
-     __syncthreads();
-
-     // One thread per bin writes results back to global memory
-     for (int b = threadIdx.x; b < B; b += blockDim.x)
-     {
-         int64_t idx = feat * B + b;
-         atomicAdd(&grad_hist[idx], sh_grad[b]);
-         atomicAdd(&hess_hist[idx], sh_hess[b]);
-     }
- }
-
- void launch_histogram_kernel_cuda_configurable(
-     const at::Tensor &bin_indices,
-     const at::Tensor &gradients,
-     at::Tensor &grad_hist,
-     at::Tensor &hess_hist,
-     int num_bins,
-     int threads_per_block = 256,
-     int rows_per_thread = 1)
- {
-
-     int64_t N = bin_indices.size(0);
-     int64_t F = bin_indices.size(1);
-
-     int rows_per_block = threads_per_block * rows_per_thread;
-     int row_tiles = (N + rows_per_block - 1) / rows_per_block;
-
-     dim3 blocks(F, row_tiles); // grid.x = F, grid.y = row_tiles
-     dim3 threads(threads_per_block);
-     int shared_mem_bytes = 2 * num_bins * sizeof(float);
-
-     histogram_tiled_configurable_kernel<<<blocks, threads, shared_mem_bytes>>>(
-         bin_indices.data_ptr<int8_t>(),
-         gradients.data_ptr<float>(),
-         grad_hist.data_ptr<float>(),
-         hess_hist.data_ptr<float>(),
-         N, F, num_bins,
-         rows_per_thread);
-
-     cudaError_t err = cudaGetLastError();
-     if (err != cudaSuccess)
-     {
-         printf("CUDA kernel launch failed: %s\n", cudaGetErrorString(err));
-     }
- }