aek-auto-mlbuilder 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,4 +2,5 @@ from .base import BaseModel
2
2
  from .utils import split_data
3
3
  from .linear_regression import LinearRegressor
4
4
  from .logistic_regression import LogisticClassifier
5
- from .decision_tree import DecisionTreeModel
5
+ from .decision_tree import DecisionTreeModel
6
+ from. knn import KNNModel
@@ -0,0 +1,48 @@
1
+ from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
2
+ from sklearn.preprocessing import StandardScaler
3
+ from sklearn.pipeline import make_pipeline
4
+ from .base import BaseModel
5
+
6
+
7
class KNNModel(BaseModel):
    """
    K-nearest-neighbors model for classification or regression.

    Set ``task`` to "classification" or "regression". ``train`` runs an
    exhaustive (brute-force) grid search over ``n_neighbors``, ``weights``
    and ``p``. Every candidate is a ``StandardScaler`` + KNN pipeline ranked
    by its mean cross-validated score; the winner is refitted on all data.
    """

    # Default search space. Stored as tuples so this shared class constant
    # cannot be mutated through an instance's ``param_grid``.
    _DEFAULT_PARAM_GRID = {
        "n_neighbors": (3, 5, 7, 9),
        "weights": ("uniform", "distance"),
        "p": (1, 2),  # Minkowski power: 1 = Manhattan, 2 = Euclidean
    }

    def __init__(self, task="classification", param_grid=None):
        """
        Parameters
        ----------
        task : str
            "classification" or "regression" (case-insensitive); validated
            when ``train`` is called.
        param_grid : dict or None
            Candidate values for "n_neighbors", "weights" and "p". Missing
            keys fall back to the defaults, so a partial grid such as
            ``{"n_neighbors": [1, 3]}`` is accepted.
        """
        super().__init__()
        self.task = task
        grid = {key: list(values) for key, values in self._DEFAULT_PARAM_GRID.items()}
        if param_grid:
            grid.update(param_grid)
        self.param_grid = grid

    def _estimator_class(self):
        """Return the sklearn KNN estimator class matching ``self.task``."""
        task = self.task.lower()
        if task == "classification":
            return KNeighborsClassifier
        if task == "regression":
            return KNeighborsRegressor
        raise ValueError("task must be 'classification' or 'regression'")

    @staticmethod
    def _make_pipeline(model_class, n_neighbors, weights, p):
        """Build an unfitted scaler + KNN pipeline for one grid point."""
        return make_pipeline(
            StandardScaler(),
            model_class(n_neighbors=n_neighbors, weights=weights, p=p),
        )

    def train(self, X, y, cv=5):
        """
        Grid-search the KNN hyper-parameters and fit the best pipeline.

        Parameters
        ----------
        X, y
            Training features and targets, in any form sklearn accepts.
        cv : int, cross-validation splitter or iterable, default 5
            Passed to ``sklearn.model_selection.cross_val_score`` to score
            each candidate. Held-out scoring matters for KNN. On its own
            training data, each sample is its own nearest neighbour at
            distance zero, and ``weights="distance"`` then reproduces its own
            target. Training-set scores would therefore be (near-)perfect and
            the selection meaningless.

        Returns
        -------
        The best pipeline refitted on all of ``X, y``, or None when the grid
        has no candidates. The pipeline is also stored in
        ``self.best_model``, and its mean CV score in ``self.best_score``.

        Raises
        ------
        ValueError
            If ``task`` is unknown, or if no candidate could be scored
            (e.g. every ``n_neighbors`` exceeds the CV training-fold size).
        """
        from itertools import product

        from sklearn.model_selection import cross_val_score

        model_class = self._estimator_class()
        grid = self.param_grid

        best_score = -float("inf")
        best_params = None
        tried = False
        last_error = None

        # Same iteration order as nested loops: n_neighbors, weights, p.
        candidates = product(grid["n_neighbors"], grid["weights"], grid["p"])
        for n, weights, p in candidates:
            tried = True
            candidate = self._make_pipeline(model_class, n, weights, p)
            try:
                # error_score="raise" surfaces fold failures as exceptions
                # instead of NaN scores accompanied by warnings.
                score = cross_val_score(
                    candidate, X, y, cv=cv, error_score="raise"
                ).mean()
            except ValueError as err:
                # Typically n_neighbors larger than a training fold: this
                # candidate cannot be evaluated here, but others may be.
                last_error = err
                continue
            # Strict ">" keeps the first candidate on ties.
            if score > best_score:
                best_score = score
                best_params = (n, weights, p)

        if best_params is None:
            if tried:
                raise ValueError(
                    "no KNN candidate could be scored; check that X has enough "
                    "samples for the requested n_neighbors and cv"
                ) from last_error
            best_model = None  # empty grid: nothing to fit
        else:
            best_model = self._make_pipeline(model_class, *best_params).fit(X, y)

        self.best_model = best_model
        self.best_score = best_score
        return self.best_model
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aek-auto-mlbuilder
3
- Version: 0.3.2
3
+ Version: 0.4.0
4
4
  Summary: Automatic ML model builder in Python
5
5
  Home-page: https://github.com/alpemre8/aek-auto-mlbuilder
6
6
  Author: Alp Emre Karaahmet
@@ -1,11 +1,12 @@
1
- aek_auto_mlbuilder/__init__.py,sha256=ZvU-pWfwDlqQbfdgnrn04PoMv3Qs969YkRKGPl406h8,201
1
+ aek_auto_mlbuilder/__init__.py,sha256=qB67FHc_lJeKMQltNz087hseQUeIvd15njtNrGAqXRk,227
2
2
  aek_auto_mlbuilder/base.py,sha256=GgMdAoceRjwz3i9rVQ0RAjvn5ZdRS-sAkLWjymbE8s0,385
3
3
  aek_auto_mlbuilder/decision_tree.py,sha256=OImOHaREz2jWcANXVC6VKcatFJfzrtyCHJKXNA5-hoI,1606
4
+ aek_auto_mlbuilder/knn.py,sha256=bNADSq2Ce6stmiWoumRgJbyCmp5SjsbvtTDq3cimKAk,1722
4
5
  aek_auto_mlbuilder/linear_regression.py,sha256=MtOSRiXDIJPd3abnz4yNT4DBtrkmvEy00Kbx4AFk4Kg,1259
5
6
  aek_auto_mlbuilder/logistic_regression.py,sha256=lp9-e9p9QrqL20DmhIJaSnra1SwyMiOdTfOMlgYsNQA,1707
6
7
  aek_auto_mlbuilder/utils.py,sha256=NcoM3b4Ng1Ogk3iKuz9DcMVwppGRqOLRp5g9jBCkWxY,190
7
- aek_auto_mlbuilder-0.3.2.dist-info/licenses/LICENSE,sha256=eSVo2jJj9FB1xvr0zZ9U1fXkyjjnT6-WM3O4HSFKJOc,133
8
- aek_auto_mlbuilder-0.3.2.dist-info/METADATA,sha256=-uiHEiwzLRLRnqjQdmv7dIpdocp55eGXmXYlw_6XQr8,1400
9
- aek_auto_mlbuilder-0.3.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
10
- aek_auto_mlbuilder-0.3.2.dist-info/top_level.txt,sha256=2ZY5rMRnVvrAH2GRGUbd6n9ey8cg_uk5iJwke0hQzFE,19
11
- aek_auto_mlbuilder-0.3.2.dist-info/RECORD,,
8
+ aek_auto_mlbuilder-0.4.0.dist-info/licenses/LICENSE,sha256=eSVo2jJj9FB1xvr0zZ9U1fXkyjjnT6-WM3O4HSFKJOc,133
9
+ aek_auto_mlbuilder-0.4.0.dist-info/METADATA,sha256=dNYzCfMNcAIpg0CVk5CE2BxjoWn3TYhhb3D5NNxhhPc,1400
10
+ aek_auto_mlbuilder-0.4.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
11
+ aek_auto_mlbuilder-0.4.0.dist-info/top_level.txt,sha256=2ZY5rMRnVvrAH2GRGUbd6n9ey8cg_uk5iJwke0hQzFE,19
12
+ aek_auto_mlbuilder-0.4.0.dist-info/RECORD,,