smallaxe 0.6.2__tar.gz → 0.6.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smallaxe-0.6.4/Goals.md +220 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/PKG-INFO +1 -1
- {smallaxe-0.6.2 → smallaxe-0.6.4}/TODO.md +1 -9
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/training/__init__.py +26 -0
- smallaxe-0.6.4/smallaxe/training/catboost.py +296 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/training/classifiers.py +151 -11
- smallaxe-0.6.4/smallaxe/training/lightgbm.py +441 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/training/regressors.py +147 -11
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe.egg-info/PKG-INFO +1 -1
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe.egg-info/SOURCES.txt +5 -0
- smallaxe-0.6.4/tests/test_catboost.py +200 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/tests/test_factories.py +111 -1
- smallaxe-0.6.4/tests/test_lightgbm.py +837 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/.github/workflows/ci.yml +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/.github/workflows/publish.yml +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/.gitignore +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/LICENSE +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/README.md +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/pyproject.toml +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/requirements-dev.txt +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/setup.cfg +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/__init__.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/_config.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/auto/__init__.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/datasets/__init__.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/datasets/_data.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/exceptions/__init__.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/metrics/__init__.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/metrics/classification.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/metrics/regression.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/pipeline/__init__.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/pipeline/pipeline.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/preprocessing/__init__.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/preprocessing/encoder.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/preprocessing/imputer.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/preprocessing/scaler.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/search/__init__.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/training/base.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/training/mixins/__init__.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/training/mixins/metadata_mixin.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/training/mixins/param_mixin.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/training/mixins/persistence_mixin.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/training/mixins/spark_model_mixin.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/training/mixins/validation_mixin.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/training/random_forest.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/training/xgboost.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe/viz/__init__.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe.egg-info/dependency_links.txt +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe.egg-info/requires.txt +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/smallaxe.egg-info/top_level.txt +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/tests/__init__.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/tests/conftest.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/tests/test_config.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/tests/test_datasets.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/tests/test_encoder.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/tests/test_exceptions.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/tests/test_imputer.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/tests/test_metrics_classification.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/tests/test_metrics_regression.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/tests/test_mixins.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/tests/test_pipeline.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/tests/test_random_forest.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/tests/test_scaler.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/tests/test_smoke.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/tests/test_training_base.py +0 -0
- {smallaxe-0.6.2 → smallaxe-0.6.4}/tests/test_xgboost.py +0 -0
smallaxe-0.6.4/Goals.md
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
1
|
+
> Note for AI agents working on this repository:
|
|
2
|
+
>
|
|
3
|
+
> Your job is to move smallaxe toward the goals in this document while keeping the library readable, simple to use, and extensible. Prefer clear APIs, small focused abstractions, and implementation patterns that match the existing codebase.
|
|
4
|
+
>
|
|
5
|
+
> Make goal-related changes on a git branch named `goals`. If the branch does not exist, create it before editing.
|
|
6
|
+
>
|
|
7
|
+
> Always validate changes with tests. The Python environment is managed with uv and should be activated with:
|
|
8
|
+
>
|
|
9
|
+
> ```bash
|
|
10
|
+
> source ~/Desktop/basic/bin/activate
|
|
11
|
+
> ```
|
|
12
|
+
>
|
|
13
|
+
> If a new Python library is needed in the environment, install it with:
|
|
14
|
+
>
|
|
15
|
+
> ```bash
|
|
16
|
+
> uv pip install <library-name>
|
|
17
|
+
> ```
|
|
18
|
+
>
|
|
19
|
+
> For PySpark tests on this machine, use OpenJDK 11:
|
|
20
|
+
>
|
|
21
|
+
> ```bash
|
|
22
|
+
> export JAVA_HOME=/opt/homebrew/opt/openjdk@11
|
|
23
|
+
> export PATH="$JAVA_HOME/bin:$PATH"
|
|
24
|
+
> ```
|
|
25
|
+
>
|
|
26
|
+
> Run relevant focused tests after each change, and run the full suite before considering work complete:
|
|
27
|
+
>
|
|
28
|
+
> ```bash
|
|
29
|
+
> pytest -q
|
|
30
|
+
> ```
|
|
31
|
+
>
|
|
32
|
+
> If you are unsure about the current behavior, API, or implementation details of a dependency or library, use the available DeepWiki MCP tools to inspect authoritative project documentation before making assumptions.
|
|
33
|
+
|
|
34
|
+
# smallaxe Goals
|
|
35
|
+
|
|
36
|
+
## Product Goal
|
|
37
|
+
|
|
38
|
+
smallaxe should make common supervised modeling on PySpark DataFrames feel as simple as scikit-learn on pandas, while keeping execution distributed through Spark-native and Spark-compatible ML libraries.
|
|
39
|
+
|
|
40
|
+
The first stable target is:
|
|
41
|
+
|
|
42
|
+
- Binary classification.
|
|
43
|
+
- Standard continuous regression across Random Forest, LightGBM, XGBoost, and CatBoost regressors.
|
|
44
|
+
- A simple, consistent user API for preprocessing, training, evaluation, prediction, persistence, and pipeline composition.
|
|
45
|
+
|
|
46
|
+
Longer-term expansion should add multiclass classification, multilabel classification, and specialized regression tasks such as quantile regression.
|
|
47
|
+
|
|
48
|
+
## Current Baseline
|
|
49
|
+
|
|
50
|
+
The current implementation already has useful foundations:
|
|
51
|
+
|
|
52
|
+
- Global configuration, custom exceptions, sample datasets, metrics, preprocessing, pipeline, and training modules.
|
|
53
|
+
- Random Forest regressors/classifiers backed by PySpark ML.
|
|
54
|
+
- Optional XGBoost and LightGBM wrappers.
|
|
55
|
+
- Imputer, Scaler, Encoder, and Pipeline classes.
|
|
56
|
+
- Model metadata, validation scores, feature importance, and save/load support for individual models.
|
|
57
|
+
- A substantial test suite. With `~/Desktop/basic`, PySpark 3.5.x, and OpenJDK 11, the current suite passes: 485 passed, 102 skipped. The skipped tests cover the optional XGBoost/LightGBM integrations and are skipped when those libraries are not installed.
|
|
58
|
+
|
|
59
|
+
## Missing For v1
|
|
60
|
+
|
|
61
|
+
### 1. Align Public API With Actual Capabilities
|
|
62
|
+
|
|
63
|
+
- Update README to describe only implemented APIs, or implement the advertised APIs before release.
|
|
64
|
+
- Current README advertises `smallaxe.search.optimize`, `smallaxe.auto.AutomatedTraining`, visualization, and CatBoost, but those modules are empty or missing.
|
|
65
|
+
- Decide whether the first regression API is called "regression" or "linear regression." Random Forest, XGBoost, LightGBM, and CatBoost are not linear models. If true linear regression is a first-class goal, add a Spark `LinearRegression` baseline separately.
|
|
66
|
+
|
|
67
|
+
### 2. Finish The Four-Algorithm Training Surface
|
|
68
|
+
|
|
69
|
+
- Add CatBoost regressor and binary classifier support, or remove CatBoost from public docs until implemented.
|
|
70
|
+
- Add factory methods for LightGBM in `Regressors` and `Classifiers`; the classes exist, but the factories only expose Random Forest and XGBoost.
|
|
71
|
+
- Make optional dependency handling explicit:
|
|
72
|
+
- `available_models()` should report installed and unavailable models with install hints.
|
|
73
|
+
- Factories should raise clear `DependencyError` messages when a requested optional model is missing.
|
|
74
|
+
- Tests should verify missing optional dependency behavior without being globally skipped.
|
|
75
|
+
- Normalize model parameter names across algorithms where possible:
|
|
76
|
+
- User-facing: `n_estimators`, `max_depth`, `learning_rate`, `seed`.
|
|
77
|
+
- Internal adapters translate to Spark/XGBoost/LightGBM/CatBoost-specific names.
|
|
78
|
+
|
|
79
|
+
### 3. Make Preprocessing Production-Ready
|
|
80
|
+
|
|
81
|
+
- Split categorical and numeric preprocessing into predictable steps:
|
|
82
|
+
- Numeric imputation.
|
|
83
|
+
- Categorical imputation.
|
|
84
|
+
- Categorical encoding.
|
|
85
|
+
- Numeric scaling when useful.
|
|
86
|
+
- Feature vector assembly.
|
|
87
|
+
- Add a fitted preprocessing schema artifact:
|
|
88
|
+
- Input columns.
|
|
89
|
+
- Output feature columns.
|
|
90
|
+
- Encoded category mappings.
|
|
91
|
+
- Unknown-category behavior.
|
|
92
|
+
- Null handling behavior.
|
|
93
|
+
- Where practical, replace Python UDF extraction in Scaler/Encoder with Spark SQL/vector functions for performance.
|
|
94
|
+
- Ensure transform-time behavior is stable for unseen categories, missing columns, and changed schemas.
|
|
95
|
+
- Avoid silently dropping rows during feature assembly. Current `VectorAssembler(handleInvalid="skip")` can change row counts during training or prediction.
|
|
96
|
+
|
|
97
|
+
### 4. Harden Pipeline Semantics
|
|
98
|
+
|
|
99
|
+
- Pipeline should own feature-column construction instead of passing all non-label columns to the model.
|
|
100
|
+
- Pipelines should support both:
|
|
101
|
+
- Preprocessing-only `fit/transform`.
|
|
102
|
+
- End-to-end `fit/predict/evaluate/save/load` with a model step.
|
|
103
|
+
- Add robust pipeline persistence for model pipelines, not only preprocessing pipelines.
|
|
104
|
+
- Save/load must preserve:
|
|
105
|
+
- Preprocessing state.
|
|
106
|
+
- Model artifacts.
|
|
107
|
+
- Feature schema.
|
|
108
|
+
- Label column.
|
|
109
|
+
- Task type.
|
|
110
|
+
- Model params.
|
|
111
|
+
- Validation/evaluation metadata.
|
|
112
|
+
- Add tests for saving and loading full pipelines with Random Forest first, then optional algorithm-specific tests.
|
|
113
|
+
|
|
114
|
+
### 5. Evaluation API
|
|
115
|
+
|
|
116
|
+
- Add a model-level `evaluate(df, label_col=None, metrics=None)` method.
|
|
117
|
+
- Add a pipeline-level `evaluate(...)` method that preprocesses, predicts, and scores in one call.
|
|
118
|
+
- For binary classification, support at least:
|
|
119
|
+
- Accuracy.
|
|
120
|
+
- Precision.
|
|
121
|
+
- Recall.
|
|
122
|
+
- F1.
|
|
123
|
+
- ROC AUC.
|
|
124
|
+
- PR AUC.
|
|
125
|
+
- Log loss.
|
|
126
|
+
- Confusion matrix.
|
|
127
|
+
- For regression, support at least:
|
|
128
|
+
- RMSE.
|
|
129
|
+
- MAE.
|
|
130
|
+
- MSE.
|
|
131
|
+
- R2.
|
|
132
|
+
- MAPE.
|
|
133
|
+
- Keep multiclass and multilabel metrics separate from binary metrics. The current binary precision/recall/F1 implementation should not be reused for multiclass without an explicit averaging policy.
|
|
134
|
+
|
|
135
|
+
### 6. Training And Validation
|
|
136
|
+
|
|
137
|
+
- Move train/test split and k-fold logic into a dedicated validation module.
|
|
138
|
+
- Add public split utilities for reuse and testing.
|
|
139
|
+
- Make validation behavior explicit:
|
|
140
|
+
- `validation="none" | "train_test" | "kfold"`.
|
|
141
|
+
- `stratified=True` only for classification.
|
|
142
|
+
- Fixed seed behavior.
|
|
143
|
+
- Empty fold and tiny-class handling.
|
|
144
|
+
- Add train/validation metrics and final model metadata in a consistent structure.
|
|
145
|
+
- Add an option to cache training data during fitting, with documented tradeoffs.
|
|
146
|
+
|
|
147
|
+
### 7. Model Persistence And Registry-Ready Artifacts
|
|
148
|
+
|
|
149
|
+
- Define a stable artifact layout:
|
|
150
|
+
- `metadata.json`.
|
|
151
|
+
- `preprocessing/`.
|
|
152
|
+
- `model/`.
|
|
153
|
+
- `metrics.json`.
|
|
154
|
+
- `schema.json`.
|
|
155
|
+
- Include a `smallaxe_version`, Spark version, algorithm name, task type, params, feature schema, and timestamp.
|
|
156
|
+
- Provide `load_model(path)` and `load_pipeline(path)` convenience functions.
|
|
157
|
+
- Ensure loaded models produce the same predictions as saved models on deterministic test data.
|
|
158
|
+
- Design the artifact format so it can later plug into MLflow or a model registry.
|
|
159
|
+
|
|
160
|
+
### 8. Automated Training
|
|
161
|
+
|
|
162
|
+
- Implement `AutomatedTraining` after the four algorithm wrappers are stable.
|
|
163
|
+
- It should:
|
|
164
|
+
- Train all available compatible algorithms.
|
|
165
|
+
- Skip missing optional dependencies with warnings and install hints.
|
|
166
|
+
- Return a comparison table as a Spark or pandas DataFrame.
|
|
167
|
+
- Select `best_model` by a user-specified metric.
|
|
168
|
+
- Persist the winning model or full comparison run.
|
|
169
|
+
- Keep the first version constrained to binary classification and continuous regression.
|
|
170
|
+
|
|
171
|
+
### 9. Hyperparameter Search
|
|
172
|
+
|
|
173
|
+
- Implement `smallaxe.search.optimize`.
|
|
174
|
+
- Start with a simple, predictable API:
|
|
175
|
+
- model instance.
|
|
176
|
+
- DataFrame.
|
|
177
|
+
- label column.
|
|
178
|
+
- search space.
|
|
179
|
+
- metric.
|
|
180
|
+
- validation strategy.
|
|
181
|
+
- max evaluations.
|
|
182
|
+
- Preserve `best_params`, `best_score`, and trial history.
|
|
183
|
+
- Make search optional and clearly dependency-gated if using Hyperopt.
|
|
184
|
+
|
|
185
|
+
### 10. Documentation And Examples
|
|
186
|
+
|
|
187
|
+
- Rewrite README around the actual v1 user journey:
|
|
188
|
+
- Install.
|
|
189
|
+
- Build a preprocessing pipeline.
|
|
190
|
+
- Train binary classifier.
|
|
191
|
+
- Train regressor.
|
|
192
|
+
- Evaluate.
|
|
193
|
+
- Save/load.
|
|
194
|
+
- Use optional algorithms.
|
|
195
|
+
- Add examples for:
|
|
196
|
+
- Random Forest binary classification.
|
|
197
|
+
- XGBoost regression.
|
|
198
|
+
- LightGBM classification when dependency is installed.
|
|
199
|
+
- Full pipeline save/load.
|
|
200
|
+
- Add a compatibility matrix for Python, Spark, Java, and optional algorithm packages.
|
|
201
|
+
|
|
202
|
+
## v1 Acceptance Criteria
|
|
203
|
+
|
|
204
|
+
- A new user can train, evaluate, save, load, and predict with Random Forest on a PySpark DataFrame in under 20 lines of code.
|
|
205
|
+
- The same user-facing workflow works for XGBoost, LightGBM, and CatBoost when optional dependencies are installed.
|
|
206
|
+
- Binary classification and continuous regression have clear metrics and stable output schemas.
|
|
207
|
+
- A full preprocessing-plus-model pipeline can be saved and loaded with identical predictions on deterministic data.
|
|
208
|
+
- Missing optional dependencies fail with actionable install instructions.
|
|
209
|
+
- Documentation does not advertise unimplemented APIs.
|
|
210
|
+
- CI runs core tests on supported Python/Spark versions and optional algorithm tests in separate dependency-enabled jobs.
|
|
211
|
+
|
|
212
|
+
## Later Goals
|
|
213
|
+
|
|
214
|
+
- Multiclass classification with explicit averaging options for metrics.
|
|
215
|
+
- Multilabel classification.
|
|
216
|
+
- Quantile regression and other specialized regression objectives.
|
|
217
|
+
- Calibration and threshold tuning for binary classifiers.
|
|
218
|
+
- Feature importance and model comparison visualizations.
|
|
219
|
+
- MLflow integration for experiment tracking and model registry workflows.
|
|
220
|
+
- Distributed hyperparameter tuning with Spark-aware execution.
|
|
@@ -2,14 +2,6 @@
|
|
|
2
2
|
|
|
3
3
|
### Phase 8: Training Module - External Algorithms (v0.7.0)
|
|
4
4
|
|
|
5
|
-
#### Step 8.2: LightGBM
|
|
6
|
-
- [ ] Create `smallaxe/training/lightgbm.py`
|
|
7
|
-
- [ ] Implement `LightGBMRegressor` and `LightGBMClassifier`
|
|
8
|
-
- [ ] Handle optional dependency
|
|
9
|
-
- [ ] Create `tests/test_lightgbm.py`
|
|
10
|
-
- [ ] Commit: "Add LightGBM support"
|
|
11
|
-
- [ ] PR → main
|
|
12
|
-
|
|
13
5
|
#### Step 8.3: CatBoost
|
|
14
6
|
- [ ] Create `smallaxe/training/catboost.py`
|
|
15
7
|
- [ ] Implement `CatBoostRegressor` and `CatBoostClassifier`
|
|
@@ -236,4 +228,4 @@
|
|
|
236
228
|
| v0.10.0 | Optimization (hyperopt) |
|
|
237
229
|
| v0.11.0 | AutomatedTraining |
|
|
238
230
|
| v0.12.0 | Visualization |
|
|
239
|
-
| v1.0.0 | Integration, README, PyPI publish |
|
|
231
|
+
| v1.0.0 | Integration, README, PyPI publish |
|
|
@@ -27,3 +27,29 @@ try:
|
|
|
27
27
|
__all__.extend(["XGBoostRegressor", "XGBoostClassifier"])
|
|
28
28
|
except ImportError:
|
|
29
29
|
pass
|
|
30
|
+
|
|
31
|
+
# LightGBM wrappers are an optional dependency: export them only when
# smallaxe.training.lightgbm can be imported in this environment.
try:
    from smallaxe.training.lightgbm import (
        LightGBMClassifier as LightGBMClassifier,
    )
    from smallaxe.training.lightgbm import (
        LightGBMRegressor as LightGBMRegressor,
    )
except ImportError:
    pass
else:
    # Only advertise the names publicly once both imports have succeeded.
    __all__.extend(["LightGBMRegressor", "LightGBMClassifier"])
|
|
43
|
+
|
|
44
|
+
# CatBoost wrappers are an optional dependency. Note that
# smallaxe.training.catboost catches the catboost_spark ImportError itself,
# so this import normally succeeds and the classes are exported even when
# catboost_spark is not installed; constructing them then raises
# DependencyError with an install hint.
try:
    from smallaxe.training.catboost import (
        CatBoostClassifier as CatBoostClassifier,
    )
    from smallaxe.training.catboost import (
        CatBoostRegressor as CatBoostRegressor,
    )
except ImportError:
    pass
else:
    # Only advertise the names publicly once both imports have succeeded.
    __all__.extend(["CatBoostRegressor", "CatBoostClassifier"])
|
|
@@ -0,0 +1,296 @@
|
|
|
1
|
+
"""CatBoost models for regression and classification."""
|
|
2
|
+
|
|
3
|
+
import shutil
|
|
4
|
+
import tempfile
|
|
5
|
+
from typing import Any, Dict, List, Optional
|
|
6
|
+
|
|
7
|
+
from pyspark.sql import DataFrame
|
|
8
|
+
|
|
9
|
+
from smallaxe.exceptions import DependencyError
|
|
10
|
+
from smallaxe.training.base import BaseClassifier, BaseRegressor
|
|
11
|
+
|
|
12
|
+
# Lazily-populated handles to the optional catboost_spark classes.
# They remain None, and the flag remains False, until _load_catboost_spark()
# manages to import catboost_spark.
CATBOOST_AVAILABLE = False
SparkCatBoostRegressor = SparkCatBoostRegressionModel = None
SparkCatBoostClassifier = SparkCatBoostClassificationModel = None
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _load_catboost_spark() -> bool:
    """Bind the catboost_spark estimator and model classes to module globals.

    A failed import is not remembered: every call retries until the import
    succeeds. This presumably matters because ``catboost_spark`` may only
    become importable after Spark is configured with the CatBoost package.
    Once it succeeds, ``CATBOOST_AVAILABLE`` is set and later calls return
    immediately.

    Returns:
        True if the CatBoost Spark classes are loaded, False otherwise.
    """
    global CATBOOST_AVAILABLE
    global SparkCatBoostRegressor, SparkCatBoostRegressionModel
    global SparkCatBoostClassifier, SparkCatBoostClassificationModel

    if not CATBOOST_AVAILABLE:
        try:
            from catboost_spark import (
                CatBoostClassificationModel as _classification_model,
                CatBoostClassifier as _classifier,
                CatBoostRegressionModel as _regression_model,
                CatBoostRegressor as _regressor,
            )
        except ImportError:
            return False

        SparkCatBoostRegressor = _regressor
        SparkCatBoostRegressionModel = _regression_model
        SparkCatBoostClassifier = _classifier
        SparkCatBoostClassificationModel = _classification_model
        CATBOOST_AVAILABLE = True

    return True


# Attempt the import once at module load so the globals above reflect the
# environment as early as possible.
_load_catboost_spark()
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _check_catboost_available() -> None:
    """Ensure CatBoost Spark support is importable.

    Triggers the lazy import in ``_load_catboost_spark`` if it has not yet
    succeeded.

    Raises:
        DependencyError: If the ``catboost_spark`` classes cannot be loaded.
    """
    if not _load_catboost_spark():
        raise DependencyError(
            package="catboost_spark",
            # Reuse the public hint so the error text and the advertised
            # install instructions can never drift apart.
            install_command=catboost_install_hint(),
        )
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def is_catboost_available() -> bool:
    """Report whether the CatBoost Spark classes can be imported right now.

    Calling this also performs the lazy import attempt. A ``True`` result
    therefore means the module-level ``SparkCatBoost*`` handles are populated.
    """
    loaded = _load_catboost_spark()
    return loaded
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def catboost_install_hint() -> str:
    """Describe how to enable CatBoost support.

    Returns:
        A message naming the pip extra to install and the Spark package
        coordinate to add to the Spark configuration.
    """
    spark_package = "ai.catboost:catboost-spark_3.5_2.12:1.2.10"
    return f"pip install smallaxe[catboost] and configure Spark with {spark_package}"
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class CatBoostRegressor(BaseRegressor):
    """CatBoost regressor for regression tasks.

    This class wraps CatBoost for Spark's CatBoostRegressor to provide the
    same smallaxe fit/predict/save/load interface as the other Spark-backed
    regressors. Training always uses CatBoost's ``RMSE`` loss.
    """

    def __init__(self, task: str = "simple_regression") -> None:
        """Initialize the CatBoost regressor.

        Args:
            task: Regression task name, forwarded to ``BaseRegressor``.

        Raises:
            DependencyError: If CatBoost Spark support is not importable.
        """
        _check_catboost_available()
        super().__init__(task)

    @property
    def params(self) -> Dict[str, str]:
        """Get parameter descriptions."""
        return {
            "n_estimators": "Number of boosting iterations",
            "max_depth": "Maximum tree depth",
            "learning_rate": "Boosting learning rate",
            "subsample": "Sample rate for bagging",
            "l2_leaf_reg": "L2 regularization coefficient",
            "random_strength": "Amount of randomness used when scoring splits",
            "one_hot_max_size": "Maximum categorical cardinality for one-hot encoding",
            "allow_writing_files": "Whether CatBoost may write training artifacts",
            "train_dir": "Directory for CatBoost training artifacts",
            "seed": "Random seed for reproducibility",
        }

    @property
    def default_params(self) -> Dict[str, Any]:
        """Get default parameter values.

        ``subsample``, ``one_hot_max_size`` and ``seed`` default to ``None``,
        which means they are not passed to CatBoost at all. A ``None``
        ``train_dir`` means a temporary directory is used during ``fit``.
        """
        return {
            "n_estimators": 100,
            "max_depth": 6,
            "learning_rate": 0.03,
            "subsample": None,
            "l2_leaf_reg": 3.0,
            "random_strength": 1.0,
            "one_hot_max_size": None,
            "allow_writing_files": False,
            "train_dir": None,
            "seed": None,
        }

    def _catboost_params(
        self,
        label_col: Optional[str] = None,
        train_dir: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Translate smallaxe parameter names to CatBoost Spark parameter names.

        Args:
            label_col: Label column to set on the estimator, if any.
            train_dir: Training directory overriding the ``train_dir`` param.

        Returns:
            Keyword arguments for the CatBoost Spark estimator constructor.
        """
        params = {
            "iterations": self.get_param("n_estimators"),
            "depth": self.get_param("max_depth"),
            "learningRate": self.get_param("learning_rate"),
            "l2LeafReg": self.get_param("l2_leaf_reg"),
            "randomStrength": self.get_param("random_strength"),
            "lossFunction": "RMSE",
            "allowWritingFiles": self.get_param("allow_writing_files"),
            "featuresCol": self.FEATURES_COL,
            "predictionCol": self.PREDICTION_COL,
        }
        if label_col is not None:
            params["labelCol"] = label_col

        configured_train_dir = train_dir or self.get_param("train_dir")
        if configured_train_dir is not None:
            params["trainDir"] = configured_train_dir

        # Only forward optional parameters the user actually set, so that
        # CatBoost's own defaults apply otherwise.
        optional_params = {
            "subsample": self.get_param("subsample"),
            "oneHotMaxSize": self.get_param("one_hot_max_size"),
            "randomSeed": self.get_param("seed"),
        }
        params.update({name: value for name, value in optional_params.items() if value is not None})
        return params

    def _create_spark_estimator(self) -> Any:
        """Create the underlying CatBoost Spark regressor."""
        return SparkCatBoostRegressor(**self._catboost_params())

    def _fit_spark_model(
        self,
        df: DataFrame,
        label_col: str,
        feature_cols: List[str],
    ) -> Any:
        """Fit the CatBoost Spark regressor.

        Args:
            df: Training DataFrame.
            label_col: Name of the label column.
            feature_cols: Columns assembled into the feature vector.

        Returns:
            The fitted CatBoost Spark regression model.
        """
        df_with_features = self._assemble_features(df, feature_cols)

        # Without a user-configured train_dir, give CatBoost a private scratch
        # directory that is removed once training finishes or fails.
        temp_train_dir = None
        if self.get_param("train_dir") is None:
            temp_train_dir = tempfile.mkdtemp(prefix="smallaxe_catboost_")

        try:
            # Estimator construction is inside the try so an invalid parameter
            # value cannot leak the temporary directory.
            estimator = SparkCatBoostRegressor(
                **self._catboost_params(label_col, train_dir=temp_train_dir)
            )
            fitted_model = estimator.fit(df_with_features)
        finally:
            if temp_train_dir is not None:
                shutil.rmtree(temp_train_dir, ignore_errors=True)

        # Record the training schema only after a successful fit. A failed
        # refit then cannot leave new feature/label columns paired with the
        # previously fitted model.
        self._feature_cols = feature_cols
        self._label_col = label_col
        self._spark_model = fitted_model
        return self._spark_model

    def _load_artifacts(self, path: str) -> None:
        """Load the CatBoost Spark model from disk.

        Raises:
            DependencyError: If ``catboost_spark`` is not importable. Without
                this check, the unset (``None``) model class would be handed
                to the loader.
        """
        _check_catboost_available()
        self._load_spark_model(path, SparkCatBoostRegressionModel)
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
class CatBoostClassifier(BaseClassifier):
    """CatBoost classifier for binary and multiclass classification tasks.

    The ``binary`` task trains with CatBoost's ``Logloss`` loss. Any other
    task uses ``MultiClass``.
    """

    def __init__(self, task: str = "binary") -> None:
        """Initialize the CatBoost classifier.

        Args:
            task: Classification task name, forwarded to ``BaseClassifier``.

        Raises:
            DependencyError: If CatBoost Spark support is not importable.
        """
        _check_catboost_available()
        super().__init__(task)

    @property
    def params(self) -> Dict[str, str]:
        """Get parameter descriptions."""
        return {
            "n_estimators": "Number of boosting iterations",
            "max_depth": "Maximum tree depth",
            "learning_rate": "Boosting learning rate",
            "subsample": "Sample rate for bagging",
            "l2_leaf_reg": "L2 regularization coefficient",
            "random_strength": "Amount of randomness used when scoring splits",
            "one_hot_max_size": "Maximum categorical cardinality for one-hot encoding",
            "scale_pos_weight": "Class 1 weight multiplier for binary classification",
            "allow_writing_files": "Whether CatBoost may write training artifacts",
            "train_dir": "Directory for CatBoost training artifacts",
            "seed": "Random seed for reproducibility",
        }

    @property
    def default_params(self) -> Dict[str, Any]:
        """Get default parameter values.

        ``subsample``, ``one_hot_max_size``, ``scale_pos_weight`` and ``seed``
        default to ``None``, which means they are not passed to CatBoost at
        all. A ``None`` ``train_dir`` means a temporary directory is used
        during ``fit``.
        """
        return {
            "n_estimators": 100,
            "max_depth": 6,
            "learning_rate": 0.03,
            "subsample": None,
            "l2_leaf_reg": 3.0,
            "random_strength": 1.0,
            "one_hot_max_size": None,
            "scale_pos_weight": None,
            "allow_writing_files": False,
            "train_dir": None,
            "seed": None,
        }

    def _catboost_params(
        self,
        label_col: Optional[str] = None,
        train_dir: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Translate smallaxe parameter names to CatBoost Spark parameter names.

        Args:
            label_col: Label column to set on the estimator, if any.
            train_dir: Training directory overriding the ``train_dir`` param.

        Returns:
            Keyword arguments for the CatBoost Spark estimator constructor.
        """
        loss_function = "Logloss" if self.task == "binary" else "MultiClass"
        params = {
            "iterations": self.get_param("n_estimators"),
            "depth": self.get_param("max_depth"),
            "learningRate": self.get_param("learning_rate"),
            "l2LeafReg": self.get_param("l2_leaf_reg"),
            "randomStrength": self.get_param("random_strength"),
            "lossFunction": loss_function,
            "allowWritingFiles": self.get_param("allow_writing_files"),
            "featuresCol": self.FEATURES_COL,
            "predictionCol": self.PREDICTION_COL,
            "probabilityCol": self.PROBABILITY_COL,
            "rawPredictionCol": self.RAW_PREDICTION_COL,
        }
        if label_col is not None:
            params["labelCol"] = label_col

        configured_train_dir = train_dir or self.get_param("train_dir")
        if configured_train_dir is not None:
            params["trainDir"] = configured_train_dir

        # Only forward optional parameters the user actually set, so that
        # CatBoost's own defaults apply otherwise.
        optional_params = {
            "subsample": self.get_param("subsample"),
            "oneHotMaxSize": self.get_param("one_hot_max_size"),
            "scalePosWeight": self.get_param("scale_pos_weight"),
            "randomSeed": self.get_param("seed"),
        }
        params.update({name: value for name, value in optional_params.items() if value is not None})
        return params

    def _create_spark_estimator(self) -> Any:
        """Create the underlying CatBoost Spark classifier."""
        return SparkCatBoostClassifier(**self._catboost_params())

    def _fit_spark_model(
        self,
        df: DataFrame,
        label_col: str,
        feature_cols: List[str],
    ) -> Any:
        """Fit the CatBoost Spark classifier.

        Args:
            df: Training DataFrame.
            label_col: Name of the label column.
            feature_cols: Columns assembled into the feature vector.

        Returns:
            The fitted CatBoost Spark classification model.
        """
        df_with_features = self._assemble_features(df, feature_cols)

        # Without a user-configured train_dir, give CatBoost a private scratch
        # directory that is removed once training finishes or fails.
        temp_train_dir = None
        if self.get_param("train_dir") is None:
            temp_train_dir = tempfile.mkdtemp(prefix="smallaxe_catboost_")

        try:
            # Estimator construction is inside the try so an invalid parameter
            # value cannot leak the temporary directory.
            estimator = SparkCatBoostClassifier(
                **self._catboost_params(label_col, train_dir=temp_train_dir)
            )
            fitted_model = estimator.fit(df_with_features)
        finally:
            if temp_train_dir is not None:
                shutil.rmtree(temp_train_dir, ignore_errors=True)

        # Record the training schema only after a successful fit. A failed
        # refit then cannot leave new feature/label columns paired with the
        # previously fitted model.
        self._feature_cols = feature_cols
        self._label_col = label_col
        self._spark_model = fitted_model
        return self._spark_model

    def _load_artifacts(self, path: str) -> None:
        """Load the CatBoost Spark model from disk.

        Raises:
            DependencyError: If ``catboost_spark`` is not importable. Without
                this check, the unset (``None``) model class would be handed
                to the loader.
        """
        _check_catboost_available()
        self._load_spark_model(path, SparkCatBoostClassificationModel)
|