smallaxe 0.6.2__tar.gz → 0.6.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. {smallaxe-0.6.2 → smallaxe-0.6.3}/PKG-INFO +1 -1
  2. {smallaxe-0.6.2 → smallaxe-0.6.3}/TODO.md +0 -8
  3. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/training/__init__.py +13 -0
  4. smallaxe-0.6.3/smallaxe/training/lightgbm.py +441 -0
  5. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe.egg-info/PKG-INFO +1 -1
  6. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe.egg-info/SOURCES.txt +2 -0
  7. smallaxe-0.6.3/tests/test_lightgbm.py +837 -0
  8. {smallaxe-0.6.2 → smallaxe-0.6.3}/.github/workflows/ci.yml +0 -0
  9. {smallaxe-0.6.2 → smallaxe-0.6.3}/.github/workflows/publish.yml +0 -0
  10. {smallaxe-0.6.2 → smallaxe-0.6.3}/.gitignore +0 -0
  11. {smallaxe-0.6.2 → smallaxe-0.6.3}/LICENSE +0 -0
  12. {smallaxe-0.6.2 → smallaxe-0.6.3}/README.md +0 -0
  13. {smallaxe-0.6.2 → smallaxe-0.6.3}/pyproject.toml +0 -0
  14. {smallaxe-0.6.2 → smallaxe-0.6.3}/requirements-dev.txt +0 -0
  15. {smallaxe-0.6.2 → smallaxe-0.6.3}/setup.cfg +0 -0
  16. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/__init__.py +0 -0
  17. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/_config.py +0 -0
  18. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/auto/__init__.py +0 -0
  19. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/datasets/__init__.py +0 -0
  20. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/datasets/_data.py +0 -0
  21. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/exceptions/__init__.py +0 -0
  22. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/metrics/__init__.py +0 -0
  23. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/metrics/classification.py +0 -0
  24. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/metrics/regression.py +0 -0
  25. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/pipeline/__init__.py +0 -0
  26. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/pipeline/pipeline.py +0 -0
  27. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/preprocessing/__init__.py +0 -0
  28. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/preprocessing/encoder.py +0 -0
  29. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/preprocessing/imputer.py +0 -0
  30. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/preprocessing/scaler.py +0 -0
  31. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/search/__init__.py +0 -0
  32. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/training/base.py +0 -0
  33. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/training/classifiers.py +0 -0
  34. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/training/mixins/__init__.py +0 -0
  35. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/training/mixins/metadata_mixin.py +0 -0
  36. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/training/mixins/param_mixin.py +0 -0
  37. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/training/mixins/persistence_mixin.py +0 -0
  38. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/training/mixins/spark_model_mixin.py +0 -0
  39. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/training/mixins/validation_mixin.py +0 -0
  40. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/training/random_forest.py +0 -0
  41. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/training/regressors.py +0 -0
  42. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/training/xgboost.py +0 -0
  43. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe/viz/__init__.py +0 -0
  44. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe.egg-info/dependency_links.txt +0 -0
  45. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe.egg-info/requires.txt +0 -0
  46. {smallaxe-0.6.2 → smallaxe-0.6.3}/smallaxe.egg-info/top_level.txt +0 -0
  47. {smallaxe-0.6.2 → smallaxe-0.6.3}/tests/__init__.py +0 -0
  48. {smallaxe-0.6.2 → smallaxe-0.6.3}/tests/conftest.py +0 -0
  49. {smallaxe-0.6.2 → smallaxe-0.6.3}/tests/test_config.py +0 -0
  50. {smallaxe-0.6.2 → smallaxe-0.6.3}/tests/test_datasets.py +0 -0
  51. {smallaxe-0.6.2 → smallaxe-0.6.3}/tests/test_encoder.py +0 -0
  52. {smallaxe-0.6.2 → smallaxe-0.6.3}/tests/test_exceptions.py +0 -0
  53. {smallaxe-0.6.2 → smallaxe-0.6.3}/tests/test_factories.py +0 -0
  54. {smallaxe-0.6.2 → smallaxe-0.6.3}/tests/test_imputer.py +0 -0
  55. {smallaxe-0.6.2 → smallaxe-0.6.3}/tests/test_metrics_classification.py +0 -0
  56. {smallaxe-0.6.2 → smallaxe-0.6.3}/tests/test_metrics_regression.py +0 -0
  57. {smallaxe-0.6.2 → smallaxe-0.6.3}/tests/test_mixins.py +0 -0
  58. {smallaxe-0.6.2 → smallaxe-0.6.3}/tests/test_pipeline.py +0 -0
  59. {smallaxe-0.6.2 → smallaxe-0.6.3}/tests/test_random_forest.py +0 -0
  60. {smallaxe-0.6.2 → smallaxe-0.6.3}/tests/test_scaler.py +0 -0
  61. {smallaxe-0.6.2 → smallaxe-0.6.3}/tests/test_smoke.py +0 -0
  62. {smallaxe-0.6.2 → smallaxe-0.6.3}/tests/test_training_base.py +0 -0
  63. {smallaxe-0.6.2 → smallaxe-0.6.3}/tests/test_xgboost.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: smallaxe
3
- Version: 0.6.2
3
+ Version: 0.6.3
4
4
  Summary: A PySpark MLOps library for simplified model training and optimization
5
5
  Author: Henok Yemam
6
6
  License: MIT
@@ -2,14 +2,6 @@
2
2
 
3
3
  ### Phase 8: Training Module - External Algorithms (v0.7.0)
4
4
 
5
- #### Step 8.2: LightGBM
6
- - [ ] Create `smallaxe/training/lightgbm.py`
7
- - [ ] Implement `LightGBMRegressor` and `LightGBMClassifier`
8
- - [ ] Handle optional dependency
9
- - [ ] Create `tests/test_lightgbm.py`
10
- - [ ] Commit: "Add LightGBM support"
11
- - [ ] PR → main
12
-
13
5
  #### Step 8.3: CatBoost
14
6
  - [ ] Create `smallaxe/training/catboost.py`
15
7
  - [ ] Implement `CatBoostRegressor` and `CatBoostClassifier`
@@ -27,3 +27,16 @@ try:
27
27
  __all__.extend(["XGBoostRegressor", "XGBoostClassifier"])
28
28
  except ImportError:
29
29
  pass
30
+
31
+ # Import LightGBM classes if available (optional dependency)
32
+ try:
33
+ from smallaxe.training.lightgbm import (
34
+ LightGBMClassifier as LightGBMClassifier,
35
+ )
36
+ from smallaxe.training.lightgbm import (
37
+ LightGBMRegressor as LightGBMRegressor,
38
+ )
39
+
40
+ __all__.extend(["LightGBMRegressor", "LightGBMClassifier"])
41
+ except ImportError:
42
+ pass
@@ -0,0 +1,441 @@
1
+ """LightGBM models for regression and classification."""
2
+
3
+ from typing import Any, Dict, List
4
+
5
+ from pyspark.sql import DataFrame
6
+
7
+ from smallaxe.exceptions import DependencyError
8
+ from smallaxe.training.base import BaseClassifier, BaseRegressor
9
+
10
+ # Check for LightGBM availability (via SynapseML)
11
+ try:
12
+ from synapse.ml.lightgbm import (
13
+ LightGBMClassifier as SparkLightGBMClassifier,
14
+ )
15
+ from synapse.ml.lightgbm import (
16
+ LightGBMClassifierModel as SparkLightGBMClassifierModel,
17
+ )
18
+ from synapse.ml.lightgbm import (
19
+ LightGBMRegressor as SparkLightGBMRegressor,
20
+ )
21
+ from synapse.ml.lightgbm import (
22
+ LightGBMRegressorModel as SparkLightGBMRegressorModel,
23
+ )
24
+
25
+ LIGHTGBM_AVAILABLE = True
26
+ except ImportError:
27
+ LIGHTGBM_AVAILABLE = False
28
+ SparkLightGBMRegressor = None
29
+ SparkLightGBMRegressorModel = None
30
+ SparkLightGBMClassifier = None
31
+ SparkLightGBMClassifierModel = None
32
+
33
+
34
def _check_lightgbm_available() -> None:
    """Fail fast when the SynapseML LightGBM bindings were not importable.

    Does nothing when the module-level ``LIGHTGBM_AVAILABLE`` flag is true.

    Raises:
        DependencyError: If ``LIGHTGBM_AVAILABLE`` is false. The error names
            the ``synapseml`` package and carries the ``pyspark --packages``
            command line as its install command.
    """
    if LIGHTGBM_AVAILABLE:
        return

    # SynapseML is pulled in as a Spark package, so the hint is a pyspark
    # launch command rather than a pip command.
    spark_launch_command = (
        "pyspark --packages com.microsoft.azure:synapseml_2.12:1.1.0 "
        "--repositories https://mmlspark.azureedge.net/maven"
    )
    raise DependencyError(
        package="synapseml",
        install_command=spark_launch_command,
    )
44
+
45
+
46
class LightGBMRegressor(BaseRegressor):
    """LightGBM Regressor for regression tasks.

    This class wraps SynapseML's LightGBMRegressor to provide
    a scikit-learn-like interface with support for train/test and k-fold
    cross-validation.

    Note:
        This requires SynapseML (v1.1.0+) which provides LightGBM integration
        for Spark. Requires Scala 2.12, Spark 3.2+, and Python 3.8+.

        **Standalone Spark (Maven)**::

            pyspark --packages com.microsoft.azure:synapseml_2.12:1.1.0 \\
                --repositories https://mmlspark.azureedge.net/maven

        **Databricks**: Add Maven library to cluster with coordinates
        ``com.microsoft.azure:synapseml_2.12:1.1.0`` and repository
        ``https://mmlspark.azureedge.net/maven``

    Args:
        task: The regression task type. Default is 'simple_regression'.

    Example:
        >>> from smallaxe.training import LightGBMRegressor
        >>> model = LightGBMRegressor()
        >>> model.set_param({"n_estimators": 100, "max_depth": 6})
        >>> model.fit(df, label_col='target', feature_cols=['f1', 'f2'])
        >>> predictions = model.predict(df)

    Raises:
        DependencyError: If synapseml is not installed.
    """

    def __init__(self, task: str = "simple_regression") -> None:
        """Initialize the LightGBM regressor.

        Args:
            task: The regression task type.

        Raises:
            DependencyError: If synapseml is not installed.
            ValidationError: If task is not a valid regression task.
        """
        # Check the optional dependency first so a missing SynapseML install
        # is reported before any base-class initialisation runs.
        _check_lightgbm_available()
        super().__init__(task)

    @property
    def params(self) -> Dict[str, str]:
        """Get parameter descriptions.

        Returns:
            Dictionary mapping parameter names to their descriptions.
        """
        return {
            "n_estimators": "Number of boosting iterations",
            "max_depth": "Maximum depth of each tree (-1 for no limit)",
            "learning_rate": "Boosting learning rate",
            "num_leaves": "Maximum number of leaves in one tree",
            "min_data_in_leaf": "Minimum number of data points in a leaf",
            "feature_fraction": "Fraction of features used for training each tree",
            "bagging_fraction": "Fraction of data used for training each tree",
            "bagging_freq": "Frequency for bagging (0 means disable bagging)",
            "lambda_l1": "L1 regularization term on weights",
            "lambda_l2": "L2 regularization term on weights",
            "seed": "Random seed for reproducibility",
        }

    @property
    def default_params(self) -> Dict[str, Any]:
        """Get default parameter values.

        Returns:
            Dictionary mapping parameter names to their default values.
        """
        return {
            "n_estimators": 100,
            "max_depth": -1,
            "learning_rate": 0.1,
            "num_leaves": 31,
            "min_data_in_leaf": 20,
            "feature_fraction": 1.0,
            "bagging_fraction": 1.0,
            "bagging_freq": 0,
            "lambda_l1": 0.0,
            "lambda_l2": 0.0,
            "seed": None,
        }

    def _build_estimator(self, **column_kwargs: Any) -> Any:
        """Build a SparkLightGBMRegressor from the current parameters.

        Single source of truth for how smallaxe parameter names map onto the
        SynapseML constructor keywords, shared by ``_create_spark_estimator``
        and ``_fit_spark_model`` so the two cannot drift apart.

        Args:
            **column_kwargs: Extra constructor keywords (for example
                ``featuresCol`` or ``labelCol``) forwarded unchanged.

        Returns:
            Configured SparkLightGBMRegressor instance.
        """
        # smallaxe parameter name -> SynapseML constructor keyword.
        # "seed" is deliberately absent: it is applied via setSeed below and
        # only when a value has been configured.
        spark_keyword_for = {
            "n_estimators": "numIterations",
            "max_depth": "maxDepth",
            "learning_rate": "learningRate",
            "num_leaves": "numLeaves",
            "min_data_in_leaf": "minDataInLeaf",
            "feature_fraction": "featureFraction",
            "bagging_fraction": "baggingFraction",
            "bagging_freq": "baggingFreq",
            "lambda_l1": "lambdaL1",
            "lambda_l2": "lambdaL2",
        }
        hyperparams = {
            spark_keyword: self.get_param(name)
            for name, spark_keyword in spark_keyword_for.items()
        }
        seed = self.get_param("seed")

        estimator = SparkLightGBMRegressor(**hyperparams, **column_kwargs)

        if seed is not None:
            estimator.setSeed(seed)

        return estimator

    def _create_spark_estimator(self) -> Any:
        """Create the underlying SparkLightGBMRegressor.

        Column names are not set here; only the hyperparameters (and the
        seed, when configured) are applied.

        Returns:
            Configured SparkLightGBMRegressor instance.
        """
        return self._build_estimator()

    def _fit_spark_model(
        self,
        df: DataFrame,
        label_col: str,
        feature_cols: List[str],
    ) -> Any:
        """Fit the LightGBM model.

        Override base class method to handle LightGBM's API.

        Args:
            df: PySpark DataFrame with training data.
            label_col: Name of the label column.
            feature_cols: List of feature column names.

        Returns:
            Fitted LightGBM model.
        """
        # Assemble features
        df_with_features = self._assemble_features(df, feature_cols)

        # Create estimator with all params including column names
        estimator = self._build_estimator(
            featuresCol=self.FEATURES_COL,
            labelCol=label_col,
            predictionCol=self.PREDICTION_COL,
        )

        # Store feature columns for prediction
        self._feature_cols = feature_cols
        self._label_col = label_col

        # Fit the model
        self._spark_model = estimator.fit(df_with_features)

        return self._spark_model

    def _load_artifacts(self, path: str) -> None:
        """Load the Spark model from disk.

        Args:
            path: Directory path where the model is saved.
        """
        self._load_spark_model(path, SparkLightGBMRegressorModel)
241
+
242
+
243
class LightGBMClassifier(BaseClassifier):
    """LightGBM Classifier for classification tasks.

    This class wraps SynapseML's LightGBMClassifier to provide
    a scikit-learn-like interface with support for train/test and k-fold
    cross-validation, including stratified sampling for classification.

    Note:
        This requires SynapseML (v1.1.0+) which provides LightGBM integration
        for Spark. Requires Scala 2.12, Spark 3.2+, and Python 3.8+.

        **Standalone Spark (Maven)**::

            pyspark --packages com.microsoft.azure:synapseml_2.12:1.1.0 \\
                --repositories https://mmlspark.azureedge.net/maven

        **Databricks**: Add Maven library to cluster with coordinates
        ``com.microsoft.azure:synapseml_2.12:1.1.0`` and repository
        ``https://mmlspark.azureedge.net/maven``

    Args:
        task: The classification task type. Options are 'binary' or 'multiclass'.
            Default is 'binary'.

    Example:
        >>> from smallaxe.training import LightGBMClassifier
        >>> model = LightGBMClassifier(task='binary')
        >>> model.set_param({"n_estimators": 100, "max_depth": 6})
        >>> model.fit(df, label_col='label', feature_cols=['f1', 'f2'])
        >>> predictions = model.predict(df)
        >>> probabilities = model.predict_proba(df)

    Raises:
        DependencyError: If synapseml is not installed.
    """

    def __init__(self, task: str = "binary") -> None:
        """Initialize the LightGBM classifier.

        Args:
            task: The classification task type.

        Raises:
            DependencyError: If synapseml is not installed.
            ValidationError: If task is not a valid classification task.
        """
        # Check the optional dependency first so a missing SynapseML install
        # is reported before any base-class initialisation runs.
        _check_lightgbm_available()
        super().__init__(task)

    @property
    def params(self) -> Dict[str, str]:
        """Get parameter descriptions.

        Returns:
            Dictionary mapping parameter names to their descriptions.
        """
        return {
            "n_estimators": "Number of boosting iterations",
            "max_depth": "Maximum depth of each tree (-1 for no limit)",
            "learning_rate": "Boosting learning rate",
            "num_leaves": "Maximum number of leaves in one tree",
            "min_data_in_leaf": "Minimum number of data points in a leaf",
            "feature_fraction": "Fraction of features used for training each tree",
            "bagging_fraction": "Fraction of data used for training each tree",
            "bagging_freq": "Frequency for bagging (0 means disable bagging)",
            "lambda_l1": "L1 regularization term on weights",
            "lambda_l2": "L2 regularization term on weights",
            "seed": "Random seed for reproducibility",
        }

    @property
    def default_params(self) -> Dict[str, Any]:
        """Get default parameter values.

        Returns:
            Dictionary mapping parameter names to their default values.
        """
        return {
            "n_estimators": 100,
            "max_depth": -1,
            "learning_rate": 0.1,
            "num_leaves": 31,
            "min_data_in_leaf": 20,
            "feature_fraction": 1.0,
            "bagging_fraction": 1.0,
            "bagging_freq": 0,
            "lambda_l1": 0.0,
            "lambda_l2": 0.0,
            "seed": None,
        }

    def _build_estimator(self, **column_kwargs: Any) -> Any:
        """Build a SparkLightGBMClassifier from the current parameters.

        Single source of truth for how smallaxe parameter names map onto the
        SynapseML constructor keywords, shared by ``_create_spark_estimator``
        and ``_fit_spark_model`` so the two cannot drift apart.

        Args:
            **column_kwargs: Extra constructor keywords (for example
                ``featuresCol`` or ``probabilityCol``) forwarded unchanged.

        Returns:
            Configured SparkLightGBMClassifier instance.
        """
        # smallaxe parameter name -> SynapseML constructor keyword.
        # "seed" is deliberately absent: it is applied via setSeed below and
        # only when a value has been configured.
        spark_keyword_for = {
            "n_estimators": "numIterations",
            "max_depth": "maxDepth",
            "learning_rate": "learningRate",
            "num_leaves": "numLeaves",
            "min_data_in_leaf": "minDataInLeaf",
            "feature_fraction": "featureFraction",
            "bagging_fraction": "baggingFraction",
            "bagging_freq": "baggingFreq",
            "lambda_l1": "lambdaL1",
            "lambda_l2": "lambdaL2",
        }
        hyperparams = {
            spark_keyword: self.get_param(name)
            for name, spark_keyword in spark_keyword_for.items()
        }
        seed = self.get_param("seed")

        estimator = SparkLightGBMClassifier(**hyperparams, **column_kwargs)

        if seed is not None:
            estimator.setSeed(seed)

        return estimator

    def _create_spark_estimator(self) -> Any:
        """Create the underlying SparkLightGBMClassifier.

        Column names are not set here; only the hyperparameters (and the
        seed, when configured) are applied.

        Returns:
            Configured SparkLightGBMClassifier instance.
        """
        return self._build_estimator()

    def _fit_spark_model(
        self,
        df: DataFrame,
        label_col: str,
        feature_cols: List[str],
    ) -> Any:
        """Fit the LightGBM classifier.

        Override base class method to handle LightGBM's API.

        Args:
            df: PySpark DataFrame with training data.
            label_col: Name of the label column.
            feature_cols: List of feature column names.

        Returns:
            Fitted LightGBM model.
        """
        # Assemble features
        df_with_features = self._assemble_features(df, feature_cols)

        # Create estimator with all params including column names
        estimator = self._build_estimator(
            featuresCol=self.FEATURES_COL,
            labelCol=label_col,
            predictionCol=self.PREDICTION_COL,
            probabilityCol=self.PROBABILITY_COL,
            rawPredictionCol=self.RAW_PREDICTION_COL,
        )

        # Store feature columns for prediction
        self._feature_cols = feature_cols
        self._label_col = label_col

        # Fit the model
        self._spark_model = estimator.fit(df_with_features)

        return self._spark_model

    def _load_artifacts(self, path: str) -> None:
        """Load the Spark model from disk.

        Args:
            path: Directory path where the model is saved.
        """
        self._load_spark_model(path, SparkLightGBMClassifierModel)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: smallaxe
3
- Version: 0.6.2
3
+ Version: 0.6.3
4
4
  Summary: A PySpark MLOps library for simplified model training and optimization
5
5
  Author: Henok Yemam
6
6
  License: MIT
@@ -30,6 +30,7 @@ smallaxe/search/__init__.py
30
30
  smallaxe/training/__init__.py
31
31
  smallaxe/training/base.py
32
32
  smallaxe/training/classifiers.py
33
+ smallaxe/training/lightgbm.py
33
34
  smallaxe/training/random_forest.py
34
35
  smallaxe/training/regressors.py
35
36
  smallaxe/training/xgboost.py
@@ -48,6 +49,7 @@ tests/test_encoder.py
48
49
  tests/test_exceptions.py
49
50
  tests/test_factories.py
50
51
  tests/test_imputer.py
52
+ tests/test_lightgbm.py
51
53
  tests/test_metrics_classification.py
52
54
  tests/test_metrics_regression.py
53
55
  tests/test_mixins.py