aek-auto-mlbuilder 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aek_auto_mlbuilder/__init__.py +2 -1
- aek_auto_mlbuilder/logistic_regression.py +48 -0
- {aek_auto_mlbuilder-0.1.1.dist-info → aek_auto_mlbuilder-0.2.0.dist-info}/METADATA +1 -1
- aek_auto_mlbuilder-0.2.0.dist-info/RECORD +10 -0
- aek_auto_mlbuilder-0.1.1.dist-info/RECORD +0 -9
- {aek_auto_mlbuilder-0.1.1.dist-info → aek_auto_mlbuilder-0.2.0.dist-info}/WHEEL +0 -0
- {aek_auto_mlbuilder-0.1.1.dist-info → aek_auto_mlbuilder-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {aek_auto_mlbuilder-0.1.1.dist-info → aek_auto_mlbuilder-0.2.0.dist-info}/top_level.txt +0 -0
aek_auto_mlbuilder/logistic_regression.py
CHANGED
@@ -0,0 +1,48 @@
|
|
1
|
+
from sklearn.linear_model import LogisticRegression
|
2
|
+
from sklearn.preprocessing import StandardScaler
|
3
|
+
from sklearn.pipeline import make_pipeline
|
4
|
+
from .base import BaseModel
|
5
|
+
|
6
|
+
|
7
|
+
|
8
|
+
class LogisticClassifier(BaseModel):
    """
    Basic Logistic Regression class for binary/multi-class classification.

    Brute-force parameter search included: every combination of the values in
    ``param_grid`` is fitted as a ``StandardScaler`` -> ``LogisticRegression``
    pipeline and the best-scoring one is kept.

    Parameters
    ----------
    param_grid : dict or None
        Mapping from hyperparameter name (``"C"``, ``"penalty"``, ``"solver"``,
        ``"fit_intercept"``) to a list of candidate values. Keys that are
        missing fall back to the built-in defaults; ``None`` or an empty dict
        uses the defaults for everything.
    cv : int, cross-validation splitter, iterable or None
        ``None`` (the default, same as earlier releases) scores each candidate
        by its accuracy on the very data it was fitted on, which makes the
        selection optimistic. Any other value is forwarded as the ``cv``
        argument of ``sklearn.model_selection.cross_val_score``; the winning
        configuration is then refitted on the full ``X, y``.
    """

    # Immutable defaults (tuples) so no instance can mutate shared state.
    _DEFAULT_PARAM_GRID = {
        "C": (0.01, 0.1, 1, 10),
        "penalty": ("l2",),
        "solver": ("lbfgs",),
        "fit_intercept": (True, False),
    }
    # Search order; the first key varies slowest, matching the original
    # nested-loop order so tie-breaking is unchanged.
    _PARAM_KEYS = ("C", "penalty", "solver", "fit_intercept")

    def __init__(self, param_grid=None, cv=None):
        super().__init__()
        # Fresh lists per instance when no grid is supplied.
        self.param_grid = param_grid or {
            key: list(values) for key, values in self._DEFAULT_PARAM_GRID.items()
        }
        self.cv = cv

    @staticmethod
    def _build_pipeline(C, penalty, solver, fit_intercept):
        """Return an unfitted StandardScaler -> LogisticRegression pipeline."""
        return make_pipeline(
            StandardScaler(),
            LogisticRegression(
                C=C,
                penalty=penalty,
                solver=solver,
                fit_intercept=fit_intercept,
                max_iter=1000,
            ),
        )

    def train(self, X, y):
        """
        Fit every hyperparameter combination and keep the best one.

        Parameters
        ----------
        X : array-like
            Feature matrix, passed unchanged to the scikit-learn pipeline.
        y : array-like
            Target labels.

        Returns
        -------
        sklearn.pipeline.Pipeline or None
            The best fitted pipeline, also stored on ``self.best_model``, or
            ``None`` when the grid yields no combinations. The winning score
            is stored on ``self.best_score`` (``-inf`` for an empty grid). It
            is training accuracy when ``cv`` is ``None``, otherwise the mean
            cross-validation score.
        """
        import itertools

        # Merge at train time so a partial grid (or one assigned after
        # construction) no longer raises KeyError for the missing keys.
        grid = {**self._DEFAULT_PARAM_GRID, **self.param_grid}
        # getattr keeps instances that predate the ``cv`` attribute working.
        cv = getattr(self, "cv", None)
        if cv is not None:
            from sklearn.model_selection import cross_val_score

        best_score = -float("inf")
        best_model = None
        best_params = None

        for values in itertools.product(*(grid[key] for key in self._PARAM_KEYS)):
            params = dict(zip(self._PARAM_KEYS, values))
            model = self._build_pipeline(**params)
            if cv is None:
                model.fit(X, y)
                score = model.score(X, y)
            else:
                # cross_val_score fits clones, so ``model`` itself stays unfitted.
                score = float(cross_val_score(model, X, y, cv=cv).mean())
            # Strict ">" keeps the earliest combination on ties.
            if score > best_score:
                best_score = score
                best_model = model
                best_params = params

        if cv is not None and best_params is not None:
            # Refit the winning configuration on all of the data.
            best_model = self._build_pipeline(**best_params)
            best_model.fit(X, y)

        self.best_model = best_model
        self.best_score = best_score
        return self.best_model
|
@@ -0,0 +1,10 @@
|
|
1
|
+
aek_auto_mlbuilder/__init__.py,sha256=SXkgDc9Wk7algCZBmE_855M6Kmqh0OauWZ_ZCDPI_-w,156
|
2
|
+
aek_auto_mlbuilder/base.py,sha256=GgMdAoceRjwz3i9rVQ0RAjvn5ZdRS-sAkLWjymbE8s0,385
|
3
|
+
aek_auto_mlbuilder/linear_regression.py,sha256=MtOSRiXDIJPd3abnz4yNT4DBtrkmvEy00Kbx4AFk4Kg,1259
|
4
|
+
aek_auto_mlbuilder/logistic_regression.py,sha256=lp9-e9p9QrqL20DmhIJaSnra1SwyMiOdTfOMlgYsNQA,1707
|
5
|
+
aek_auto_mlbuilder/utils.py,sha256=NcoM3b4Ng1Ogk3iKuz9DcMVwppGRqOLRp5g9jBCkWxY,190
|
6
|
+
aek_auto_mlbuilder-0.2.0.dist-info/licenses/LICENSE,sha256=eSVo2jJj9FB1xvr0zZ9U1fXkyjjnT6-WM3O4HSFKJOc,133
|
7
|
+
aek_auto_mlbuilder-0.2.0.dist-info/METADATA,sha256=jM2zfl1x7SNJUcduNuO4sDEEj84MZCW-2kg3XMMqmps,1400
|
8
|
+
aek_auto_mlbuilder-0.2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
9
|
+
aek_auto_mlbuilder-0.2.0.dist-info/top_level.txt,sha256=2ZY5rMRnVvrAH2GRGUbd6n9ey8cg_uk5iJwke0hQzFE,19
|
10
|
+
aek_auto_mlbuilder-0.2.0.dist-info/RECORD,,
|
@@ -1,9 +0,0 @@
|
|
1
|
-
aek_auto_mlbuilder/__init__.py,sha256=h1Y0NFNG_7TJ37NbYaoKWvO-2FrqdjGlP2Zui6677hM,104
|
2
|
-
aek_auto_mlbuilder/base.py,sha256=GgMdAoceRjwz3i9rVQ0RAjvn5ZdRS-sAkLWjymbE8s0,385
|
3
|
-
aek_auto_mlbuilder/linear_regression.py,sha256=MtOSRiXDIJPd3abnz4yNT4DBtrkmvEy00Kbx4AFk4Kg,1259
|
4
|
-
aek_auto_mlbuilder/utils.py,sha256=NcoM3b4Ng1Ogk3iKuz9DcMVwppGRqOLRp5g9jBCkWxY,190
|
5
|
-
aek_auto_mlbuilder-0.1.1.dist-info/licenses/LICENSE,sha256=eSVo2jJj9FB1xvr0zZ9U1fXkyjjnT6-WM3O4HSFKJOc,133
|
6
|
-
aek_auto_mlbuilder-0.1.1.dist-info/METADATA,sha256=S0_iXPkKXa5kwz0CcTUBym5xCjaenkGrZCStUrFBwII,1400
|
7
|
-
aek_auto_mlbuilder-0.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
8
|
-
aek_auto_mlbuilder-0.1.1.dist-info/top_level.txt,sha256=2ZY5rMRnVvrAH2GRGUbd6n9ey8cg_uk5iJwke0hQzFE,19
|
9
|
-
aek_auto_mlbuilder-0.1.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|