ai-data-science-team 0.0.0.9014__py3-none-any.whl → 0.0.0.9016__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/data_visualization_agent.py +172 -129
- ai_data_science_team/agents/data_wrangling_agent.py +1 -0
- ai_data_science_team/ds_agents/eda_tools_agent.py +46 -50
- ai_data_science_team/multiagents/pandas_data_analyst.py +5 -5
- ai_data_science_team/multiagents/sql_data_analyst.py +7 -18
- ai_data_science_team/tools/eda.py +123 -60
- {ai_data_science_team-0.0.0.9014.dist-info → ai_data_science_team-0.0.0.9016.dist-info}/METADATA +64 -57
- {ai_data_science_team-0.0.0.9014.dist-info → ai_data_science_team-0.0.0.9016.dist-info}/RECORD +12 -12
- {ai_data_science_team-0.0.0.9014.dist-info → ai_data_science_team-0.0.0.9016.dist-info}/WHEEL +1 -1
- {ai_data_science_team-0.0.0.9014.dist-info → ai_data_science_team-0.0.0.9016.dist-info/licenses}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9014.dist-info → ai_data_science_team-0.0.0.9016.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,3 @@
|
|
1
|
-
|
2
|
-
|
3
1
|
from typing import Any, Optional, Annotated, Sequence, Dict
|
4
2
|
import operator
|
5
3
|
import pandas as pd
|
@@ -17,10 +15,11 @@ from ai_data_science_team.utils.regex import format_agent_name
|
|
17
15
|
|
18
16
|
from ai_data_science_team.tools.eda import (
|
19
17
|
explain_data,
|
20
|
-
describe_dataset,
|
21
|
-
visualize_missing,
|
22
|
-
|
18
|
+
describe_dataset,
|
19
|
+
visualize_missing,
|
20
|
+
generate_correlation_funnel,
|
23
21
|
generate_sweetviz_report,
|
22
|
+
generate_dtale_report,
|
24
23
|
)
|
25
24
|
from ai_data_science_team.utils.messages import get_tool_call_names
|
26
25
|
|
@@ -32,15 +31,17 @@ EDA_TOOLS = [
|
|
32
31
|
explain_data,
|
33
32
|
describe_dataset,
|
34
33
|
visualize_missing,
|
35
|
-
|
34
|
+
generate_correlation_funnel,
|
36
35
|
generate_sweetviz_report,
|
36
|
+
generate_dtale_report,
|
37
37
|
]
|
38
38
|
|
39
|
+
|
39
40
|
class EDAToolsAgent(BaseAgent):
|
40
41
|
"""
|
41
42
|
An Exploratory Data Analysis Tools Agent that interacts with EDA tools to generate summary statistics,
|
42
43
|
missing data visualizations, correlation funnels, EDA reports, etc.
|
43
|
-
|
44
|
+
|
44
45
|
Parameters:
|
45
46
|
----------
|
46
47
|
model : langchain.llms.base.LLM
|
@@ -52,9 +53,9 @@ class EDAToolsAgent(BaseAgent):
|
|
52
53
|
checkpointer : Checkpointer, optional
|
53
54
|
The checkpointer for the agent.
|
54
55
|
"""
|
55
|
-
|
56
|
+
|
56
57
|
def __init__(
|
57
|
-
self,
|
58
|
+
self,
|
58
59
|
model: Any,
|
59
60
|
create_react_agent_kwargs: Optional[Dict] = {},
|
60
61
|
invoke_react_agent_kwargs: Optional[Dict] = {},
|
@@ -64,18 +65,18 @@ class EDAToolsAgent(BaseAgent):
|
|
64
65
|
"model": model,
|
65
66
|
"create_react_agent_kwargs": create_react_agent_kwargs,
|
66
67
|
"invoke_react_agent_kwargs": invoke_react_agent_kwargs,
|
67
|
-
"checkpointer": checkpointer
|
68
|
+
"checkpointer": checkpointer,
|
68
69
|
}
|
69
70
|
self._compiled_graph = self._make_compiled_graph()
|
70
71
|
self.response = None
|
71
|
-
|
72
|
+
|
72
73
|
def _make_compiled_graph(self):
|
73
74
|
"""
|
74
75
|
Creates the compiled state graph for the EDA agent.
|
75
76
|
"""
|
76
77
|
self.response = None
|
77
78
|
return make_eda_tools_agent(**self._params)
|
78
|
-
|
79
|
+
|
79
80
|
def update_params(self, **kwargs):
|
80
81
|
"""
|
81
82
|
Updates the agent's parameters and rebuilds the compiled graph.
|
@@ -83,16 +84,13 @@ class EDAToolsAgent(BaseAgent):
|
|
83
84
|
for k, v in kwargs.items():
|
84
85
|
self._params[k] = v
|
85
86
|
self._compiled_graph = self._make_compiled_graph()
|
86
|
-
|
87
|
+
|
87
88
|
async def ainvoke_agent(
|
88
|
-
self,
|
89
|
-
user_instructions: str = None,
|
90
|
-
data_raw: pd.DataFrame = None,
|
91
|
-
**kwargs
|
89
|
+
self, user_instructions: str = None, data_raw: pd.DataFrame = None, **kwargs
|
92
90
|
):
|
93
91
|
"""
|
94
92
|
Asynchronously runs the agent with user instructions and data.
|
95
|
-
|
93
|
+
|
96
94
|
Parameters:
|
97
95
|
----------
|
98
96
|
user_instructions : str, optional
|
@@ -105,20 +103,17 @@ class EDAToolsAgent(BaseAgent):
|
|
105
103
|
"user_instructions": user_instructions,
|
106
104
|
"data_raw": data_raw.to_dict() if data_raw is not None else None,
|
107
105
|
},
|
108
|
-
**kwargs
|
106
|
+
**kwargs,
|
109
107
|
)
|
110
108
|
self.response = response
|
111
109
|
return None
|
112
|
-
|
110
|
+
|
113
111
|
def invoke_agent(
|
114
|
-
self,
|
115
|
-
user_instructions: str = None,
|
116
|
-
data_raw: pd.DataFrame = None,
|
117
|
-
**kwargs
|
112
|
+
self, user_instructions: str = None, data_raw: pd.DataFrame = None, **kwargs
|
118
113
|
):
|
119
114
|
"""
|
120
115
|
Synchronously runs the agent with user instructions and data.
|
121
|
-
|
116
|
+
|
122
117
|
Parameters:
|
123
118
|
----------
|
124
119
|
user_instructions : str, optional
|
@@ -131,24 +126,26 @@ class EDAToolsAgent(BaseAgent):
|
|
131
126
|
"user_instructions": user_instructions,
|
132
127
|
"data_raw": data_raw.to_dict() if data_raw is not None else None,
|
133
128
|
},
|
134
|
-
**kwargs
|
129
|
+
**kwargs,
|
135
130
|
)
|
136
131
|
self.response = response
|
137
132
|
return None
|
138
|
-
|
133
|
+
|
139
134
|
def get_internal_messages(self, markdown: bool = False):
|
140
135
|
"""
|
141
136
|
Returns internal messages from the agent response.
|
142
137
|
"""
|
143
138
|
pretty_print = "\n\n".join(
|
144
|
-
[
|
145
|
-
|
139
|
+
[
|
140
|
+
f"### {msg.type.upper()}\n\nID: {msg.id}\n\nContent:\n\n{msg.content}"
|
141
|
+
for msg in self.response["internal_messages"]
|
142
|
+
]
|
146
143
|
)
|
147
144
|
if markdown:
|
148
145
|
return Markdown(pretty_print)
|
149
146
|
else:
|
150
147
|
return self.response["internal_messages"]
|
151
|
-
|
148
|
+
|
152
149
|
def get_artifacts(self, as_dataframe: bool = False):
|
153
150
|
"""
|
154
151
|
Returns the EDA artifacts from the agent response.
|
@@ -157,7 +154,7 @@ class EDAToolsAgent(BaseAgent):
|
|
157
154
|
return pd.DataFrame(self.response["eda_artifacts"])
|
158
155
|
else:
|
159
156
|
return self.response["eda_artifacts"]
|
160
|
-
|
157
|
+
|
161
158
|
def get_ai_message(self, markdown: bool = False):
|
162
159
|
"""
|
163
160
|
Returns the AI message from the agent response.
|
@@ -166,13 +163,14 @@ class EDAToolsAgent(BaseAgent):
|
|
166
163
|
return Markdown(self.response["messages"][0].content)
|
167
164
|
else:
|
168
165
|
return self.response["messages"][0].content
|
169
|
-
|
166
|
+
|
170
167
|
def get_tool_calls(self):
|
171
168
|
"""
|
172
169
|
Returns the tool calls made by the agent.
|
173
170
|
"""
|
174
171
|
return self.response["tool_calls"]
|
175
172
|
|
173
|
+
|
176
174
|
def make_eda_tools_agent(
|
177
175
|
model: Any,
|
178
176
|
create_react_agent_kwargs: Optional[Dict] = {},
|
@@ -181,7 +179,7 @@ def make_eda_tools_agent(
|
|
181
179
|
):
|
182
180
|
"""
|
183
181
|
Creates an Exploratory Data Analyst Agent that can interact with EDA tools.
|
184
|
-
|
182
|
+
|
185
183
|
Parameters:
|
186
184
|
----------
|
187
185
|
model : Any
|
@@ -192,13 +190,13 @@ def make_eda_tools_agent(
|
|
192
190
|
Additional kwargs for agent invocation.
|
193
191
|
checkpointer : Checkpointer, optional
|
194
192
|
The checkpointer for the agent.
|
195
|
-
|
193
|
+
|
196
194
|
Returns:
|
197
195
|
-------
|
198
196
|
app : langgraph.graph.CompiledStateGraph
|
199
197
|
The compiled state graph for the EDA agent.
|
200
198
|
"""
|
201
|
-
|
199
|
+
|
202
200
|
class GraphState(AgentState):
|
203
201
|
internal_messages: Annotated[Sequence[BaseMessage], operator.add]
|
204
202
|
user_instructions: str
|
@@ -209,11 +207,9 @@ def make_eda_tools_agent(
|
|
209
207
|
def exploratory_agent(state):
|
210
208
|
print(format_agent_name(AGENT_NAME))
|
211
209
|
print(" * RUN REACT TOOL-CALLING AGENT FOR EDA")
|
212
|
-
|
213
|
-
tool_node = ToolNode(
|
214
|
-
|
215
|
-
)
|
216
|
-
|
210
|
+
|
211
|
+
tool_node = ToolNode(tools=EDA_TOOLS)
|
212
|
+
|
217
213
|
eda_agent = create_react_agent(
|
218
214
|
model,
|
219
215
|
tools=tool_node,
|
@@ -221,7 +217,7 @@ def make_eda_tools_agent(
|
|
221
217
|
**create_react_agent_kwargs,
|
222
218
|
checkpointer=checkpointer,
|
223
219
|
)
|
224
|
-
|
220
|
+
|
225
221
|
response = eda_agent.invoke(
|
226
222
|
{
|
227
223
|
"messages": [("user", state["user_instructions"])],
|
@@ -229,13 +225,13 @@ def make_eda_tools_agent(
|
|
229
225
|
},
|
230
226
|
invoke_react_agent_kwargs,
|
231
227
|
)
|
232
|
-
|
228
|
+
|
233
229
|
print(" * POST-PROCESSING EDA RESULTS")
|
234
|
-
|
235
|
-
internal_messages = response[
|
230
|
+
|
231
|
+
internal_messages = response["messages"]
|
236
232
|
if not internal_messages:
|
237
233
|
return {"internal_messages": [], "eda_artifacts": None}
|
238
|
-
|
234
|
+
|
239
235
|
last_ai_message = AIMessage(internal_messages[-1].content, role=AGENT_NAME)
|
240
236
|
last_tool_artifact = None
|
241
237
|
if len(internal_messages) > 1:
|
@@ -244,24 +240,24 @@ def make_eda_tools_agent(
|
|
244
240
|
last_tool_artifact = last_message.artifact
|
245
241
|
elif isinstance(last_message, dict) and "artifact" in last_message:
|
246
242
|
last_tool_artifact = last_message["artifact"]
|
247
|
-
|
243
|
+
|
248
244
|
tool_calls = get_tool_call_names(internal_messages)
|
249
|
-
|
245
|
+
|
250
246
|
return {
|
251
247
|
"messages": [last_ai_message],
|
252
248
|
"internal_messages": internal_messages,
|
253
249
|
"eda_artifacts": last_tool_artifact,
|
254
250
|
"tool_calls": tool_calls,
|
255
251
|
}
|
256
|
-
|
252
|
+
|
257
253
|
workflow = StateGraph(GraphState)
|
258
254
|
workflow.add_node("exploratory_agent", exploratory_agent)
|
259
255
|
workflow.add_edge(START, "exploratory_agent")
|
260
256
|
workflow.add_edge("exploratory_agent", END)
|
261
|
-
|
257
|
+
|
262
258
|
app = workflow.compile(
|
263
259
|
checkpointer=checkpointer,
|
264
260
|
name=AGENT_NAME,
|
265
261
|
)
|
266
|
-
|
262
|
+
|
267
263
|
return app
|
@@ -132,10 +132,10 @@ class PandasDataAnalyst(BaseAgent):
|
|
132
132
|
"""Returns a summary of the workflow."""
|
133
133
|
if self.response and self.response.get("messages"):
|
134
134
|
agents = [msg.role for msg in self.response["messages"]]
|
135
|
-
agent_labels = [f"- **Agent {i+1}:** {role}" for i, role in enumerate(agents)]
|
135
|
+
agent_labels = [f"- **Agent {i+1}:** {role}\n" for i, role in enumerate(agents)]
|
136
136
|
header = f"# Pandas Data Analyst Workflow Summary\n\nThis workflow contains {len(agents)} agents:\n\n" + "\n".join(agent_labels)
|
137
137
|
reports = [get_generic_summary(json.loads(msg.content)) for msg in self.response["messages"]]
|
138
|
-
summary = "\n" +header + "\n\n".join(reports)
|
138
|
+
summary = "\n\n" + header + "\n\n".join(reports)
|
139
139
|
return Markdown(summary) if markdown else summary
|
140
140
|
|
141
141
|
@staticmethod
|
@@ -177,15 +177,15 @@ def make_pandas_data_analyst(
|
|
177
177
|
|
178
178
|
routing_preprocessor_prompt = PromptTemplate(
|
179
179
|
template="""
|
180
|
-
You are an expert in routing decisions for a Pandas Data Manipulation Wrangling Agent, a Charting Visualization Agent, and a Pandas Table Agent. Your job is to:
|
180
|
+
You are an expert in routing decisions for a Pandas Data Manipulation Wrangling Agent, a Charting Visualization Agent, and a Pandas Table Agent. Your job is to tell the agents which actions to perform and determine the correct routing for the incoming user question:
|
181
181
|
|
182
|
-
1. Determine what the correct format for a Users Question should be for use with a Pandas Data Wrangling Agent based on the incoming user question. Anything related to data wrangling and manipulation should be passed along.
|
182
|
+
1. Determine what the correct format for a Users Question should be for use with a Pandas Data Wrangling Agent based on the incoming user question. Anything related to data wrangling and manipulation should be passed along. Anything related to data analysis can be handled by the Pandas Agent. Anything that uses Pandas can be passed along. Tables can be returned from this agent. Don't pass along anything about plotting or visualization.
|
183
183
|
2. Determine whether or not a chart should be generated or a table should be returned based on the users question.
|
184
184
|
3. If a chart is requested, determine the correct format of a Users Question should be used with a Data Visualization Agent. Anything related to plotting and visualization should be passed along.
|
185
185
|
|
186
186
|
Use the following criteria on how to route the the initial user question:
|
187
187
|
|
188
|
-
From the incoming user question, remove any details about the format of the final response as either a Chart or Table and return only the important part of the incoming user question that is relevant for the
|
188
|
+
From the incoming user question, remove any details about the format of the final response as either a Chart or Table and return only the important part of the incoming user question that is relevant for the Pandas Data Wrangling and Transformation agent. This will be the 'user_instructions_data_wrangling'. If 'None' is found, return the original user question.
|
189
189
|
|
190
190
|
Next, determine if the user would like a data visualization ('chart') or a 'table' returned with the results of the Data Wrangling Agent. If unknown, not specified or 'None' is found, then select 'table'.
|
191
191
|
|
@@ -301,24 +301,13 @@ class SQLDataAnalyst(BaseAgent):
|
|
301
301
|
markdown: bool
|
302
302
|
If True, returns the summary as a Markdown-formatted string.
|
303
303
|
"""
|
304
|
-
if self.response and self.
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
# Construct header
|
313
|
-
header = f"# SQL Data Analyst Workflow Summary Report\n\nThis agentic workflow contains {len(agents)} agents:\n\n" + "\n".join(agent_labels)
|
314
|
-
|
315
|
-
reports = []
|
316
|
-
for msg in self.get_response()['messages']:
|
317
|
-
reports.append(get_generic_summary(json.loads(msg.content)))
|
318
|
-
|
319
|
-
if markdown:
|
320
|
-
return Markdown(header + "\n\n".join(reports))
|
321
|
-
return "\n\n".join(reports)
|
304
|
+
if self.response and self.response.get("messages"):
|
305
|
+
agents = [msg.role for msg in self.response["messages"]]
|
306
|
+
agent_labels = [f"- **Agent {i+1}:** {role}\n" for i, role in enumerate(agents)]
|
307
|
+
header = f"# SQL Data Analyst Workflow Summary\n\nThis workflow contains {len(agents)} agents:\n\n" + "\n".join(agent_labels)
|
308
|
+
reports = [get_generic_summary(json.loads(msg.content)) for msg in self.response["messages"]]
|
309
|
+
summary = "\n\n" + header + "\n\n".join(reports)
|
310
|
+
return Markdown(summary) if markdown else summary
|
322
311
|
|
323
312
|
|
324
313
|
|