warpgbm 1.0.0__tar.gz → 2.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. {warpgbm-1.0.0/warpgbm.egg-info → warpgbm-2.0.0}/PKG-INFO +319 -210
  2. warpgbm-2.0.0/README.md +424 -0
  3. {warpgbm-1.0.0 → warpgbm-2.0.0}/pyproject.toml +1 -1
  4. warpgbm-2.0.0/tests/test_multiclass.py +332 -0
  5. warpgbm-2.0.0/version.txt +1 -0
  6. {warpgbm-1.0.0 → warpgbm-2.0.0}/warpgbm/core.py +356 -21
  7. warpgbm-2.0.0/warpgbm/metrics.py +37 -0
  8. {warpgbm-1.0.0 → warpgbm-2.0.0/warpgbm.egg-info}/PKG-INFO +319 -210
  9. {warpgbm-1.0.0 → warpgbm-2.0.0}/warpgbm.egg-info/SOURCES.txt +1 -0
  10. warpgbm-1.0.0/README.md +0 -315
  11. warpgbm-1.0.0/version.txt +0 -1
  12. warpgbm-1.0.0/warpgbm/metrics.py +0 -10
  13. {warpgbm-1.0.0 → warpgbm-2.0.0}/LICENSE +0 -0
  14. {warpgbm-1.0.0 → warpgbm-2.0.0}/MANIFEST.in +0 -0
  15. {warpgbm-1.0.0 → warpgbm-2.0.0}/setup.cfg +0 -0
  16. {warpgbm-1.0.0 → warpgbm-2.0.0}/setup.py +0 -0
  17. {warpgbm-1.0.0 → warpgbm-2.0.0}/tests/__init__.py +0 -0
  18. {warpgbm-1.0.0 → warpgbm-2.0.0}/tests/full_numerai_test.py +0 -0
  19. {warpgbm-1.0.0 → warpgbm-2.0.0}/tests/numerai_test.py +0 -0
  20. {warpgbm-1.0.0 → warpgbm-2.0.0}/tests/test_fit_predict_corr.py +0 -0
  21. {warpgbm-1.0.0 → warpgbm-2.0.0}/tests/test_invariant.py +0 -0
  22. {warpgbm-1.0.0 → warpgbm-2.0.0}/warpgbm/__init__.py +0 -0
  23. {warpgbm-1.0.0 → warpgbm-2.0.0}/warpgbm/cuda/__init__.py +0 -0
  24. {warpgbm-1.0.0 → warpgbm-2.0.0}/warpgbm/cuda/best_split_kernel.cu +0 -0
  25. {warpgbm-1.0.0 → warpgbm-2.0.0}/warpgbm/cuda/binner.cu +0 -0
  26. {warpgbm-1.0.0 → warpgbm-2.0.0}/warpgbm/cuda/histogram_kernel.cu +0 -0
  27. {warpgbm-1.0.0 → warpgbm-2.0.0}/warpgbm/cuda/node_kernel.cpp +0 -0
  28. {warpgbm-1.0.0 → warpgbm-2.0.0}/warpgbm/cuda/predict.cu +0 -0
  29. {warpgbm-1.0.0 → warpgbm-2.0.0}/warpgbm.egg-info/dependency_links.txt +0 -0
  30. {warpgbm-1.0.0 → warpgbm-2.0.0}/warpgbm.egg-info/requires.txt +0 -0
  31. {warpgbm-1.0.0 → warpgbm-2.0.0}/warpgbm.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: warpgbm
3
- Version: 1.0.0
3
+ Version: 2.0.0
4
4
  Summary: A fast GPU-accelerated Gradient Boosted Decision Tree library with PyTorch + CUDA
5
5
  License: GNU GENERAL PUBLIC LICENSE
6
6
  Version 3, 29 June 2007
@@ -686,318 +686,427 @@ Requires-Dist: tqdm
686
686
  Requires-Dist: scikit-learn
687
687
  Dynamic: license-file
688
688
 
689
- ![raw](https://github.com/user-attachments/assets/924844ef-2536-4bde-a330-5e30f6b0762c)
689
+ ![warpgbm](https://github.com/user-attachments/assets/dee9de16-091b-49c1-a8fa-2b4ab6891184)
690
690
 
691
- # WarpGBM
691
+ # WarpGBM
692
692
 
693
- WarpGBM is a high-performance, GPU-accelerated Gradient Boosted Decision Tree (GBDT) library built with PyTorch and CUDA. It offers blazing-fast histogram-based training and efficient prediction, with compatibility for research and production workflows.
693
+ > **Neural-speed gradient boosting. GPU-native. Distribution-aware. Production-ready.**
694
694
 
695
- **New in v1.0.0:** WarpGBM introduces *Invariant Gradient Boosting* — a powerful approach to learning signals that remain stable across shifting environments (e.g., time, regimes, or datasets). Powered by a novel algorithm called **[Directional Era-Splitting (DES)](https://arxiv.org/abs/2309.14496)**, WarpGBM doesn't just train faster than other leading GBDT libraries — it trains smarter.
695
+ WarpGBM is a high-performance, GPU-accelerated Gradient Boosted Decision Tree (GBDT) library engineered from silicon up with PyTorch and custom CUDA kernels. Built for speed demons and researchers who refuse to compromise.
696
696
 
697
- If your data evolves over time, WarpGBM is the only GBDT library designed to *adapt and generalize*.
698
- ---
697
+ ## 🎯 What Sets WarpGBM Apart
699
698
 
700
- ## Contents
699
+ **Regression + Classification Unified**
700
+ Train on continuous targets or multiclass labels with the same blazing-fast infrastructure.
701
701
 
702
- - [Features](#features)
703
- - [Benchmarks](#benchmarks)
704
- - [Installation](#installation)
705
- - [Learning Invariant Signals Across Environments](#learning-invariant-signals-across-environments)
706
- - [Why This Matters](#why-this-matters)
707
- - [Visual Intuition](#visual-intuition)
708
- - [Key References](#key-references)
709
- - [Examples](#examples)
710
- - [Quick Comparison with LightGBM CPU version](#quick-comparison-with-lightgbm-cpu-version)
711
- - [Pre-binned Data Example (Numerai)](#pre-binned-data-example-numerai)
712
- - [Documentation](#documentation)
713
- - [Acknowledgements](#acknowledgements)
714
- - [Version Notes](#version-notes)
702
+ **Invariant Learning (DES Algorithm)**
703
+ The only open-source GBDT that natively learns signals stable across shifting distributions. Powered by **[Directional Era-Splitting](https://arxiv.org/abs/2309.14496)** — because your data doesn't live in a vacuum.
715
704
 
705
+ **GPU-Accelerated Everything**
706
+ Custom CUDA kernels for binning, histograms, splits, and inference. No compromises, no CPU bottlenecks.
716
707
 
717
- ## Features
708
+ **Scikit-Learn Compatible**
709
+ Drop-in replacement. Same API you know, 10x the speed you need.
718
710
 
719
- - **Blazing-fast GPU training** with custom CUDA kernels for binning, histogram building, split finding, and prediction
720
- - **Invariant signal learning** via [Directional Era-Splitting (DES)](https://arxiv.org/abs/2309.14496) — designed for datasets with shifting environments (e.g., time, regimes, experimental settings)
721
- - Drop-in **scikit-learn style interface** for easy adoption
722
- - Supports **pre-binned data** or **automatic quantile binning**
723
- - Works with `float32` or `int8` inputs
724
- - Built-in **validation and early stopping** support with MSE, RMSLE, or correlation metrics
725
- - Simple install with `pip`, no custom drivers required
711
+ ---
726
712
 
727
- > 💡 **Note:** WarpGBM v1.0.0 is a *generalization* of the traditional GBDT algorithm.
728
- > To run standard GBM training at maximum speed, simply omit the `era_id` argument — WarpGBM will behave like a traditional booster but with industry-leading performance.
713
+ ## 🚀 Quick Start
729
714
 
730
- ---
715
+ ### Installation
731
716
 
732
- ## Benchmarks
717
+ ```bash
718
+ # Latest from GitHub (recommended)
719
+ pip install git+https://github.com/jefferythewind/warpgbm.git
733
720
 
734
- ### Scikit-Learn Synthetic Data: 1 Million Rows and 1,000 Features
721
+ # Stable from PyPI
722
+ pip install warpgbm
723
+ ```
735
724
 
736
- In this benchmark we compare the speed and in-sample correlation of **WarpGBM v0.1.21** against LightGBM, XGBoost and CatBoost, all with their GPU-enabled versions. This benchmark runs on Google Colab with the L4 GPU environment.
725
+ **Prerequisites:** PyTorch with CUDA support ([install guide](https://pytorch.org/get-started/locally/))
737
726
 
727
+ ### Regression in 5 Lines
728
+
729
+ ```python
730
+ from warpgbm import WarpGBM
731
+ import numpy as np
732
+
733
+ model = WarpGBM(objective='regression', max_depth=5, n_estimators=100)
734
+ model.fit(X_train, y_train)
735
+ predictions = model.predict(X_test)
738
736
  ```
739
- WarpGBM: corr = 0.8882, train = 18.7s, infer = 4.9s
740
- XGBoost: corr = 0.8877, train = 33.1s, infer = 8.1s
741
- LightGBM: corr = 0.8604, train = 30.3s, infer = 1.4s
742
- CatBoost: corr = 0.8935, train = 400.0s, infer = 382.6s
737
+
738
+ ### Classification in 5 Lines
739
+
740
+ ```python
741
+ from warpgbm import WarpGBM
742
+
743
+ model = WarpGBM(objective='multiclass', max_depth=5, n_estimators=50)
744
+ model.fit(X_train, y_train) # y can be integers, strings, whatever
745
+ probabilities = model.predict_proba(X_test)
746
+ labels = model.predict(X_test)
743
747
  ```
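The comment above ("y can be integers, strings, whatever") refers to the automatic label encoding listed under Features. A minimal sketch of what that means in practice, assuming random `float32` features and made-up string labels purely for illustration (the column order of `predict_proba` is assumed to follow `model.classes_`, as in scikit-learn):

```python
import numpy as np
from warpgbm import WarpGBM

# Illustrative data: string labels, no manual encoding step.
X_train = np.random.rand(1000, 20).astype(np.float32)
y_train = np.array(["cat", "dog", "cat", "bird"] * 250)

model = WarpGBM(objective='multiclass', max_depth=4, n_estimators=20)
model.fit(X_train, y_train)

print(model.classes_)                  # learned label set, e.g. ['bird' 'cat' 'dog']
labels = model.predict(X_train)        # decoded back to the original strings
probs = model.predict_proba(X_train)   # (n_samples, n_classes); columns assumed to follow model.classes_
```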
744
748
 
745
- Colab Notebook: https://colab.research.google.com/drive/16U1kbYlD5HibGbnF5NGsjChZ1p1IA2pK?usp=sharing
749
+ ---
750
+
751
+ ## 🎮 Features
752
+
753
+ ### Core Engine
754
+ - ⚡ **GPU-native CUDA kernels** for histogram building, split finding, binning, and prediction
755
+ - 🎯 **Multi-objective support**: regression, binary, multiclass classification
756
+ - 📊 **Pre-binned data optimization** — skip binning if your data's already quantized
757
+ - 🔥 **Flexible input dtypes** — `float32` or `int8` inputs
758
+ - 🎲 **Stochastic features** — `colsample_bytree` for regularization
759
+
760
+ ### Intelligence
761
+ - 🧠 **Invariant learning via DES** — identifies signals that generalize across time/regimes/environments
762
+ - 📈 **Smart initialization** — class priors for classification, mean for regression
763
+ - 🎯 **Automatic label encoding** — handles strings, integers, whatever you throw at it
764
+
765
+ ### Training Utilities
766
+ - ✅ **Early stopping** with validation sets
767
+ - 📊 **Rich metrics**: MSE, RMSLE, correlation, log loss, accuracy
768
+ - 🔍 **Progress tracking** with loss curves
769
+ - 🎚️ **Regularization** — L2 leaf penalties, min split gain, min child weight
746
770
 
747
771
  ---
748
772
 
749
- ## Installation
773
+ ## ⚔️ Benchmarks
750
774
 
751
- ### Recommended (GitHub, always latest):
775
+ ### Synthetic Data: 1M Rows × 1K Features (Google Colab L4 GPU)
752
776
 
753
- ```bash
754
- pip install git+https://github.com/jefferythewind/warpgbm.git
777
+ ```
778
+ WarpGBM: corr = 0.8882, train = 17.4s, infer = 3.2s ⚡
779
+ XGBoost: corr = 0.8877, train = 33.2s, infer = 8.0s
780
+ LightGBM: corr = 0.8604, train = 29.8s, infer = 1.6s
781
+ CatBoost: corr = 0.8935, train = 392.1s, infer = 379.2s
755
782
  ```
756
783
 
757
- This installs the latest version directly from GitHub and compiles CUDA extensions on your machine using your **local PyTorch and CUDA setup**. It's the most reliable method for ensuring compatibility and staying up to date with the latest features.
784
+ **~2× faster than XGBoost. 23× faster than CatBoost.**
758
785
 
759
- ### Alternatively (PyPI, stable releases):
786
+ [→ Run the benchmark yourself](https://colab.research.google.com/drive/16U1kbYlD5HibGbnF5NGsjChZ1p1IA2pK?usp=sharing)
760
787
 
761
- ```bash
762
- pip install warpgbm
788
+ ### Multiclass Classification: 3.5K Samples, 3 Classes, 50 Rounds
789
+
790
+ ```
791
+ Training: 2.13s
792
+ Inference: 0.37s
793
+ Accuracy: 75.3%
763
794
  ```
764
795
 
765
- This installs from PyPI and also compiles CUDA code locally during installation. This method works well **if your environment already has PyTorch with GPU support** installed and configured.
796
+ **Production-ready multiclass at neural network speeds.**
766
797
 
767
- > **Tip:**\
768
- > If you encounter an error related to mismatched or missing CUDA versions, try installing with the following flag. This is currently required in the Colab environments.
769
- >
770
- > ```bash
771
- > pip install warpgbm --no-build-isolation
772
- > ```
798
+ ---
773
799
 
774
- ### Windows
800
+ ## 📖 Examples
775
801
 
776
- Thank you, ShatteredX, for providing working instructions for a Windows installation.
802
+ ### Regression: Beat LightGBM on Your Laptop
777
803
 
778
- ```
779
- git clone https://github.com/jefferythewind/warpgbm.git
780
- cd warpgbm
781
- python setup.py bdist_wheel
782
- pip install .\dist\warpgbm-0.1.15-cp310-cp310-win_amd64.whl
804
+ ```python
805
+ import numpy as np
806
+ from sklearn.datasets import make_regression
807
+ from warpgbm import WarpGBM
808
+
809
+ # Generate data
810
+ X, y = make_regression(n_samples=100_000, n_features=500, random_state=42)
811
+ X, y = X.astype(np.float32), y.astype(np.float32)
812
+
813
+ # Train
814
+ model = WarpGBM(
815
+ objective='regression',
816
+ max_depth=5,
817
+ n_estimators=100,
818
+ learning_rate=0.01,
819
+ num_bins=32
820
+ )
821
+ model.fit(X, y)
822
+
823
+ # Predict
824
+ preds = model.predict(X)
825
+ print(f"Correlation: {np.corrcoef(preds, y)[0,1]:.4f}")
783
826
  ```
784
827
 
785
- Before either method, make sure you’ve installed PyTorch with GPU support:\
786
- [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
828
+ ### Classification: Multiclass with Early Stopping
787
829
 
788
- ---
830
+ ```python
831
+ from sklearn.datasets import make_classification
832
+ from sklearn.model_selection import train_test_split
833
+ from warpgbm import WarpGBM
789
834
 
790
- ## Learning Invariant Signals Across Environments
835
+ # 5-class problem
836
+ X, y = make_classification(
837
+ n_samples=10_000,
838
+ n_features=50,
839
+ n_classes=5,
840
+ n_informative=30
841
+ )
791
842
 
792
- Most supervised learning models rely on an assumption known as the **Empirical Risk Minimization (ERM)** principle. Under ERM, the data distribution connecting inputs \( X \) and targets \( Y \) is assumed to be **fixed** and **stationary** across training, validation, and test splits. That is:
843
+ X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2)
793
844
 
794
- > The patterns you learn from the training set are expected to generalize out-of-sample — *as long as the test data follows the same distribution as the training data.*
845
+ model = WarpGBM(
846
+ objective='multiclass',
847
+ max_depth=6,
848
+ n_estimators=200,
849
+ learning_rate=0.1,
850
+ num_bins=32
851
+ )
795
852
 
796
- However, this assumption is often violated in real-world settings. Data frequently shifts across time, geography, experimental conditions, or other hidden factors. This phenomenon is known as **distribution shift**, and it leads to models that perform well in-sample but fail catastrophically out-of-sample.
853
+ model.fit(
854
+ X_train, y_train,
855
+ X_eval=X_val, y_eval=y_val,
856
+ eval_every_n_trees=10,
857
+ early_stopping_rounds=5,
858
+ eval_metric='logloss'
859
+ )
797
860
 
798
- This challenge motivates the field of **Out-of-Distribution (OOD) Generalization**, which assumes your data is drawn from **distinct environments or eras** — e.g., time periods, customer segments, experimental trials. Some signals may appear predictive within specific environments but vanish or reverse in others. These are called **spurious signals**. On the other hand, signals that remain consistently predictive across all environments are called **invariant signals**.
861
+ # Get probabilities or class predictions
862
+ probs = model.predict_proba(X_val) # shape: (n_samples, n_classes)
863
+ labels = model.predict(X_val) # class labels
864
+ ```
799
865
 
800
- WarpGBM v1.0.0 introduces **Directional Era-Splitting (DES)**, a new algorithm designed to identify and learn from invariant signals — ignoring signals that fail to generalize across environments.
866
+ ### Invariant Learning: Distribution-Robust Signals
801
867
 
802
- ---
868
+ ```python
869
+ # Your data spans multiple time periods/regimes/environments
870
+ # Pass era_id to learn only signals that work across ALL eras
803
871
 
804
- ### Why This Matters
872
+ model = WarpGBM(
873
+ objective='regression',
874
+ max_depth=8,
875
+ n_estimators=100
876
+ )
805
877
 
806
- - Standard models trained via ERM can learn to exploit **spurious correlations** that only hold in some parts of the data.
807
- - DES explicitly tests whether a feature's split is **directionally consistent** across all eras — only such *invariant splits* are kept.
808
- - This approach has been shown to reduce overfitting and improve out-of-sample generalization, particularly in financial and scientific datasets.
878
+ model.fit(
879
+ X, y,
880
+ era_id=era_labels # Array marking which era each sample belongs to
881
+ )
809
882
 
810
- ---
883
+ # Now your model ignores spurious correlations that don't generalize!
884
+ ```
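
The `era_labels` array is not something the library provides; you construct it from whatever marks an environment in your data. A minimal sketch, assuming a pandas frame with a date column and one era per calendar month (the column names and era granularity here are illustrative, not part of the package):

```python
import numpy as np
import pandas as pd
from warpgbm import WarpGBM

# Hypothetical daily panel with two features and a target.
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "date": pd.date_range("2020-01-01", periods=1000, freq="D"),
    "f1": rng.normal(size=1000).astype(np.float32),
    "f2": rng.normal(size=1000).astype(np.float32),
})
df["target"] = (0.5 * df["f1"] + rng.normal(size=1000)).astype(np.float32)

# One era per calendar month: integer ids 0..n_eras-1.
era_labels, _ = pd.factorize(df["date"].dt.to_period("M"))

model = WarpGBM(objective="regression", max_depth=4, n_estimators=50)
model.fit(df[["f1", "f2"]].values, df["target"].values, era_id=era_labels)
```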
811
885
 
812
- ### Visual Intuition
886
+ ### Pre-binned Data: Maximum Speed (Numerai Example)
813
887
 
814
- We contrast two views of the data:
888
+ ```python
889
+ import pandas as pd
890
+ from numerapi import NumerAPI
891
+ from warpgbm import WarpGBM
815
892
 
816
- - **ERM Setting**: All data is assumed to come from the same source (single distribution).\
817
- No awareness of environments — spurious signals can dominate.
893
+ # Download Numerai data (already quantized to integers)
894
+ napi = NumerAPI()
895
+ napi.download_dataset('v5.0/train.parquet', 'train.parquet')
896
+ train = pd.read_parquet('train.parquet')
818
897
 
819
- - **OOD Setting (Era-Splitting)**: Data is explicitly grouped by environment (era).\
820
- The model checks whether a signal holds across all groups — enforcing **robustness**.
898
+ features = [f for f in train.columns if 'feature' in f]
899
+ X = train[features].astype('int8').values
900
+ y = train['target'].values
821
901
 
822
- *📷 [Placeholder for future visual illustration]*
902
+ # WarpGBM detects pre-binned data and skips binning
903
+ model = WarpGBM(max_depth=5, n_estimators=100, num_bins=20)
904
+ model.fit(X, y) # Blazing fast!
905
+ ```
823
906
 
907
+ **Result: 13× faster than LightGBM on Numerai data (49s vs 643s)**
824
908
 
825
909
  ---
826
910
 
827
- ### Key References
911
+ ## 🧠 Invariant Learning: Why It Matters
828
912
 
829
- - **Invariant Risk Minimization (IRM)**: [Arjovsky et al., 2019](https://arxiv.org/abs/1907.02893)
830
- - **Learning Explanations That Are Hard to Vary**: [Parascandolo et al., 2020](https://arxiv.org/abs/2009.00329)
831
- - **Era Splitting: Invariant Learning for Decision Trees**: [DeLise, 2023](https://arxiv.org/abs/2309.14496)
913
+ Most ML models assume your training and test data come from the same distribution. **Reality check: they don't.**
832
914
 
833
- ---
915
+ - Stock prices shift with market regimes
916
+ - User behavior changes over time
917
+ - Experimental data varies by batch/site/condition
834
918
 
835
- WarpGBM is the **first open-source GBDT framework to integrate this OOD-aware approach natively**, using efficient CUDA kernels to evaluate per-era consistency during tree growth. It’s not just faster — it’s smarter.
919
+ **Traditional GBDT:** Learns any signal that correlates with the target, including fragile patterns that break OOD.
836
920
 
837
- ---
921
+ **WarpGBM with DES:** Explicitly tests if each split generalizes across ALL environments (eras). Only keeps robust signals.
838
922
 
839
- ## Examples
923
+ ### The Algorithm
840
924
 
841
- WarpGBM is easy to drop into any supervised learning workflow and comes with curated examples in the `examples/` folder.
925
+ For each potential split, compute gain separately in each era. Only accept splits where:
926
+ 1. Gain is positive in ALL eras
927
+ 2. Split direction is consistent across eras
842
928
 
843
- - `Spiral Data.ipynb`: synthetic OOD benchmark from Learning Explanations That Are Hard to Vary
929
+ This prevents overfitting to spurious correlations that only work in some time periods or environments.
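
For intuition only, here is a self-contained sketch of that acceptance rule for a single feature/threshold pair. The gain below is a plain variance-reduction proxy, not WarpGBM's exact histogram gain, and the real CUDA kernels evaluate candidates per bin from per-era histograms rather than looping in Python:

```python
import numpy as np

def des_accepts_split(x, residuals, era_id, threshold):
    """Simplified Directional Era-Splitting check (illustrative, not the CUDA implementation)."""
    gains, directions = [], []
    for era in np.unique(era_id):
        mask = era_id == era
        in_era = residuals[mask]
        left = in_era[x[mask] <= threshold]
        right = in_era[x[mask] > threshold]
        if len(left) == 0 or len(right) == 0:
            return False  # split is degenerate in this era
        # Variance-reduction gain of the split within this era.
        gain = in_era.var() * len(in_era) - (left.var() * len(left) + right.var() * len(right))
        gains.append(gain)
        # Direction: which side ends up with the larger mean residual.
        directions.append(np.sign(left.mean() - right.mean()))
    # Accept only if the split helps in every era AND points the same way everywhere.
    return all(g > 0 for g in gains) and len(set(directions)) == 1
```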
844
930
 
845
- ### Quick Comparison with LightGBM CPU version
931
+ ### Visual Intuition
846
932
 
847
- ```python
848
- import numpy as np
849
- from sklearn.datasets import make_regression
850
- from time import time
851
- import lightgbm as lgb
852
- from warpgbm import WarpGBM
933
+ <img src="https://github.com/user-attachments/assets/2be11ef3-6f2e-4636-ab91-307a73add247" alt="Era Splitting Visualization" width="400"/>
853
934
 
854
- # Create synthetic regression dataset
855
- X, y = make_regression(n_samples=100_000, n_features=500, noise=0.1, random_state=42)
856
- X = X.astype(np.float32)
857
- y = y.astype(np.float32)
858
-
859
- # Train LightGBM
860
- start = time()
861
- lgb_model = lgb.LGBMRegressor(max_depth=5, n_estimators=100, learning_rate=0.01, max_bin=7)
862
- lgb_model.fit(X, y)
863
- lgb_time = time() - start
864
- lgb_preds = lgb_model.predict(X)
865
-
866
- # Train WarpGBM
867
- start = time()
868
- wgbm_model = WarpGBM(max_depth=5, n_estimators=100, learning_rate=0.01, num_bins=7)
869
- wgbm_model.fit(X, y)
870
- wgbm_time = time() - start
871
- wgbm_preds = wgbm_model.predict(X)
872
-
873
- # Results
874
- print(f"LightGBM: corr = {np.corrcoef(lgb_preds, y)[0,1]:.4f}, time = {lgb_time:.2f}s")
875
- print(f"WarpGBM: corr = {np.corrcoef(wgbm_preds, y)[0,1]:.4f}, time = {wgbm_time:.2f}s")
876
- ```
935
+ **Left:** Standard training pools all data together — learns any signal that correlates.
936
+ **Right:** Era-aware training demands signals work across all periods — learns robust features only.
877
937
 
878
- **Results (Ryzen 9 CPU, NVIDIA 3090 GPU):**
938
+ ### Research Foundation
879
939
 
880
- ```
881
- LightGBM: corr = 0.8742, time = 37.33s
882
- WarpGBM: corr = 0.8621, time = 5.40s
883
- ```
940
+ - **Invariant Risk Minimization**: [Arjovsky et al., 2019](https://arxiv.org/abs/1907.02893)
941
+ - **Hard-to-Vary Explanations**: [Parascandolo et al., 2020](https://arxiv.org/abs/2009.00329)
942
+ - **Era Splitting for Trees**: [DeLise, 2023](https://arxiv.org/abs/2309.14496)
884
943
 
885
944
  ---
886
945
 
887
- ### Pre-binned Data Example (Numerai)
946
+ ## 📚 API Reference
888
947
 
889
- WarpGBM can save additional training time if your dataset is already pre-binned. The Numerai tournament data is a great example:
948
+ ### Constructor Parameters
890
949
 
891
950
  ```python
892
- import pandas as pd
893
- from numerapi import NumerAPI
894
- from time import time
895
- import lightgbm as lgb
896
- from warpgbm import WarpGBM
897
- import numpy as np
951
+ WarpGBM(
952
+ objective='regression', # 'regression', 'binary', or 'multiclass'
953
+ num_bins=10, # Histogram bins for feature quantization
954
+ max_depth=3, # Maximum tree depth
955
+ learning_rate=0.1, # Shrinkage rate (aka eta)
956
+ n_estimators=100, # Number of boosting rounds
957
+ min_child_weight=20, # Min sum of instance weights in child node
958
+ min_split_gain=0.0, # Min loss reduction to split
959
+ L2_reg=1e-6, # L2 leaf regularization
960
+ colsample_bytree=1.0, # Feature subsample ratio per tree
961
+ threads_per_block=64, # CUDA block size (tune for your GPU)
962
+ rows_per_thread=4, # Rows processed per thread
963
+ device='cuda' # 'cuda' or 'cpu' (GPU strongly recommended)
964
+ )
965
+ ```
898
966
 
899
- napi = NumerAPI()
900
- napi.download_dataset('v5.0/train.parquet', 'train.parquet')
901
- train = pd.read_parquet('train.parquet')
967
+ ### Training Methods
902
968
 
903
- feature_set = [f for f in train.columns if 'feature' in f]
904
- target = 'target_cyrus'
905
-
906
- X_np = train[feature_set].astype('int8').values
907
- Y_np = train[target].values
908
-
909
- # LightGBM
910
- start = time()
911
- lgb_model = lgb.LGBMRegressor(max_depth=5, n_estimators=100, learning_rate=0.01, max_bin=7)
912
- lgb_model.fit(X_np, Y_np)
913
- lgb_time = time() - start
914
- lgb_preds = lgb_model.predict(X_np)
915
-
916
- # WarpGBM
917
- start = time()
918
- wgbm_model = WarpGBM(max_depth=5, n_estimators=100, learning_rate=0.01, num_bins=7)
919
- wgbm_model.fit(X_np, Y_np)
920
- wgbm_time = time() - start
921
- wgbm_preds = wgbm_model.predict(X_np)
922
-
923
- # Results
924
- print(f"LightGBM: corr = {np.corrcoef(lgb_preds, Y_np)[0,1]:.4f}, time = {lgb_time:.2f}s")
925
- print(f"WarpGBM: corr = {np.corrcoef(wgbm_preds, Y_np)[0,1]:.4f}, time = {wgbm_time:.2f}s")
969
+ ```python
970
+ model.fit(
971
+ X, # Features: np.array shape (n_samples, n_features)
972
+ y, # Target: np.array shape (n_samples,)
973
+ era_id=None, # Optional: era labels for invariant learning
974
+ X_eval=None, # Optional: validation features
975
+ y_eval=None, # Optional: validation targets
976
+ eval_every_n_trees=None, # Eval frequency (in rounds)
977
+ early_stopping_rounds=None, # Stop if no improvement for N evals
978
+ eval_metric='mse' # 'mse', 'rmsle', 'corr', 'logloss', 'accuracy'
979
+ )
926
980
  ```
927
981
 
928
- **Results (Google Colab Pro, A100 GPU):**
982
+ ### Prediction Methods
983
+
984
+ ```python
985
+ # Regression: returns predicted values
986
+ predictions = model.predict(X)
987
+
988
+ # Classification: returns class labels (decoded)
989
+ labels = model.predict(X)
929
990
 
991
+ # Classification: returns class probabilities
992
+ probabilities = model.predict_proba(X) # shape: (n_samples, n_classes)
930
993
  ```
931
- LightGBM: corr = 0.0703, time = 643.88s
932
- WarpGBM: corr = 0.0660, time = 49.16s
994
+
995
+ ### Attributes
996
+
997
+ ```python
998
+ model.classes_ # Unique class labels (classification only)
999
+ model.num_classes # Number of classes (classification only)
1000
+ model.forest # Trained tree structures
1001
+ model.training_loss # Training loss history
1002
+ model.eval_loss # Validation loss history (if eval set provided)
933
1003
  ```
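
A short usage sketch of these attributes, continuing the multiclass example above and assuming `training_loss`/`eval_loss` are per-evaluation lists and that `predict_proba` columns follow the order of `model.classes_` (the scikit-learn convention):

```python
# Inspect the loss history recorded during fit (eval_loss exists only if an eval set was passed).
print(f"training evaluations recorded: {len(model.training_loss)}")
print(f"best eval loss: {min(model.eval_loss):.4f}")

# Classification only: map probability columns back to labels.
probs = model.predict_proba(X_val)
top1 = model.classes_[probs.argmax(axis=1)]   # most likely class per sample
```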
934
1004
 
935
1005
  ---
936
1006
 
937
- ## Documentation
938
-
939
- ### `WarpGBM` Parameters:
940
- - `num_bins`: Number of histogram bins to use (default: 10)
941
- - `max_depth`: Maximum depth of trees (default: 3)
942
- - `learning_rate`: Shrinkage rate applied to leaf outputs (default: 0.1)
943
- - `n_estimators`: Number of boosting iterations (default: 100)
944
- - `min_child_weight`: Minimum sum of instance weight needed in a child (default: 20)
945
- - `min_split_gain`: Minimum loss reduction required to make a further partition (default: 0.0)
946
- - `histogram_computer`: Choice of histogram kernel (`'hist1'`, `'hist2'`, `'hist3'`) (default: `'hist3'`)
947
- - `threads_per_block`: CUDA threads per block (default: 32)
948
- - `rows_per_thread`: Number of training rows processed per thread (default: 4)
949
- - `L2_reg`: L2 regularizer (default: 1e-6)
950
- - `colsample_bytree`: Proportion of features to subsample to grow each tree (default: 1)
951
-
952
- ### Methods:
953
- ```
954
- .fit(
955
- X, # numpy array (float or int) 2 dimensions (num_samples, num_features)
956
- y, # numpy array (float or int) 1 dimension (num_samples)
957
- era_id=None, # numpy array (int) 1 dimension (num_samples)
958
- X_eval=None, # numpy array (float or int) 2 dimensions (eval_num_samples, num_features)
959
- y_eval=None, # numpy array (float or int) 1 dimension (eval_num_samples)
960
- eval_every_n_trees=None, # const (int) >= 1
961
- early_stopping_rounds=None, # const (int) >= 1
962
- eval_metric='mse' # string, one of 'mse', 'rmsle' or 'corr'. For corr, loss is 1 - correlation(y_true, preds)
963
- )
1007
+ ## 🔧 Installation Details
1008
+
1009
+ ### Linux (Recommended)
1010
+
1011
+ ```bash
1012
+ pip install git+https://github.com/jefferythewind/warpgbm.git
964
1013
  ```
965
- Train with optional validation set and early stopping.
966
1014
 
1015
+ Compiles CUDA extensions using your local PyTorch + CUDA setup.
967
1016
 
1017
+ ### Colab / Mismatched CUDA Versions
1018
+
1019
+ ```bash
1020
+ pip install warpgbm --no-build-isolation
968
1021
  ```
969
- .predict(
970
- X # numpy array (float or int) 2 dimensions (predict_num_samples, num_features)
971
- )
1022
+
1023
+ ### Windows
1024
+
1025
+ ```bash
1026
+ git clone https://github.com/jefferythewind/warpgbm.git
1027
+ cd warpgbm
1028
+ python setup.py bdist_wheel
1029
+ pip install dist/warpgbm-*.whl
972
1030
  ```
973
- Predict on new data, using parallelized CUDA kernel.
974
1031
 
975
1032
  ---
976
1033
 
977
- ## Acknowledgements
1034
+ ## 🎯 Use Cases
978
1035
 
979
- WarpGBM builds on the shoulders of PyTorch, scikit-learn, LightGBM, and the CUDA ecosystem. Thanks to all contributors in the GBDT research and engineering space.
1036
+ **Financial ML:** Learn signals that work across market regimes
1037
+ **Time Series:** Robust forecasting across distribution shifts
1038
+ **Scientific Research:** Models that generalize across experimental batches
1039
+ **High-Speed Inference:** Production systems with millisecond SLAs
1040
+ **Kaggle/Competitions:** GPU-accelerated hyperparameter tuning
1041
+ **Multiclass Problems:** Image classification fallback, text categorization, fraud detection
980
1042
 
981
1043
  ---
982
1044
 
983
- ## Version Notes
1045
+ ## 🚧 Roadmap
984
1046
 
985
- ### v0.1.21
1047
+ - [ ] Multi-GPU training support
1048
+ - [ ] SHAP value computation on GPU
1049
+ - [ ] Feature interaction constraints
1050
+ - [ ] Monotonic constraints
1051
+ - [ ] Custom loss functions
1052
+ - [ ] Distributed training
1053
+ - [ ] ONNX export for deployment
986
1054
 
987
- - Vectorized predict function replaced with CUDA kernel (`warpgbm/cuda/predict.cu`), parallelizing per sample, per tree.
1055
+ ---
988
1056
 
989
- ### v0.1.23
1057
+ ## 🙏 Acknowledgements
990
1058
 
991
- - Adjust gain in split kernel and added support for an eval set with early stopping based on MSE.
1059
+ Built on the shoulders of PyTorch, scikit-learn, LightGBM, XGBoost, and the CUDA ecosystem. Special thanks to the GBDT research community and all contributors.
992
1060
 
993
- ### v0.1.25
1061
+ ---
1062
+
1063
+ ## 📝 Version History
994
1064
 
995
- - Added `colsample_bytree` parameter and new test using Numerai data.
1065
+ ### v2.0.0 (Current)
1066
+ - ✨ **Multiclass classification support** via softmax objective
1067
+ - 🎯 Binary classification mode
1068
+ - 📊 New metrics: log loss, accuracy
1069
+ - 🏷️ Automatic label encoding (supports strings)
1070
+ - 🔮 `predict_proba()` for probability outputs
1071
+ - ✅ Comprehensive test suite for classification
1072
+ - 🔒 Full backward compatibility with regression
1073
+ - 🐛 Fixed unused variable issue (#8)
1074
+ - 🧹 Removed unimplemented L1_reg parameter
1075
+ - 📚 Major documentation overhaul with AGENT_GUIDE.md
1076
+
1077
+ ### v1.0.0
1078
+ - 🧠 Invariant learning via Directional Era-Splitting (DES)
1079
+ - 🚀 VRAM optimizations
1080
+ - 📈 Era-aware histogram computation
996
1081
 
997
1082
  ### v0.1.26
1083
+ - 🐛 Memory bug fixes in prediction
1084
+ - 📊 Added correlation eval metric
998
1085
 
999
- - Fix Memory bugs in prediction and colsample bytree logic. Added "corr" eval metric.
1086
+ ### v0.1.25
1087
+ - 🎲 Feature subsampling (`colsample_bytree`)
1000
1088
 
1001
- ### v1.0.0
1089
+ ### v0.1.23
1090
+ - ⏹️ Early stopping support
1091
+ - ✅ Validation set evaluation
1092
+
1093
+ ### v0.1.21
1094
+ - ⚡ CUDA prediction kernel (replaced vectorized Python)
1095
+
1096
+ ---
1097
+
1098
+ ## 📄 License
1099
+
1100
+ GPL-3.0 License - see [LICENSE](LICENSE) file
1101
+
1102
+ ---
1103
+
1104
+ ## 🤝 Contributing
1105
+
1106
+ Pull requests welcome! See [AGENT_GUIDE.md](AGENT_GUIDE.md) for architecture details and development guidelines.
1107
+
1108
+ ---
1109
+
1110
+ **Built with 🔥 by @jefferythewind**
1002
1111
 
1003
- - Introduce invariant learning via directional era splitting (DES). Also streamline VRAM improvements over previous sub versions.
1112
+ *"Train smarter. Predict faster. Generalize better."*