warpgbm 0.1.27__tar.gz → 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {warpgbm-0.1.27/warpgbm.egg-info → warpgbm-1.0.0}/PKG-INFO +94 -20
  2. {warpgbm-0.1.27 → warpgbm-1.0.0}/README.md +94 -20
  3. {warpgbm-0.1.27 → warpgbm-1.0.0}/pyproject.toml +1 -1
  4. warpgbm-1.0.0/tests/test_invariant.py +100 -0
  5. warpgbm-1.0.0/version.txt +1 -0
  6. {warpgbm-0.1.27 → warpgbm-1.0.0}/warpgbm/core.py +30 -32
  7. warpgbm-1.0.0/warpgbm/cuda/best_split_kernel.cu +89 -0
  8. {warpgbm-0.1.27 → warpgbm-1.0.0}/warpgbm/cuda/histogram_kernel.cu +24 -15
  9. {warpgbm-0.1.27 → warpgbm-1.0.0}/warpgbm/cuda/node_kernel.cpp +9 -8
  10. {warpgbm-0.1.27 → warpgbm-1.0.0/warpgbm.egg-info}/PKG-INFO +94 -20
  11. {warpgbm-0.1.27 → warpgbm-1.0.0}/warpgbm.egg-info/SOURCES.txt +1 -0
  12. warpgbm-0.1.27/version.txt +0 -1
  13. warpgbm-0.1.27/warpgbm/cuda/best_split_kernel.cu +0 -79
  14. {warpgbm-0.1.27 → warpgbm-1.0.0}/LICENSE +0 -0
  15. {warpgbm-0.1.27 → warpgbm-1.0.0}/MANIFEST.in +0 -0
  16. {warpgbm-0.1.27 → warpgbm-1.0.0}/setup.cfg +0 -0
  17. {warpgbm-0.1.27 → warpgbm-1.0.0}/setup.py +0 -0
  18. {warpgbm-0.1.27 → warpgbm-1.0.0}/tests/__init__.py +0 -0
  19. {warpgbm-0.1.27 → warpgbm-1.0.0}/tests/full_numerai_test.py +0 -0
  20. {warpgbm-0.1.27 → warpgbm-1.0.0}/tests/numerai_test.py +0 -0
  21. {warpgbm-0.1.27 → warpgbm-1.0.0}/tests/test_fit_predict_corr.py +0 -0
  22. {warpgbm-0.1.27 → warpgbm-1.0.0}/warpgbm/__init__.py +0 -0
  23. {warpgbm-0.1.27 → warpgbm-1.0.0}/warpgbm/cuda/__init__.py +0 -0
  24. {warpgbm-0.1.27 → warpgbm-1.0.0}/warpgbm/cuda/binner.cu +0 -0
  25. {warpgbm-0.1.27 → warpgbm-1.0.0}/warpgbm/cuda/predict.cu +0 -0
  26. {warpgbm-0.1.27 → warpgbm-1.0.0}/warpgbm/metrics.py +0 -0
  27. {warpgbm-0.1.27 → warpgbm-1.0.0}/warpgbm.egg-info/dependency_links.txt +0 -0
  28. {warpgbm-0.1.27 → warpgbm-1.0.0}/warpgbm.egg-info/requires.txt +0 -0
  29. {warpgbm-0.1.27 → warpgbm-1.0.0}/warpgbm.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: warpgbm
3
- Version: 0.1.27
3
+ Version: 1.0.0
4
4
  Summary: A fast GPU-accelerated Gradient Boosted Decision Tree library with PyTorch + CUDA
5
5
  License: GNU GENERAL PUBLIC LICENSE
6
6
  Version 3, 29 June 2007
@@ -686,21 +686,46 @@ Requires-Dist: tqdm
686
686
  Requires-Dist: scikit-learn
687
687
  Dynamic: license-file
688
688
 
689
- ![warpgbm](https://github.com/user-attachments/assets/dee9de16-091b-49c1-a8fa-2b4ab6891184)
690
-
689
+ ![raw](https://github.com/user-attachments/assets/924844ef-2536-4bde-a330-5e30f6b0762c)
691
690
 
692
691
  # WarpGBM
693
692
 
694
693
  WarpGBM is a high-performance, GPU-accelerated Gradient Boosted Decision Tree (GBDT) library built with PyTorch and CUDA. It offers blazing-fast histogram-based training and efficient prediction, with compatibility for research and production workflows.
695
694
 
695
+ **New in v1.0.0:** WarpGBM introduces *Invariant Gradient Boosting* — a powerful approach to learning signals that remain stable across shifting environments (e.g., time, regimes, or datasets). Powered by a novel algorithm called **[Directional Era-Splitting (DES)](https://arxiv.org/abs/2309.14496)**, WarpGBM doesn't just train faster than other leading GBDT libraries — it trains smarter.
696
+
697
+ If your data evolves over time, WarpGBM is the only GBDT library designed to *adapt and generalize*.
696
698
  ---
697
699
 
700
+ ## Contents
701
+
702
+ - [Features](#features)
703
+ - [Benchmarks](#benchmarks)
704
+ - [Installation](#installation)
705
+ - [Learning Invariant Signals Across Environments](#learning-invariant-signals-across-environments)
706
+ - [Why This Matters](#why-this-matters)
707
+ - [Visual Intuition](#visual-intuition)
708
+ - [Key References](#key-references)
709
+ - [Examples](#examples)
710
+ - [Quick Comparison with LightGBM CPU version](#quick-comparison-with-lightgbm-cpu-version)
711
+ - [Pre-binned Data Example (Numerai)](#pre-binned-data-example-numerai)
712
+ - [Documentation](#documentation)
713
+ - [Acknowledgements](#acknowledgements)
714
+ - [Version Notes](#version-notes)
715
+
716
+
698
717
  ## Features
699
718
 
700
- - GPU-accelerated training and histogram construction using custom CUDA kernels
701
- - Drop-in scikit-learn style interface
702
- - Supports pre-binned data or automatic quantile binning
703
- - Simple install with `pip`
719
+ - **Blazing-fast GPU training** with custom CUDA kernels for binning, histogram building, split finding, and prediction
720
+ - **Invariant signal learning** via [Directional Era-Splitting (DES)](https://arxiv.org/abs/2309.14496) — designed for datasets with shifting environments (e.g., time, regimes, experimental settings)
721
+ - Drop-in **scikit-learn style interface** for easy adoption
722
+ - Supports **pre-binned data** or **automatic quantile binning**
723
+ - Works with `float32` or `int8` inputs
724
+ - Built-in **validation and early stopping** support with MSE, RMSLE, or correlation metrics
725
+ - Simple install with `pip`, no custom drivers required
726
+
727
+ > 💡 **Note:** WarpGBM v1.0.0 is a *generalization* of the traditional GBDT algorithm.
728
+ > To run standard GBM training at maximum speed, simply omit the `era_id` argument — WarpGBM will behave like a traditional booster but with industry-leading performance.
704
729
 
705
730
  ---
706
731
 
@@ -762,7 +787,62 @@ Before either method, make sure you’ve installed PyTorch with GPU support:\
762
787
 
763
788
  ---
764
789
 
765
- ## Example
790
+ ## Learning Invariant Signals Across Environments
791
+
792
+ Most supervised learning models rely on an assumption known as the **Empirical Risk Minimization (ERM)** principle. Under ERM, the data distribution connecting inputs \( X \) and targets \( Y \) is assumed to be **fixed** and **stationary** across training, validation, and test splits. That is:
793
+
794
+ > The patterns you learn from the training set are expected to generalize out-of-sample — *as long as the test data follows the same distribution as the training data.*
795
+
796
+ However, this assumption is often violated in real-world settings. Data frequently shifts across time, geography, experimental conditions, or other hidden factors. This phenomenon is known as **distribution shift**, and it leads to models that perform well in-sample but fail catastrophically out-of-sample.
797
+
798
+ This challenge motivates the field of **Out-of-Distribution (OOD) Generalization**, which assumes your data is drawn from **distinct environments or eras** — e.g., time periods, customer segments, experimental trials. Some signals may appear predictive within specific environments but vanish or reverse in others. These are called **spurious signals**. On the other hand, signals that remain consistently predictive across all environments are called **invariant signals**.
799
+
800
+ WarpGBM v1.0.0 introduces **Directional Era-Splitting (DES)**, a new algorithm designed to identify and learn from invariant signals — ignoring signals that fail to generalize across environments.
801
+
802
+ ---
803
+
804
+ ### Why This Matters
805
+
806
+ - Standard models trained via ERM can learn to exploit **spurious correlations** that only hold in some parts of the data.
807
+ - DES explicitly tests whether a feature's split is **directionally consistent** across all eras — only such *invariant splits* are kept.
808
+ - This approach has been shown to reduce overfitting and improve out-of-sample generalization, particularly in financial and scientific datasets.
809
+
810
+ ---
811
+
812
+ ### Visual Intuition
813
+
814
+ We contrast two views of the data:
815
+
816
+ - **ERM Setting**: All data is assumed to come from the same source (single distribution).\
817
+ No awareness of environments — spurious signals can dominate.
818
+
819
+ - **OOD Setting (Era-Splitting)**: Data is explicitly grouped by environment (era).\
820
+ The model checks whether a signal holds across all groups — enforcing **robustness**.
821
+
822
+ *📷 [Placeholder for future visual illustration]*
823
+
824
+
825
+ ---
826
+
827
+ ### Key References
828
+
829
+ - **Invariant Risk Minimization (IRM)**: [Arjovsky et al., 2019](https://arxiv.org/abs/1907.02893)
830
+ - **Learning Explanations That Are Hard to Vary**: [Parascandolo et al., 2020](https://arxiv.org/abs/2009.00329)
831
+ - **Era Splitting: Invariant Learning for Decision Trees**: [DeLise, 2023](https://arxiv.org/abs/2309.14496)
832
+
833
+ ---
834
+
835
+ WarpGBM is the **first open-source GBDT framework to integrate this OOD-aware approach natively**, using efficient CUDA kernels to evaluate per-era consistency during tree growth. It’s not just faster — it’s smarter.
836
+
837
+ ---
838
+
839
+ ## Examples
840
+
841
+ WarpGBM is easy to drop into any supervised learning workflow and comes with curated examples in the `examples/` folder.
842
+
843
+ - `Spiral Data.ipynb`: synthetic OOD benchmark from Learning Explanations That Are Hard to Vary
844
+
845
+ ### Quick Comparison with LightGBM CPU version
766
846
 
767
847
  ```python
768
848
  import numpy as np
@@ -804,7 +884,7 @@ WarpGBM: corr = 0.8621, time = 5.40s
804
884
 
805
885
  ---
806
886
 
807
- ## Pre-binned Data Example (Numerai)
887
+ ### Pre-binned Data Example (Numerai)
808
888
 
809
889
  WarpGBM can save additional training time if your dataset is already pre-binned. The Numerai tournament data is a great example:
810
890
 
@@ -854,16 +934,6 @@ WarpGBM: corr = 0.0660, time = 49.16s
854
934
 
855
935
  ---
856
936
 
857
- ### Run it live in Colab
858
-
859
- You can try WarpGBM in a live Colab notebook using real pre-binned Numerai tournament data:
860
-
861
- [Open in Colab](https://colab.research.google.com/drive/10mKSjs9UvmMgM5_lOXAylq5LUQAnNSi7?usp=sharing)
862
-
863
- No installation required — just press **"Open in Playground"**, then **Run All**!
864
-
865
- ---
866
-
867
937
  ## Documentation
868
938
 
869
939
  ### `WarpGBM` Parameters:
@@ -889,7 +959,7 @@ No installation required — just press **"Open in Playground"**, then **Run All
889
959
  y_eval=None, # numpy array (float or int) 1 dimension (eval_num_samples)
890
960
  eval_every_n_trees=None, # const (int) >= 1
891
961
  early_stopping_rounds=None, # const (int) >= 1
892
- eval_metric='mse' # string, one of 'mse' or 'corr'. For corr, loss is 1 - correlation(y_true, preds)
962
+ eval_metric='mse' # string, one of 'mse', 'rmsle' or 'corr'. For corr, loss is 1 - correlation(y_true, preds)
893
963
  )
894
964
  ```
895
965
  Train with optional validation set and early stopping.
@@ -927,3 +997,7 @@ WarpGBM builds on the shoulders of PyTorch, scikit-learn, LightGBM, and the CUDA
927
997
  ### v0.1.26
928
998
 
929
999
  - Fix Memory bugs in prediction and colsample bytree logic. Added "corr" eval metric.
1000
+
1001
+ ### v1.0.0
1002
+
1003
+ - Introduce invariant learning via Directional Era-Splitting (DES). Also streamlined VRAM usage compared to previous point releases.
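The README text above describes Directional Era-Splitting only in prose: a candidate split is kept when its gain holds up in every era and the split pushes the two children in the same direction in every era. A minimal NumPy sketch of that idea, assuming per-era gradient/hessian histograms are already built; the function and argument names here are illustrative and are not WarpGBM's API.

```python
# Minimal sketch of the DES scoring described above (illustrative, not WarpGBM's code).
import numpy as np

def des_split_scores(grad_hist, hess_hist, eps=1e-6, min_child_weight=1.0):
    """grad_hist, hess_hist: [num_eras, num_features, num_bins] per-era histograms.
    Returns the mean gain over eras and a directional-agreement score per (feature, bin)."""
    E, F, B = grad_hist.shape
    G_tot = grad_hist.sum(axis=2, keepdims=True)               # [E, F, 1]
    H_tot = hess_hist.sum(axis=2, keepdims=True)
    G_L = np.cumsum(grad_hist, axis=2)[:, :, :-1]              # left sums for thresholds 0..B-2
    H_L = np.cumsum(hess_hist, axis=2)[:, :, :-1]
    G_R, H_R = G_tot - G_L, H_tot - H_L
    gain = G_L**2 / (H_L + eps) + G_R**2 / (H_R + eps) - G_tot**2 / (H_tot + eps)
    valid = (H_L >= min_child_weight) & (H_R >= min_child_weight)
    gain = np.where(valid, gain, 0.0)
    # +1 if the left child would get the larger value, -1 otherwise, 0 if invalid
    direction = np.where(valid, np.sign(G_L / (H_L + eps) - G_R / (H_R + eps)), 0.0)
    agreement = np.abs(direction.mean(axis=0))                  # 1.0 => same sign in every era
    return gain.mean(axis=0), agreement                         # both [F, B-1]
```

A split candidate is then accepted only where `agreement` is maximal and the mean gain clears the minimum-gain threshold, which mirrors the selection logic in the `warpgbm/core.py` hunk later in this diff.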
@@ -1,18 +1,43 @@
1
- ![warpgbm](https://github.com/user-attachments/assets/dee9de16-091b-49c1-a8fa-2b4ab6891184)
2
-
1
+ ![raw](https://github.com/user-attachments/assets/924844ef-2536-4bde-a330-5e30f6b0762c)
3
2
 
4
3
  # WarpGBM
5
4
 
6
5
  WarpGBM is a high-performance, GPU-accelerated Gradient Boosted Decision Tree (GBDT) library built with PyTorch and CUDA. It offers blazing-fast histogram-based training and efficient prediction, with compatibility for research and production workflows.
7
6
 
7
+ **New in v1.0.0:** WarpGBM introduces *Invariant Gradient Boosting* — a powerful approach to learning signals that remain stable across shifting environments (e.g., time, regimes, or datasets). Powered by a novel algorithm called **[Directional Era-Splitting (DES)](https://arxiv.org/abs/2309.14496)**, WarpGBM doesn't just train faster than other leading GBDT libraries — it trains smarter.
8
+
9
+ If your data evolves over time, WarpGBM is the only GBDT library designed to *adapt and generalize*.
8
10
  ---
9
11
 
12
+ ## Contents
13
+
14
+ - [Features](#features)
15
+ - [Benchmarks](#benchmarks)
16
+ - [Installation](#installation)
17
+ - [Learning Invariant Signals Across Environments](#learning-invariant-signals-across-environments)
18
+ - [Why This Matters](#why-this-matters)
19
+ - [Visual Intuition](#visual-intuition)
20
+ - [Key References](#key-references)
21
+ - [Examples](#examples)
22
+ - [Quick Comparison with LightGBM CPU version](#quick-comparison-with-lightgbm-cpu-version)
23
+ - [Pre-binned Data Example (Numerai)](#pre-binned-data-example-numerai)
24
+ - [Documentation](#documentation)
25
+ - [Acknowledgements](#acknowledgements)
26
+ - [Version Notes](#version-notes)
27
+
28
+
10
29
  ## Features
11
30
 
12
- - GPU-accelerated training and histogram construction using custom CUDA kernels
13
- - Drop-in scikit-learn style interface
14
- - Supports pre-binned data or automatic quantile binning
15
- - Simple install with `pip`
31
+ - **Blazing-fast GPU training** with custom CUDA kernels for binning, histogram building, split finding, and prediction
32
+ - **Invariant signal learning** via [Directional Era-Splitting (DES)](https://arxiv.org/abs/2309.14496) — designed for datasets with shifting environments (e.g., time, regimes, experimental settings)
33
+ - Drop-in **scikit-learn style interface** for easy adoption
34
+ - Supports **pre-binned data** or **automatic quantile binning**
35
+ - Works with `float32` or `int8` inputs
36
+ - Built-in **validation and early stopping** support with MSE, RMSLE, or correlation metrics
37
+ - Simple install with `pip`, no custom drivers required
38
+
39
+ > 💡 **Note:** WarpGBM v1.0.0 is a *generalization* of the traditional GBDT algorithm.
40
+ > To run standard GBM training at maximum speed, simply omit the `era_id` argument — WarpGBM will behave like a traditional booster but with industry-leading performance.
16
41
 
17
42
  ---
18
43
 
@@ -74,7 +99,62 @@ Before either method, make sure you’ve installed PyTorch with GPU support:\
74
99
 
75
100
  ---
76
101
 
77
- ## Example
102
+ ## Learning Invariant Signals Across Environments
103
+
104
+ Most supervised learning models rely on an assumption known as the **Empirical Risk Minimization (ERM)** principle. Under ERM, the data distribution connecting inputs \( X \) and targets \( Y \) is assumed to be **fixed** and **stationary** across training, validation, and test splits. That is:
105
+
106
+ > The patterns you learn from the training set are expected to generalize out-of-sample — *as long as the test data follows the same distribution as the training data.*
107
+
108
+ However, this assumption is often violated in real-world settings. Data frequently shifts across time, geography, experimental conditions, or other hidden factors. This phenomenon is known as **distribution shift**, and it leads to models that perform well in-sample but fail catastrophically out-of-sample.
109
+
110
+ This challenge motivates the field of **Out-of-Distribution (OOD) Generalization**, which assumes your data is drawn from **distinct environments or eras** — e.g., time periods, customer segments, experimental trials. Some signals may appear predictive within specific environments but vanish or reverse in others. These are called **spurious signals**. On the other hand, signals that remain consistently predictive across all environments are called **invariant signals**.
111
+
112
+ WarpGBM v1.0.0 introduces **Directional Era-Splitting (DES)**, a new algorithm designed to identify and learn from invariant signals — ignoring signals that fail to generalize across environments.
113
+
114
+ ---
115
+
116
+ ### Why This Matters
117
+
118
+ - Standard models trained via ERM can learn to exploit **spurious correlations** that only hold in some parts of the data.
119
+ - DES explicitly tests whether a feature's split is **directionally consistent** across all eras — only such *invariant splits* are kept.
120
+ - This approach has been shown to reduce overfitting and improve out-of-sample generalization, particularly in financial and scientific datasets.
121
+
122
+ ---
123
+
124
+ ### Visual Intuition
125
+
126
+ We contrast two views of the data:
127
+
128
+ - **ERM Setting**: All data is assumed to come from the same source (single distribution).\
129
+ No awareness of environments — spurious signals can dominate.
130
+
131
+ - **OOD Setting (Era-Splitting)**: Data is explicitly grouped by environment (era).\
132
+ The model checks whether a signal holds across all groups — enforcing **robustness**.
133
+
134
+ *📷 [Placeholder for future visual illustration]*
135
+
136
+
137
+ ---
138
+
139
+ ### Key References
140
+
141
+ - **Invariant Risk Minimization (IRM)**: [Arjovsky et al., 2019](https://arxiv.org/abs/1907.02893)
142
+ - **Learning Explanations That Are Hard to Vary**: [Parascandolo et al., 2020](https://arxiv.org/abs/2009.00329)
143
+ - **Era Splitting: Invariant Learning for Decision Trees**: [DeLise, 2023](https://arxiv.org/abs/2309.14496)
144
+
145
+ ---
146
+
147
+ WarpGBM is the **first open-source GBDT framework to integrate this OOD-aware approach natively**, using efficient CUDA kernels to evaluate per-era consistency during tree growth. It’s not just faster — it’s smarter.
148
+
149
+ ---
150
+
151
+ ## Examples
152
+
153
+ WarpGBM is easy to drop into any supervised learning workflow and comes with curated examples in the `examples/` folder.
154
+
155
+ - `Spiral Data.ipynb`: synthetic OOD benchmark from Learning Explanations That Are Hard to Vary
156
+
157
+ ### Quick Comparison with LightGBM CPU version
78
158
 
79
159
  ```python
80
160
  import numpy as np
@@ -116,7 +196,7 @@ WarpGBM: corr = 0.8621, time = 5.40s
116
196
 
117
197
  ---
118
198
 
119
- ## Pre-binned Data Example (Numerai)
199
+ ### Pre-binned Data Example (Numerai)
120
200
 
121
201
  WarpGBM can save additional training time if your dataset is already pre-binned. The Numerai tournament data is a great example:
122
202
 
@@ -166,16 +246,6 @@ WarpGBM: corr = 0.0660, time = 49.16s
166
246
 
167
247
  ---
168
248
 
169
- ### Run it live in Colab
170
-
171
- You can try WarpGBM in a live Colab notebook using real pre-binned Numerai tournament data:
172
-
173
- [Open in Colab](https://colab.research.google.com/drive/10mKSjs9UvmMgM5_lOXAylq5LUQAnNSi7?usp=sharing)
174
-
175
- No installation required — just press **"Open in Playground"**, then **Run All**!
176
-
177
- ---
178
-
179
249
  ## Documentation
180
250
 
181
251
  ### `WarpGBM` Parameters:
@@ -201,7 +271,7 @@ No installation required — just press **"Open in Playground"**, then **Run All
201
271
  y_eval=None, # numpy array (float or int) 1 dimension (eval_num_samples)
202
272
  eval_every_n_trees=None, # const (int) >= 1
203
273
  early_stopping_rounds=None, # const (int) >= 1
204
- eval_metric='mse' # string, one of 'mse' or 'corr'. For corr, loss is 1 - correlation(y_true, preds)
274
+ eval_metric='mse' # string, one of 'mse', 'rmsle' or 'corr'. For corr, loss is 1 - correlation(y_true, preds)
205
275
  )
206
276
  ```
207
277
  Train with optional validation set and early stopping.
@@ -238,4 +308,8 @@ WarpGBM builds on the shoulders of PyTorch, scikit-learn, LightGBM, and the CUDA
238
308
 
239
309
  ### v0.1.26
240
310
 
241
- - Fix Memory bugs in prediction and colsample bytree logic. Added "corr" eval metric.
311
+ - Fix Memory bugs in prediction and colsample bytree logic. Added "corr" eval metric.
312
+
313
+ ### v1.0.0
314
+
315
+ - Introduce invariant learning via Directional Era-Splitting (DES). Also streamlined VRAM usage compared to previous point releases.
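Taken together with the README's Documentation section, basic usage looks roughly like the sketch below. It is assembled from the constructor and `fit()` arguments that appear elsewhere in this diff (the README and `tests/test_invariant.py`); the random data and hyperparameter values are purely illustrative, and a CUDA-capable GPU is assumed.

```python
# Condensed usage sketch based on the fit() signature documented in the README and the
# call pattern in tests/test_invariant.py; the data here is made up for illustration.
import numpy as np
from warpgbm import WarpGBM

rng = np.random.default_rng(0)
X = rng.normal(size=(10_000, 20)).astype(np.float32)
era = rng.integers(0, 10, size=10_000).astype(np.int32)      # environment / era labels
y = (0.5 * X[:, 0] + rng.normal(scale=0.1, size=10_000)).astype(np.float32)

model = WarpGBM(max_depth=6, num_bins=32, n_estimators=100,
                learning_rate=0.1, colsample_bytree=0.9)

# Invariant (DES) training: pass era_id so splits are checked across eras.
model.fit(X, y, era_id=era)

# Standard GBDT behaviour: simply omit era_id (all rows fall into a single era).
# model.fit(X, y)

preds = model.predict(X)
```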
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "warpgbm"
7
- version = "0.1.27"
7
+ version = "1.0.0"
8
8
  description = "A fast GPU-accelerated Gradient Boosted Decision Tree library with PyTorch + CUDA"
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.8"
@@ -0,0 +1,100 @@
1
+ import numpy as np
2
+ from warpgbm import WarpGBM
3
+ import time
4
+
5
+ import os
6
+ import requests
7
+
8
+ def download_file_if_missing(url, local_dir):
9
+ filename = os.path.basename(url)
10
+ local_path = os.path.join(local_dir, filename)
11
+
12
+ if os.path.exists(local_path):
13
+ print(f"✅ {filename} already exists, skipping download.")
14
+ return
15
+
16
+ # Convert GitHub blob URL to raw URL
17
+ raw_url = url.replace("github.com", "raw.githubusercontent.com").replace("/blob/", "/")
18
+
19
+ print(f"⬇️ Downloading {filename}...")
20
+ response = requests.get(raw_url)
21
+ response.raise_for_status()
22
+
23
+ os.makedirs(local_dir, exist_ok=True)
24
+ with open(local_path, "wb") as f:
25
+ f.write(response.content)
26
+ print(f"✅ Saved to {local_path}")
27
+
28
+ # === Usage ===
29
+
30
+ urls = [
31
+ "https://github.com/jefferythewind/era-splitting-notebook-examples/blob/main/Synthetic%20Memorization%20Data%20Set/X_train.npy",
32
+ "https://github.com/jefferythewind/era-splitting-notebook-examples/blob/main/Synthetic%20Memorization%20Data%20Set/y_train.npy",
33
+ "https://github.com/jefferythewind/era-splitting-notebook-examples/blob/main/Synthetic%20Memorization%20Data%20Set/X_test.npy",
34
+ "https://github.com/jefferythewind/era-splitting-notebook-examples/blob/main/Synthetic%20Memorization%20Data%20Set/y_test.npy",
35
+ "https://github.com/jefferythewind/era-splitting-notebook-examples/blob/main/Synthetic%20Memorization%20Data%20Set/X_eras.npy",
36
+ ]
37
+
38
+
39
+ local_folder = "./synthetic_data"
40
+
41
+ for url in urls:
42
+ download_file_if_missing(url, local_folder)
43
+
44
+ def test_fit_predictpytee_correlation():
45
+ import numpy as np
46
+ import os
47
+ from warpgbm import WarpGBM
48
+ from sklearn.metrics import mean_squared_error
49
+ import time
50
+
51
+ # Load the real dataset from local .npy files
52
+ data_dir = "./synthetic_data"
53
+ X = np.load(os.path.join(data_dir, "X_train.npy"))
54
+ y = np.load(os.path.join(data_dir, "y_train.npy"))
55
+ # era = np.zeros(X.shape[0], dtype=np.int32) # one era for default GBDT equivalence
56
+ era = np.load(os.path.join(data_dir, "X_eras.npy"))
57
+
58
+ X_test = np.load(os.path.join(data_dir, "X_test.npy"))
59
+ y_test = np.load(os.path.join(data_dir, "y_test.npy"))
60
+
61
+ print(f"X shape: {X.shape}, y shape: {y.shape}")
62
+
63
+ model = WarpGBM(
64
+ max_depth=10,
65
+ num_bins=127,
66
+ n_estimators=50,
67
+ learning_rate=1,
68
+ threads_per_block=128,
69
+ rows_per_thread=4,
70
+ colsample_bytree=0.9,
71
+ min_child_weight=4
72
+ )
73
+
74
+ start_fit = time.time()
75
+ model.fit(
76
+ X,
77
+ y,
78
+ era_id=era,
79
+ X_eval=X_test,
80
+ y_eval=y_test,
81
+ eval_every_n_trees=10,
82
+ early_stopping_rounds=100,
83
+ eval_metric="corr",
84
+ )
85
+ fit_time = time.time() - start_fit
86
+ print(f" Fit time: {fit_time:.3f} seconds")
87
+
88
+ start_pred = time.time()
89
+ preds = model.predict(X_test)
90
+ pred_time = time.time() - start_pred
91
+ print(f" Predict time: {pred_time:.3f} seconds")
92
+
93
+ corr = np.corrcoef(preds, y_test)[0, 1]
94
+ mse = mean_squared_error(preds, y_test)
95
+ print(f" Correlation: {corr:.4f}")
96
+ print(f" MSE: {mse:.4f}")
97
+
98
+ assert corr > 0.95, f"Out-of-sample correlation too low: {corr:.4f}"
99
+ assert mse < 0.02, f"Out-of-sample MSE too high: {mse:.4f}"
100
+
@@ -0,0 +1 @@
1
+ 1.0.0
@@ -219,10 +219,12 @@ class WarpGBM(BaseEstimator, RegressorMixin):
219
219
  era_id = np.ones(X.shape[0], dtype="int32")
220
220
 
221
221
  # Train data preprocessing
222
- self.bin_indices, era_indices, self.bin_edges, self.unique_eras, self.Y_gpu = (
222
+ self.bin_indices, self.era_indices, self.bin_edges, self.unique_eras, self.Y_gpu = (
223
223
  self.preprocess_gpu_data(X, y, era_id)
224
224
  )
225
225
  self.num_samples, self.num_features = X.shape
226
+ self.num_eras = len(self.unique_eras)
227
+ self.era_indices = self.era_indices.to(dtype=torch.int32)
226
228
  self.gradients = torch.zeros_like(self.Y_gpu)
227
229
  self.root_node_indices = torch.arange(self.num_samples, device=self.device, dtype=torch.int32)
228
230
  self.base_prediction = self.Y_gpu.mean().item()
@@ -231,8 +233,6 @@ class WarpGBM(BaseEstimator, RegressorMixin):
231
233
  k = max(1, int(self.colsample_bytree * self.num_features))
232
234
  else:
233
235
  k = self.num_features
234
- self.best_gains = torch.zeros(k, device=self.device)
235
- self.best_bins = torch.zeros(k, device=self.device, dtype=torch.int32)
236
236
  self.feature_indices = torch.arange(self.num_features, device=self.device, dtype=torch.int32)
237
237
 
238
238
  # ─── Optional Eval Set ───
@@ -275,9 +275,7 @@ class WarpGBM(BaseEstimator, RegressorMixin):
275
275
  max_vals = X_np.max(axis=0)
276
276
 
277
277
  if is_integer_type and np.all(max_vals < self.num_bins):
278
- print(
279
- "Detected pre-binned integer input — skipping quantile binning."
280
- )
278
+ print("Detected pre-binned integer input — skipping quantile binning.")
281
279
  for f in range(self.num_features):
282
280
  bin_indices[:,f] = torch.as_tensor( X_np[:, f], device=self.device).contiguous()
283
281
  # bin_indices = X_np.to("cuda", non_blocking=True).contiguous()
@@ -319,10 +317,10 @@ class WarpGBM(BaseEstimator, RegressorMixin):
319
317
 
320
318
  def compute_histograms(self, sample_indices, feature_indices):
321
319
  grad_hist = torch.zeros(
322
- (len(feature_indices), self.num_bins), device=self.device, dtype=torch.float32
320
+ ( self.num_eras, len(feature_indices), self.num_bins), device=self.device, dtype=torch.float32
323
321
  )
324
322
  hess_hist = torch.zeros(
325
- (len(feature_indices), self.num_bins), device=self.device, dtype=torch.float32
323
+ ( self.num_eras, len(feature_indices), self.num_bins), device=self.device, dtype=torch.float32
326
324
  )
327
325
 
328
326
  node_kernel.compute_histogram3(
@@ -330,6 +328,7 @@ class WarpGBM(BaseEstimator, RegressorMixin):
330
328
  self.residual,
331
329
  sample_indices,
332
330
  feature_indices,
331
+ self.era_indices,
333
332
  grad_hist,
334
333
  hess_hist,
335
334
  self.num_bins,
@@ -345,21 +344,30 @@ class WarpGBM(BaseEstimator, RegressorMixin):
345
344
  self.min_split_gain,
346
345
  self.min_child_weight,
347
346
  self.L2_reg,
348
- self.best_gains,
349
- self.best_bins,
350
- self.threads_per_block,
347
+ self.per_era_gain,
348
+ self.per_era_direction,
349
+ self.threads_per_block
351
350
  )
352
351
 
353
- if torch.all(self.best_bins == -1):
354
- return -1, -1 # No valid split found
352
+ if self.num_eras == 1:
353
+ era_splitting_criterion = self.per_era_gain[0,:,:] # [F, B-1]
354
+ dir_score_mask = era_splitting_criterion > self.min_split_gain
355
+ else:
356
+ directional_agreement = self.per_era_direction.mean(dim=0).abs() # [F, B-1]
357
+ era_splitting_criterion = self.per_era_gain.mean(dim=0) # [F, B-1]
358
+ dir_score_mask = ( directional_agreement == directional_agreement.max() ) & (era_splitting_criterion > self.min_split_gain)
359
+
360
+ if not dir_score_mask.any():
361
+ return -1, -1
355
362
 
356
- # print(self.best_bins)
357
- # print(self.best_gains)
363
+ era_splitting_criterion[dir_score_mask == 0] = float("-inf")
364
+ best_idx = torch.argmax(era_splitting_criterion) #index of flattened tensor
365
+ split_bins = self.num_bins - 1
366
+ best_feature = best_idx // split_bins
367
+ best_bin = best_idx % split_bins
358
368
 
359
- f = torch.argmax(self.best_gains).item()
360
- b = self.best_bins[f].item()
369
+ return best_feature.item(), best_bin.item()
361
370
 
362
- return f, b
363
371
 
364
372
  def grow_tree(self, gradient_histogram, hessian_histogram, node_indices, depth):
365
373
  if depth == self.max_depth:
@@ -372,29 +380,15 @@ class WarpGBM(BaseEstimator, RegressorMixin):
372
380
  gradient_histogram, hessian_histogram
373
381
  )
374
382
 
375
- # print(local_feature, best_bin)
376
-
377
383
  if local_feature == -1:
378
384
  leaf_value = self.residual[node_indices].mean()
379
385
  self.gradients[node_indices] += self.learning_rate * leaf_value
380
386
  return {"leaf_value": leaf_value.item(), "samples": parent_size}
381
387
 
382
- # print("DEBUG SHAPES -> bin_indices:", self.bin_indices.shape,
383
- # "| node_indices max:", node_indices.max().item(),
384
- # "| local_feature:", local_feature,
385
- # "| feat_indices_tree len:", len(self.feat_indices_tree),
386
- # "| feat index:", self.feat_indices_tree[local_feature])
387
-
388
388
  split_mask = self.bin_indices[node_indices, self.feat_indices_tree[local_feature]] <= best_bin
389
389
  left_indices = node_indices[split_mask]
390
390
  right_indices = node_indices[~split_mask]
391
391
 
392
- # print("DEBUG SHAPES -> left_indices:", left_indices.shape,
393
- # "| right_indices:", right_indices.shape,
394
- # "| parent_size:", parent_size,
395
- # "| local_feature:", local_feature,
396
- # "| best_bin:", best_bin)
397
-
398
392
  left_size = left_indices.numel()
399
393
  right_size = right_indices.numel()
400
394
 
@@ -463,6 +457,10 @@ class WarpGBM(BaseEstimator, RegressorMixin):
463
457
  k = max(1, int(self.colsample_bytree * self.num_features))
464
458
  else:
465
459
  self.feat_indices_tree = self.feature_indices
460
+ k = self.num_features
461
+
462
+ self.per_era_gain = torch.zeros(self.num_eras, k, self.num_bins-1, device=self.device, dtype=torch.float32)
463
+ self.per_era_direction = torch.zeros(self.num_eras, k, self.num_bins-1, device=self.device, dtype=torch.float32)
466
464
 
467
465
  for i in range(self.n_estimators):
468
466
  self.residual = self.Y_gpu - self.gradients
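The find_best_split hunk above replaces the old per-feature best-gain reduction with a per-era criterion: gains are averaged over eras, splits whose direction disagrees across eras are masked out, and a flat argmax over the `[F, B-1]` grid is unravelled back into a (feature, bin) pair. A standalone toy restatement of that selection step, with made-up tensor sizes and a stand-in threshold:

```python
# Toy restatement of the selection step in find_best_split above (values are random).
import torch

E, F, Bm1 = 3, 4, 7                                   # eras, features, num_bins - 1
min_split_gain = 0.0                                  # stand-in threshold for this sketch
per_era_gain = torch.rand(E, F, Bm1)
per_era_direction = torch.randint(0, 2, (E, F, Bm1)).float() * 2 - 1   # +/-1 per era

agreement = per_era_direction.mean(dim=0).abs()       # 1.0 where every era agrees on the sign
criterion = per_era_gain.mean(dim=0)                  # mean gain across eras, [F, B-1]
mask = (agreement == agreement.max()) & (criterion > min_split_gain)

if not mask.any():
    best_feature, best_bin = -1, -1                   # no valid split, as in core.py
else:
    criterion = criterion.masked_fill(~mask, float("-inf"))
    flat = torch.argmax(criterion)                    # index into the flattened [F, B-1] grid
    best_feature, best_bin = int(flat) // Bm1, int(flat) % Bm1
print(best_feature, best_bin)
```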
@@ -0,0 +1,89 @@
1
+ #include <torch/extension.h>
2
+ #include <cuda.h>
3
+ #include <cuda_runtime.h>
4
+
5
+ __global__ void directional_split_kernel(
6
+ const float *__restrict__ G, // [E * F * B]
7
+ const float *__restrict__ H, // [E * F * B]
8
+ int E, int F, int B,
9
+ float min_split_gain,
10
+ float min_child_samples,
11
+ float eps,
12
+ float *__restrict__ per_era_gain, // [E * F * (B-1)]
13
+ float *__restrict__ per_era_direction // [E * F * (B-1)]
14
+ )
15
+ {
16
+ int f = blockIdx.x * blockDim.x + threadIdx.x; // feature index
17
+ int e = blockIdx.y; // era index
18
+
19
+ if (f >= F || e >= E) return;
20
+
21
+ // Access base offset for this (era, feature)
22
+ int base = e * F * B + f * B;
23
+ int base_gain = e * F * (B - 1) + f * (B - 1);
24
+
25
+ float G_total = 0.0f, H_total = 0.0f;
26
+ for (int b = 0; b < B; ++b) {
27
+ G_total += G[base + b];
28
+ H_total += H[base + b];
29
+ }
30
+
31
+ float G_L = 0.0f, H_L = 0.0f;
32
+ for (int b = 0; b < B - 1; ++b) {
33
+ G_L += G[base + b];
34
+ H_L += H[base + b];
35
+
36
+ float G_R = G_total - G_L;
37
+ float H_R = H_total - H_L;
38
+
39
+ float gain = 0.0f;
40
+ float dir = 0.0f;
41
+
42
+ if (H_L >= min_child_samples && H_R >= min_child_samples) {
43
+ gain = (G_L * G_L) / (H_L + eps)
44
+ + (G_R * G_R) / (H_R + eps)
45
+ - (G_total * G_total) / (H_total + eps);
46
+
47
+ float left_val = G_L / (H_L + eps);
48
+ float right_val = G_R / (H_R + eps);
49
+ dir = (left_val > right_val) ? 1.0f : -1.0f;
50
+ }
51
+
52
+ per_era_gain[base_gain + b] = gain;
53
+ per_era_direction[base_gain + b] = dir;
54
+ }
55
+ }
56
+
57
+ void launch_directional_split_kernel(
58
+ const at::Tensor &G, // [E, F, B]
59
+ const at::Tensor &H, // [E, F, B]
60
+ float min_split_gain,
61
+ float min_child_samples,
62
+ float eps,
63
+ at::Tensor &per_era_gain, // [E, F, B]
64
+ at::Tensor &per_era_direction, // [E, F, B]
65
+ int threads = 128)
66
+ {
67
+ int E = G.size(0);
68
+ int F = G.size(1);
69
+ int B = G.size(2);
70
+
71
+ dim3 blocks((F + threads - 1) / threads, E); // (feature blocks, era grid)
72
+ dim3 thread_dims(threads);
73
+
74
+ directional_split_kernel<<<blocks, thread_dims>>>(
75
+ G.data_ptr<float>(),
76
+ H.data_ptr<float>(),
77
+ E, F, B,
78
+ min_split_gain,
79
+ min_child_samples,
80
+ eps,
81
+ per_era_gain.data_ptr<float>(),
82
+ per_era_direction.data_ptr<float>());
83
+
84
+ cudaError_t err = cudaGetLastError();
85
+ if (err != cudaSuccess) {
86
+ printf("Directional split kernel launch failed: %s\n", cudaGetErrorString(err));
87
+ }
88
+ }
89
+
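For readers without a GPU at hand, the per-(era, feature) work done by `directional_split_kernel` above can be written out in plain Python: cumulative left/right sums over bins, the usual variance-reduction gain, and a +/-1 direction flag per candidate threshold. The function name and loop structure below are a readability aid only, not part of the package.

```python
# Plain-Python reference of the per-(era, feature) computation in directional_split_kernel.
import numpy as np

def directional_split_reference(G, H, min_child_samples=1.0, eps=1e-6):
    """G, H: [E, F, B] per-era gradient / hessian (count) histograms.
    Returns per_era_gain and per_era_direction with shape [E, F, B-1]."""
    E, F, B = G.shape
    gain = np.zeros((E, F, B - 1), dtype=np.float32)
    direction = np.zeros((E, F, B - 1), dtype=np.float32)
    for e in range(E):
        for f in range(F):
            G_tot, H_tot = G[e, f].sum(), H[e, f].sum()
            G_L = H_L = 0.0
            for b in range(B - 1):
                G_L += G[e, f, b]
                H_L += H[e, f, b]
                G_R, H_R = G_tot - G_L, H_tot - H_L
                if H_L >= min_child_samples and H_R >= min_child_samples:
                    gain[e, f, b] = (G_L * G_L) / (H_L + eps) + (G_R * G_R) / (H_R + eps) \
                                    - (G_tot * G_tot) / (H_tot + eps)
                    # +1 if the left child gets the larger value, -1 otherwise
                    direction[e, f, b] = 1.0 if G_L / (H_L + eps) > G_R / (H_R + eps) else -1.0
    return gain, direction
```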
@@ -3,13 +3,14 @@
3
3
  #include <torch/extension.h>
4
4
 
5
5
  __global__ void histogram_tiled_configurable_kernel(
6
- const int8_t *__restrict__ bin_indices, // [N, F]
6
+ const int8_t *__restrict__ bin_indices, // [N, F_master]
7
7
  const float *__restrict__ residuals, // [N]
8
8
  const int32_t *__restrict__ sample_indices, // [N]
9
9
  const int32_t *__restrict__ feature_indices, // [F]
10
+ const int32_t *__restrict__ era_indices, // [N]
10
11
  float *__restrict__ grad_hist, // [F * B]
11
12
  float *__restrict__ hess_hist, // [F * B]
12
- int64_t N, int64_t F, int64_t B,
13
+ int64_t N, int64_t F_master, int64_t F, int64_t B, int64_t num_eras,
13
14
  int rows_per_thread)
14
15
  {
15
16
  int hist_feat_idx = blockIdx.x;
@@ -17,15 +18,15 @@ __global__ void histogram_tiled_configurable_kernel(
17
18
  int row_start = (blockIdx.y * blockDim.x + threadIdx.x) * rows_per_thread;
18
19
 
19
20
  extern __shared__ float shmem[];
20
- float *sh_grad = shmem; // [B]
21
- float *sh_hess = &sh_grad[B]; // [B]
21
+ float *sh_grad = shmem; // [num_eras * B]
22
+ float *sh_hess = &sh_grad[num_eras * B]; // [num_eras * B]
22
23
 
23
24
  // Initialize shared memory histograms
24
- for (int b = threadIdx.x; b < B; b += blockDim.x)
25
- {
26
- sh_grad[b] = 0.0f;
27
- sh_hess[b] = 0.0f;
25
+ for (int i = threadIdx.x; i < num_eras * B; i += blockDim.x) {
26
+ sh_grad[i] = 0.0f;
27
+ sh_hess[i] = 0.0f;
28
28
  }
29
+
29
30
  __syncthreads();
30
31
 
31
32
  // Each thread processes multiple rows
@@ -35,23 +36,28 @@ __global__ void histogram_tiled_configurable_kernel(
35
36
  if (row < N)
36
37
  {
37
38
  int sample = sample_indices[row];
38
- int8_t bin = bin_indices[sample * F + feat];
39
+ int8_t bin = bin_indices[sample * F_master + feat];
40
+ int32_t era = era_indices[sample];
39
41
  if (bin >= 0 && bin < B)
40
42
  {
41
- atomicAdd(&sh_grad[bin], residuals[sample]);
42
- atomicAdd(&sh_hess[bin], 1.0f);
43
+ atomicAdd(&sh_grad[era * B + bin], residuals[sample]);
44
+ atomicAdd(&sh_hess[era * B + bin], 1.0f);
43
45
  }
44
46
  }
45
47
  }
46
48
  __syncthreads();
47
49
 
48
50
  // One thread per bin writes results back to global memory
49
- for (int b = threadIdx.x; b < B; b += blockDim.x)
51
+ for (int b = threadIdx.x; b < num_eras * B; b += blockDim.x)
50
52
  {
51
- int64_t idx = hist_feat_idx * B + b;
53
+ int e = b / B;
54
+ int bin = b % B;
55
+ int64_t idx = e * F * B + hist_feat_idx * B + bin;
56
+
52
57
  atomicAdd(&grad_hist[idx], sh_grad[b]);
53
58
  atomicAdd(&hess_hist[idx], sh_hess[b]);
54
59
  }
60
+
55
61
  }
56
62
 
57
63
  void launch_histogram_kernel_cuda_configurable(
@@ -59,6 +65,7 @@ void launch_histogram_kernel_cuda_configurable(
59
65
  const at::Tensor &residuals,
60
66
  const at::Tensor &sample_indices,
61
67
  const at::Tensor &feature_indices,
68
+ const at::Tensor &era_indices,
62
69
  at::Tensor &grad_hist,
63
70
  at::Tensor &hess_hist,
64
71
  int num_bins,
@@ -75,16 +82,18 @@ void launch_histogram_kernel_cuda_configurable(
75
82
 
76
83
  dim3 blocks(F, row_tiles); // grid.x = F, grid.y = row_tiles
77
84
  dim3 threads(threads_per_block);
78
- int shared_mem_bytes = 2 * num_bins * sizeof(float);
85
+ int num_eras = grad_hist.size(0); // inferred from output tensor
86
+ int shared_mem_bytes = 2 * num_eras * num_bins * sizeof(float);
79
87
 
80
88
  histogram_tiled_configurable_kernel<<<blocks, threads, shared_mem_bytes>>>(
81
89
  bin_indices.data_ptr<int8_t>(),
82
90
  residuals.data_ptr<float>(),
83
91
  sample_indices.data_ptr<int32_t>(),
84
92
  feature_indices.data_ptr<int32_t>(),
93
+ era_indices.data_ptr<int32_t>(),
85
94
  grad_hist.data_ptr<float>(),
86
95
  hess_hist.data_ptr<float>(),
87
- N, num_features_master, num_bins,
96
+ N, num_features_master, F, num_bins, num_eras,
88
97
  rows_per_thread);
89
98
 
90
99
  cudaError_t err = cudaGetLastError();
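The histogram kernel change above switches the accumulation from a flat `[F, B]` layout to per-era buckets indexed as `[era, feature, bin]`. A rough NumPy equivalent of what it accumulates, with the tiling, shared memory, and atomics omitted; the argument names mirror the kernel, but the function itself is illustrative.

```python
# Rough NumPy equivalent of the per-era accumulation in histogram_tiled_configurable_kernel.
import numpy as np

def compute_histograms_reference(bin_indices, residuals, sample_indices,
                                 feature_indices, era_indices, num_eras, num_bins):
    """bin_indices: [N, F_master] int8, residuals: [N] float32, era_indices: [N] int32.
    Returns grad_hist, hess_hist of shape [num_eras, len(feature_indices), num_bins]."""
    F = len(feature_indices)
    grad_hist = np.zeros((num_eras, F, num_bins), dtype=np.float32)
    hess_hist = np.zeros((num_eras, F, num_bins), dtype=np.float32)
    for j, feat in enumerate(feature_indices):
        for sample in sample_indices:
            b = bin_indices[sample, feat]
            if 0 <= b < num_bins:
                e = era_indices[sample]
                grad_hist[e, j, b] += residuals[sample]
                hess_hist[e, j, b] += 1.0        # hessian == sample count for squared loss
    return grad_hist, hess_hist
```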
@@ -3,21 +3,22 @@
3
3
 
4
4
  // Declare the function from histogram_kernel.cu
5
5
 
6
- void launch_best_split_kernel_cuda(
7
- const at::Tensor &G, // [F x B]
8
- const at::Tensor &H, // [F x B]
6
+ void launch_directional_split_kernel(
7
+ const at::Tensor &G, // [E, F, B]
8
+ const at::Tensor &H, // [E, F, B]
9
9
  float min_split_gain,
10
10
  float min_child_samples,
11
11
  float eps,
12
- at::Tensor &best_gains, // [F], float32
13
- at::Tensor &best_bins,
14
- int threads);
12
+ at::Tensor &per_era_gain, // [E, F, B]
13
+ at::Tensor &per_era_direction, // [E, F, B]
14
+ int threads = 128);
15
15
 
16
16
  void launch_histogram_kernel_cuda_configurable(
17
17
  const at::Tensor &bin_indices,
18
- const at::Tensor &residual,
18
+ const at::Tensor &residuals,
19
19
  const at::Tensor &sample_indices,
20
20
  const at::Tensor &feature_indices,
21
+ const at::Tensor &era_indices,
21
22
  at::Tensor &grad_hist,
22
23
  at::Tensor &hess_hist,
23
24
  int num_bins,
@@ -40,7 +41,7 @@ void predict_with_forest(
40
41
  PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
41
42
  {
42
43
  m.def("compute_histogram3", &launch_histogram_kernel_cuda_configurable, "Histogram Feature Shared Mem");
43
- m.def("compute_split", &launch_best_split_kernel_cuda, "Best Split (CUDA)");
44
+ m.def("compute_split", &launch_directional_split_kernel, "Best Split (CUDA)");
44
45
  m.def("custom_cuda_binner", &launch_bin_column_kernel, "Custom CUDA binning kernel");
45
46
  m.def("predict_forest", &predict_with_forest, "CUDA Predictions");
46
47
  }
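The binding above keeps the Python-facing name `compute_split` while pointing it at the new directional kernel, so the call in `warpgbm/core.py` remains a single extension call. A small shape and argument sketch follows; the import path is an assumption, and only the shapes and argument order come from the declarations in `node_kernel.cpp` and the call site shown earlier.

```python
# Shape/argument sketch for the rebound compute_split binding (import path hypothetical).
import torch
# from warpgbm.cuda import node_kernel            # hypothetical import path

E, F, B = 4, 16, 32                                # eras, sampled features, bins
G = torch.zeros(E, F, B)                           # per-era gradient histograms
H = torch.zeros(E, F, B)                           # per-era hessian (count) histograms
per_era_gain = torch.zeros(E, F, B - 1)            # outputs filled by the kernel
per_era_direction = torch.zeros(E, F, B - 1)

# node_kernel.compute_split(G, H, min_split_gain, min_child_weight, L2_reg,
#                           per_era_gain, per_era_direction, threads_per_block)
# Note: core.py passes min_child_weight and L2_reg into the kernel's
# min_child_samples and eps parameters, respectively.
```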
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: warpgbm
3
- Version: 0.1.27
3
+ Version: 1.0.0
4
4
  Summary: A fast GPU-accelerated Gradient Boosted Decision Tree library with PyTorch + CUDA
5
5
  License: GNU GENERAL PUBLIC LICENSE
6
6
  Version 3, 29 June 2007
@@ -686,21 +686,46 @@ Requires-Dist: tqdm
686
686
  Requires-Dist: scikit-learn
687
687
  Dynamic: license-file
688
688
 
689
- ![warpgbm](https://github.com/user-attachments/assets/dee9de16-091b-49c1-a8fa-2b4ab6891184)
690
-
689
+ ![raw](https://github.com/user-attachments/assets/924844ef-2536-4bde-a330-5e30f6b0762c)
691
690
 
692
691
  # WarpGBM
693
692
 
694
693
  WarpGBM is a high-performance, GPU-accelerated Gradient Boosted Decision Tree (GBDT) library built with PyTorch and CUDA. It offers blazing-fast histogram-based training and efficient prediction, with compatibility for research and production workflows.
695
694
 
695
+ **New in v1.0.0:** WarpGBM introduces *Invariant Gradient Boosting* — a powerful approach to learning signals that remain stable across shifting environments (e.g., time, regimes, or datasets). Powered by a novel algorithm called **[Directional Era-Splitting (DES)](https://arxiv.org/abs/2309.14496)**, WarpGBM doesn't just train faster than other leading GBDT libraries — it trains smarter.
696
+
697
+ If your data evolves over time, WarpGBM is the only GBDT library designed to *adapt and generalize*.
696
698
  ---
697
699
 
700
+ ## Contents
701
+
702
+ - [Features](#features)
703
+ - [Benchmarks](#benchmarks)
704
+ - [Installation](#installation)
705
+ - [Learning Invariant Signals Across Environments](#learning-invariant-signals-across-environments)
706
+ - [Why This Matters](#why-this-matters)
707
+ - [Visual Intuition](#visual-intuition)
708
+ - [Key References](#key-references)
709
+ - [Examples](#examples)
710
+ - [Quick Comparison with LightGBM CPU version](#quick-comparison-with-lightgbm-cpu-version)
711
+ - [Pre-binned Data Example (Numerai)](#pre-binned-data-example-numerai)
712
+ - [Documentation](#documentation)
713
+ - [Acknowledgements](#acknowledgements)
714
+ - [Version Notes](#version-notes)
715
+
716
+
698
717
  ## Features
699
718
 
700
- - GPU-accelerated training and histogram construction using custom CUDA kernels
701
- - Drop-in scikit-learn style interface
702
- - Supports pre-binned data or automatic quantile binning
703
- - Simple install with `pip`
719
+ - **Blazing-fast GPU training** with custom CUDA kernels for binning, histogram building, split finding, and prediction
720
+ - **Invariant signal learning** via [Directional Era-Splitting (DES)](https://arxiv.org/abs/2309.14496) — designed for datasets with shifting environments (e.g., time, regimes, experimental settings)
721
+ - Drop-in **scikit-learn style interface** for easy adoption
722
+ - Supports **pre-binned data** or **automatic quantile binning**
723
+ - Works with `float32` or `int8` inputs
724
+ - Built-in **validation and early stopping** support with MSE, RMSLE, or correlation metrics
725
+ - Simple install with `pip`, no custom drivers required
726
+
727
+ > 💡 **Note:** WarpGBM v1.0.0 is a *generalization* of the traditional GBDT algorithm.
728
+ > To run standard GBM training at maximum speed, simply omit the `era_id` argument — WarpGBM will behave like a traditional booster but with industry-leading performance.
704
729
 
705
730
  ---
706
731
 
@@ -762,7 +787,62 @@ Before either method, make sure you’ve installed PyTorch with GPU support:\
762
787
 
763
788
  ---
764
789
 
765
- ## Example
790
+ ## Learning Invariant Signals Across Environments
791
+
792
+ Most supervised learning models rely on an assumption known as the **Empirical Risk Minimization (ERM)** principle. Under ERM, the data distribution connecting inputs \( X \) and targets \( Y \) is assumed to be **fixed** and **stationary** across training, validation, and test splits. That is:
793
+
794
+ > The patterns you learn from the training set are expected to generalize out-of-sample — *as long as the test data follows the same distribution as the training data.*
795
+
796
+ However, this assumption is often violated in real-world settings. Data frequently shifts across time, geography, experimental conditions, or other hidden factors. This phenomenon is known as **distribution shift**, and it leads to models that perform well in-sample but fail catastrophically out-of-sample.
797
+
798
+ This challenge motivates the field of **Out-of-Distribution (OOD) Generalization**, which assumes your data is drawn from **distinct environments or eras** — e.g., time periods, customer segments, experimental trials. Some signals may appear predictive within specific environments but vanish or reverse in others. These are called **spurious signals**. On the other hand, signals that remain consistently predictive across all environments are called **invariant signals**.
799
+
800
+ WarpGBM v1.0.0 introduces **Directional Era-Splitting (DES)**, a new algorithm designed to identify and learn from invariant signals — ignoring signals that fail to generalize across environments.
801
+
802
+ ---
803
+
804
+ ### Why This Matters
805
+
806
+ - Standard models trained via ERM can learn to exploit **spurious correlations** that only hold in some parts of the data.
807
+ - DES explicitly tests whether a feature's split is **directionally consistent** across all eras — only such *invariant splits* are kept.
808
+ - This approach has been shown to reduce overfitting and improve out-of-sample generalization, particularly in financial and scientific datasets.
809
+
810
+ ---
811
+
812
+ ### Visual Intuition
813
+
814
+ We contrast two views of the data:
815
+
816
+ - **ERM Setting**: All data is assumed to come from the same source (single distribution).\
817
+ No awareness of environments — spurious signals can dominate.
818
+
819
+ - **OOD Setting (Era-Splitting)**: Data is explicitly grouped by environment (era).\
820
+ The model checks whether a signal holds across all groups — enforcing **robustness**.
821
+
822
+ *📷 [Placeholder for future visual illustration]*
823
+
824
+
825
+ ---
826
+
827
+ ### Key References
828
+
829
+ - **Invariant Risk Minimization (IRM)**: [Arjovsky et al., 2019](https://arxiv.org/abs/1907.02893)
830
+ - **Learning Explanations That Are Hard to Vary**: [Parascandolo et al., 2020](https://arxiv.org/abs/2009.00329)
831
+ - **Era Splitting: Invariant Learning for Decision Trees**: [DeLise, 2023](https://arxiv.org/abs/2309.14496)
832
+
833
+ ---
834
+
835
+ WarpGBM is the **first open-source GBDT framework to integrate this OOD-aware approach natively**, using efficient CUDA kernels to evaluate per-era consistency during tree growth. It’s not just faster — it’s smarter.
836
+
837
+ ---
838
+
839
+ ## Examples
840
+
841
+ WarpGBM is easy to drop into any supervised learning workflow and comes with curated examples in the `examples/` folder.
842
+
843
+ - `Spiral Data.ipynb`: synthetic OOD benchmark from Learning Explanations That Are Hard to Vary
844
+
845
+ ### Quick Comparison with LightGBM CPU version
766
846
 
767
847
  ```python
768
848
  import numpy as np
@@ -804,7 +884,7 @@ WarpGBM: corr = 0.8621, time = 5.40s
804
884
 
805
885
  ---
806
886
 
807
- ## Pre-binned Data Example (Numerai)
887
+ ### Pre-binned Data Example (Numerai)
808
888
 
809
889
  WarpGBM can save additional training time if your dataset is already pre-binned. The Numerai tournament data is a great example:
810
890
 
@@ -854,16 +934,6 @@ WarpGBM: corr = 0.0660, time = 49.16s
854
934
 
855
935
  ---
856
936
 
857
- ### Run it live in Colab
858
-
859
- You can try WarpGBM in a live Colab notebook using real pre-binned Numerai tournament data:
860
-
861
- [Open in Colab](https://colab.research.google.com/drive/10mKSjs9UvmMgM5_lOXAylq5LUQAnNSi7?usp=sharing)
862
-
863
- No installation required — just press **"Open in Playground"**, then **Run All**!
864
-
865
- ---
866
-
867
937
  ## Documentation
868
938
 
869
939
  ### `WarpGBM` Parameters:
@@ -889,7 +959,7 @@ No installation required — just press **"Open in Playground"**, then **Run All
889
959
  y_eval=None, # numpy array (float or int) 1 dimension (eval_num_samples)
890
960
  eval_every_n_trees=None, # const (int) >= 1
891
961
  early_stopping_rounds=None, # const (int) >= 1
892
- eval_metric='mse' # string, one of 'mse' or 'corr'. For corr, loss is 1 - correlation(y_true, preds)
962
+ eval_metric='mse' # string, one of 'mse', 'rmsle' or 'corr'. For corr, loss is 1 - correlation(y_true, preds)
893
963
  )
894
964
  ```
895
965
  Train with optional validation set and early stopping.
@@ -927,3 +997,7 @@ WarpGBM builds on the shoulders of PyTorch, scikit-learn, LightGBM, and the CUDA
927
997
  ### v0.1.26
928
998
 
929
999
  - Fix Memory bugs in prediction and colsample bytree logic. Added "corr" eval metric.
1000
+
1001
+ ### v1.0.0
1002
+
1003
+ - Introduce invariant learning via Directional Era-Splitting (DES). Also streamlined VRAM usage compared to previous point releases.
@@ -8,6 +8,7 @@ tests/__init__.py
8
8
  tests/full_numerai_test.py
9
9
  tests/numerai_test.py
10
10
  tests/test_fit_predict_corr.py
11
+ tests/test_invariant.py
11
12
  warpgbm/__init__.py
12
13
  warpgbm/core.py
13
14
  warpgbm/metrics.py
@@ -1 +0,0 @@
1
- 0.1.27
@@ -1,79 +0,0 @@
1
- #include <torch/extension.h>
2
- #include <cuda.h>
3
- #include <cuda_runtime.h>
4
-
5
- __global__ void best_split_kernel_global_only(
6
- const float *__restrict__ G, // [F x B]
7
- const float *__restrict__ H, // [F x B]
8
- int F,
9
- int B,
10
- float min_split_gain,
11
- float min_child_samples,
12
- float eps,
13
- float *__restrict__ best_gains, // [F]
14
- int *__restrict__ best_bins // [F]
15
- )
16
- {
17
- int f = blockIdx.x * blockDim.x + threadIdx.x;
18
- if (f >= F)
19
- return;
20
-
21
- float G_total = 0.0f, H_total = 0.0f;
22
- for (int b = 0; b < B; ++b)
23
- {
24
- G_total += G[f * B + b];
25
- H_total += H[f * B + b];
26
- }
27
-
28
- float G_L = 0.0f, H_L = 0.0f;
29
- float best_gain = min_split_gain;
30
- int best_bin = -1;
31
-
32
- for (int b = 0; b < B - 1; ++b)
33
- {
34
- G_L += G[f * B + b];
35
- H_L += H[f * B + b];
36
- float G_R = G_total - G_L;
37
- float H_R = H_total - H_L;
38
-
39
- if (H_L >= min_child_samples && H_R >= min_child_samples)
40
- {
41
- float gain = (G_L * G_L) / (H_L + eps) + (G_R * G_R) / (H_R + eps) - (G_total * G_total) / (H_total + eps);
42
- if (gain > best_gain)
43
- {
44
- best_gain = gain;
45
- best_bin = b;
46
- }
47
- }
48
- }
49
-
50
- best_gains[f] = best_gain;
51
- best_bins[f] = best_bin;
52
- }
53
-
54
- void launch_best_split_kernel_cuda(
55
- const at::Tensor &G, // [F x B]
56
- const at::Tensor &H, // [F x B]
57
- float min_split_gain,
58
- float min_child_samples,
59
- float eps,
60
- at::Tensor &best_gains, // [F], float32
61
- at::Tensor &best_bins, // [F], int32
62
- int threads)
63
- {
64
- int F = G.size(0);
65
- int B = G.size(1);
66
-
67
- int blocks = (F + threads - 1) / threads;
68
-
69
- best_split_kernel_global_only<<<blocks, threads>>>(
70
- G.data_ptr<float>(),
71
- H.data_ptr<float>(),
72
- F,
73
- B,
74
- min_split_gain,
75
- min_child_samples,
76
- eps,
77
- best_gains.data_ptr<float>(),
78
- best_bins.data_ptr<int>());
79
- }
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes