ai-data-science-team 0.0.0.9012__py3-none-any.whl → 0.0.0.9014__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_data_science_team/__init__.py +22 -0
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/data_cleaning_agent.py +17 -3
- ai_data_science_team/agents/data_loader_tools_agent.py +24 -1
- ai_data_science_team/agents/data_visualization_agent.py +17 -3
- ai_data_science_team/agents/data_wrangling_agent.py +30 -10
- ai_data_science_team/agents/feature_engineering_agent.py +17 -4
- ai_data_science_team/agents/sql_database_agent.py +15 -2
- ai_data_science_team/ds_agents/eda_tools_agent.py +28 -6
- ai_data_science_team/ml_agents/h2o_ml_agent.py +15 -3
- ai_data_science_team/ml_agents/mlflow_tools_agent.py +23 -1
- ai_data_science_team/multiagents/__init__.py +2 -1
- ai_data_science_team/multiagents/pandas_data_analyst.py +305 -0
- ai_data_science_team/multiagents/sql_data_analyst.py +119 -30
- ai_data_science_team/templates/agent_templates.py +41 -5
- ai_data_science_team/tools/dataframe.py +6 -1
- ai_data_science_team/tools/eda.py +75 -16
- ai_data_science_team/utils/messages.py +27 -0
- {ai_data_science_team-0.0.0.9012.dist-info → ai_data_science_team-0.0.0.9014.dist-info}/METADATA +7 -3
- {ai_data_science_team-0.0.0.9012.dist-info → ai_data_science_team-0.0.0.9014.dist-info}/RECORD +23 -21
- {ai_data_science_team-0.0.0.9012.dist-info → ai_data_science_team-0.0.0.9014.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9012.dist-info → ai_data_science_team-0.0.0.9014.dist-info}/WHEEL +0 -0
- {ai_data_science_team-0.0.0.9012.dist-info → ai_data_science_team-0.0.0.9014.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,8 @@
|
|
1
1
|
|
2
2
|
|
3
|
-
from typing import Any, Optional, Annotated, Sequence,
|
3
|
+
from typing import Any, Optional, Annotated, Sequence, Dict
|
4
4
|
import operator
|
5
5
|
import pandas as pd
|
6
|
-
import os
|
7
|
-
from io import StringIO, BytesIO
|
8
|
-
import base64
|
9
|
-
import matplotlib.pyplot as plt
|
10
6
|
|
11
7
|
from IPython.display import Markdown
|
12
8
|
|
@@ -14,22 +10,26 @@ from langchain_core.messages import BaseMessage, AIMessage
|
|
14
10
|
from langgraph.prebuilt import create_react_agent, ToolNode
|
15
11
|
from langgraph.prebuilt.chat_agent_executor import AgentState
|
16
12
|
from langgraph.graph import START, END, StateGraph
|
13
|
+
from langgraph.types import Checkpointer
|
17
14
|
|
18
15
|
from ai_data_science_team.templates import BaseAgent
|
19
16
|
from ai_data_science_team.utils.regex import format_agent_name
|
20
17
|
|
21
18
|
from ai_data_science_team.tools.eda import (
|
19
|
+
explain_data,
|
22
20
|
describe_dataset,
|
23
21
|
visualize_missing,
|
24
22
|
correlation_funnel,
|
25
23
|
generate_sweetviz_report,
|
26
24
|
)
|
25
|
+
from ai_data_science_team.utils.messages import get_tool_call_names
|
27
26
|
|
28
27
|
|
29
28
|
AGENT_NAME = "exploratory_data_analyst_agent"
|
30
29
|
|
31
30
|
# Updated tool list for EDA
|
32
31
|
EDA_TOOLS = [
|
32
|
+
explain_data,
|
33
33
|
describe_dataset,
|
34
34
|
visualize_missing,
|
35
35
|
correlation_funnel,
|
@@ -49,6 +49,8 @@ class EDAToolsAgent(BaseAgent):
|
|
49
49
|
Additional kwargs for create_react_agent.
|
50
50
|
invoke_react_agent_kwargs : dict
|
51
51
|
Additional kwargs for agent invocation.
|
52
|
+
checkpointer : Checkpointer, optional
|
53
|
+
The checkpointer for the agent.
|
52
54
|
"""
|
53
55
|
|
54
56
|
def __init__(
|
@@ -56,11 +58,13 @@ class EDAToolsAgent(BaseAgent):
|
|
56
58
|
model: Any,
|
57
59
|
create_react_agent_kwargs: Optional[Dict] = {},
|
58
60
|
invoke_react_agent_kwargs: Optional[Dict] = {},
|
61
|
+
checkpointer: Optional[Checkpointer] = None,
|
59
62
|
):
|
60
63
|
self._params = {
|
61
64
|
"model": model,
|
62
65
|
"create_react_agent_kwargs": create_react_agent_kwargs,
|
63
66
|
"invoke_react_agent_kwargs": invoke_react_agent_kwargs,
|
67
|
+
"checkpointer": checkpointer
|
64
68
|
}
|
65
69
|
self._compiled_graph = self._make_compiled_graph()
|
66
70
|
self.response = None
|
@@ -162,11 +166,18 @@ class EDAToolsAgent(BaseAgent):
|
|
162
166
|
return Markdown(self.response["messages"][0].content)
|
163
167
|
else:
|
164
168
|
return self.response["messages"][0].content
|
169
|
+
|
170
|
+
def get_tool_calls(self):
    """
    Returns the tool calls made by the agent.

    Returns
    -------
    list or None
        The ``tool_calls`` entry of the last response (the value computed by
        ``get_tool_call_names`` in the agent node), or None when the agent has
        not been invoked yet or the response carries no ``tool_calls`` entry.
    """
    # Before the first invoke, self.response is None; indexing it directly
    # would raise an unhelpful TypeError.
    if not self.response:
        return None
    return self.response.get("tool_calls")
|
165
175
|
|
166
176
|
def make_eda_tools_agent(
|
167
177
|
model: Any,
|
168
178
|
create_react_agent_kwargs: Optional[Dict] = {},
|
169
179
|
invoke_react_agent_kwargs: Optional[Dict] = {},
|
180
|
+
checkpointer: Optional[Checkpointer] = None,
|
170
181
|
):
|
171
182
|
"""
|
172
183
|
Creates an Exploratory Data Analyst Agent that can interact with EDA tools.
|
@@ -179,6 +190,8 @@ def make_eda_tools_agent(
|
|
179
190
|
Additional kwargs for create_react_agent.
|
180
191
|
invoke_react_agent_kwargs : dict
|
181
192
|
Additional kwargs for agent invocation.
|
193
|
+
checkpointer : Checkpointer, optional
|
194
|
+
The checkpointer for the agent.
|
182
195
|
|
183
196
|
Returns:
|
184
197
|
-------
|
@@ -191,6 +204,7 @@ def make_eda_tools_agent(
|
|
191
204
|
user_instructions: str
|
192
205
|
data_raw: dict
|
193
206
|
eda_artifacts: dict
|
207
|
+
tool_calls: list
|
194
208
|
|
195
209
|
def exploratory_agent(state):
|
196
210
|
print(format_agent_name(AGENT_NAME))
|
@@ -205,6 +219,7 @@ def make_eda_tools_agent(
|
|
205
219
|
tools=tool_node,
|
206
220
|
state_schema=GraphState,
|
207
221
|
**create_react_agent_kwargs,
|
222
|
+
checkpointer=checkpointer,
|
208
223
|
)
|
209
224
|
|
210
225
|
response = eda_agent.invoke(
|
@@ -229,11 +244,14 @@ def make_eda_tools_agent(
|
|
229
244
|
last_tool_artifact = last_message.artifact
|
230
245
|
elif isinstance(last_message, dict) and "artifact" in last_message:
|
231
246
|
last_tool_artifact = last_message["artifact"]
|
247
|
+
|
248
|
+
tool_calls = get_tool_call_names(internal_messages)
|
232
249
|
|
233
250
|
return {
|
234
251
|
"messages": [last_ai_message],
|
235
252
|
"internal_messages": internal_messages,
|
236
253
|
"eda_artifacts": last_tool_artifact,
|
254
|
+
"tool_calls": tool_calls,
|
237
255
|
}
|
238
256
|
|
239
257
|
workflow = StateGraph(GraphState)
|
@@ -241,5 +259,9 @@ def make_eda_tools_agent(
|
|
241
259
|
workflow.add_edge(START, "exploratory_agent")
|
242
260
|
workflow.add_edge("exploratory_agent", END)
|
243
261
|
|
244
|
-
app = workflow.compile(
|
262
|
+
app = workflow.compile(
|
263
|
+
checkpointer=checkpointer,
|
264
|
+
name=AGENT_NAME,
|
265
|
+
)
|
266
|
+
|
245
267
|
return app
|
@@ -5,7 +5,7 @@
|
|
5
5
|
|
6
6
|
import os
|
7
7
|
import json
|
8
|
-
from typing import TypedDict, Annotated, Sequence, Literal
|
8
|
+
from typing import TypedDict, Annotated, Sequence, Literal, Optional
|
9
9
|
import operator
|
10
10
|
|
11
11
|
import pandas as pd
|
@@ -14,7 +14,7 @@ from IPython.display import Markdown
|
|
14
14
|
from langchain.prompts import PromptTemplate
|
15
15
|
from langchain_core.messages import BaseMessage
|
16
16
|
|
17
|
-
from langgraph.types import Command
|
17
|
+
from langgraph.types import Command, Checkpointer
|
18
18
|
from langgraph.checkpoint.memory import MemorySaver
|
19
19
|
|
20
20
|
from ai_data_science_team.templates import(
|
@@ -79,6 +79,8 @@ class H2OMLAgent(BaseAgent):
|
|
79
79
|
Name of the MLflow experiment (created if doesn't exist).
|
80
80
|
mlflow_run_name : str, default None
|
81
81
|
A custom name for the MLflow run.
|
82
|
+
checkpointer : langgraph.checkpoint.memory.MemorySaver, optional
|
83
|
+
A checkpointer object for saving the agent's state. Defaults to None.
|
82
84
|
|
83
85
|
|
84
86
|
Methods
|
@@ -176,6 +178,7 @@ class H2OMLAgent(BaseAgent):
|
|
176
178
|
mlflow_tracking_uri=None,
|
177
179
|
mlflow_experiment_name="H2O AutoML",
|
178
180
|
mlflow_run_name=None,
|
181
|
+
checkpointer: Optional[Checkpointer]=None,
|
179
182
|
):
|
180
183
|
self._params = {
|
181
184
|
"model": model,
|
@@ -193,6 +196,7 @@ class H2OMLAgent(BaseAgent):
|
|
193
196
|
"mlflow_tracking_uri": mlflow_tracking_uri,
|
194
197
|
"mlflow_experiment_name": mlflow_experiment_name,
|
195
198
|
"mlflow_run_name": mlflow_run_name,
|
199
|
+
"checkpointer": checkpointer,
|
196
200
|
}
|
197
201
|
self._compiled_graph = self._make_compiled_graph()
|
198
202
|
self.response = None
|
@@ -350,6 +354,7 @@ def make_h2o_ml_agent(
|
|
350
354
|
mlflow_tracking_uri=None,
|
351
355
|
mlflow_experiment_name="H2O AutoML",
|
352
356
|
mlflow_run_name=None,
|
357
|
+
checkpointer=None,
|
353
358
|
):
|
354
359
|
"""
|
355
360
|
Creates a machine learning agent that uses H2O for AutoML.
|
@@ -384,6 +389,12 @@ def make_h2o_ml_agent(
|
|
384
389
|
" pip install h2o\n\n"
|
385
390
|
"Visit https://docs.h2o.ai/h2o/latest-stable/h2o-docs/downloading.html for details."
|
386
391
|
) from e
|
392
|
+
|
393
|
+
if human_in_the_loop:
|
394
|
+
if checkpointer is None:
|
395
|
+
print("Human in the loop is enabled. A checkpointer is required. Setting to MemorySaver().")
|
396
|
+
checkpointer = MemorySaver()
|
397
|
+
|
387
398
|
|
388
399
|
# Define GraphState
|
389
400
|
class GraphState(TypedDict):
|
@@ -844,9 +855,10 @@ def make_h2o_ml_agent(
|
|
844
855
|
retry_count_key="retry_count",
|
845
856
|
human_in_the_loop=human_in_the_loop,
|
846
857
|
human_review_node_name="human_review",
|
847
|
-
checkpointer=
|
858
|
+
checkpointer=checkpointer,
|
848
859
|
bypass_recommended_steps=bypass_recommended_steps,
|
849
860
|
bypass_explain_code=bypass_explain_code,
|
861
|
+
agent_name=AGENT_NAME,
|
850
862
|
)
|
851
863
|
|
852
864
|
return app
|
@@ -10,6 +10,7 @@ from langchain_core.messages import BaseMessage, AIMessage
|
|
10
10
|
|
11
11
|
from langgraph.prebuilt import create_react_agent, ToolNode
|
12
12
|
from langgraph.prebuilt.chat_agent_executor import AgentState
|
13
|
+
from langgraph.types import Checkpointer
|
13
14
|
from langgraph.graph import START, END, StateGraph
|
14
15
|
|
15
16
|
from ai_data_science_team.templates import BaseAgent
|
@@ -27,6 +28,7 @@ from ai_data_science_team.tools.mlflow import (
|
|
27
28
|
mlflow_search_registered_models,
|
28
29
|
mlflow_get_model_version_details,
|
29
30
|
)
|
31
|
+
from ai_data_science_team.utils.messages import get_tool_call_names
|
30
32
|
|
31
33
|
AGENT_NAME = "mlflow_tools_agent"
|
32
34
|
|
@@ -67,6 +69,8 @@ class MLflowToolsAgent(BaseAgent):
|
|
67
69
|
Additional keyword arguments to pass to the create_react_agent function.
|
68
70
|
invoke_react_agent_kwargs : dict
|
69
71
|
Additional keyword arguments to pass to the invoke method of the react agent.
|
72
|
+
checkpointer : langchain.checkpointing.Checkpointer, optional
|
73
|
+
A checkpointer to use for saving and loading the agent's state. Defaults to None.
|
70
74
|
|
71
75
|
Methods:
|
72
76
|
--------
|
@@ -118,6 +122,7 @@ class MLflowToolsAgent(BaseAgent):
|
|
118
122
|
mlflow_registry_uri: Optional[str]=None,
|
119
123
|
create_react_agent_kwargs: Optional[Dict]={},
|
120
124
|
invoke_react_agent_kwargs: Optional[Dict]={},
|
125
|
+
checkpointer: Optional[Checkpointer]=None,
|
121
126
|
):
|
122
127
|
self._params = {
|
123
128
|
"model": model,
|
@@ -125,6 +130,7 @@ class MLflowToolsAgent(BaseAgent):
|
|
125
130
|
"mlflow_registry_uri": mlflow_registry_uri,
|
126
131
|
"create_react_agent_kwargs": create_react_agent_kwargs,
|
127
132
|
"invoke_react_agent_kwargs": invoke_react_agent_kwargs,
|
133
|
+
"checkpointer": checkpointer,
|
128
134
|
}
|
129
135
|
self._compiled_graph = self._make_compiled_graph()
|
130
136
|
self.response = None
|
@@ -228,6 +234,12 @@ class MLflowToolsAgent(BaseAgent):
|
|
228
234
|
return Markdown(self.response["messages"][0].content)
|
229
235
|
else:
|
230
236
|
return self.response["messages"][0].content
|
237
|
+
|
238
|
+
def get_tool_calls(self):
    """
    Returns the tool calls made by the agent.

    Returns
    -------
    list or None
        The ``tool_calls`` entry of the last response (the value computed by
        ``get_tool_call_names`` in the agent node), or None when the agent has
        not been invoked yet or the response carries no ``tool_calls`` entry.
    """
    # Before the first invoke, self.response is None; indexing it directly
    # would raise an unhelpful TypeError.
    if not self.response:
        return None
    return self.response.get("tool_calls")
|
231
243
|
|
232
244
|
|
233
245
|
|
@@ -238,6 +250,7 @@ def make_mlflow_tools_agent(
|
|
238
250
|
mlflow_registry_uri: str=None,
|
239
251
|
create_react_agent_kwargs: Optional[Dict]={},
|
240
252
|
invoke_react_agent_kwargs: Optional[Dict]={},
|
253
|
+
checkpointer: Optional[Checkpointer]=None,
|
241
254
|
):
|
242
255
|
"""
|
243
256
|
MLflow Tool Calling Agent
|
@@ -254,6 +267,8 @@ def make_mlflow_tools_agent(
|
|
254
267
|
Additional keyword arguments to pass to the agent's create_react_agent method.
|
255
268
|
invoke_react_agent_kwargs : dict, optional
|
256
269
|
Additional keyword arguments to pass to the agent's invoke method.
|
270
|
+
checkpointer : langchain.checkpointing.Checkpointer, optional
|
271
|
+
A checkpointer to use for saving and loading the agent's state. Defaults to None.
|
257
272
|
|
258
273
|
Returns
|
259
274
|
-------
|
@@ -296,6 +311,7 @@ def make_mlflow_tools_agent(
|
|
296
311
|
model,
|
297
312
|
tools=tool_node,
|
298
313
|
state_schema=GraphState,
|
314
|
+
checkpointer=checkpointer,
|
299
315
|
**create_react_agent_kwargs,
|
300
316
|
)
|
301
317
|
|
@@ -330,10 +346,13 @@ def make_mlflow_tools_agent(
|
|
330
346
|
elif isinstance(last_message, dict) and "artifact" in last_message:
|
331
347
|
last_tool_artifact = last_message["artifact"]
|
332
348
|
|
349
|
+
tool_calls = get_tool_call_names(internal_messages)
|
350
|
+
|
333
351
|
return {
|
334
352
|
"messages": [last_ai_message],
|
335
353
|
"internal_messages": internal_messages,
|
336
354
|
"mlflow_artifacts": last_tool_artifact,
|
355
|
+
"tool_calls": tool_calls,
|
337
356
|
}
|
338
357
|
|
339
358
|
|
@@ -344,7 +363,10 @@ def make_mlflow_tools_agent(
|
|
344
363
|
workflow.add_edge(START, "mlflow_tools_agent")
|
345
364
|
workflow.add_edge("mlflow_tools_agent", END)
|
346
365
|
|
347
|
-
app = workflow.compile(
|
366
|
+
app = workflow.compile(
|
367
|
+
checkpointer=checkpointer,
|
368
|
+
name=AGENT_NAME,
|
369
|
+
)
|
348
370
|
|
349
371
|
return app
|
350
372
|
|
@@ -1 +1,2 @@
|
|
1
|
-
from ai_data_science_team.multiagents.sql_data_analyst import SQLDataAnalyst, make_sql_data_analyst
|
1
|
+
from ai_data_science_team.multiagents.sql_data_analyst import SQLDataAnalyst, make_sql_data_analyst
|
2
|
+
from ai_data_science_team.multiagents.pandas_data_analyst import PandasDataAnalyst, make_pandas_data_analyst
|
@@ -0,0 +1,305 @@
|
|
1
|
+
from langchain_core.messages import BaseMessage
|
2
|
+
from langchain.prompts import PromptTemplate
|
3
|
+
from langchain_core.output_parsers import JsonOutputParser
|
4
|
+
from langgraph.types import Checkpointer
|
5
|
+
from langgraph.graph import START, END, StateGraph
|
6
|
+
from langgraph.graph.state import CompiledStateGraph
|
7
|
+
|
8
|
+
from typing import TypedDict, Annotated, Sequence, Union
|
9
|
+
import operator
|
10
|
+
|
11
|
+
import pandas as pd
|
12
|
+
import json
|
13
|
+
from IPython.display import Markdown
|
14
|
+
|
15
|
+
from ai_data_science_team.templates import BaseAgent
|
16
|
+
from ai_data_science_team.agents import DataWranglingAgent, DataVisualizationAgent
|
17
|
+
from ai_data_science_team.utils.plotly import plotly_from_dict
|
18
|
+
from ai_data_science_team.utils.regex import remove_consecutive_duplicates, get_generic_summary
|
19
|
+
|
20
|
+
AGENT_NAME = "pandas_data_analyst"
|
21
|
+
|
22
|
+
class PandasDataAnalyst(BaseAgent):
    """
    PandasDataAnalyst is a multi-agent class that combines data wrangling and visualization capabilities.

    A routing step rewrites the user's request into instructions for the Data
    Wrangling Agent and decides whether the result is returned as a table or
    handed to the Data Visualization Agent to build a chart.

    Parameters:
    -----------
    model:
        The language model to be used for the agents.
    data_wrangling_agent: DataWranglingAgent
        The Data Wrangling Agent for transforming raw data. Its compiled graph
        (``_compiled_graph``) is embedded in this workflow.
    data_visualization_agent: DataVisualizationAgent
        The Data Visualization Agent for generating plots. Its compiled graph
        (``_compiled_graph``) is embedded in this workflow.
    checkpointer: Checkpointer (optional)
        The checkpointer to save the state of the multi-agent system.

    Methods:
    --------
    ainvoke_agent(user_instructions, data_raw, max_retries=3, retry_count=0, **kwargs)
        Asynchronously invokes the multi-agent with user instructions and raw data.
    invoke_agent(user_instructions, data_raw, max_retries=3, retry_count=0, **kwargs)
        Synchronously invokes the multi-agent with user instructions and raw data.
    update_params(**kwargs)
        Updates stored parameters and rebuilds the compiled graph.
    get_data_wrangled()
        Returns the wrangled data as a Pandas DataFrame.
    get_plotly_graph()
        Returns the Plotly graph as a Plotly object.
    get_data_wrangler_function(markdown=False)
        Returns the data wrangling function as a string, optionally in Markdown.
    get_data_visualization_function(markdown=False)
        Returns the data visualization function as a string, optionally in Markdown.
    get_workflow_summary(markdown=False)
        Returns a summary of the agents and their reports, optionally in Markdown.

    All ``get_*`` methods return None when the agent has not been invoked yet
    or the corresponding value is missing from the response.
    """

    def __init__(
        self,
        model,
        data_wrangling_agent: DataWranglingAgent,
        data_visualization_agent: DataVisualizationAgent,
        checkpointer: Checkpointer = None,
    ):
        self._params = {
            "model": model,
            "data_wrangling_agent": data_wrangling_agent,
            "data_visualization_agent": data_visualization_agent,
            "checkpointer": checkpointer,
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """Create or rebuild the compiled graph. Resets response to None."""
        self.response = None
        # The wrapper agents are unwrapped to their compiled graphs, which
        # make_pandas_data_analyst calls directly from its graph nodes.
        return make_pandas_data_analyst(
            model=self._params["model"],
            data_wrangling_agent=self._params["data_wrangling_agent"]._compiled_graph,
            data_visualization_agent=self._params["data_visualization_agent"]._compiled_graph,
            checkpointer=self._params["checkpointer"],
        )

    def update_params(self, **kwargs):
        """
        Updates parameters and rebuilds the compiled graph.

        Keys correspond to the ``__init__`` arguments (model,
        data_wrangling_agent, data_visualization_agent, checkpointer).
        Rebuilding the graph also clears ``self.response``.
        """
        for k, v in kwargs.items():
            self._params[k] = v
        self._compiled_graph = self._make_compiled_graph()

    def _build_input(self, user_instructions, data_raw, max_retries, retry_count):
        """Assemble the initial graph state shared by ``invoke_agent`` and ``ainvoke_agent``."""
        return {
            "user_instructions": user_instructions,
            "data_raw": self._convert_data_input(data_raw),
            "max_retries": max_retries,
            "retry_count": retry_count,
        }

    def _store_response(self, response):
        """Collapse consecutive duplicate messages (if any) and keep the response on ``self``."""
        if response.get("messages"):
            response["messages"] = remove_consecutive_duplicates(response["messages"])
        self.response = response

    async def ainvoke_agent(self, user_instructions, data_raw: Union[pd.DataFrame, dict, list], max_retries: int = 3, retry_count: int = 0, **kwargs):
        """
        Asynchronously invokes the multi-agent.

        Parameters
        ----------
        user_instructions : str
            The user's request.
        data_raw : pd.DataFrame, dict, or list
            Input data; converted with ``_convert_data_input`` before invocation.
        max_retries : int, default 3
            Placed in the graph state and forwarded to the sub-agents.
        retry_count : int, default 0
            Placed in the graph state and forwarded to the sub-agents.
        **kwargs
            Passed through unchanged to the compiled graph's ``ainvoke``.

        Returns
        -------
        None
            The final state is stored on ``self.response``.
        """
        response = await self._compiled_graph.ainvoke(
            self._build_input(user_instructions, data_raw, max_retries, retry_count),
            **kwargs,
        )
        self._store_response(response)

    def invoke_agent(self, user_instructions, data_raw: Union[pd.DataFrame, dict, list], max_retries: int = 3, retry_count: int = 0, **kwargs):
        """
        Synchronously invokes the multi-agent.

        Parameters
        ----------
        user_instructions : str
            The user's request.
        data_raw : pd.DataFrame, dict, or list
            Input data; converted with ``_convert_data_input`` before invocation.
        max_retries : int, default 3
            Placed in the graph state and forwarded to the sub-agents.
        retry_count : int, default 0
            Placed in the graph state and forwarded to the sub-agents.
        **kwargs
            Passed through unchanged to the compiled graph's ``invoke``.

        Returns
        -------
        None
            The final state is stored on ``self.response``.
        """
        response = self._compiled_graph.invoke(
            self._build_input(user_instructions, data_raw, max_retries, retry_count),
            **kwargs,
        )
        self._store_response(response)

    def get_data_wrangled(self):
        """Returns the wrangled data as a Pandas DataFrame."""
        if self.response and self.response.get("data_wrangled"):
            return pd.DataFrame(self.response.get("data_wrangled"))

    def get_plotly_graph(self):
        """Returns the Plotly graph as a Plotly object."""
        if self.response and self.response.get("plotly_graph"):
            return plotly_from_dict(self.response.get("plotly_graph"))

    def get_data_wrangler_function(self, markdown=False):
        """Returns the data wrangling function as a string, or as a Markdown code block if ``markdown``."""
        if self.response and self.response.get("data_wrangler_function"):
            code = self.response.get("data_wrangler_function")
            return Markdown(f"```python\n{code}\n```") if markdown else code

    def get_data_visualization_function(self, markdown=False):
        """Returns the data visualization function as a string, or as a Markdown code block if ``markdown``."""
        if self.response and self.response.get("data_visualization_function"):
            code = self.response.get("data_visualization_function")
            return Markdown(f"```python\n{code}\n```") if markdown else code

    def get_workflow_summary(self, markdown=False):
        """
        Returns a summary of the workflow.

        Lists the ``role`` of every message in the last response, then appends
        one report per message built by ``get_generic_summary`` from the
        message's JSON-decoded ``content``. Messages are therefore expected to
        carry a ``role`` attribute and JSON content (presumably set by the
        sub-agents -- verify when adding new agents).

        Parameters
        ----------
        markdown : bool, default False
            If True, wrap the summary in an IPython ``Markdown`` object.
        """
        if self.response and self.response.get("messages"):
            messages = self.response["messages"]
            agents = [msg.role for msg in messages]
            agent_labels = [f"- **Agent {i+1}:** {role}" for i, role in enumerate(agents)]
            header = f"# Pandas Data Analyst Workflow Summary\n\nThis workflow contains {len(agents)} agents:\n\n" + "\n".join(agent_labels)
            reports = [get_generic_summary(json.loads(msg.content)) for msg in messages]
            # Separate the agent list from the first report; without this the
            # first report was glued onto the last agent-label line.
            summary = "\n" + header + "\n\n" + "\n\n".join(reports)
            return Markdown(summary) if markdown else summary

    @staticmethod
    def _convert_data_input(data_raw: Union[pd.DataFrame, dict, list]) -> Union[dict, list]:
        """
        Converts input data to the expected format (dict or list of dicts).

        DataFrames (top-level or inside a list) are converted with
        ``DataFrame.to_dict()``; dicts and other list items pass through as-is.

        Raises
        ------
        ValueError
            If ``data_raw`` is not a DataFrame, dict, or list.
        """
        if isinstance(data_raw, pd.DataFrame):
            return data_raw.to_dict()
        if isinstance(data_raw, dict):
            return data_raw
        if isinstance(data_raw, list):
            return [item.to_dict() if isinstance(item, pd.DataFrame) else item for item in data_raw]
        raise ValueError("data_raw must be a DataFrame, dict, or list of DataFrames/dicts")
|
151
|
+
|
152
|
+
def make_pandas_data_analyst(
    model,
    data_wrangling_agent: CompiledStateGraph,
    data_visualization_agent: CompiledStateGraph,
    checkpointer: Checkpointer = None
):
    """
    Creates a multi-agent system that wrangles data and optionally visualizes it.

    Graph layout::

        START -> routing_preprocessor -> data_wrangling_agent
            -> (chart) data_visualization_agent -> route_printer -> END
            -> (table) route_printer -> END

    Parameters:
    -----------
    model: The language model to be used.
        Used by the routing step, which splits the user's question into
        per-agent instructions and a 'chart' / 'table' decision.
    data_wrangling_agent: CompiledStateGraph
        The Data Wrangling Agent.
    data_visualization_agent: CompiledStateGraph
        The Data Visualization Agent.
    checkpointer: Checkpointer (optional)
        The checkpointer to save the state.

    Returns:
    --------
    CompiledStateGraph: The compiled multi-agent system.
    """

    llm = model

    routing_preprocessor_prompt = PromptTemplate(
        template="""
        You are an expert in routing decisions for a Pandas Data Manipulation Wrangling Agent, a Charting Visualization Agent, and a Pandas Table Agent. Your job is to:

        1. Determine what the correct format for a Users Question should be for use with a Pandas Data Wrangling Agent based on the incoming user question. Anything related to data wrangling and manipulation should be passed along.
        2. Determine whether or not a chart should be generated or a table should be returned based on the users question.
        3. If a chart is requested, determine the correct format of a Users Question should be used with a Data Visualization Agent. Anything related to plotting and visualization should be passed along.

        Use the following criteria on how to route the initial user question:

        From the incoming user question, remove any details about the format of the final response as either a Chart or Table and return only the important part of the incoming user question that is relevant for the Data Wrangling Agent. This will be the 'user_instructions_data_wrangling'. If 'None' is found, return the original user question.

        Next, determine if the user would like a data visualization ('chart') or a 'table' returned with the results of the Data Wrangling Agent. If unknown, not specified or 'None' is found, then select 'table'.

        If a 'chart' is requested, return the 'user_instructions_data_visualization'. If 'None' is found, return None.

        Return JSON with 'user_instructions_data_wrangling', 'user_instructions_data_visualization' and 'routing_preprocessor_decision'.

        INITIAL_USER_QUESTION: {user_instructions}
        """,
        input_variables=["user_instructions"]
    )

    routing_preprocessor = routing_preprocessor_prompt | llm | JsonOutputParser()

    class PrimaryState(TypedDict):
        # Messages from every sub-agent; operator.add concatenates node outputs.
        messages: Annotated[Sequence[BaseMessage], operator.add]
        user_instructions: str
        user_instructions_data_wrangling: str
        user_instructions_data_visualization: str
        # 'chart' routes to the visualization agent; anything else is a table.
        routing_preprocessor_decision: str
        data_raw: Union[dict, list]
        data_wrangled: dict
        data_wrangler_function: str
        data_visualization_function: str
        plotly_graph: dict
        plotly_error: str
        max_retries: int
        retry_count: int

    def preprocess_routing(state: PrimaryState):
        """Ask the LLM to split the question into per-agent instructions and a chart/table decision."""
        print("---PANDAS DATA ANALYST---")
        print("*************************")
        print("---PREPROCESS ROUTER---")
        question = state.get("user_instructions")

        # Chart Routing and Data Wrangling Prep
        response = routing_preprocessor.invoke({"user_instructions": question})

        return {
            # The prompt tells the LLM to echo the original question when no
            # wrangling instructions can be extracted; enforce that here in case
            # it returns null or omits the key.
            "user_instructions_data_wrangling": response.get('user_instructions_data_wrangling') or question,
            "user_instructions_data_visualization": response.get('user_instructions_data_visualization'),
            "routing_preprocessor_decision": response.get('routing_preprocessor_decision'),
        }

    def router_chart_or_table(state: PrimaryState):
        """Conditional-edge function: returns 'chart' or 'table'."""
        print("---ROUTER: CHART OR TABLE---")
        # Normalize so LLM variants like "Chart" or " chart " still take the chart path.
        decision = str(state.get('routing_preprocessor_decision') or "").strip().lower()
        return "chart" if decision == "chart" else "table"

    def invoke_data_wrangling_agent(state: PrimaryState):
        """Run the compiled Data Wrangling Agent on the raw data and copy its outputs into the state."""
        response = data_wrangling_agent.invoke({
            "user_instructions": state.get("user_instructions_data_wrangling"),
            "data_raw": state.get("data_raw"),
            "max_retries": state.get("max_retries"),
            "retry_count": state.get("retry_count"),
        })

        return {
            "messages": response.get("messages"),
            "data_wrangled": response.get("data_wrangled"),
            "data_wrangler_function": response.get("data_wrangler_function"),
            # NOTE(review): this reads a visualization-error key from the
            # wrangling agent's output. Unless that agent sets it, this writes
            # None and thereby clears any plotly_error already in the state.
            # Confirm this is intended.
            "plotly_error": response.get("data_visualization_error"),
        }

    def invoke_data_visualization_agent(state: PrimaryState):
        """Run the compiled Data Visualization Agent on the wrangled data (or the raw data if none)."""
        response = data_visualization_agent.invoke({
            # Fall back to the original question if the router asked for a chart
            # but returned no visualization-specific instructions.
            "user_instructions": state.get("user_instructions_data_visualization") or state.get("user_instructions"),
            "data_raw": state.get("data_wrangled") or state.get("data_raw"),
            "max_retries": state.get("max_retries"),
            "retry_count": state.get("retry_count"),
        })

        return {
            "messages": response.get("messages"),
            "data_visualization_function": response.get("data_visualization_function"),
            "plotly_graph": response.get("plotly_graph"),
            "plotly_error": response.get("data_visualization_error"),
        }

    def route_printer(state: PrimaryState):
        """Terminal node: log the chosen route; makes no state changes."""
        print("---ROUTE PRINTER---")
        print(f"    Route: {state.get('routing_preprocessor_decision')}")
        print("---END---")
        return {}

    workflow = StateGraph(PrimaryState)

    workflow.add_node("routing_preprocessor", preprocess_routing)
    workflow.add_node("data_wrangling_agent", invoke_data_wrangling_agent)
    workflow.add_node("data_visualization_agent", invoke_data_visualization_agent)
    workflow.add_node("route_printer", route_printer)

    workflow.add_edge(START, "routing_preprocessor")
    workflow.add_edge("routing_preprocessor", "data_wrangling_agent")

    workflow.add_conditional_edges(
        "data_wrangling_agent",
        router_chart_or_table,
        {
            "chart": "data_visualization_agent",
            "table": "route_printer"
        }
    )

    workflow.add_edge("data_visualization_agent", "route_printer")
    workflow.add_edge("route_printer", END)

    app = workflow.compile(
        checkpointer=checkpointer,
        name=AGENT_NAME
    )

    return app
|