warpgbm 1.0.0__tar.gz → 2.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {warpgbm-1.0.0/warpgbm.egg-info → warpgbm-2.0.0}/PKG-INFO +319 -210
- warpgbm-2.0.0/README.md +424 -0
- {warpgbm-1.0.0 → warpgbm-2.0.0}/pyproject.toml +1 -1
- warpgbm-2.0.0/tests/test_multiclass.py +332 -0
- warpgbm-2.0.0/version.txt +1 -0
- {warpgbm-1.0.0 → warpgbm-2.0.0}/warpgbm/core.py +356 -21
- warpgbm-2.0.0/warpgbm/metrics.py +37 -0
- {warpgbm-1.0.0 → warpgbm-2.0.0/warpgbm.egg-info}/PKG-INFO +319 -210
- {warpgbm-1.0.0 → warpgbm-2.0.0}/warpgbm.egg-info/SOURCES.txt +1 -0
- warpgbm-1.0.0/README.md +0 -315
- warpgbm-1.0.0/version.txt +0 -1
- warpgbm-1.0.0/warpgbm/metrics.py +0 -10
- {warpgbm-1.0.0 → warpgbm-2.0.0}/LICENSE +0 -0
- {warpgbm-1.0.0 → warpgbm-2.0.0}/MANIFEST.in +0 -0
- {warpgbm-1.0.0 → warpgbm-2.0.0}/setup.cfg +0 -0
- {warpgbm-1.0.0 → warpgbm-2.0.0}/setup.py +0 -0
- {warpgbm-1.0.0 → warpgbm-2.0.0}/tests/__init__.py +0 -0
- {warpgbm-1.0.0 → warpgbm-2.0.0}/tests/full_numerai_test.py +0 -0
- {warpgbm-1.0.0 → warpgbm-2.0.0}/tests/numerai_test.py +0 -0
- {warpgbm-1.0.0 → warpgbm-2.0.0}/tests/test_fit_predict_corr.py +0 -0
- {warpgbm-1.0.0 → warpgbm-2.0.0}/tests/test_invariant.py +0 -0
- {warpgbm-1.0.0 → warpgbm-2.0.0}/warpgbm/__init__.py +0 -0
- {warpgbm-1.0.0 → warpgbm-2.0.0}/warpgbm/cuda/__init__.py +0 -0
- {warpgbm-1.0.0 → warpgbm-2.0.0}/warpgbm/cuda/best_split_kernel.cu +0 -0
- {warpgbm-1.0.0 → warpgbm-2.0.0}/warpgbm/cuda/binner.cu +0 -0
- {warpgbm-1.0.0 → warpgbm-2.0.0}/warpgbm/cuda/histogram_kernel.cu +0 -0
- {warpgbm-1.0.0 → warpgbm-2.0.0}/warpgbm/cuda/node_kernel.cpp +0 -0
- {warpgbm-1.0.0 → warpgbm-2.0.0}/warpgbm/cuda/predict.cu +0 -0
- {warpgbm-1.0.0 → warpgbm-2.0.0}/warpgbm.egg-info/dependency_links.txt +0 -0
- {warpgbm-1.0.0 → warpgbm-2.0.0}/warpgbm.egg-info/requires.txt +0 -0
- {warpgbm-1.0.0 → warpgbm-2.0.0}/warpgbm.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: warpgbm
|
3
|
-
Version:
|
3
|
+
Version: 2.0.0
|
4
4
|
Summary: A fast GPU-accelerated Gradient Boosted Decision Tree library with PyTorch + CUDA
|
5
5
|
License: GNU GENERAL PUBLIC LICENSE
|
6
6
|
Version 3, 29 June 2007
|
@@ -686,318 +686,427 @@ Requires-Dist: tqdm
|
|
686
686
|
Requires-Dist: scikit-learn
|
687
687
|
Dynamic: license-file
|
688
688
|
|
689
|
-

|
690
690
|
|
691
|
-
# WarpGBM
|
691
|
+
# WarpGBM ⚡
|
692
692
|
|
693
|
-
|
693
|
+
> **Neural-speed gradient boosting. GPU-native. Distribution-aware. Production-ready.**
|
694
694
|
|
695
|
-
|
695
|
+
WarpGBM is a high-performance, GPU-accelerated Gradient Boosted Decision Tree (GBDT) library engineered from silicon up with PyTorch and custom CUDA kernels. Built for speed demons and researchers who refuse to compromise.
|
696
696
|
|
697
|
-
|
698
|
-
---
|
697
|
+
## 🎯 What Sets WarpGBM Apart
|
699
698
|
|
700
|
-
|
699
|
+
**Regression + Classification Unified**
|
700
|
+
Train on continuous targets or multiclass labels with the same blazing-fast infrastructure.
|
701
701
|
|
702
|
-
|
703
|
-
- [
|
704
|
-
- [Installation](#installation)
|
705
|
-
- [Learning Invariant Signals Across Environments](#learning-invariant-signals-across-environments)
|
706
|
-
- [Why This Matters](#why-this-matters)
|
707
|
-
- [Visual Intuition](#visual-intuition)
|
708
|
-
- [Key References](#key-references)
|
709
|
-
- [Examples](#examples)
|
710
|
-
- [Quick Comparison with LightGBM CPU version](#quick-comparison-with-lightgbm-cpu-version)
|
711
|
-
- [Pre-binned Data Example (Numerai)](#pre-binned-data-example-numerai)
|
712
|
-
- [Documentation](#documentation)
|
713
|
-
- [Acknowledgements](#acknowledgements)
|
714
|
-
- [Version Notes](#version-notes)
|
702
|
+
**Invariant Learning (DES Algorithm)**
|
703
|
+
The only open-source GBDT that natively learns signals stable across shifting distributions. Powered by **[Directional Era-Splitting](https://arxiv.org/abs/2309.14496)** — because your data doesn't live in a vacuum.
|
715
704
|
|
705
|
+
**GPU-Accelerated Everything**
|
706
|
+
Custom CUDA kernels for binning, histograms, splits, and inference. No compromises, no CPU bottlenecks.
|
716
707
|
|
717
|
-
|
708
|
+
**Scikit-Learn Compatible**
|
709
|
+
Drop-in replacement. Same API you know, 10x the speed you need.
|
718
710
|
|
719
|
-
|
720
|
-
- **Invariant signal learning** via [Directional Era-Splitting (DES)](https://arxiv.org/abs/2309.14496) — designed for datasets with shifting environments (e.g., time, regimes, experimental settings)
|
721
|
-
- Drop-in **scikit-learn style interface** for easy adoption
|
722
|
-
- Supports **pre-binned data** or **automatic quantile binning**
|
723
|
-
- Works with `float32` or `int8` inputs
|
724
|
-
- Built-in **validation and early stopping** support with MSE, RMSLE, or correlation metrics
|
725
|
-
- Simple install with `pip`, no custom drivers required
|
711
|
+
---
|
726
712
|
|
727
|
-
|
728
|
-
> To run standard GBM training at maximum speed, simply omit the `era_id` argument — WarpGBM will behave like a traditional booster but with industry-leading performance.
|
713
|
+
## 🚀 Quick Start
|
729
714
|
|
730
|
-
|
715
|
+
### Installation
|
731
716
|
|
732
|
-
|
717
|
+
```bash
|
718
|
+
# Latest from GitHub (recommended)
|
719
|
+
pip install git+https://github.com/jefferythewind/warpgbm.git
|
733
720
|
|
734
|
-
|
721
|
+
# Stable from PyPI
|
722
|
+
pip install warpgbm
|
723
|
+
```
|
735
724
|
|
736
|
-
|
725
|
+
**Prerequisites:** PyTorch with CUDA support ([install guide](https://pytorch.org/get-started/locally/))
|
737
726
|
|
727
|
+
### Regression in 5 Lines
|
728
|
+
|
729
|
+
```python
|
730
|
+
from warpgbm import WarpGBM
|
731
|
+
import numpy as np
|
732
|
+
|
733
|
+
model = WarpGBM(objective='regression', max_depth=5, n_estimators=100)
|
734
|
+
model.fit(X_train, y_train)
|
735
|
+
predictions = model.predict(X_test)
|
738
736
|
```
|
739
|
-
|
740
|
-
|
741
|
-
|
742
|
-
|
737
|
+
|
738
|
+
### Classification in 5 Lines
|
739
|
+
|
740
|
+
```python
|
741
|
+
from warpgbm import WarpGBM
|
742
|
+
|
743
|
+
model = WarpGBM(objective='multiclass', max_depth=5, n_estimators=50)
|
744
|
+
model.fit(X_train, y_train) # y can be integers, strings, whatever
|
745
|
+
probabilities = model.predict_proba(X_test)
|
746
|
+
labels = model.predict(X_test)
|
743
747
|
```
|
744
748
|
|
745
|
-
|
749
|
+
---
|
750
|
+
|
751
|
+
## 🎮 Features
|
752
|
+
|
753
|
+
### Core Engine
|
754
|
+
- ⚡ **GPU-native CUDA kernels** for histogram building, split finding, binning, and prediction
|
755
|
+
- 🎯 **Multi-objective support**: regression, binary, multiclass classification
|
756
|
+
- 📊 **Pre-binned data optimization** — skip binning if your data's already quantized
|
757
|
+
- 🔥 **Mixed precision support** — `float32` or `int8` inputs
|
758
|
+
- 🎲 **Stochastic features** — `colsample_bytree` for regularization
|
759
|
+
|
760
|
+
### Intelligence
|
761
|
+
- 🧠 **Invariant learning via DES** — identifies signals that generalize across time/regimes/environments
|
762
|
+
- 📈 **Smart initialization** — class priors for classification, mean for regression
|
763
|
+
- 🎯 **Automatic label encoding** — handles strings, integers, whatever you throw at it
|
764
|
+
|
765
|
+
### Training Utilities
|
766
|
+
- ✅ **Early stopping** with validation sets
|
767
|
+
- 📊 **Rich metrics**: MSE, RMSLE, correlation, log loss, accuracy
|
768
|
+
- 🔍 **Progress tracking** with loss curves
|
769
|
+
- 🎚️ **Regularization** — L2 leaf penalties, min split gain, min child weight
|
746
770
|
|
747
771
|
---
|
748
772
|
|
749
|
-
##
|
773
|
+
## ⚔️ Benchmarks
|
750
774
|
|
751
|
-
###
|
775
|
+
### Synthetic Data: 1M Rows × 1K Features (Google Colab L4 GPU)
|
752
776
|
|
753
|
-
```
|
754
|
-
|
777
|
+
```
|
778
|
+
WarpGBM: corr = 0.8882, train = 17.4s, infer = 3.2s ⚡
|
779
|
+
XGBoost: corr = 0.8877, train = 33.2s, infer = 8.0s
|
780
|
+
LightGBM: corr = 0.8604, train = 29.8s, infer = 1.6s
|
781
|
+
CatBoost: corr = 0.8935, train = 392.1s, infer = 379.2s
|
755
782
|
```
|
756
783
|
|
757
|
-
|
784
|
+
**2× faster than XGBoost. 23× faster than CatBoost.**
|
758
785
|
|
759
|
-
|
786
|
+
[→ Run the benchmark yourself](https://colab.research.google.com/drive/16U1kbYlD5HibGbnF5NGsjChZ1p1IA2pK?usp=sharing)
|
760
787
|
|
761
|
-
|
762
|
-
|
788
|
+
### Multiclass Classification: 3.5K Samples, 3 Classes, 50 Rounds
|
789
|
+
|
790
|
+
```
|
791
|
+
Training: 2.13s
|
792
|
+
Inference: 0.37s
|
793
|
+
Accuracy: 75.3%
|
763
794
|
```
|
764
795
|
|
765
|
-
|
796
|
+
**Production-ready multiclass at neural network speeds.**
|
766
797
|
|
767
|
-
|
768
|
-
> If you encounter an error related to mismatched or missing CUDA versions, try installing with the following flag. This is currently required in the Colab environments.
|
769
|
-
>
|
770
|
-
> ```bash
|
771
|
-
> pip install warpgbm --no-build-isolation
|
772
|
-
> ```
|
798
|
+
---
|
773
799
|
|
774
|
-
|
800
|
+
## 📖 Examples
|
775
801
|
|
776
|
-
|
802
|
+
### Regression: Beat LightGBM on Your Laptop
|
777
803
|
|
778
|
-
```
|
779
|
-
|
780
|
-
|
781
|
-
|
782
|
-
|
804
|
+
```python
|
805
|
+
import numpy as np
|
806
|
+
from sklearn.datasets import make_regression
|
807
|
+
from warpgbm import WarpGBM
|
808
|
+
|
809
|
+
# Generate data
|
810
|
+
X, y = make_regression(n_samples=100_000, n_features=500, random_state=42)
|
811
|
+
X, y = X.astype(np.float32), y.astype(np.float32)
|
812
|
+
|
813
|
+
# Train
|
814
|
+
model = WarpGBM(
|
815
|
+
objective='regression',
|
816
|
+
max_depth=5,
|
817
|
+
n_estimators=100,
|
818
|
+
learning_rate=0.01,
|
819
|
+
num_bins=32
|
820
|
+
)
|
821
|
+
model.fit(X, y)
|
822
|
+
|
823
|
+
# Predict
|
824
|
+
preds = model.predict(X)
|
825
|
+
print(f"Correlation: {np.corrcoef(preds, y)[0,1]:.4f}")
|
783
826
|
```
|
784
827
|
|
785
|
-
|
786
|
-
[https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
|
828
|
+
### Classification: Multiclass with Early Stopping
|
787
829
|
|
788
|
-
|
830
|
+
```python
|
831
|
+
from sklearn.datasets import make_classification
|
832
|
+
from sklearn.model_selection import train_test_split
|
833
|
+
from warpgbm import WarpGBM
|
789
834
|
|
790
|
-
|
835
|
+
# 5-class problem
|
836
|
+
X, y = make_classification(
|
837
|
+
n_samples=10_000,
|
838
|
+
n_features=50,
|
839
|
+
n_classes=5,
|
840
|
+
n_informative=30
|
841
|
+
)
|
791
842
|
|
792
|
-
|
843
|
+
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2)
|
793
844
|
|
794
|
-
|
845
|
+
model = WarpGBM(
|
846
|
+
objective='multiclass',
|
847
|
+
max_depth=6,
|
848
|
+
n_estimators=200,
|
849
|
+
learning_rate=0.1,
|
850
|
+
num_bins=32
|
851
|
+
)
|
795
852
|
|
796
|
-
|
853
|
+
model.fit(
|
854
|
+
X_train, y_train,
|
855
|
+
X_eval=X_val, y_eval=y_val,
|
856
|
+
eval_every_n_trees=10,
|
857
|
+
early_stopping_rounds=5,
|
858
|
+
eval_metric='logloss'
|
859
|
+
)
|
797
860
|
|
798
|
-
|
861
|
+
# Get probabilities or class predictions
|
862
|
+
probs = model.predict_proba(X_val) # shape: (n_samples, n_classes)
|
863
|
+
labels = model.predict(X_val) # class labels
|
864
|
+
```
|
799
865
|
|
800
|
-
|
866
|
+
### Invariant Learning: Distribution-Robust Signals
|
801
867
|
|
802
|
-
|
868
|
+
```python
|
869
|
+
# Your data spans multiple time periods/regimes/environments
|
870
|
+
# Pass era_id to learn only signals that work across ALL eras
|
803
871
|
|
804
|
-
|
872
|
+
model = WarpGBM(
|
873
|
+
objective='regression',
|
874
|
+
max_depth=8,
|
875
|
+
n_estimators=100
|
876
|
+
)
|
805
877
|
|
806
|
-
|
807
|
-
|
808
|
-
|
878
|
+
model.fit(
|
879
|
+
X, y,
|
880
|
+
era_id=era_labels # Array marking which era each sample belongs to
|
881
|
+
)
|
809
882
|
|
810
|
-
|
883
|
+
# Now your model ignores spurious correlations that don't generalize!
|
884
|
+
```
|
811
885
|
|
812
|
-
###
|
886
|
+
### Pre-binned Data: Maximum Speed (Numerai Example)
|
813
887
|
|
814
|
-
|
888
|
+
```python
|
889
|
+
import pandas as pd
|
890
|
+
from numerapi import NumerAPI
|
891
|
+
from warpgbm import WarpGBM
|
815
892
|
|
816
|
-
|
817
|
-
|
893
|
+
# Download Numerai data (already quantized to integers)
|
894
|
+
napi = NumerAPI()
|
895
|
+
napi.download_dataset('v5.0/train.parquet', 'train.parquet')
|
896
|
+
train = pd.read_parquet('train.parquet')
|
818
897
|
|
819
|
-
|
820
|
-
|
898
|
+
features = [f for f in train.columns if 'feature' in f]
|
899
|
+
X = train[features].astype('int8').values
|
900
|
+
y = train['target'].values
|
821
901
|
|
822
|
-
|
902
|
+
# WarpGBM detects pre-binned data and skips binning
|
903
|
+
model = WarpGBM(max_depth=5, n_estimators=100, num_bins=20)
|
904
|
+
model.fit(X, y) # Blazing fast!
|
905
|
+
```
|
823
906
|
|
907
|
+
**Result: 13× faster than LightGBM on Numerai data (49s vs 643s)**
|
824
908
|
|
825
909
|
---
|
826
910
|
|
827
|
-
|
911
|
+
## 🧠 Invariant Learning: Why It Matters
|
828
912
|
|
829
|
-
|
830
|
-
- **Learning Explanations That Are Hard to Vary**: [Parascandolo et al., 2020](https://arxiv.org/abs/2009.00329)
|
831
|
-
- **Era Splitting: Invariant Learning for Decision Trees**: [DeLise, 2023](https://arxiv.org/abs/2309.14496)
|
913
|
+
Most ML models assume your training and test data come from the same distribution. **Reality check: they don't.**
|
832
914
|
|
833
|
-
|
915
|
+
- Stock prices shift with market regimes
|
916
|
+
- User behavior changes over time
|
917
|
+
- Experimental data varies by batch/site/condition
|
834
918
|
|
835
|
-
|
919
|
+
**Traditional GBDT:** Learns any signal that correlates with the target, including fragile patterns that break OOD.
|
836
920
|
|
837
|
-
|
921
|
+
**WarpGBM with DES:** Explicitly tests if each split generalizes across ALL environments (eras). Only keeps robust signals.
|
838
922
|
|
839
|
-
|
923
|
+
### The Algorithm
|
840
924
|
|
841
|
-
|
925
|
+
For each potential split, compute gain separately in each era. Only accept splits where:
|
926
|
+
1. Gain is positive in ALL eras
|
927
|
+
2. Split direction is consistent across eras
|
842
928
|
|
843
|
-
|
929
|
+
This prevents overfitting to spurious correlations that only work in some time periods or environments.
|
844
930
|
|
845
|
-
###
|
931
|
+
### Visual Intuition
|
846
932
|
|
847
|
-
|
848
|
-
import numpy as np
|
849
|
-
from sklearn.datasets import make_regression
|
850
|
-
from time import time
|
851
|
-
import lightgbm as lgb
|
852
|
-
from warpgbm import WarpGBM
|
933
|
+
<img src="https://github.com/user-attachments/assets/2be11ef3-6f2e-4636-ab91-307a73add247" alt="Era Splitting Visualization" width="400"/>
|
853
934
|
|
854
|
-
|
855
|
-
|
856
|
-
X = X.astype(np.float32)
|
857
|
-
y = y.astype(np.float32)
|
858
|
-
|
859
|
-
# Train LightGBM
|
860
|
-
start = time()
|
861
|
-
lgb_model = lgb.LGBMRegressor(max_depth=5, n_estimators=100, learning_rate=0.01, max_bin=7)
|
862
|
-
lgb_model.fit(X, y)
|
863
|
-
lgb_time = time() - start
|
864
|
-
lgb_preds = lgb_model.predict(X)
|
865
|
-
|
866
|
-
# Train WarpGBM
|
867
|
-
start = time()
|
868
|
-
wgbm_model = WarpGBM(max_depth=5, n_estimators=100, learning_rate=0.01, num_bins=7)
|
869
|
-
wgbm_model.fit(X, y)
|
870
|
-
wgbm_time = time() - start
|
871
|
-
wgbm_preds = wgbm_model.predict(X)
|
872
|
-
|
873
|
-
# Results
|
874
|
-
print(f"LightGBM: corr = {np.corrcoef(lgb_preds, y)[0,1]:.4f}, time = {lgb_time:.2f}s")
|
875
|
-
print(f"WarpGBM: corr = {np.corrcoef(wgbm_preds, y)[0,1]:.4f}, time = {wgbm_time:.2f}s")
|
876
|
-
```
|
935
|
+
**Left:** Standard training pools all data together — learns any signal that correlates.
|
936
|
+
**Right:** Era-aware training demands signals work across all periods — learns robust features only.
|
877
937
|
|
878
|
-
|
938
|
+
### Research Foundation
|
879
939
|
|
880
|
-
|
881
|
-
|
882
|
-
|
883
|
-
```
|
940
|
+
- **Invariant Risk Minimization**: [Arjovsky et al., 2019](https://arxiv.org/abs/1907.02893)
|
941
|
+
- **Hard-to-Vary Explanations**: [Parascandolo et al., 2020](https://arxiv.org/abs/2009.00329)
|
942
|
+
- **Era Splitting for Trees**: [DeLise, 2023](https://arxiv.org/abs/2309.14496)
|
884
943
|
|
885
944
|
---
|
886
945
|
|
887
|
-
|
946
|
+
## 📚 API Reference
|
888
947
|
|
889
|
-
|
948
|
+
### Constructor Parameters
|
890
949
|
|
891
950
|
```python
|
892
|
-
|
893
|
-
|
894
|
-
|
895
|
-
|
896
|
-
|
897
|
-
|
951
|
+
WarpGBM(
|
952
|
+
objective='regression', # 'regression', 'binary', or 'multiclass'
|
953
|
+
num_bins=10, # Histogram bins for feature quantization
|
954
|
+
max_depth=3, # Maximum tree depth
|
955
|
+
learning_rate=0.1, # Shrinkage rate (aka eta)
|
956
|
+
n_estimators=100, # Number of boosting rounds
|
957
|
+
min_child_weight=20, # Min sum of instance weights in child node
|
958
|
+
min_split_gain=0.0, # Min loss reduction to split
|
959
|
+
L2_reg=1e-6, # L2 leaf regularization
|
960
|
+
colsample_bytree=1.0, # Feature subsample ratio per tree
|
961
|
+
threads_per_block=64, # CUDA block size (tune for your GPU)
|
962
|
+
rows_per_thread=4, # Rows processed per thread
|
963
|
+
device='cuda' # 'cuda' or 'cpu' (GPU strongly recommended)
|
964
|
+
)
|
965
|
+
```
|
898
966
|
|
899
|
-
|
900
|
-
napi.download_dataset('v5.0/train.parquet', 'train.parquet')
|
901
|
-
train = pd.read_parquet('train.parquet')
|
967
|
+
### Training Methods
|
902
968
|
|
903
|
-
|
904
|
-
|
905
|
-
|
906
|
-
|
907
|
-
|
908
|
-
|
909
|
-
#
|
910
|
-
|
911
|
-
|
912
|
-
|
913
|
-
|
914
|
-
lgb_preds = lgb_model.predict(X_np)
|
915
|
-
|
916
|
-
# WarpGBM
|
917
|
-
start = time()
|
918
|
-
wgbm_model = WarpGBM(max_depth=5, n_estimators=100, learning_rate=0.01, num_bins=7)
|
919
|
-
wgbm_model.fit(X_np, Y_np)
|
920
|
-
wgbm_time = time() - start
|
921
|
-
wgbm_preds = wgbm_model.predict(X_np)
|
922
|
-
|
923
|
-
# Results
|
924
|
-
print(f"LightGBM: corr = {np.corrcoef(lgb_preds, Y_np)[0,1]:.4f}, time = {lgb_time:.2f}s")
|
925
|
-
print(f"WarpGBM: corr = {np.corrcoef(wgbm_preds, Y_np)[0,1]:.4f}, time = {wgbm_time:.2f}s")
|
969
|
+
```python
|
970
|
+
model.fit(
|
971
|
+
X, # Features: np.array shape (n_samples, n_features)
|
972
|
+
y, # Target: np.array shape (n_samples,)
|
973
|
+
era_id=None, # Optional: era labels for invariant learning
|
974
|
+
X_eval=None, # Optional: validation features
|
975
|
+
y_eval=None, # Optional: validation targets
|
976
|
+
eval_every_n_trees=None, # Eval frequency (in rounds)
|
977
|
+
early_stopping_rounds=None, # Stop if no improvement for N evals
|
978
|
+
eval_metric='mse' # 'mse', 'rmsle', 'corr', 'logloss', 'accuracy'
|
979
|
+
)
|
926
980
|
```
|
927
981
|
|
928
|
-
|
982
|
+
### Prediction Methods
|
983
|
+
|
984
|
+
```python
|
985
|
+
# Regression: returns predicted values
|
986
|
+
predictions = model.predict(X)
|
987
|
+
|
988
|
+
# Classification: returns class labels (decoded)
|
989
|
+
labels = model.predict(X)
|
929
990
|
|
991
|
+
# Classification: returns class probabilities
|
992
|
+
probabilities = model.predict_proba(X) # shape: (n_samples, n_classes)
|
930
993
|
```
|
931
|
-
|
932
|
-
|
994
|
+
|
995
|
+
### Attributes
|
996
|
+
|
997
|
+
```python
|
998
|
+
model.classes_ # Unique class labels (classification only)
|
999
|
+
model.num_classes # Number of classes (classification only)
|
1000
|
+
model.forest # Trained tree structures
|
1001
|
+
model.training_loss # Training loss history
|
1002
|
+
model.eval_loss # Validation loss history (if eval set provided)
|
933
1003
|
```
|
934
1004
|
|
935
1005
|
---
|
936
1006
|
|
937
|
-
##
|
938
|
-
|
939
|
-
###
|
940
|
-
|
941
|
-
|
942
|
-
|
943
|
-
- `n_estimators`: Number of boosting iterations (default: 100)
|
944
|
-
- `min_child_weight`: Minimum sum of instance weight needed in a child (default: 20)
|
945
|
-
- `min_split_gain`: Minimum loss reduction required to make a further partition (default: 0.0)
|
946
|
-
- `histogram_computer`: Choice of histogram kernel (`'hist1'`, `'hist2'`, `'hist3'`) (default: `'hist3'`)
|
947
|
-
- `threads_per_block`: CUDA threads per block (default: 32)
|
948
|
-
- `rows_per_thread`: Number of training rows processed per thread (default: 4)
|
949
|
-
- `L2_reg`: L2 regularizer (default: 1e-6)
|
950
|
-
- `colsample_bytree`: Proportion of features to subsample to grow each tree (default: 1)
|
951
|
-
|
952
|
-
### Methods:
|
953
|
-
```
|
954
|
-
.fit(
|
955
|
-
X, # numpy array (float or int) 2 dimensions (num_samples, num_features)
|
956
|
-
y, # numpy array (float or int) 1 dimension (num_samples)
|
957
|
-
era_id=None, # numpy array (int) 1 dimension (num_samples)
|
958
|
-
X_eval=None, # numpy array (float or int) 2 dimensions (eval_num_samples, num_features)
|
959
|
-
y_eval=None, # numpy array (float or int) 1 dimension (eval_num_samples)
|
960
|
-
eval_every_n_trees=None, # const (int) >= 1
|
961
|
-
early_stopping_rounds=None, # const (int) >= 1
|
962
|
-
eval_metric='mse' # string, one of 'mse', 'rmsle' or 'corr'. For corr, loss is 1 - correlation(y_true, preds)
|
963
|
-
)
|
1007
|
+
## 🔧 Installation Details
|
1008
|
+
|
1009
|
+
### Linux / macOS (Recommended)
|
1010
|
+
|
1011
|
+
```bash
|
1012
|
+
pip install git+https://github.com/jefferythewind/warpgbm.git
|
964
1013
|
```
|
965
|
-
Train with optional validation set and early stopping.
|
966
1014
|
|
1015
|
+
Compiles CUDA extensions using your local PyTorch + CUDA setup.
|
967
1016
|
|
1017
|
+
### Colab / Mismatched CUDA Versions
|
1018
|
+
|
1019
|
+
```bash
|
1020
|
+
pip install warpgbm --no-build-isolation
|
968
1021
|
```
|
969
|
-
|
970
|
-
|
971
|
-
|
1022
|
+
|
1023
|
+
### Windows
|
1024
|
+
|
1025
|
+
```bash
|
1026
|
+
git clone https://github.com/jefferythewind/warpgbm.git
|
1027
|
+
cd warpgbm
|
1028
|
+
python setup.py bdist_wheel
|
1029
|
+
pip install dist/warpgbm-*.whl
|
972
1030
|
```
|
973
|
-
Predict on new data, using parallelized CUDA kernel.
|
974
1031
|
|
975
1032
|
---
|
976
1033
|
|
977
|
-
##
|
1034
|
+
## 🎯 Use Cases
|
978
1035
|
|
979
|
-
|
1036
|
+
**Financial ML:** Learn signals that work across market regimes
|
1037
|
+
**Time Series:** Robust forecasting across distribution shifts
|
1038
|
+
**Scientific Research:** Models that generalize across experimental batches
|
1039
|
+
**High-Speed Inference:** Production systems with millisecond SLAs
|
1040
|
+
**Kaggle/Competitions:** GPU-accelerated hyperparameter tuning
|
1041
|
+
**Multiclass Problems:** Image classification fallback, text categorization, fraud detection
|
980
1042
|
|
981
1043
|
---
|
982
1044
|
|
983
|
-
##
|
1045
|
+
## 🚧 Roadmap
|
984
1046
|
|
985
|
-
|
1047
|
+
- [ ] Multi-GPU training support
|
1048
|
+
- [ ] SHAP value computation on GPU
|
1049
|
+
- [ ] Feature interaction constraints
|
1050
|
+
- [ ] Monotonic constraints
|
1051
|
+
- [ ] Custom loss functions
|
1052
|
+
- [ ] Distributed training
|
1053
|
+
- [ ] ONNX export for deployment
|
986
1054
|
|
987
|
-
|
1055
|
+
---
|
988
1056
|
|
989
|
-
|
1057
|
+
## 🙏 Acknowledgements
|
990
1058
|
|
991
|
-
|
1059
|
+
Built on the shoulders of PyTorch, scikit-learn, LightGBM, XGBoost, and the CUDA ecosystem. Special thanks to the GBDT research community and all contributors.
|
992
1060
|
|
993
|
-
|
1061
|
+
---
|
1062
|
+
|
1063
|
+
## 📝 Version History
|
994
1064
|
|
995
|
-
|
1065
|
+
### v2.0.0 (Current)
|
1066
|
+
- ✨ **Multiclass classification support** via softmax objective
|
1067
|
+
- 🎯 Binary classification mode
|
1068
|
+
- 📊 New metrics: log loss, accuracy
|
1069
|
+
- 🏷️ Automatic label encoding (supports strings)
|
1070
|
+
- 🔮 `predict_proba()` for probability outputs
|
1071
|
+
- ✅ Comprehensive test suite for classification
|
1072
|
+
- 🔒 Full backward compatibility with regression
|
1073
|
+
- 🐛 Fixed unused variable issue (#8)
|
1074
|
+
- 🧹 Removed unimplemented L1_reg parameter
|
1075
|
+
- 📚 Major documentation overhaul with AGENT_GUIDE.md
|
1076
|
+
|
1077
|
+
### v1.0.0
|
1078
|
+
- 🧠 Invariant learning via Directional Era-Splitting (DES)
|
1079
|
+
- 🚀 VRAM optimizations
|
1080
|
+
- 📈 Era-aware histogram computation
|
996
1081
|
|
997
1082
|
### v0.1.26
|
1083
|
+
- 🐛 Memory bug fixes in prediction
|
1084
|
+
- 📊 Added correlation eval metric
|
998
1085
|
|
999
|
-
|
1086
|
+
### v0.1.25
|
1087
|
+
- 🎲 Feature subsampling (`colsample_bytree`)
|
1000
1088
|
|
1001
|
-
###
|
1089
|
+
### v0.1.23
|
1090
|
+
- ⏹️ Early stopping support
|
1091
|
+
- ✅ Validation set evaluation
|
1092
|
+
|
1093
|
+
### v0.1.21
|
1094
|
+
- ⚡ CUDA prediction kernel (replaced vectorized Python)
|
1095
|
+
|
1096
|
+
---
|
1097
|
+
|
1098
|
+
## 📄 License
|
1099
|
+
|
1100
|
+
MIT License - see [LICENSE](LICENSE) file
|
1101
|
+
|
1102
|
+
---
|
1103
|
+
|
1104
|
+
## 🤝 Contributing
|
1105
|
+
|
1106
|
+
Pull requests welcome! See [AGENT_GUIDE.md](AGENT_GUIDE.md) for architecture details and development guidelines.
|
1107
|
+
|
1108
|
+
---
|
1109
|
+
|
1110
|
+
**Built with 🔥 by @jefferythewind**
|
1002
1111
|
|
1003
|
-
|
1112
|
+
*"Train smarter. Predict faster. Generalize better."*
|