zero_agent-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentz/agent/base.py +262 -0
- agentz/artifacts/__init__.py +5 -0
- agentz/artifacts/artifact_writer.py +538 -0
- agentz/artifacts/reporter.py +235 -0
- agentz/artifacts/terminal_writer.py +100 -0
- agentz/context/__init__.py +6 -0
- agentz/context/context.py +91 -0
- agentz/context/conversation.py +205 -0
- agentz/context/data_store.py +208 -0
- agentz/llm/llm_setup.py +156 -0
- agentz/mcp/manager.py +142 -0
- agentz/mcp/patches.py +88 -0
- agentz/mcp/servers/chrome_devtools/server.py +14 -0
- agentz/profiles/base.py +108 -0
- agentz/profiles/data/data_analysis.py +38 -0
- agentz/profiles/data/data_loader.py +35 -0
- agentz/profiles/data/evaluation.py +43 -0
- agentz/profiles/data/model_training.py +47 -0
- agentz/profiles/data/preprocessing.py +47 -0
- agentz/profiles/data/visualization.py +47 -0
- agentz/profiles/manager/evaluate.py +51 -0
- agentz/profiles/manager/memory.py +62 -0
- agentz/profiles/manager/observe.py +48 -0
- agentz/profiles/manager/routing.py +66 -0
- agentz/profiles/manager/writer.py +51 -0
- agentz/profiles/mcp/browser.py +21 -0
- agentz/profiles/mcp/chrome.py +21 -0
- agentz/profiles/mcp/notion.py +21 -0
- agentz/runner/__init__.py +74 -0
- agentz/runner/base.py +28 -0
- agentz/runner/executor.py +320 -0
- agentz/runner/hooks.py +110 -0
- agentz/runner/iteration.py +142 -0
- agentz/runner/patterns.py +215 -0
- agentz/runner/tracker.py +188 -0
- agentz/runner/utils.py +45 -0
- agentz/runner/workflow.py +250 -0
- agentz/tools/__init__.py +20 -0
- agentz/tools/data_tools/__init__.py +17 -0
- agentz/tools/data_tools/data_analysis.py +152 -0
- agentz/tools/data_tools/data_loading.py +92 -0
- agentz/tools/data_tools/evaluation.py +175 -0
- agentz/tools/data_tools/helpers.py +120 -0
- agentz/tools/data_tools/model_training.py +192 -0
- agentz/tools/data_tools/preprocessing.py +229 -0
- agentz/tools/data_tools/visualization.py +281 -0
- agentz/utils/__init__.py +69 -0
- agentz/utils/config.py +708 -0
- agentz/utils/helpers.py +10 -0
- agentz/utils/parsers.py +142 -0
- agentz/utils/printer.py +539 -0
- pipelines/base.py +972 -0
- pipelines/data_scientist.py +97 -0
- pipelines/data_scientist_memory.py +151 -0
- pipelines/experience_learner.py +0 -0
- pipelines/prompt_generator.py +0 -0
- pipelines/simple.py +78 -0
- pipelines/simple_browser.py +145 -0
- pipelines/simple_chrome.py +75 -0
- pipelines/simple_notion.py +103 -0
- pipelines/tool_builder.py +0 -0
- zero_agent-0.1.0.dist-info/METADATA +269 -0
- zero_agent-0.1.0.dist-info/RECORD +66 -0
- zero_agent-0.1.0.dist-info/WHEEL +5 -0
- zero_agent-0.1.0.dist-info/licenses/LICENSE +21 -0
- zero_agent-0.1.0.dist-info/top_level.txt +2 -0
agentz/tools/data_tools/evaluation.py

@@ -0,0 +1,175 @@

"""Model evaluation tool for assessing model performance."""

from typing import Union, Dict, Any, Optional
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, classification_report,
    mean_squared_error, mean_absolute_error, r2_score
)
from agents import function_tool
from agents.run_context import RunContextWrapper
from agentz.context.data_store import DataStore
from .helpers import load_or_get_dataframe
from loguru import logger


@function_tool
async def evaluate_model(
    ctx: RunContextWrapper[DataStore],
    target_column: str,
    file_path: Optional[str] = None,
    model_type: str = "random_forest",
    test_size: float = 0.2,
    random_state: int = 42
) -> Union[Dict[str, Any], str]:
    """Evaluates machine learning model performance with comprehensive metrics.

    This tool automatically uses the current dataset from the pipeline context.
    A file_path can optionally be provided to evaluate on a different dataset.

    Args:
        ctx: Pipeline context wrapper for accessing the data store
        target_column: Name of the target column to predict
        file_path: Optional path to dataset file. If not provided, uses current dataset.
        model_type: Type of model to evaluate. Currently a random forest is always
            trained, regardless of this value.
        test_size: Proportion of data to use for testing (default: 0.2)
        random_state: Random seed for reproducibility (default: 42)

    Returns:
        Dictionary containing:
        - problem_type: "classification" or "regression"
        - metrics: Performance metrics
        - confusion_matrix: Confusion matrix (for classification)
        - classification_report: Detailed classification report
        - cross_validation: Cross-validation results
        - error_analysis: Error distribution analysis (for regression)
        Or error message string if evaluation fails
    """
    try:
        # Get DataFrame - either from file_path or current dataset
        data_store = ctx.context
        if file_path is None:
            if data_store and data_store.has("current_dataset"):
                df = data_store.get("current_dataset")
                logger.info("Evaluating model on current dataset from pipeline context")
            else:
                return "Error: No dataset loaded. Please load a dataset first using the load_dataset tool."
        else:
            df = load_or_get_dataframe(file_path, prefer_preprocessed=True, data_store=data_store)
            logger.info(f"Evaluating model on dataset from: {file_path}")

        if target_column not in df.columns:
            return f"Target column '{target_column}' not found in dataset"

        # Separate features and target
        X = df.drop(columns=[target_column])
        y = df[target_column]

        # One-hot encode categorical features
        X = pd.get_dummies(X, drop_first=True)

        # Determine problem type: object dtype or low cardinality -> classification
        is_classification = y.dtype == 'object' or y.nunique() < 20

        # Encode target if categorical (classes_ only exists after fitting)
        original_labels = None
        if is_classification and y.dtype == 'object':
            from sklearn.preprocessing import LabelEncoder
            le = LabelEncoder()
            y = le.fit_transform(y)
            original_labels = le.classes_

        # Split data
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, random_state=random_state
        )

        # Train model (a random forest in both cases; model_type is not consulted)
        if is_classification:
            model = RandomForestClassifier(random_state=random_state, n_estimators=100)
        else:
            model = RandomForestRegressor(random_state=random_state, n_estimators=100)

        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)

        result = {
            "problem_type": "classification" if is_classification else "regression",
        }

        if is_classification:
            # Classification metrics
            metrics = {
                "accuracy": float(accuracy_score(y_test, y_pred)),
                "precision": float(precision_score(y_test, y_pred, average='weighted', zero_division=0)),
                "recall": float(recall_score(y_test, y_pred, average='weighted', zero_division=0)),
                "f1_score": float(f1_score(y_test, y_pred, average='weighted', zero_division=0)),
            }
            result["metrics"] = metrics

            # Confusion matrix
            cm = confusion_matrix(y_test, y_pred)
            result["confusion_matrix"] = cm.tolist()

            # Classification report; keep class values aligned with display names
            if original_labels is not None:
                target_names = [str(label) for label in original_labels]
                class_values = list(range(len(original_labels)))  # LabelEncoder maps to 0..n-1
            else:
                class_values = sorted(np.unique(y))
                target_names = [str(v) for v in class_values]

            class_report = classification_report(
                y_test, y_pred, target_names=target_names, output_dict=True, zero_division=0
            )
            result["classification_report"] = class_report

            # Per-class accuracy (compare against the class value, not the loop index)
            per_class_accuracy = {}
            y_test_arr = np.asarray(y_test)
            for label, value in zip(target_names, class_values):
                mask = y_test_arr == value
                if mask.sum() > 0:
                    per_class_accuracy[label] = float(accuracy_score(y_test_arr[mask], y_pred[mask]))
            result["per_class_accuracy"] = per_class_accuracy

        else:
            # Regression metrics
            metrics = {
                "r2_score": float(r2_score(y_test, y_pred)),
                "mean_squared_error": float(mean_squared_error(y_test, y_pred)),
                "root_mean_squared_error": float(np.sqrt(mean_squared_error(y_test, y_pred))),
                "mean_absolute_error": float(mean_absolute_error(y_test, y_pred)),
                # Note: MAPE is undefined when y_test contains zeros
                "mean_absolute_percentage_error": float(np.mean(np.abs((y_test - y_pred) / y_test)) * 100),
            }
            result["metrics"] = metrics

            # Error analysis
            errors = y_test - y_pred
            error_analysis = {
                "mean_error": float(np.mean(errors)),
                "std_error": float(np.std(errors)),
                "min_error": float(np.min(errors)),
                "max_error": float(np.max(errors)),
                "median_error": float(np.median(errors)),
            }
            result["error_analysis"] = error_analysis

        # Cross-validation on the full dataset
        cv_scores = cross_val_score(model, X, y, cv=5)
        result["cross_validation"] = {
            "scores": cv_scores.tolist(),
            "mean": float(cv_scores.mean()),
            "std": float(cv_scores.std()),
        }

        # Feature importance (top 10)
        if hasattr(model, 'feature_importances_'):
            importance_dict = dict(zip(X.columns, model.feature_importances_))
            sorted_importance = dict(sorted(importance_dict.items(), key=lambda x: x[1], reverse=True))
            result["feature_importance"] = {k: float(v) for k, v in list(sorted_importance.items())[:10]}

        return result

    except Exception as e:
        return f"Error evaluating model: {str(e)}"
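
Both evaluate_model and the train_model tool below share the same problem-type heuristic: an object-dtype target, or one with fewer than 20 distinct values, is treated as classification; everything else as regression. A minimal sketch of that rule in isolation (the toy Series below are illustrative, not part of the package) also shows its main caveat: a low-cardinality numeric target is routed to classification.

import pandas as pd

def infer_problem_type(y: pd.Series) -> str:
    # Same rule as in evaluate_model / train_model
    return "classification" if y.dtype == "object" or y.nunique() < 20 else "regression"

species = pd.Series(["cat", "dog", "cat", "dog"])        # object dtype
rating = pd.Series([1, 2, 3, 4, 5] * 10)                 # numeric, 5 unique values
price = pd.Series([float(i) * 1.5 for i in range(100)])  # numeric, 100 unique values

print(infer_problem_type(species))  # classification
print(infer_problem_type(rating))   # classification (fewer than 20 unique values)
print(infer_problem_type(price))    # regression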
agentz/tools/data_tools/helpers.py

@@ -0,0 +1,120 @@

"""Helper utilities for data tools to access cached DataFrames."""

from typing import Any, Optional
from pathlib import Path
import pandas as pd
from agentz.context.data_store import DataStore
from loguru import logger


def get_dataframe(file_path: str, prefer_preprocessed: bool = False, data_store: Optional[DataStore] = None) -> Optional[pd.DataFrame]:
    """Get a DataFrame from the pipeline cache.

    Args:
        file_path: Path to the dataset file
        prefer_preprocessed: If True, try to get preprocessed version first
        data_store: Pipeline data store instance

    Returns:
        DataFrame or None if not found
    """
    file_path = Path(file_path)

    if not data_store:
        return None

    # Try preprocessed first if requested
    if prefer_preprocessed:
        preprocessed_key = f"preprocessed:{file_path.resolve()}"
        if data_store.has(preprocessed_key):
            logger.info(f"Using cached preprocessed DataFrame for {file_path}")
            return data_store.get(preprocessed_key)

    # Fall back to the regular cached DataFrame
    cache_key = f"dataframe:{file_path.resolve()}"
    if data_store.has(cache_key):
        logger.info(f"Using cached DataFrame for {file_path}")
        return data_store.get(cache_key)

    return None


def load_or_get_dataframe(file_path: str, prefer_preprocessed: bool = False, data_store: Optional[DataStore] = None) -> pd.DataFrame:
    """Get DataFrame from cache, falling back to loading it from file.

    Args:
        file_path: Path to the dataset file
        prefer_preprocessed: If True, try to get preprocessed version first
        data_store: Pipeline data store instance

    Returns:
        DataFrame

    Raises:
        FileNotFoundError: If file doesn't exist and is not in cache
        ValueError: If file format is not supported
    """
    # Try cache first
    df = get_dataframe(file_path, prefer_preprocessed, data_store)
    if df is not None:
        return df

    # Load from file
    file_path = Path(file_path)
    if not file_path.exists():
        raise FileNotFoundError(f"File not found: {file_path}")

    suffix = file_path.suffix.lower()
    if suffix == '.csv':
        df = pd.read_csv(file_path)
    elif suffix in ['.xlsx', '.xls']:
        df = pd.read_excel(file_path)
    elif suffix == '.json':
        df = pd.read_json(file_path)
    elif suffix == '.parquet':
        df = pd.read_parquet(file_path)
    else:
        raise ValueError(f"Unsupported file format: {file_path.suffix}")

    # Cache it for future use
    if data_store:
        cache_key = f"dataframe:{file_path.resolve()}"
        data_store.set(
            cache_key,
            df,
            data_type="dataframe",
            metadata={"file_path": str(file_path), "shape": df.shape}
        )
        logger.info(f"Cached DataFrame from {file_path}")

    return df


def cache_object(key: str, obj: Any, data_type: Optional[str] = None, metadata: Optional[dict] = None, data_store: Optional[DataStore] = None) -> None:
    """Cache an object in the pipeline data store.

    Args:
        key: Cache key
        obj: Object to cache
        data_type: Type descriptor (e.g., 'model', 'scaler')
        metadata: Optional metadata
        data_store: Pipeline data store instance
    """
    if data_store:
        data_store.set(key, obj, data_type=data_type, metadata=metadata)
        logger.info(f"Cached {data_type or 'object'} with key: {key}")


def get_cached_object(key: str, data_store: Optional[DataStore] = None) -> Optional[Any]:
    """Get a cached object from the pipeline data store.

    Args:
        key: Cache key
        data_store: Pipeline data store instance

    Returns:
        Cached object or None
    """
    if data_store and data_store.has(key):
        logger.info(f"Retrieved cached object with key: {key}")
        return data_store.get(key)
    return None
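
These helpers give the data tools a read-through cache keyed by resolved path ("preprocessed:<abs path>" and "dataframe:<abs path>"). A hedged usage sketch follows, assuming DataStore() can be constructed with no arguments (its constructor is not shown in this diff) and using an illustrative file path that is expected to exist on disk.

from agentz.context.data_store import DataStore
from agentz.tools.data_tools.helpers import load_or_get_dataframe

store = DataStore()  # assumption: no-arg constructor

# First call reads data/train.csv (illustrative path) from disk and caches it
# under f"dataframe:{Path('data/train.csv').resolve()}".
df = load_or_get_dataframe("data/train.csv", data_store=store)

# Second call is served from the cache; no file I/O, logged as a cache hit.
df_again = load_or_get_dataframe("data/train.csv", data_store=store)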
agentz/tools/data_tools/model_training.py

@@ -0,0 +1,192 @@

"""Model training tool for training machine learning models."""

from typing import Union, Dict, Any, Optional
from pathlib import Path
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.metrics import accuracy_score, r2_score
from agents import function_tool
from agents.run_context import RunContextWrapper
from agentz.context.data_store import DataStore
from .helpers import load_or_get_dataframe, cache_object
from loguru import logger


@function_tool
async def train_model(
    ctx: RunContextWrapper[DataStore],
    target_column: str,
    file_path: Optional[str] = None,
    model_type: str = "auto",
    test_size: float = 0.2,
    random_state: int = 42
) -> Union[Dict[str, Any], str]:
    """Trains machine learning models on a dataset.

    This tool automatically uses the current dataset from the pipeline context.
    A file_path can optionally be provided to train on a different dataset.

    Args:
        ctx: Pipeline context wrapper for accessing the data store
        target_column: Name of the target column to predict
        file_path: Optional path to dataset file. If not provided, uses current dataset.
        model_type: Type of model to train. Options:
            - "auto": Detect the problem type and use a random forest
            - "random_forest": Random Forest
            - "logistic_regression": Logistic Regression (classification only)
            - "linear_regression": Linear Regression (regression only)
            - "decision_tree": Decision Tree
        test_size: Proportion of data to use for testing (default: 0.2)
        random_state: Random seed for reproducibility (default: 42)

    Returns:
        Dictionary containing:
        - model_type: Type of model trained
        - problem_type: "classification" or "regression"
        - train_score: Training score
        - test_score: Testing score
        - cross_val_mean / cross_val_std: Cross-validation score mean and std
        - feature_importance: Feature importance scores (if available)
        - predictions_sample: Sample of predictions vs actual values
        Or error message string if training fails
    """
    try:
        # Get DataFrame - either from file_path or current dataset
        data_store = ctx.context
        if file_path is None:
            if data_store and data_store.has("current_dataset"):
                df = data_store.get("current_dataset")
                logger.info("Training model on current dataset from pipeline context")
            else:
                return "Error: No dataset loaded. Please load a dataset first using the load_dataset tool."
        else:
            df = load_or_get_dataframe(file_path, prefer_preprocessed=True, data_store=data_store)
            logger.info(f"Training model on dataset from: {file_path}")

        if target_column not in df.columns:
            return f"Target column '{target_column}' not found in dataset"

        # Separate features and target
        X = df.drop(columns=[target_column])
        y = df[target_column]

        # One-hot encode categorical features in X
        X = pd.get_dummies(X, drop_first=True)

        # Determine problem type: object dtype or low cardinality -> classification
        is_classification = y.dtype == 'object' or y.nunique() < 20

        # Encode target if categorical
        if is_classification and y.dtype == 'object':
            from sklearn.preprocessing import LabelEncoder
            le = LabelEncoder()
            y = le.fit_transform(y)

        # Split data
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, random_state=random_state
        )

        # Select model ("auto" and "random_forest" behave identically)
        if model_type in ("auto", "random_forest"):
            if is_classification:
                model = RandomForestClassifier(random_state=random_state, n_estimators=100)
                model_name = "Random Forest Classifier"
            else:
                model = RandomForestRegressor(random_state=random_state, n_estimators=100)
                model_name = "Random Forest Regressor"
        elif model_type == "logistic_regression":
            model = LogisticRegression(random_state=random_state, max_iter=1000)
            model_name = "Logistic Regression"
        elif model_type == "linear_regression":
            model = LinearRegression()
            model_name = "Linear Regression"
        elif model_type == "decision_tree":
            if is_classification:
                model = DecisionTreeClassifier(random_state=random_state)
                model_name = "Decision Tree Classifier"
            else:
                model = DecisionTreeRegressor(random_state=random_state)
                model_name = "Decision Tree Regressor"
        else:
            return f"Unknown model type: {model_type}"

        # Train model
        model.fit(X_train, y_train)

        # Evaluate
        train_pred = model.predict(X_train)
        test_pred = model.predict(X_test)

        if is_classification:
            train_score = accuracy_score(y_train, train_pred)
            test_score = accuracy_score(y_test, test_pred)
            metric_name = "accuracy"
        else:
            train_score = r2_score(y_train, train_pred)
            test_score = r2_score(y_test, test_pred)
            metric_name = "r2_score"

        # Cross-validation on the full dataset
        cv_scores = cross_val_score(model, X, y, cv=5)

        # Feature importance (tree-based models only; linear models lack feature_importances_)
        feature_importance = {}
        if hasattr(model, 'feature_importances_'):
            importance_dict = dict(zip(X.columns, model.feature_importances_))
            # Sort by importance, descending
            feature_importance = dict(sorted(importance_dict.items(), key=lambda x: x[1], reverse=True))

        # Sample predictions
        predictions_sample = []
        for i in range(min(10, len(y_test))):
            predictions_sample.append({
                "actual": float(y_test.iloc[i]) if hasattr(y_test, 'iloc') else float(y_test[i]),
                "predicted": float(test_pred[i]),
            })

        # Cache the trained model for reuse; falls back to a generic "model" key
        # when training on the in-context dataset
        file_path_obj = Path(file_path) if file_path else Path("model")
        model_key = f"model:{file_path_obj.resolve()}"
        cache_object(
            model_key,
            model,
            data_type="model",
            data_store=data_store,
            metadata={
                "file_path": str(file_path) if file_path else None,
                "model_type": model_name,
                "target_column": target_column,
                "test_score": float(test_score)
            }
        )

        result = {
            "model_type": model_name,
            "problem_type": "classification" if is_classification else "regression",
            "train_score": float(train_score),
            "test_score": float(test_score),
            "metric": metric_name,
            "cross_val_mean": float(cv_scores.mean()),
            "cross_val_std": float(cv_scores.std()),
            "feature_importance": {k: float(v) for k, v in list(feature_importance.items())[:10]},  # Top 10
            "predictions_sample": predictions_sample,
            "train_size": len(X_train),
            "test_size": len(X_test),
        }

        return result

    except Exception as e:
        return f"Error training model: {str(e)}"
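
train_model caches the fitted estimator under f"model:{Path(file_path).resolve()}", or under the resolved path of a literal "model" placeholder when no file_path was given. A hedged sketch of fetching it back via get_cached_object, assuming the same DataStore instance that backed the pipeline run is available (a fresh store, as constructed here, would simply return None) and using illustrative path and column names.

import pandas as pd
from pathlib import Path
from agentz.context.data_store import DataStore
from agentz.tools.data_tools.helpers import get_cached_object

store = DataStore()  # assumption: in practice, reuse the pipeline's store

model_key = f"model:{Path('data/train.csv').resolve()}"  # illustrative path
model = get_cached_object(model_key, data_store=store)

if model is not None:
    # Features must be encoded the same way train_model encoded them
    # (pd.get_dummies with drop_first=True) and have matching columns.
    X_new = pd.DataFrame({"feature_a": [1.0], "feature_b": [0.0]})  # hypothetical columns
    print(model.predict(X_new))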