warpgbm 0.1.26__tar.gz → 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. {warpgbm-0.1.26/warpgbm.egg-info → warpgbm-1.0.0}/PKG-INFO +94 -20
  2. {warpgbm-0.1.26 → warpgbm-1.0.0}/README.md +94 -20
  3. {warpgbm-0.1.26 → warpgbm-1.0.0}/pyproject.toml +2 -2
  4. warpgbm-1.0.0/tests/full_numerai_test.py +67 -0
  5. warpgbm-1.0.0/tests/test_fit_predict_corr.py +52 -0
  6. warpgbm-1.0.0/tests/test_invariant.py +100 -0
  7. warpgbm-1.0.0/version.txt +1 -0
  8. {warpgbm-0.1.26 → warpgbm-1.0.0}/warpgbm/core.py +85 -89
  9. warpgbm-1.0.0/warpgbm/cuda/best_split_kernel.cu +89 -0
  10. warpgbm-1.0.0/warpgbm/cuda/histogram_kernel.cu +104 -0
  11. warpgbm-1.0.0/warpgbm/cuda/node_kernel.cpp +47 -0
  12. warpgbm-1.0.0/warpgbm/metrics.py +10 -0
  13. {warpgbm-0.1.26 → warpgbm-1.0.0/warpgbm.egg-info}/PKG-INFO +94 -20
  14. {warpgbm-0.1.26 → warpgbm-1.0.0}/warpgbm.egg-info/SOURCES.txt +3 -0
  15. warpgbm-0.1.26/tests/test_fit_predict_corr.py +0 -58
  16. warpgbm-0.1.26/version.txt +0 -1
  17. warpgbm-0.1.26/warpgbm/cuda/best_split_kernel.cu +0 -79
  18. warpgbm-0.1.26/warpgbm/cuda/histogram_kernel.cu +0 -250
  19. warpgbm-0.1.26/warpgbm/cuda/node_kernel.cpp +0 -63
  20. {warpgbm-0.1.26 → warpgbm-1.0.0}/LICENSE +0 -0
  21. {warpgbm-0.1.26 → warpgbm-1.0.0}/MANIFEST.in +0 -0
  22. {warpgbm-0.1.26 → warpgbm-1.0.0}/setup.cfg +0 -0
  23. {warpgbm-0.1.26 → warpgbm-1.0.0}/setup.py +0 -0
  24. {warpgbm-0.1.26 → warpgbm-1.0.0}/tests/__init__.py +0 -0
  25. {warpgbm-0.1.26 → warpgbm-1.0.0}/tests/numerai_test.py +0 -0
  26. {warpgbm-0.1.26 → warpgbm-1.0.0}/warpgbm/__init__.py +0 -0
  27. {warpgbm-0.1.26 → warpgbm-1.0.0}/warpgbm/cuda/__init__.py +0 -0
  28. {warpgbm-0.1.26 → warpgbm-1.0.0}/warpgbm/cuda/binner.cu +0 -0
  29. {warpgbm-0.1.26 → warpgbm-1.0.0}/warpgbm/cuda/predict.cu +0 -0
  30. {warpgbm-0.1.26 → warpgbm-1.0.0}/warpgbm.egg-info/dependency_links.txt +0 -0
  31. {warpgbm-0.1.26 → warpgbm-1.0.0}/warpgbm.egg-info/requires.txt +0 -0
  32. {warpgbm-0.1.26 → warpgbm-1.0.0}/warpgbm.egg-info/top_level.txt +0 -0
{warpgbm-0.1.26/warpgbm.egg-info → warpgbm-1.0.0}/PKG-INFO

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: warpgbm
-Version: 0.1.26
+Version: 1.0.0
 Summary: A fast GPU-accelerated Gradient Boosted Decision Tree library with PyTorch + CUDA
 License: GNU GENERAL PUBLIC LICENSE
 Version 3, 29 June 2007
````
````diff
@@ -686,21 +686,46 @@ Requires-Dist: tqdm
 Requires-Dist: scikit-learn
 Dynamic: license-file
 
-![warpgbm](https://github.com/user-attachments/assets/dee9de16-091b-49c1-a8fa-2b4ab6891184)
-
+![raw](https://github.com/user-attachments/assets/924844ef-2536-4bde-a330-5e30f6b0762c)
 
 # WarpGBM
 
 WarpGBM is a high-performance, GPU-accelerated Gradient Boosted Decision Tree (GBDT) library built with PyTorch and CUDA. It offers blazing-fast histogram-based training and efficient prediction, with compatibility for research and production workflows.
 
+**New in v1.0.0:** WarpGBM introduces *Invariant Gradient Boosting* — a powerful approach to learning signals that remain stable across shifting environments (e.g., time, regimes, or datasets). Powered by a novel algorithm called **[Directional Era-Splitting (DES)](https://arxiv.org/abs/2309.14496)**, WarpGBM doesn't just train faster than other leading GBDT libraries — it trains smarter.
+
+If your data evolves over time, WarpGBM is the only GBDT library designed to *adapt and generalize*.
 ---
 
+## Contents
+
+- [Features](#features)
+- [Benchmarks](#benchmarks)
+- [Installation](#installation)
+- [Learning Invariant Signals Across Environments](#learning-invariant-signals-across-environments)
+- [Why This Matters](#why-this-matters)
+- [Visual Intuition](#visual-intuition)
+- [Key References](#key-references)
+- [Examples](#examples)
+- [Quick Comparison with LightGBM CPU version](#quick-comparison-with-lightgbm-cpu-version)
+- [Pre-binned Data Example (Numerai)](#pre-binned-data-example-numerai)
+- [Documentation](#documentation)
+- [Acknowledgements](#acknowledgements)
+- [Version Notes](#version-notes)
+
+
 ## Features
 
-- GPU-accelerated training and histogram construction using custom CUDA kernels
-- Drop-in scikit-learn style interface
-- Supports pre-binned data or automatic quantile binning
-- Simple install with `pip`
+- **Blazing-fast GPU training** with custom CUDA kernels for binning, histogram building, split finding, and prediction
+- **Invariant signal learning** via [Directional Era-Splitting (DES)](https://arxiv.org/abs/2309.14496) — designed for datasets with shifting environments (e.g., time, regimes, experimental settings)
+- Drop-in **scikit-learn style interface** for easy adoption
+- Supports **pre-binned data** or **automatic quantile binning**
+- Works with `float32` or `int8` inputs
+- Built-in **validation and early stopping** support with MSE, RMSLE, or correlation metrics
+- Simple install with `pip`, no custom drivers required
+
+> 💡 **Note:** WarpGBM v1.0.0 is a *generalization* of the traditional GBDT algorithm.
+> To run standard GBM training at maximum speed, simply omit the `era_id` argument — WarpGBM will behave like a traditional booster but with industry-leading performance.
 
 ---
 
````
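The new feature list says WarpGBM either accepts pre-binned data or performs automatic quantile binning, and works with `float32` or `int8` inputs. As an illustration of what quantile binning generally involves, the sketch below places per-feature bin edges at evenly spaced quantiles and maps every value to a small integer code that fits in `int8`. It is a NumPy sketch under our own assumptions: `quantile_bin` is a hypothetical helper, and WarpGBM's own `binner.cu` kernel may choose edges or handle ties differently.

```python
import numpy as np


def quantile_bin(X, num_bins=8):
    """Hypothetical helper: map each column of X to int8 bin codes 0..num_bins-1."""
    if not 2 <= num_bins <= 128:
        raise ValueError("int8 codes allow at most 128 bins (codes 0..127)")
    X = np.asarray(X, dtype=np.float32)
    n_rows, n_features = X.shape
    # Interior quantile levels, e.g. num_bins=4 -> [0.25, 0.5, 0.75]
    levels = np.linspace(0.0, 1.0, num_bins + 1)[1:-1]
    codes = np.empty((n_rows, n_features), dtype=np.int8)
    edges = []
    for j in range(n_features):
        col_edges = np.quantile(X[:, j], levels)
        edges.append(col_edges)
        # side="right": a value equal to an edge is placed in the bin above it
        codes[:, j] = np.searchsorted(col_edges, X[:, j], side="right")
    return codes, edges


rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 3)).astype(np.float32)
codes, edges = quantile_bin(X, num_bins=4)
print(codes.dtype, codes.min(), codes.max())
print(np.bincount(codes[:, 0], minlength=4))  # roughly equal counts per bin
```

Pre-binned inputs such as the Numerai features, which the tests cast with `.astype("int8")`, already arrive as small integer codes, so this step can be skipped for them.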
````diff
@@ -762,7 +787,62 @@ Before either method, make sure you’ve installed PyTorch with GPU support:\
 
 ---
 
-## Example
+## Learning Invariant Signals Across Environments
+
+Most supervised learning models rely on an assumption known as the **Empirical Risk Minimization (ERM)** principle. Under ERM, the data distribution connecting inputs \( X \) and targets \( Y \) is assumed to be **fixed** and **stationary** across training, validation, and test splits. That is:
+
+> The patterns you learn from the training set are expected to generalize out-of-sample — *as long as the test data follows the same distribution as the training data.*
+
+However, this assumption is often violated in real-world settings. Data frequently shifts across time, geography, experimental conditions, or other hidden factors. This phenomenon is known as **distribution shift**, and it leads to models that perform well in-sample but fail catastrophically out-of-sample.
+
+This challenge motivates the field of **Out-of-Distribution (OOD) Generalization**, which assumes your data is drawn from **distinct environments or eras** — e.g., time periods, customer segments, experimental trials. Some signals may appear predictive within specific environments but vanish or reverse in others. These are called **spurious signals**. On the other hand, signals that remain consistently predictive across all environments are called **invariant signals**.
+
+WarpGBM v1.0.0 introduces **Directional Era-Splitting (DES)**, a new algorithm designed to identify and learn from invariant signals — ignoring signals that fail to generalize across environments.
+
+---
+
+### Why This Matters
+
+- Standard models trained via ERM can learn to exploit **spurious correlations** that only hold in some parts of the data.
+- DES explicitly tests whether a feature's split is **directionally consistent** across all eras — only such *invariant splits* are kept.
+- This approach has been shown to reduce overfitting and improve out-of-sample generalization, particularly in financial and scientific datasets.
+
+---
+
+### Visual Intuition
+
+We contrast two views of the data:
+
+- **ERM Setting**: All data is assumed to come from the same source (single distribution).\
+No awareness of environments — spurious signals can dominate.
+
+- **OOD Setting (Era-Splitting)**: Data is explicitly grouped by environment (era).\
+The model checks whether a signal holds across all groups — enforcing **robustness**.
+
+*📷 [Placeholder for future visual illustration]*
+
+
+---
+
+### Key References
+
+- **Invariant Risk Minimization (IRM)**: [Arjovsky et al., 2019](https://arxiv.org/abs/1907.02893)
+- **Learning Explanations That Are Hard to Vary**: [Parascandolo et al., 2020](https://arxiv.org/abs/2009.00329)
+- **Era Splitting: Invariant Learning for Decision Trees**: [DeLise, 2023](https://arxiv.org/abs/2309.14496)
+
+---
+
+WarpGBM is the **first open-source GBDT framework to integrate this OOD-aware approach natively**, using efficient CUDA kernels to evaluate per-era consistency during tree growth. It’s not just faster — it’s smarter.
+
+---
+
+## Examples
+
+WarpGBM is easy to drop into any supervised learning workflow and comes with curated examples in the `examples/` folder.
+
+- `Spiral Data.ipynb`: synthetic OOD benchmark from Learning Explanations That Are Hard to Vary
+
+### Quick Comparison with LightGBM CPU version
 
 ```python
 import numpy as np
````
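The ERM and OOD settings described above can be stated as objectives. In the notation of the IRM paper listed under Key References (Arjovsky et al., 2019), with eras $e \in \mathcal{E}$ each drawing data from its own distribution $P^{e}$, ERM minimizes the pooled average loss, while OOD generalization targets the risk of the worst era. The symbols are ours for illustration and do not come from WarpGBM's code.

```latex
\begin{aligned}
\hat f_{\mathrm{ERM}} &= \arg\min_{f} \; \frac{1}{n} \sum_{i=1}^{n} \ell\big(f(x_i), y_i\big) \\
R^{e}(f) &= \mathbb{E}_{(X, Y) \sim P^{e}} \big[ \ell\big(f(X), Y\big) \big] \\
R^{\mathrm{OOD}}(f) &= \max_{e \in \mathcal{E}} R^{e}(f)
\end{aligned}
```

A spurious signal can lower the pooled ERM average by helping in most eras while raising $R^{e}$ in the eras where it reverses; the worst-era objective $R^{\mathrm{OOD}}$ does not reward that trade.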
````diff
@@ -804,7 +884,7 @@ WarpGBM: corr = 0.8621, time = 5.40s
 
 ---
 
-## Pre-binned Data Example (Numerai)
+### Pre-binned Data Example (Numerai)
 
 WarpGBM can save additional training time if your dataset is already pre-binned. The Numerai tournament data is a great example:
 
````
````diff
@@ -854,16 +934,6 @@ WarpGBM: corr = 0.0660, time = 49.16s
 
 ---
 
-### Run it live in Colab
-
-You can try WarpGBM in a live Colab notebook using real pre-binned Numerai tournament data:
-
-[Open in Colab](https://colab.research.google.com/drive/10mKSjs9UvmMgM5_lOXAylq5LUQAnNSi7?usp=sharing)
-
-No installation required — just press **"Open in Playground"**, then **Run All**!
-
----
-
 ## Documentation
 
 ### `WarpGBM` Parameters:
````
````diff
@@ -889,7 +959,7 @@ No installation required — just press **"Open in Playground"**, then **Run All
 y_eval=None, # numpy array (float or int) 1 dimension (eval_num_samples)
 eval_every_n_trees=None, # const (int) >= 1
 early_stopping_rounds=None, # const (int) >= 1
-eval_metric='mse' # string, one of 'mse' or 'corr'. For corr, loss is 1 - correlation(y_true, preds)
+eval_metric='mse' # string, one of 'mse', 'rmsle' or 'corr'. For corr, loss is 1 - correlation(y_true, preds)
 )
 ```
 Train with optional validation set and early stopping.
````
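The updated `eval_metric` comment lists `'mse'`, `'rmsle'` and `'corr'`, and states that the corr loss is `1 - correlation(y_true, preds)`. Below is a minimal NumPy sketch of those three validation losses using their textbook definitions. The RMSLE form (`log1p` of both sides, so values must exceed -1) is our assumption, and the function names and the `EVAL_METRICS` table are hypothetical; they are not necessarily what `warpgbm/metrics.py` contains.

```python
import numpy as np


def mse_loss(y_true, y_pred):
    # Mean squared error
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    return float(np.mean((y_true - y_pred) ** 2))


def rmsle_loss(y_true, y_pred):
    # Root mean squared logarithmic error; log1p requires values > -1
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    return float(np.sqrt(np.mean((np.log1p(y_pred) - np.log1p(y_true)) ** 2)))


def corr_loss(y_true, y_pred):
    # 1 - Pearson correlation, so lower is better, like the other two losses
    return float(1.0 - np.corrcoef(y_true, y_pred)[0, 1])


# Hypothetical name-to-function table keyed like the eval_metric strings
EVAL_METRICS = {"mse": mse_loss, "rmsle": rmsle_loss, "corr": corr_loss}

y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_pred = np.array([1.0, 2.0, 3.0, 5.0])
for name, fn in EVAL_METRICS.items():
    print(name, round(fn(y_true, y_pred), 4))
```

Expressing correlation as `1 - corr` lets a single "lower is better" comparison drive early stopping for all three metrics.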
````diff
@@ -927,3 +997,7 @@ WarpGBM builds on the shoulders of PyTorch, scikit-learn, LightGBM, and the CUDA
 ### v0.1.26
 
 - Fix Memory bugs in prediction and colsample bytree logic. Added "corr" eval metric.
+
+### v1.0.0
+
+- Introduce invariant learning via directional era splitting (DES). Also streamline VRAM improvements over previous sub versions.
````
{warpgbm-0.1.26 → warpgbm-1.0.0}/README.md

````diff
@@ -1,18 +1,43 @@
-![warpgbm](https://github.com/user-attachments/assets/dee9de16-091b-49c1-a8fa-2b4ab6891184)
-
+![raw](https://github.com/user-attachments/assets/924844ef-2536-4bde-a330-5e30f6b0762c)
 
 # WarpGBM
 
 WarpGBM is a high-performance, GPU-accelerated Gradient Boosted Decision Tree (GBDT) library built with PyTorch and CUDA. It offers blazing-fast histogram-based training and efficient prediction, with compatibility for research and production workflows.
 
+**New in v1.0.0:** WarpGBM introduces *Invariant Gradient Boosting* — a powerful approach to learning signals that remain stable across shifting environments (e.g., time, regimes, or datasets). Powered by a novel algorithm called **[Directional Era-Splitting (DES)](https://arxiv.org/abs/2309.14496)**, WarpGBM doesn't just train faster than other leading GBDT libraries — it trains smarter.
+
+If your data evolves over time, WarpGBM is the only GBDT library designed to *adapt and generalize*.
 ---
 
+## Contents
+
+- [Features](#features)
+- [Benchmarks](#benchmarks)
+- [Installation](#installation)
+- [Learning Invariant Signals Across Environments](#learning-invariant-signals-across-environments)
+- [Why This Matters](#why-this-matters)
+- [Visual Intuition](#visual-intuition)
+- [Key References](#key-references)
+- [Examples](#examples)
+- [Quick Comparison with LightGBM CPU version](#quick-comparison-with-lightgbm-cpu-version)
+- [Pre-binned Data Example (Numerai)](#pre-binned-data-example-numerai)
+- [Documentation](#documentation)
+- [Acknowledgements](#acknowledgements)
+- [Version Notes](#version-notes)
+
+
 ## Features
 
-- GPU-accelerated training and histogram construction using custom CUDA kernels
-- Drop-in scikit-learn style interface
-- Supports pre-binned data or automatic quantile binning
-- Simple install with `pip`
+- **Blazing-fast GPU training** with custom CUDA kernels for binning, histogram building, split finding, and prediction
+- **Invariant signal learning** via [Directional Era-Splitting (DES)](https://arxiv.org/abs/2309.14496) — designed for datasets with shifting environments (e.g., time, regimes, experimental settings)
+- Drop-in **scikit-learn style interface** for easy adoption
+- Supports **pre-binned data** or **automatic quantile binning**
+- Works with `float32` or `int8` inputs
+- Built-in **validation and early stopping** support with MSE, RMSLE, or correlation metrics
+- Simple install with `pip`, no custom drivers required
+
+> 💡 **Note:** WarpGBM v1.0.0 is a *generalization* of the traditional GBDT algorithm.
+> To run standard GBM training at maximum speed, simply omit the `era_id` argument — WarpGBM will behave like a traditional booster but with industry-leading performance.
 
 ---
 
````
````diff
@@ -74,7 +99,62 @@ Before either method, make sure you’ve installed PyTorch with GPU support:\
 
 ---
 
-## Example
+## Learning Invariant Signals Across Environments
+
+Most supervised learning models rely on an assumption known as the **Empirical Risk Minimization (ERM)** principle. Under ERM, the data distribution connecting inputs \( X \) and targets \( Y \) is assumed to be **fixed** and **stationary** across training, validation, and test splits. That is:
+
+> The patterns you learn from the training set are expected to generalize out-of-sample — *as long as the test data follows the same distribution as the training data.*
+
+However, this assumption is often violated in real-world settings. Data frequently shifts across time, geography, experimental conditions, or other hidden factors. This phenomenon is known as **distribution shift**, and it leads to models that perform well in-sample but fail catastrophically out-of-sample.
+
+This challenge motivates the field of **Out-of-Distribution (OOD) Generalization**, which assumes your data is drawn from **distinct environments or eras** — e.g., time periods, customer segments, experimental trials. Some signals may appear predictive within specific environments but vanish or reverse in others. These are called **spurious signals**. On the other hand, signals that remain consistently predictive across all environments are called **invariant signals**.
+
+WarpGBM v1.0.0 introduces **Directional Era-Splitting (DES)**, a new algorithm designed to identify and learn from invariant signals — ignoring signals that fail to generalize across environments.
+
+---
+
+### Why This Matters
+
+- Standard models trained via ERM can learn to exploit **spurious correlations** that only hold in some parts of the data.
+- DES explicitly tests whether a feature's split is **directionally consistent** across all eras — only such *invariant splits* are kept.
+- This approach has been shown to reduce overfitting and improve out-of-sample generalization, particularly in financial and scientific datasets.
+
+---
+
+### Visual Intuition
+
+We contrast two views of the data:
+
+- **ERM Setting**: All data is assumed to come from the same source (single distribution).\
+No awareness of environments — spurious signals can dominate.
+
+- **OOD Setting (Era-Splitting)**: Data is explicitly grouped by environment (era).\
+The model checks whether a signal holds across all groups — enforcing **robustness**.
+
+*📷 [Placeholder for future visual illustration]*
+
+
+---
+
+### Key References
+
+- **Invariant Risk Minimization (IRM)**: [Arjovsky et al., 2019](https://arxiv.org/abs/1907.02893)
+- **Learning Explanations That Are Hard to Vary**: [Parascandolo et al., 2020](https://arxiv.org/abs/2009.00329)
+- **Era Splitting: Invariant Learning for Decision Trees**: [DeLise, 2023](https://arxiv.org/abs/2309.14496)
+
+---
+
+WarpGBM is the **first open-source GBDT framework to integrate this OOD-aware approach natively**, using efficient CUDA kernels to evaluate per-era consistency during tree growth. It’s not just faster — it’s smarter.
+
+---
+
+## Examples
+
+WarpGBM is easy to drop into any supervised learning workflow and comes with curated examples in the `examples/` folder.
+
+- `Spiral Data.ipynb`: synthetic OOD benchmark from Learning Explanations That Are Hard to Vary
+
+### Quick Comparison with LightGBM CPU version
 
 ```python
 import numpy as np
````
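The README says DES tests whether a feature's split is directionally consistent across all eras and keeps only such invariant splits. The sketch below illustrates one simple reading of that idea for a single candidate split: inside each era, record whether rows sent left have a higher or lower mean residual than rows sent right, then score the split by how uniformly those signs agree. It is a simplified illustration under our own assumptions (the function name, the agreement score, and skipping eras the split does not separate are all hypothetical); it is neither WarpGBM's CUDA kernel nor the exact criterion of DeLise (2023).

```python
import numpy as np


def directional_agreement(residuals, go_left, era_ids):
    """Hypothetical score in [0, 1]: 1.0 when every era agrees on the split's direction.

    residuals: gradient-like targets for the current tree (e.g. y - current prediction)
    go_left:   boolean mask, True where the candidate split sends a row left
    era_ids:   integer era label per row
    """
    directions = []
    for era in np.unique(era_ids):
        in_era = era_ids == era
        left = residuals[in_era & go_left]
        right = residuals[in_era & ~go_left]
        if left.size == 0 or right.size == 0:
            continue  # the split does not separate this era, so it casts no vote
        # +1 if the left child has the larger mean residual in this era, -1 otherwise
        directions.append(np.sign(left.mean() - right.mean()))
    if not directions:
        return 0.0
    # abs(mean of signs): 1.0 for unanimous eras, lower as eras disagree
    return float(abs(np.mean(directions)))


rng = np.random.default_rng(0)
n_eras, rows_per_era = 6, 500
era_ids = np.repeat(np.arange(n_eras), rows_per_era)
invariant = rng.normal(size=era_ids.size)
spurious = rng.normal(size=era_ids.size)
# The spurious relationship holds in eras 0-3 and reverses in eras 4-5
flip = np.where(era_ids < 4, 1.0, -1.0)
y = invariant + 2.0 * flip * spurious + 0.1 * rng.normal(size=era_ids.size)

score_inv = directional_agreement(y, invariant < 0, era_ids)
score_spur = directional_agreement(y, spurious < 0, era_ids)
print(score_inv, round(score_spur, 3))
```

Pooled over all eras, the spurious split still separates `y` because four eras outweigh two, so a gain computed on pooled data would favor it; the per-era vote exposes the disagreement. With a single era, which the tests emulate with `np.zeros(N, dtype=np.int32)`, any split whose children have different mean residuals scores 1.0, so the check stops filtering anything, in line with the README's note that omitting `era_id` gives traditional boosting.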
````diff
@@ -116,7 +196,7 @@ WarpGBM: corr = 0.8621, time = 5.40s
 
 ---
 
-## Pre-binned Data Example (Numerai)
+### Pre-binned Data Example (Numerai)
 
 WarpGBM can save additional training time if your dataset is already pre-binned. The Numerai tournament data is a great example:
 
````
````diff
@@ -166,16 +246,6 @@ WarpGBM: corr = 0.0660, time = 49.16s
 
 ---
 
-### Run it live in Colab
-
-You can try WarpGBM in a live Colab notebook using real pre-binned Numerai tournament data:
-
-[Open in Colab](https://colab.research.google.com/drive/10mKSjs9UvmMgM5_lOXAylq5LUQAnNSi7?usp=sharing)
-
-No installation required — just press **"Open in Playground"**, then **Run All**!
-
----
-
 ## Documentation
 
 ### `WarpGBM` Parameters:
````
````diff
@@ -201,7 +271,7 @@ No installation required — just press **"Open in Playground"**, then **Run All
 y_eval=None, # numpy array (float or int) 1 dimension (eval_num_samples)
 eval_every_n_trees=None, # const (int) >= 1
 early_stopping_rounds=None, # const (int) >= 1
-eval_metric='mse' # string, one of 'mse' or 'corr'. For corr, loss is 1 - correlation(y_true, preds)
+eval_metric='mse' # string, one of 'mse', 'rmsle' or 'corr'. For corr, loss is 1 - correlation(y_true, preds)
 )
 ```
 Train with optional validation set and early stopping.
````
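The `fit` signature pairs `eval_every_n_trees` with `early_stopping_rounds`. Below is a generic sketch of how such a loop is commonly structured: score the validation set every `n` trees and stop once the loss has failed to improve for a number of consecutive evaluations. Whether WarpGBM counts patience in evaluations or in trees is not stated here, so the counting rule, the function name, and the `fit_one_tree`/`eval_loss` callables are all hypothetical.

```python
def boost_with_early_stopping(n_estimators, fit_one_tree, eval_loss,
                              eval_every_n_trees=10, early_stopping_rounds=3):
    """Hypothetical early-stopping loop (not WarpGBM's code).

    fit_one_tree(): grows and appends one tree to the ensemble
    eval_loss():    returns the current validation loss (lower is better)
    Returns (number of trees at the best evaluation, best loss).
    """
    best_loss = float("inf")
    best_n_trees = 0
    evals_without_improvement = 0
    for t in range(1, n_estimators + 1):
        fit_one_tree()
        if t % eval_every_n_trees != 0:
            continue  # only score the validation set every n trees
        loss = eval_loss()
        if loss < best_loss:
            best_loss, best_n_trees = loss, t
            evals_without_improvement = 0
        else:
            evals_without_improvement += 1
            if evals_without_improvement >= early_stopping_rounds:
                break  # patience exhausted
    return best_n_trees, best_loss


# Simulated validation curve: improves until tree 40, then degrades
trees = []


def fit_one_tree():
    trees.append(len(trees) + 1)


def eval_loss():
    return (len(trees) - 40) ** 2 / 1000 + 0.5


best_n, best = boost_with_early_stopping(200, fit_one_tree, eval_loss,
                                         eval_every_n_trees=10, early_stopping_rounds=2)
print(best_n, best, len(trees))
```

In this simulation the loss peaks in quality at tree 40, worsens at trees 50 and 60, and the loop stops after tree 60 instead of growing all 200 trees.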
````diff
@@ -238,4 +308,8 @@ WarpGBM builds on the shoulders of PyTorch, scikit-learn, LightGBM, and the CUDA
 
 ### v0.1.26
 
-- Fix Memory bugs in prediction and colsample bytree logic. Added "corr" eval metric.
+- Fix Memory bugs in prediction and colsample bytree logic. Added "corr" eval metric.
+
+### v1.0.0
+
+- Introduce invariant learning via directional era splitting (DES). Also streamline VRAM improvements over previous sub versions.
````
{warpgbm-0.1.26 → warpgbm-1.0.0}/pyproject.toml

````diff
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "warpgbm"
-version = "0.1.26"
+version = "1.0.0"
 description = "A fast GPU-accelerated Gradient Boosted Decision Tree library with PyTorch + CUDA"
 readme = "README.md"
 requires-python = ">=3.8"
````
````diff
@@ -13,5 +13,5 @@ dependencies = [
 "torch",
 "numpy",
 "tqdm",
-"scikit-learn"
+"scikit-learn"
 ]
````
warpgbm-1.0.0/tests/full_numerai_test.py

````diff
@@ -0,0 +1,67 @@
+from numerapi import NumerAPI
+import pandas as pd
+import numpy as np
+from warpgbm import WarpGBM
+import time
+from sklearn.metrics import mean_squared_error
+
+
+def predict_in_chunks(model, X, chunk_size=100_000):
+    preds = []
+    for i in range(0, X.shape[0], chunk_size):
+        X_chunk = X[i : i + chunk_size]
+        preds.append(model.predict(X_chunk))
+    return np.concatenate(preds)
+
+
+def test_numerai_data():
+    napi = NumerAPI()
+    napi.download_dataset("v5.0/train.parquet", "numerai_train.parquet")
+    napi.download_dataset("v5.0/validation.parquet", "numerai_validation.parquet")
+
+    data = pd.concat([
+        pd.read_parquet("numerai_train.parquet"),
+        pd.read_parquet("numerai_validation.parquet")
+    ])
+    features = [f for f in list(data) if "feature" in f]
+    target = "target"
+    data = data.loc[data[ target].isna() == False ]
+
+    X = data[features].astype("int8").values[:]
+    y = data[target].values
+
+    model = WarpGBM(
+        max_depth=3,
+        num_bins=5,
+        n_estimators=10,
+        learning_rate=1,
+        threads_per_block=64,
+        rows_per_thread=4,
+        colsample_bytree=0.8,
+    )
+
+    start_fit = time.time()
+    model.fit(
+        X,
+        y,
+        # era_id=era,
+        # X_eval=X,
+        # y_eval=y,
+        # eval_every_n_trees=10,
+        # early_stopping_rounds=1,
+    )
+    fit_time = time.time() - start_fit
+    print(f" Fit time: {fit_time:.3f} seconds")
+
+    start_pred = time.time()
+    preds = predict_in_chunks(model, X, chunk_size=500_000)
+    pred_time = time.time() - start_pred
+    print(f" Predict time: {pred_time:.3f} seconds")
+
+    corr = np.corrcoef(preds, y)[0, 1]
+    mse = mean_squared_error(preds, y)
+    print(f" Correlation: {corr:.4f}")
+    print(f" MSE: {mse:.4f}")
+
+    # assert corr > 0.68, f"In-sample correlation too low: {corr}"
+    # assert mse < 0.03, f"In-sample mse too high: {mse}"
````
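In this test the `era_id=era` argument is commented out and `era` is never defined, while the other two tests pass `era_id` as an integer NumPy array. One way to build it, assuming the parquet files carry Numerai's string `era` column (an assumption about the dataset that this test does not read), is to factorize the era labels into contiguous integer codes after the `isna()` filter, so the ids stay row-aligned with `X` and `y`. The helper name `era_ids_from_column` is hypothetical.

```python
import numpy as np
import pandas as pd


def era_ids_from_column(eras):
    """Hypothetical helper: map era labels (e.g. "0001") to contiguous int32 ids in sorted order."""
    codes, uniques = pd.factorize(pd.Series(eras), sort=True)
    if (codes < 0).any():
        # factorize marks missing labels with -1
        raise ValueError("era column contains missing values")
    return codes.astype(np.int32), uniques


# Stand-in for data["era"] after the target NaN filter
era_column = pd.Series(["0001", "0001", "0002", "0003", "0003", "0003"])
era, labels = era_ids_from_column(era_column)
print(era)
```

With zero-padded labels, sorted order is also chronological order; the resulting array would then be passed as `era_id=era` to `model.fit`.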
warpgbm-1.0.0/tests/test_fit_predict_corr.py

````diff
@@ -0,0 +1,52 @@
+import numpy as np
+from warpgbm import WarpGBM
+from sklearn.datasets import make_regression
+import time
+from sklearn.metrics import mean_squared_error
+
+
+def test_fit_predictpytee_correlation():
+    np.random.seed(42)
+    N = 100_000
+    F = 1000
+    X, y = make_regression(n_samples=N, n_features=F, noise=0.1, random_state=42)
+    era = np.zeros(N, dtype=np.int32)
+    corrs = []
+    mses = []
+
+    model = WarpGBM(
+        max_depth=10,
+        num_bins=10,
+        n_estimators=100,
+        learning_rate=1,
+        threads_per_block=64,
+        rows_per_thread=4,
+        colsample_bytree=1.0,
+    )
+
+    start_fit = time.time()
+    model.fit(
+        X,
+        y,
+        era_id=era,
+        X_eval=X,
+        y_eval=y,
+        eval_every_n_trees=10,
+        early_stopping_rounds=1,
+        eval_metric="corr",
+    )
+    fit_time = time.time() - start_fit
+    print(f" Fit time: {fit_time:.3f} seconds")
+
+    start_pred = time.time()
+    preds = model.predict(X)
+    pred_time = time.time() - start_pred
+    print(f" Predict time: {pred_time:.3f} seconds")
+
+    corr = np.corrcoef(preds, y)[0, 1]
+    mse = mean_squared_error(preds, y)
+    print(f" Correlation: {corr:.4f}")
+    print(f" MSE: {mse:.4f}")
+
+    assert (corr > 0.9), f"In-sample correlation too low: {corrs}"
+    assert (mse < 2), f"In-sample mse too high: {mses}"
````
warpgbm-1.0.0/tests/test_invariant.py

````diff
@@ -0,0 +1,100 @@
+import numpy as np
+from warpgbm import WarpGBM
+import time
+
+import os
+import requests
+
+def download_file_if_missing(url, local_dir):
+    filename = os.path.basename(url)
+    local_path = os.path.join(local_dir, filename)
+
+    if os.path.exists(local_path):
+        print(f"✅ {filename} already exists, skipping download.")
+        return
+
+    # Convert GitHub blob URL to raw URL
+    raw_url = url.replace("github.com", "raw.githubusercontent.com").replace("/blob/", "/")
+
+    print(f"⬇️ Downloading {filename}...")
+    response = requests.get(raw_url)
+    response.raise_for_status()
+
+    os.makedirs(local_dir, exist_ok=True)
+    with open(local_path, "wb") as f:
+        f.write(response.content)
+    print(f"✅ Saved to {local_path}")
+
+# === Usage ===
+
+urls = [
+    "https://github.com/jefferythewind/era-splitting-notebook-examples/blob/main/Synthetic%20Memorization%20Data%20Set/X_train.npy",
+    "https://github.com/jefferythewind/era-splitting-notebook-examples/blob/main/Synthetic%20Memorization%20Data%20Set/y_train.npy",
+    "https://github.com/jefferythewind/era-splitting-notebook-examples/blob/main/Synthetic%20Memorization%20Data%20Set/X_test.npy",
+    "https://github.com/jefferythewind/era-splitting-notebook-examples/blob/main/Synthetic%20Memorization%20Data%20Set/y_test.npy",
+    "https://github.com/jefferythewind/era-splitting-notebook-examples/blob/main/Synthetic%20Memorization%20Data%20Set/X_eras.npy",
+]
+
+
+local_folder = "./synthetic_data"
+
+for url in urls:
+    download_file_if_missing(url, local_folder)
+
+def test_fit_predictpytee_correlation():
+    import numpy as np
+    import os
+    from warpgbm import WarpGBM
+    from sklearn.metrics import mean_squared_error
+    import time
+
+    # Load the real dataset from local .npy files
+    data_dir = "./synthetic_data"
+    X = np.load(os.path.join(data_dir, "X_train.npy"))
+    y = np.load(os.path.join(data_dir, "y_train.npy"))
+    # era = np.zeros(X.shape[0], dtype=np.int32) # one era for default GBDT equivalence
+    era = np.load(os.path.join(data_dir, "X_eras.npy"))
+
+    X_test = np.load(os.path.join(data_dir, "X_test.npy"))
+    y_test = np.load(os.path.join(data_dir, "y_test.npy"))
+
+    print(f"X shape: {X.shape}, y shape: {y.shape}")
+
+    model = WarpGBM(
+        max_depth=10,
+        num_bins=127,
+        n_estimators=50,
+        learning_rate=1,
+        threads_per_block=128,
+        rows_per_thread=4,
+        colsample_bytree=0.9,
+        min_child_weight=4
+    )
+
+    start_fit = time.time()
+    model.fit(
+        X,
+        y,
+        era_id=era,
+        X_eval=X_test,
+        y_eval=y_test,
+        eval_every_n_trees=10,
+        early_stopping_rounds=100,
+        eval_metric="corr",
+    )
+    fit_time = time.time() - start_fit
+    print(f" Fit time: {fit_time:.3f} seconds")
+
+    start_pred = time.time()
+    preds = model.predict(X_test)
+    pred_time = time.time() - start_pred
+    print(f" Predict time: {pred_time:.3f} seconds")
+
+    corr = np.corrcoef(preds, y_test)[0, 1]
+    mse = mean_squared_error(preds, y_test)
+    print(f" Correlation: {corr:.4f}")
+    print(f" MSE: {mse:.4f}")
+
+    assert corr > 0.95, f"Out-of-sample correlation too low: {corr:.4f}"
+    assert mse < 0.02, f"Out-of-sample MSE too high: {mse:.4f}"
+
````
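This test judges the invariant model by one pooled correlation on `X_test`. When era labels are available for the rows being scored (the test downloads `X_eras.npy` only for the training rows), a per-era breakdown shows whether the fit holds in every era rather than only on average. A minimal sketch; `per_era_corr` is a hypothetical helper, not part of WarpGBM.

```python
import numpy as np


def per_era_corr(preds, y, era_ids):
    """Hypothetical helper: Pearson correlation of preds vs y, computed separately per era."""
    preds, y, era_ids = map(np.asarray, (preds, y, era_ids))
    out = {}
    for era in np.unique(era_ids):
        mask = era_ids == era
        out[era.item()] = float(np.corrcoef(preds[mask], y[mask])[0, 1])
    return out


rng = np.random.default_rng(1)
era_ids = np.repeat([0, 1, 2], 200)
y = rng.normal(size=era_ids.size)
# Simulated predictions that track y in eras 0 and 1 but are inverted in era 2
preds = np.where(era_ids == 2, -y, y) + 0.5 * rng.normal(size=era_ids.size)

corrs = per_era_corr(preds, y, era_ids)
values = np.array(list(corrs.values()))
print({e: round(c, 2) for e, c in corrs.items()})
print("mean:", round(float(values.mean()), 3), "share positive:", (values > 0).mean())
```

Here the pooled correlation stays positive, yet one era in three is strongly negative; the share of positive eras is what reveals the failure.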
warpgbm-1.0.0/version.txt

````diff
@@ -0,0 +1 @@
+1.0.0
````