aek-auto-mlbuilder 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aek_auto_mlbuilder/__init__.py +3 -0
- aek_auto_mlbuilder/base.py +12 -0
- aek_auto_mlbuilder/linear_regression.py +39 -0
- aek_auto_mlbuilder/utils.py +4 -0
- {aek_auto_mlbuilder-0.0.1.dist-info → aek_auto_mlbuilder-0.1.0.dist-info}/METADATA +1 -1
- aek_auto_mlbuilder-0.1.0.dist-info/RECORD +9 -0
- aek_auto_mlbuilder-0.1.0.dist-info/licenses/LICENSE +6 -0
- aek_auto_mlbuilder-0.0.1.dist-info/RECORD +0 -8
- aek_auto_mlbuilder-0.0.1.dist-info/licenses/LICENSE +0 -0
- {aek_auto_mlbuilder-0.0.1.dist-info → aek_auto_mlbuilder-0.1.0.dist-info}/WHEEL +0 -0
- {aek_auto_mlbuilder-0.0.1.dist-info → aek_auto_mlbuilder-0.1.0.dist-info}/top_level.txt +0 -0
aek_auto_mlbuilder/__init__.py
CHANGED
aek_auto_mlbuilder/base.py
CHANGED
@@ -0,0 +1,12 @@
|
|
1
|
+
class BaseModel:
    """
    Common base for the package's auto-tuning model wrappers.

    Subclasses implement ``train`` and are expected to store the selected
    estimator in ``best_model`` (and its score in ``best_score``) so that
    ``evaluate`` can use it.
    """

    def __init__(self):
        # None means "not trained yet"; evaluate() checks best_model for this.
        self.best_model = None
        self.best_score = None

    def train(self, X, y):
        """
        Fit candidate models on ``X``/``y`` and keep the best one.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        # ``NotImplementedError`` is the exception class. ``NotImplemented`` is a
        # sentinel constant for binary dunder methods; calling it raises TypeError.
        raise NotImplementedError("Train method must be implemented by subclass.")

    def evaluate(self, X, y):
        """
        Score the selected model on ``X``/``y``.

        Returns:
            Whatever ``self.best_model.score(X, y)`` returns.

        Raises:
            RuntimeError: if no model has been trained yet. RuntimeError is a
                subclass of Exception, so ``except Exception`` handlers still
                catch it.
        """
        if self.best_model is None:
            raise RuntimeError("Model has not been trained yet!")
        return self.best_model.score(X, y)
|
@@ -0,0 +1,39 @@
|
|
1
|
+
from sklearn.linear_model import LinearRegression
|
2
|
+
from .base import BaseModel
|
3
|
+
from sklearn.preprocessing import StandardScaler
|
4
|
+
from sklearn.pipeline import make_pipeline
|
5
|
+
|
6
|
+
class LinearRegressor(BaseModel):
    """
    Basic Linear Regression class.

    Tries every combination in ``param_grid`` by brute force and keeps the
    candidate with the highest ``score`` (R^2 for sklearn regressors).

    Grid keys:
        fit_intercept: values passed to ``LinearRegression(fit_intercept=...)``.
        normalize: when True, the regressor is wrapped in a pipeline that first
            applies ``StandardScaler``. sklearn's own ``normalize`` argument is
            not used.
    """

    # Template for missing grid keys; copied per instance so it is never mutated.
    _DEFAULT_PARAM_GRID = {
        "fit_intercept": [True, False],
        "normalize": [True, False],
    }

    def __init__(self, param_grid=None):
        """
        Args:
            param_grid: optional dict with ``"fit_intercept"`` and/or
                ``"normalize"`` lists. Keys not supplied fall back to the
                defaults, so a partial grid no longer causes a KeyError in
                ``train``. A falsy value (None or ``{}``) uses the full defaults.
        """
        super().__init__()
        self.param_grid = {
            key: list(values) for key, values in self._DEFAULT_PARAM_GRID.items()
        }
        if param_grid:
            self.param_grid.update(param_grid)

    def train(self, X, y, X_val=None, y_val=None):
        """
        Fit one model per grid point on ``X``/``y`` and keep the best scorer.

        Args:
            X, y: training data passed to each candidate's ``fit``.
            X_val, y_val: optional held-out data used for scoring. When omitted,
                candidates are scored on the training data, as before.

        Returns:
            The selected (fitted) estimator, also stored in ``self.best_model``.
            Its score is stored in ``self.best_score``.

        Raises:
            ValueError: if only one of ``X_val``/``y_val`` is given, or if no
                candidate could be selected (empty grid or every score NaN).
        """
        if (X_val is None) != (y_val is None):
            raise ValueError("X_val and y_val must be provided together.")
        if X_val is None:
            # Training-set R^2 never decreases when an intercept is added, so
            # this default favours fit_intercept=True. Pass a held-out set for
            # a meaningful comparison.
            X_val, y_val = X, y

        best_score = -float("inf")
        best_model = None

        for fit_intercept in self.param_grid["fit_intercept"]:
            for normalize in self.param_grid["normalize"]:
                model = self._build_candidate(fit_intercept, normalize)
                model.fit(X, y)
                score = model.score(X_val, y_val)
                # Strict ">" keeps the first of tied candidates; NaN never wins.
                if score > best_score:
                    best_score = score
                    best_model = model

        if best_model is None:
            raise ValueError(
                "No candidate model could be selected: the parameter grid is "
                "empty or every candidate scored NaN."
            )

        self.best_model = best_model
        self.best_score = best_score
        return self.best_model

    @staticmethod
    def _build_candidate(fit_intercept, normalize):
        """Return an unfitted estimator for one (fit_intercept, normalize) grid point."""
        if normalize:
            return make_pipeline(
                StandardScaler(), LinearRegression(fit_intercept=fit_intercept)
            )
        return LinearRegression(fit_intercept=fit_intercept)
|
aek_auto_mlbuilder/utils.py
CHANGED
@@ -0,0 +1,9 @@
|
|
1
|
+
aek_auto_mlbuilder/__init__.py,sha256=h1Y0NFNG_7TJ37NbYaoKWvO-2FrqdjGlP2Zui6677hM,104
|
2
|
+
aek_auto_mlbuilder/base.py,sha256=GgMdAoceRjwz3i9rVQ0RAjvn5ZdRS-sAkLWjymbE8s0,385
|
3
|
+
aek_auto_mlbuilder/linear_regression.py,sha256=MtOSRiXDIJPd3abnz4yNT4DBtrkmvEy00Kbx4AFk4Kg,1259
|
4
|
+
aek_auto_mlbuilder/utils.py,sha256=NcoM3b4Ng1Ogk3iKuz9DcMVwppGRqOLRp5g9jBCkWxY,190
|
5
|
+
aek_auto_mlbuilder-0.1.0.dist-info/licenses/LICENSE,sha256=eSVo2jJj9FB1xvr0zZ9U1fXkyjjnT6-WM3O4HSFKJOc,133
|
6
|
+
aek_auto_mlbuilder-0.1.0.dist-info/METADATA,sha256=2dRUDj917EQGVM_fZorH1vs0Pg0reDN1KKGCKk3EPlE,504
|
7
|
+
aek_auto_mlbuilder-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
8
|
+
aek_auto_mlbuilder-0.1.0.dist-info/top_level.txt,sha256=2ZY5rMRnVvrAH2GRGUbd6n9ey8cg_uk5iJwke0hQzFE,19
|
9
|
+
aek_auto_mlbuilder-0.1.0.dist-info/RECORD,,
|
@@ -1,8 +0,0 @@
|
|
1
|
-
aek_auto_mlbuilder/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
-
aek_auto_mlbuilder/base.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
-
aek_auto_mlbuilder/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
-
aek_auto_mlbuilder-0.0.1.dist-info/licenses/LICENSE,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
5
|
-
aek_auto_mlbuilder-0.0.1.dist-info/METADATA,sha256=G_P3Y4T6GCNMyR6A_eOPwflcTu4u8OJhwoZCqhYlddc,504
|
6
|
-
aek_auto_mlbuilder-0.0.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
7
|
-
aek_auto_mlbuilder-0.0.1.dist-info/top_level.txt,sha256=2ZY5rMRnVvrAH2GRGUbd6n9ey8cg_uk5iJwke0hQzFE,19
|
8
|
-
aek_auto_mlbuilder-0.0.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|