zero_agent-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentz/agent/base.py +262 -0
- agentz/artifacts/__init__.py +5 -0
- agentz/artifacts/artifact_writer.py +538 -0
- agentz/artifacts/reporter.py +235 -0
- agentz/artifacts/terminal_writer.py +100 -0
- agentz/context/__init__.py +6 -0
- agentz/context/context.py +91 -0
- agentz/context/conversation.py +205 -0
- agentz/context/data_store.py +208 -0
- agentz/llm/llm_setup.py +156 -0
- agentz/mcp/manager.py +142 -0
- agentz/mcp/patches.py +88 -0
- agentz/mcp/servers/chrome_devtools/server.py +14 -0
- agentz/profiles/base.py +108 -0
- agentz/profiles/data/data_analysis.py +38 -0
- agentz/profiles/data/data_loader.py +35 -0
- agentz/profiles/data/evaluation.py +43 -0
- agentz/profiles/data/model_training.py +47 -0
- agentz/profiles/data/preprocessing.py +47 -0
- agentz/profiles/data/visualization.py +47 -0
- agentz/profiles/manager/evaluate.py +51 -0
- agentz/profiles/manager/memory.py +62 -0
- agentz/profiles/manager/observe.py +48 -0
- agentz/profiles/manager/routing.py +66 -0
- agentz/profiles/manager/writer.py +51 -0
- agentz/profiles/mcp/browser.py +21 -0
- agentz/profiles/mcp/chrome.py +21 -0
- agentz/profiles/mcp/notion.py +21 -0
- agentz/runner/__init__.py +74 -0
- agentz/runner/base.py +28 -0
- agentz/runner/executor.py +320 -0
- agentz/runner/hooks.py +110 -0
- agentz/runner/iteration.py +142 -0
- agentz/runner/patterns.py +215 -0
- agentz/runner/tracker.py +188 -0
- agentz/runner/utils.py +45 -0
- agentz/runner/workflow.py +250 -0
- agentz/tools/__init__.py +20 -0
- agentz/tools/data_tools/__init__.py +17 -0
- agentz/tools/data_tools/data_analysis.py +152 -0
- agentz/tools/data_tools/data_loading.py +92 -0
- agentz/tools/data_tools/evaluation.py +175 -0
- agentz/tools/data_tools/helpers.py +120 -0
- agentz/tools/data_tools/model_training.py +192 -0
- agentz/tools/data_tools/preprocessing.py +229 -0
- agentz/tools/data_tools/visualization.py +281 -0
- agentz/utils/__init__.py +69 -0
- agentz/utils/config.py +708 -0
- agentz/utils/helpers.py +10 -0
- agentz/utils/parsers.py +142 -0
- agentz/utils/printer.py +539 -0
- pipelines/base.py +972 -0
- pipelines/data_scientist.py +97 -0
- pipelines/data_scientist_memory.py +151 -0
- pipelines/experience_learner.py +0 -0
- pipelines/prompt_generator.py +0 -0
- pipelines/simple.py +78 -0
- pipelines/simple_browser.py +145 -0
- pipelines/simple_chrome.py +75 -0
- pipelines/simple_notion.py +103 -0
- pipelines/tool_builder.py +0 -0
- zero_agent-0.1.0.dist-info/METADATA +269 -0
- zero_agent-0.1.0.dist-info/RECORD +66 -0
- zero_agent-0.1.0.dist-info/WHEEL +5 -0
- zero_agent-0.1.0.dist-info/licenses/LICENSE +21 -0
- zero_agent-0.1.0.dist-info/top_level.txt +2 -0
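For orientation, a minimal import sketch under stated assumptions: the wheel installs two top-level packages (agentz and pipelines, per top_level.txt), and the entry points shown are taken from the files diffed below; whether the distribution is actually published under the name zero-agent on a public index is not confirmed by this diff.

# Hypothetical usage sketch; assumes the wheel is installed in the current environment.
from agentz.tools import load_dataset, analyze_data    # re-exported data tools (see agentz/tools/__init__.py)
from agentz.runner.workflow import WorkflowHelpers      # control-flow helpers (see agentz/runner/workflow.py)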
agentz/runner/workflow.py
ADDED
@@ -0,0 +1,250 @@
"""Workflow control flow helpers for pipeline execution."""

import asyncio
from typing import Any, Awaitable, Callable, Dict, List, Optional, Union


class WorkflowHelpers:
    """Collection of workflow control flow helpers.

    These helpers provide common patterns for pipeline execution:
    - Iterative loops
    - Custom groups
    - Parallel execution
    - Conditional execution
    - Loop-until patterns
    """

    def __init__(
        self,
        iteration_manager: Any,
        hook_registry: Any,
        start_group_callback: Optional[Callable] = None,
        end_group_callback: Optional[Callable] = None,
        update_printer_callback: Optional[Callable] = None,
    ):
        """Initialize workflow helpers.

        Args:
            iteration_manager: IterationManager instance
            hook_registry: HookRegistry instance
            start_group_callback: Optional callback(group_id, title, border_style)
            end_group_callback: Optional callback(group_id, is_done)
            update_printer_callback: Optional callback(key, message, is_done, group_id)
        """
        self.iteration_manager = iteration_manager
        self.hook_registry = hook_registry
        self.start_group = start_group_callback
        self.end_group = end_group_callback
        self.update_printer = update_printer_callback

    async def run_iterative_loop(
        self,
        iteration_body: Callable[[Any, str], Awaitable[Any]],
        final_body: Optional[Callable[[str], Awaitable[Any]]] = None,
        should_continue: Optional[Callable[[], bool]] = None,
    ) -> Any:
        """Execute standard iterative loop pattern.

        Args:
            iteration_body: Async function(iteration, group_id) -> result
            final_body: Optional async function(final_group_id) -> result
            should_continue: Optional custom condition (default: iteration_manager.should_continue_iteration)

        Returns:
            Result from final_body if provided, else None

        Example:
            async def my_iteration(iteration, group):
                observations = await observe_agent(...)
                evaluations = await evaluate_agent(...)
                await route_and_execute(evaluations, group)

            async def my_final(group):
                return await writer_agent(...)

            result = await helpers.run_iterative_loop(
                iteration_body=my_iteration,
                final_body=my_final
            )
        """
        should_continue_fn = should_continue or self.iteration_manager.should_continue_iteration

        while should_continue_fn():
            iteration, group_id = self.iteration_manager.begin_iteration()

            await self.hook_registry.trigger(
                "before_iteration",
                context=self.iteration_manager.context,
                iteration=iteration,
                group_id=group_id
            )

            try:
                await iteration_body(iteration, group_id)
            finally:
                await self.hook_registry.trigger(
                    "after_iteration",
                    context=self.iteration_manager.context,
                    iteration=iteration,
                    group_id=group_id
                )
                self.iteration_manager.end_iteration(group_id)

            # Check if state indicates completion
            context = self.iteration_manager.context
            if hasattr(context, 'state') and context.state and context.state.complete:
                break

        result = None
        if final_body:
            final_group = self.iteration_manager.start_final_group()
            result = await final_body(final_group)
            self.iteration_manager.end_final_group(final_group)

        return result

    async def run_custom_group(
        self,
        group_id: str,
        title: str,
        body: Callable[[], Awaitable[Any]],
        border_style: str = "white",
    ) -> Any:
        """Execute code within a custom printer group.

        Args:
            group_id: Unique group identifier
            title: Display title for the group
            body: Async function to execute within group
            border_style: Border color for printer

        Returns:
            Result from body()

        Example:
            exploration = await helpers.run_custom_group(
                "exploration",
                "Exploration Phase",
                self._explore
            )

            analysis = await helpers.run_custom_group(
                "analysis",
                "Deep Analysis",
                lambda: self._analyze(exploration)
            )
        """
        if self.start_group:
            self.start_group(group_id, title=title, border_style=border_style)
        try:
            result = await body()
            return result
        finally:
            if self.end_group:
                self.end_group(group_id, is_done=True)

    async def run_parallel_steps(
        self,
        steps: Dict[str, Callable[[], Awaitable[Any]]],
        group_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Execute multiple steps in parallel.

        Args:
            steps: Dict mapping step_name -> async callable
            group_id: Optional group to nest steps in

        Returns:
            Dict mapping step_name -> result

        Example:
            results = await helpers.run_parallel_steps({
                "data_loading": self.load_data,
                "validation": self.validate_inputs,
                "model_init": self.initialize_models,
            })

            data = results["data_loading"]
        """
        async def run_step(name: str, fn: Callable):
            key = f"{group_id}:{name}" if group_id else name
            if self.update_printer:
                self.update_printer(key, f"Running {name}...", group_id=group_id)
            result = await fn()
            if self.update_printer:
                self.update_printer(key, f"Completed {name}", is_done=True, group_id=group_id)
            return name, result

        tasks = [run_step(name, fn) for name, fn in steps.items()]
        completed = await asyncio.gather(*tasks)
        return dict(completed)

    async def run_if(
        self,
        condition: Union[bool, Callable[[], bool]],
        body: Callable[[], Awaitable[Any]],
        else_body: Optional[Callable[[], Awaitable[Any]]] = None,
    ) -> Any:
        """Conditional execution helper.

        Args:
            condition: Boolean or callable returning bool
            body: Execute if condition is True
            else_body: Optional execute if condition is False

        Returns:
            Result from executed body

        Example:
            initial = await quick_check()

            return await helpers.run_if(
                condition=initial.needs_deep_analysis,
                body=lambda: deep_analysis(initial),
                else_body=lambda: simple_report(initial)
            )
        """
        cond_result = condition() if callable(condition) else condition
        if cond_result:
            return await body()
        elif else_body:
            return await else_body()
        return None

    async def run_until(
        self,
        condition: Callable[[], bool],
        body: Callable[[int], Awaitable[Any]],
        max_iterations: Optional[int] = None,
    ) -> List[Any]:
        """Execute body repeatedly until condition is met.

        Args:
            condition: Callable returning True to stop
            body: Async function(iteration_number) -> result
            max_iterations: Optional max iterations (default: unlimited)

        Returns:
            List of results from each iteration

        Example:
            results = await helpers.run_until(
                condition=lambda: context.state.complete,
                body=self._exploration_step,
                max_iterations=10
            )
            return aggregate(results)
        """
        results = []
        iteration = 0

        while not condition():
            if max_iterations and iteration >= max_iterations:
                break

            result = await body(iteration)
            results.append(result)
            iteration += 1

        return results
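WorkflowHelpers itself carries no pipeline logic; everything is injected. A rough usage sketch follows, not taken from the package's own pipelines: it only exercises helpers that never touch the iteration manager or hook registry, so placeholder None values are passed for those two constructor arguments.

# Hypothetical usage sketch; only the WorkflowHelpers API shown above is assumed.
import asyncio

from agentz.runner.workflow import WorkflowHelpers

async def main():
    helpers = WorkflowHelpers(iteration_manager=None, hook_registry=None)

    async def load():
        return {"rows": 100}

    async def validate():
        return True

    # Independent steps run concurrently (asyncio.gather under the hood).
    results = await helpers.run_parallel_steps({"load": load, "validate": validate})

    # Branch without inlining if/else in the pipeline body.
    report = await helpers.run_if(
        condition=results["validate"],
        body=lambda: asyncio.sleep(0, result="full report"),
        else_body=lambda: asyncio.sleep(0, result="validation failed"),
    )
    print(results, report)

asyncio.run(main())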
agentz/tools/__init__.py
ADDED
@@ -0,0 +1,20 @@
"""Tools for agent workflows."""

# Re-export data tools for backward compatibility
from .data_tools import (
    load_dataset,
    analyze_data,
    preprocess_data,
    train_model,
    evaluate_model,
    create_visualization,
)

__all__ = [
    "load_dataset",
    "analyze_data",
    "preprocess_data",
    "train_model",
    "evaluate_model",
    "create_visualization",
]
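Because of these re-exports, both import forms below resolve to the same tool objects; a two-line illustration rather than anything from the package docs:

from agentz.tools import load_dataset                # backward-compatible re-export
from agentz.tools.data_tools import load_dataset     # direct submodule import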
agentz/tools/data_tools/__init__.py
ADDED
@@ -0,0 +1,17 @@
"""Data science tools for data analysis, preprocessing, modeling, and visualization."""

from .data_loading import load_dataset
from .data_analysis import analyze_data
from .preprocessing import preprocess_data
from .model_training import train_model
from .evaluation import evaluate_model
from .visualization import create_visualization

__all__ = [
    "load_dataset",
    "analyze_data",
    "preprocess_data",
    "train_model",
    "evaluate_model",
    "create_visualization",
]
agentz/tools/data_tools/data_analysis.py
ADDED
@@ -0,0 +1,152 @@
"""Data analysis tool for exploratory data analysis and statistical analysis."""

from typing import Union, Dict, Any, Optional
from pathlib import Path
import pandas as pd
import numpy as np
from agents import function_tool
from agents.run_context import RunContextWrapper
from agentz.context.data_store import DataStore
from .helpers import load_or_get_dataframe
from loguru import logger


@function_tool
async def analyze_data(ctx: RunContextWrapper[DataStore], file_path: Optional[str] = None, target_column: str = None) -> Union[Dict[str, Any], str]:
    """Performs comprehensive exploratory data analysis on a dataset.

    This tool automatically uses the current dataset from the pipeline context.
    A file_path can optionally be provided to analyze a different dataset.

    Args:
        ctx: Pipeline context wrapper for accessing the data store
        file_path: Optional path to dataset file. If not provided, uses current dataset.
        target_column: Optional target column for correlation analysis

    Returns:
        Dictionary containing:
        - distributions: Distribution statistics for each column
        - correlations: Correlation matrix for numerical columns
        - outliers: Outlier detection results using IQR method
        - patterns: Identified patterns and insights
        - recommendations: Data quality and preprocessing recommendations
        Or error message string if analysis fails
    """
    try:
        # Get DataFrame - either from file_path or current dataset
        data_store = ctx.context
        if file_path is None:
            if data_store and data_store.has("current_dataset"):
                df = data_store.get("current_dataset")
                logger.info("Analyzing current dataset from pipeline context")
            else:
                return "Error: No dataset loaded. Please load a dataset first using the load_dataset tool."
        else:
            df = load_or_get_dataframe(file_path, prefer_preprocessed=False, data_store=data_store)
            logger.info(f"Analyzing dataset from: {file_path}")

        result = {}

        # Distribution analysis
        distributions = {}
        for col in df.columns:
            if pd.api.types.is_numeric_dtype(df[col]):
                distributions[col] = {
                    "mean": float(df[col].mean()),
                    "median": float(df[col].median()),
                    "std": float(df[col].std()),
                    "min": float(df[col].min()),
                    "max": float(df[col].max()),
                    "q25": float(df[col].quantile(0.25)),
                    "q75": float(df[col].quantile(0.75)),
                    "skewness": float(df[col].skew()),
                    "kurtosis": float(df[col].kurtosis()),
                }
            else:
                distributions[col] = {
                    "unique_values": int(df[col].nunique()),
                    "top_value": str(df[col].mode()[0]) if not df[col].mode().empty else None,
                    "top_frequency": int(df[col].value_counts().iloc[0]) if len(df[col].value_counts()) > 0 else 0,
                }
        result["distributions"] = distributions

        # Correlation analysis
        numeric_df = df.select_dtypes(include=['number'])
        if not numeric_df.empty:
            corr_matrix = numeric_df.corr()
            result["correlations"] = corr_matrix.to_dict()

            if target_column and target_column in corr_matrix.columns:
                target_corr = corr_matrix[target_column].drop(target_column).sort_values(ascending=False)
                result["target_correlations"] = target_corr.to_dict()

        # Outlier detection using IQR method
        outliers = {}
        for col in numeric_df.columns:
            Q1 = df[col].quantile(0.25)
            Q3 = df[col].quantile(0.75)
            IQR = Q3 - Q1
            lower_bound = Q1 - 1.5 * IQR
            upper_bound = Q3 + 1.5 * IQR
            outlier_count = ((df[col] < lower_bound) | (df[col] > upper_bound)).sum()
            outliers[col] = {
                "count": int(outlier_count),
                "percentage": float(outlier_count / len(df) * 100),
                "lower_bound": float(lower_bound),
                "upper_bound": float(upper_bound),
            }
        result["outliers"] = outliers

        # Pattern identification
        patterns = []

        # High correlation patterns
        if "correlations" in result:
            for col1 in corr_matrix.columns:
                for col2 in corr_matrix.columns:
                    if col1 < col2:  # Avoid duplicates
                        corr_val = corr_matrix.loc[col1, col2]
                        if abs(corr_val) > 0.7:
                            patterns.append(f"Strong correlation ({corr_val:.2f}) between {col1} and {col2}")

        # Missing data patterns
        missing_cols = [col for col in df.columns if df[col].isnull().sum() > 0]
        if missing_cols:
            patterns.append(f"Missing data detected in {len(missing_cols)} columns: {', '.join(missing_cols[:5])}")

        # Outlier patterns
        high_outlier_cols = [col for col, info in outliers.items() if info['percentage'] > 5]
        if high_outlier_cols:
            patterns.append(f"High outlier percentage (>5%) in columns: {', '.join(high_outlier_cols)}")

        result["patterns"] = patterns

        # Recommendations
        recommendations = []

        if missing_cols:
            recommendations.append("Consider imputation strategies for missing values")

        if high_outlier_cols:
            recommendations.append("Review and handle outliers before modeling")

        # Check for imbalanced categorical columns
        for col in df.select_dtypes(include=['object']).columns:
            value_counts = df[col].value_counts()
            if len(value_counts) > 1:
                imbalance_ratio = value_counts.iloc[0] / value_counts.iloc[-1]
                if imbalance_ratio > 10:
                    recommendations.append(f"Column '{col}' shows class imbalance (ratio: {imbalance_ratio:.1f})")

        # Check for constant or near-constant columns
        for col in df.columns:
            unique_ratio = df[col].nunique() / len(df)
            if unique_ratio < 0.01 and df[col].nunique() > 1:
                recommendations.append(f"Column '{col}' has very low variance, consider removing")

        result["recommendations"] = recommendations

        return result

    except Exception as e:
        return f"Error analyzing dataset: {str(e)}"
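The outlier step above is the standard 1.5 × IQR fence. A self-contained sketch of the same rule on a toy Series, outside the tool wrapper, for readers who want to check the arithmetic:

# Standalone illustration of the IQR fence used by analyze_data above.
import pandas as pd

s = pd.Series([1, 2, 2, 3, 3, 3, 4, 4, 5, 100])   # 100 is an obvious outlier
q1, q3 = s.quantile(0.25), s.quantile(0.75)
iqr = q3 - q1
lower, upper = q1 - 1.5 * iqr, q3 + 1.5 * iqr
outliers = s[(s < lower) | (s > upper)]
print(len(outliers))                                # 1: only the value 100 falls outside the fences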
agentz/tools/data_tools/data_loading.py
ADDED
@@ -0,0 +1,92 @@
"""Data loading tool for loading and inspecting datasets."""

from typing import Union, Dict, Any
from pathlib import Path
import pandas as pd
from agents import function_tool
from agents.run_context import RunContextWrapper
from agentz.context.data_store import DataStore
from loguru import logger


@function_tool
async def load_dataset(ctx: RunContextWrapper[DataStore], file_path: str) -> Union[Dict[str, Any], str]:
    """Loads a dataset and provides comprehensive inspection information.

    This tool caches the loaded DataFrame in the pipeline data store so other
    tools can reuse it without reloading from disk.

    Args:
        ctx: Pipeline context wrapper for accessing the data store
        file_path: Path to the dataset file (CSV, JSON, Excel, etc.)

    Returns:
        Dictionary containing:
        - shape: Tuple of (rows, columns)
        - columns: List of column names
        - dtypes: Dictionary of column data types
        - missing_values: Dictionary of missing value counts per column
        - sample_data: First 5 rows as dictionary
        - summary_stats: Statistical summary for numerical columns
        - memory_usage: Memory usage information
        Or error message string if loading fails
    """
    try:
        file_path = Path(file_path)

        if not file_path.exists():
            return f"File not found: {file_path}"

        # Load based on file extension
        if file_path.suffix.lower() == '.csv':
            df = pd.read_csv(file_path)
        elif file_path.suffix.lower() in ['.xlsx', '.xls']:
            df = pd.read_excel(file_path)
        elif file_path.suffix.lower() == '.json':
            df = pd.read_json(file_path)
        elif file_path.suffix.lower() == '.parquet':
            df = pd.read_parquet(file_path)
        else:
            return f"Unsupported file format: {file_path.suffix}"

        # Store DataFrame in data store for reuse by other tools
        data_store = ctx.context
        if data_store is not None:
            # Store with file path key for backward compatibility
            cache_key = f"dataframe:{file_path.resolve()}"
            data_store.set(
                cache_key,
                df,
                data_type="dataframe",
                metadata={"file_path": str(file_path), "shape": df.shape}
            )
            logger.info(f"Cached DataFrame from {file_path} with key: {cache_key}")

            # Also set as the current active dataset
            data_store.set(
                "current_dataset",
                df,
                data_type="dataframe",
                metadata={"file_path": str(file_path), "shape": df.shape, "source": "loaded"}
            )
            logger.info(f"Set as current dataset for pipeline")

        # Gather comprehensive information
        result = {
            "file_path": str(file_path),
            "shape": df.shape,
            "columns": df.columns.tolist(),
            "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()},
            "missing_values": df.isnull().sum().to_dict(),
            "missing_percentage": (df.isnull().sum() / len(df) * 100).to_dict(),
            "sample_data": df.head(5).to_dict(orient='records'),
            "summary_stats": df.describe().to_dict() if not df.select_dtypes(include=['number']).empty else {},
            "memory_usage": df.memory_usage(deep=True).to_dict(),
            "total_memory_mb": df.memory_usage(deep=True).sum() / 1024 / 1024,
            "duplicate_rows": int(df.duplicated().sum()),
        }

        return result

    except Exception as e:
        return f"Error loading dataset: {str(e)}"
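load_dataset and analyze_data never pass DataFrames to each other directly; the hand-off goes through the DataStore held in ctx.context. A rough sketch of that contract follows, assuming DataStore's default constructor and the set/has/get signatures implied by the calls above (the real implementation lives in agentz/context/data_store.py and is not shown in this diff):

# Hypothetical hand-off sketch; the DataStore API is inferred from this diff, not documented here.
import pandas as pd
from agentz.context.data_store import DataStore

store = DataStore()                                  # assumed default constructor
df = pd.DataFrame({"x": [1, 2, 3]})

# What load_dataset does after reading the file:
store.set("current_dataset", df, data_type="dataframe",
          metadata={"file_path": "data.csv", "shape": df.shape})

# What analyze_data does when no file_path is supplied:
if store.has("current_dataset"):
    df_again = store.get("current_dataset")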