aek-auto-mlbuilder 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,5 @@
1
1
  from .base import BaseModel
2
2
  from .utils import split_data
3
3
  from .linear_regression import LinearRegressor
4
- from .logistic_regression import LogisticClassifier
4
+ from .logistic_regression import LogisticClassifier
5
+ from .decision_tree import DecisionTreeModel
@@ -0,0 +1,45 @@
1
+ from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
2
+ from .base import BaseModel
3
+
4
+
5
class DecisionTreeModel(BaseModel):
    """
    Decision tree for classification or regression with hyperparameter search.

    Use the ``task`` parameter to choose ``"classification"`` (a
    ``DecisionTreeClassifier``) or ``"regression"`` (a
    ``DecisionTreeRegressor``). ``train`` runs an exhaustive ("brute force")
    search over every combination in ``param_grid`` and keeps the best model.

    Candidates are ranked by their mean cross-validated score, not by their
    score on the data they were fit on. A tree's training score almost
    always increases with its complexity, so ranking by it would select the
    most overfit candidate every time.

    Args:
        task: ``"classification"`` or ``"regression"`` (case-insensitive).
        param_grid: Grid passed to scikit-learn's ``GridSearchCV``. It is
            either a dict mapping estimator parameter names to lists of
            candidate values, or a list of such dicts. Defaults to a grid
            over ``max_depth``, ``min_samples_split`` and
            ``min_samples_leaf``.
        cv: Number of cross-validation folds (capped at the number of
            samples), or any other ``cv`` value accepted by ``GridSearchCV``.
        random_state: Seed forwarded to every candidate tree for
            reproducible results; ``None`` keeps scikit-learn's default.
    """

    def __init__(self, task="classification", param_grid=None, cv=5,
                 random_state=None):
        super().__init__()
        self.task = task
        # The default literal is rebuilt on every call, so instances never
        # share (and cannot mutate) each other's grid.
        self.param_grid = param_grid or {
            "max_depth": [None, 3, 5, 10],
            "min_samples_split": [2, 5, 10],
            "min_samples_leaf": [1, 2, 4],
        }
        self.cv = cv
        self.random_state = random_state

    def _model_class(self):
        """Return the scikit-learn tree class matching ``self.task``."""
        task = self.task.lower()
        if task == "classification":
            return DecisionTreeClassifier
        if task == "regression":
            return DecisionTreeRegressor
        # Previously any unrecognised value (e.g. a typo) silently produced
        # a regressor; fail loudly instead.
        raise ValueError(
            f"task must be 'classification' or 'regression', got {self.task!r}"
        )

    def train(self, X, y):
        """
        Search ``param_grid`` with cross-validation and fit the best tree.

        Args:
            X: Training features, in any format scikit-learn accepts.
            y: Training targets; ``len(y)`` must give the number of samples.

        Returns:
            The best estimator, refit on all of ``X`` and ``y``. It is also
            stored on ``self.best_model``. Its mean cross-validated score is
            stored on ``self.best_score`` and its parameters on
            ``self.best_params``.

        Raises:
            ValueError: If ``task`` is not recognised, or if there are too
                few samples to form at least two cross-validation folds.
        """
        # Local import so the module's top-level imports stay unchanged.
        from sklearn.model_selection import GridSearchCV

        model_class = self._model_class()

        cv = self.cv
        if isinstance(cv, int):
            # K-fold splitting needs at least one sample per fold, so tiny
            # datasets get fewer folds instead of an opaque sklearn error.
            cv = min(cv, len(y))
            if cv < 2:
                raise ValueError(
                    "at least 2 samples are required for cross-validation"
                )

        search = GridSearchCV(
            model_class(random_state=self.random_state),
            self.param_grid,
            cv=cv,
            refit=True,  # refit the winning parameters on the full data
        )
        search.fit(X, y)

        self.best_model = search.best_estimator_
        self.best_score = search.best_score_
        self.best_params = search.best_params_
        return self.best_model
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aek-auto-mlbuilder
3
- Version: 0.2.0
3
+ Version: 0.3.0
4
4
  Summary: Automatic ML model builder in Python
5
5
  Home-page: https://github.com/alpemre8/aek-auto-mlbuilder
6
6
  Author: Alp Emre Karaahmet
@@ -0,0 +1,11 @@
1
+ aek_auto_mlbuilder/__init__.py,sha256=ZvU-pWfwDlqQbfdgnrn04PoMv3Qs969YkRKGPl406h8,201
2
+ aek_auto_mlbuilder/base.py,sha256=GgMdAoceRjwz3i9rVQ0RAjvn5ZdRS-sAkLWjymbE8s0,385
3
+ aek_auto_mlbuilder/decision_tree.py,sha256=OImOHaREz2jWcANXVC6VKcatFJfzrtyCHJKXNA5-hoI,1606
4
+ aek_auto_mlbuilder/linear_regression.py,sha256=MtOSRiXDIJPd3abnz4yNT4DBtrkmvEy00Kbx4AFk4Kg,1259
5
+ aek_auto_mlbuilder/logistic_regression.py,sha256=lp9-e9p9QrqL20DmhIJaSnra1SwyMiOdTfOMlgYsNQA,1707
6
+ aek_auto_mlbuilder/utils.py,sha256=NcoM3b4Ng1Ogk3iKuz9DcMVwppGRqOLRp5g9jBCkWxY,190
7
+ aek_auto_mlbuilder-0.3.0.dist-info/licenses/LICENSE,sha256=eSVo2jJj9FB1xvr0zZ9U1fXkyjjnT6-WM3O4HSFKJOc,133
8
+ aek_auto_mlbuilder-0.3.0.dist-info/METADATA,sha256=TSWjsFxMYsNRT3MrSjC6D4zJcAt2UQBrxi9ohyR0WmE,1400
9
+ aek_auto_mlbuilder-0.3.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
10
+ aek_auto_mlbuilder-0.3.0.dist-info/top_level.txt,sha256=2ZY5rMRnVvrAH2GRGUbd6n9ey8cg_uk5iJwke0hQzFE,19
11
+ aek_auto_mlbuilder-0.3.0.dist-info/RECORD,,
@@ -1,10 +0,0 @@
1
- aek_auto_mlbuilder/__init__.py,sha256=SXkgDc9Wk7algCZBmE_855M6Kmqh0OauWZ_ZCDPI_-w,156
2
- aek_auto_mlbuilder/base.py,sha256=GgMdAoceRjwz3i9rVQ0RAjvn5ZdRS-sAkLWjymbE8s0,385
3
- aek_auto_mlbuilder/linear_regression.py,sha256=MtOSRiXDIJPd3abnz4yNT4DBtrkmvEy00Kbx4AFk4Kg,1259
4
- aek_auto_mlbuilder/logistic_regression.py,sha256=lp9-e9p9QrqL20DmhIJaSnra1SwyMiOdTfOMlgYsNQA,1707
5
- aek_auto_mlbuilder/utils.py,sha256=NcoM3b4Ng1Ogk3iKuz9DcMVwppGRqOLRp5g9jBCkWxY,190
6
- aek_auto_mlbuilder-0.2.0.dist-info/licenses/LICENSE,sha256=eSVo2jJj9FB1xvr0zZ9U1fXkyjjnT6-WM3O4HSFKJOc,133
7
- aek_auto_mlbuilder-0.2.0.dist-info/METADATA,sha256=jM2zfl1x7SNJUcduNuO4sDEEj84MZCW-2kg3XMMqmps,1400
8
- aek_auto_mlbuilder-0.2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
9
- aek_auto_mlbuilder-0.2.0.dist-info/top_level.txt,sha256=2ZY5rMRnVvrAH2GRGUbd6n9ey8cg_uk5iJwke0hQzFE,19
10
- aek_auto_mlbuilder-0.2.0.dist-info/RECORD,,