ai-data-science-team 0.0.0.9013__py3-none-any.whl → 0.0.0.9014__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_data_science_team/__init__.py +22 -0
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/data_cleaning_agent.py +17 -3
- ai_data_science_team/agents/data_loader_tools_agent.py +13 -1
- ai_data_science_team/agents/data_visualization_agent.py +17 -3
- ai_data_science_team/agents/data_wrangling_agent.py +30 -10
- ai_data_science_team/agents/feature_engineering_agent.py +17 -4
- ai_data_science_team/agents/sql_database_agent.py +15 -2
- ai_data_science_team/ds_agents/eda_tools_agent.py +15 -6
- ai_data_science_team/ml_agents/h2o_ml_agent.py +15 -3
- ai_data_science_team/ml_agents/mlflow_tools_agent.py +13 -1
- ai_data_science_team/multiagents/__init__.py +2 -1
- ai_data_science_team/multiagents/pandas_data_analyst.py +305 -0
- ai_data_science_team/multiagents/sql_data_analyst.py +119 -30
- ai_data_science_team/templates/agent_templates.py +41 -5
- {ai_data_science_team-0.0.0.9013.dist-info → ai_data_science_team-0.0.0.9014.dist-info}/METADATA +2 -2
- {ai_data_science_team-0.0.0.9013.dist-info → ai_data_science_team-0.0.0.9014.dist-info}/RECORD +20 -19
- {ai_data_science_team-0.0.0.9013.dist-info → ai_data_science_team-0.0.0.9014.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9013.dist-info → ai_data_science_team-0.0.0.9014.dist-info}/WHEEL +0 -0
- {ai_data_science_team-0.0.0.9013.dist-info → ai_data_science_team-0.0.0.9014.dist-info}/top_level.txt +0 -0
@@ -10,6 +10,7 @@ from langchain_core.messages import BaseMessage, AIMessage
|
|
10
10
|
|
11
11
|
from langgraph.prebuilt import create_react_agent, ToolNode
|
12
12
|
from langgraph.prebuilt.chat_agent_executor import AgentState
|
13
|
+
from langgraph.types import Checkpointer
|
13
14
|
from langgraph.graph import START, END, StateGraph
|
14
15
|
|
15
16
|
from ai_data_science_team.templates import BaseAgent
|
@@ -68,6 +69,8 @@ class MLflowToolsAgent(BaseAgent):
|
|
68
69
|
Additional keyword arguments to pass to the create_react_agent function.
|
69
70
|
invoke_react_agent_kwargs : dict
|
70
71
|
Additional keyword arguments to pass to the invoke method of the react agent.
|
72
|
+
checkpointer : langchain.checkpointing.Checkpointer, optional
|
73
|
+
A checkpointer to use for saving and loading the agent's state. Defaults to None.
|
71
74
|
|
72
75
|
Methods:
|
73
76
|
--------
|
@@ -119,6 +122,7 @@ class MLflowToolsAgent(BaseAgent):
|
|
119
122
|
mlflow_registry_uri: Optional[str]=None,
|
120
123
|
create_react_agent_kwargs: Optional[Dict]={},
|
121
124
|
invoke_react_agent_kwargs: Optional[Dict]={},
|
125
|
+
checkpointer: Optional[Checkpointer]=None,
|
122
126
|
):
|
123
127
|
self._params = {
|
124
128
|
"model": model,
|
@@ -126,6 +130,7 @@ class MLflowToolsAgent(BaseAgent):
|
|
126
130
|
"mlflow_registry_uri": mlflow_registry_uri,
|
127
131
|
"create_react_agent_kwargs": create_react_agent_kwargs,
|
128
132
|
"invoke_react_agent_kwargs": invoke_react_agent_kwargs,
|
133
|
+
"checkpointer": checkpointer,
|
129
134
|
}
|
130
135
|
self._compiled_graph = self._make_compiled_graph()
|
131
136
|
self.response = None
|
@@ -245,6 +250,7 @@ def make_mlflow_tools_agent(
|
|
245
250
|
mlflow_registry_uri: str=None,
|
246
251
|
create_react_agent_kwargs: Optional[Dict]={},
|
247
252
|
invoke_react_agent_kwargs: Optional[Dict]={},
|
253
|
+
checkpointer: Optional[Checkpointer]=None,
|
248
254
|
):
|
249
255
|
"""
|
250
256
|
MLflow Tool Calling Agent
|
@@ -261,6 +267,8 @@ def make_mlflow_tools_agent(
|
|
261
267
|
Additional keyword arguments to pass to the agent's create_react_agent method.
|
262
268
|
invoke_react_agent_kwargs : dict, optional
|
263
269
|
Additional keyword arguments to pass to the agent's invoke method.
|
270
|
+
checkpointer : langchain.checkpointing.Checkpointer, optional
|
271
|
+
A checkpointer to use for saving and loading the agent's state. Defaults to None.
|
264
272
|
|
265
273
|
Returns
|
266
274
|
-------
|
@@ -303,6 +311,7 @@ def make_mlflow_tools_agent(
|
|
303
311
|
model,
|
304
312
|
tools=tool_node,
|
305
313
|
state_schema=GraphState,
|
314
|
+
checkpointer=checkpointer,
|
306
315
|
**create_react_agent_kwargs,
|
307
316
|
)
|
308
317
|
|
@@ -354,7 +363,10 @@ def make_mlflow_tools_agent(
|
|
354
363
|
workflow.add_edge(START, "mlflow_tools_agent")
|
355
364
|
workflow.add_edge("mlflow_tools_agent", END)
|
356
365
|
|
357
|
-
app = workflow.compile(
|
366
|
+
app = workflow.compile(
|
367
|
+
checkpointer=checkpointer,
|
368
|
+
name=AGENT_NAME,
|
369
|
+
)
|
358
370
|
|
359
371
|
return app
|
360
372
|
|
@@ -1 +1,2 @@
|
|
1
|
-
from ai_data_science_team.multiagents.sql_data_analyst import SQLDataAnalyst, make_sql_data_analyst
|
1
|
+
from ai_data_science_team.multiagents.sql_data_analyst import SQLDataAnalyst, make_sql_data_analyst
|
2
|
+
from ai_data_science_team.multiagents.pandas_data_analyst import PandasDataAnalyst, make_pandas_data_analyst
|
@@ -0,0 +1,305 @@
|
|
1
|
+
from langchain_core.messages import BaseMessage
|
2
|
+
from langchain.prompts import PromptTemplate
|
3
|
+
from langchain_core.output_parsers import JsonOutputParser
|
4
|
+
from langgraph.types import Checkpointer
|
5
|
+
from langgraph.graph import START, END, StateGraph
|
6
|
+
from langgraph.graph.state import CompiledStateGraph
|
7
|
+
|
8
|
+
from typing import TypedDict, Annotated, Sequence, Union
|
9
|
+
import operator
|
10
|
+
|
11
|
+
import pandas as pd
|
12
|
+
import json
|
13
|
+
from IPython.display import Markdown
|
14
|
+
|
15
|
+
from ai_data_science_team.templates import BaseAgent
|
16
|
+
from ai_data_science_team.agents import DataWranglingAgent, DataVisualizationAgent
|
17
|
+
from ai_data_science_team.utils.plotly import plotly_from_dict
|
18
|
+
from ai_data_science_team.utils.regex import remove_consecutive_duplicates, get_generic_summary
|
19
|
+
|
20
|
+
AGENT_NAME = "pandas_data_analyst"
|
21
|
+
|
22
|
+
class PandasDataAnalyst(BaseAgent):
|
23
|
+
"""
|
24
|
+
PandasDataAnalyst is a multi-agent class that combines data wrangling and visualization capabilities.
|
25
|
+
|
26
|
+
Parameters:
|
27
|
+
-----------
|
28
|
+
model:
|
29
|
+
The language model to be used for the agents.
|
30
|
+
data_wrangling_agent: DataWranglingAgent
|
31
|
+
The Data Wrangling Agent for transforming raw data.
|
32
|
+
data_visualization_agent: DataVisualizationAgent
|
33
|
+
The Data Visualization Agent for generating plots.
|
34
|
+
checkpointer: Checkpointer (optional)
|
35
|
+
The checkpointer to save the state of the multi-agent system.
|
36
|
+
|
37
|
+
Methods:
|
38
|
+
--------
|
39
|
+
ainvoke_agent(user_instructions, data_raw, **kwargs)
|
40
|
+
Asynchronously invokes the multi-agent with user instructions and raw data.
|
41
|
+
invoke_agent(user_instructions, data_raw, **kwargs)
|
42
|
+
Synchronously invokes the multi-agent with user instructions and raw data.
|
43
|
+
get_data_wrangled()
|
44
|
+
Returns the wrangled data as a Pandas DataFrame.
|
45
|
+
get_plotly_graph()
|
46
|
+
Returns the Plotly graph as a Plotly object.
|
47
|
+
get_data_wrangler_function(markdown=False)
|
48
|
+
Returns the data wrangling function as a string, optionally in Markdown.
|
49
|
+
get_data_visualization_function(markdown=False)
|
50
|
+
Returns the data visualization function as a string, optionally in Markdown.
|
51
|
+
"""
|
52
|
+
|
53
|
+
def __init__(
|
54
|
+
self,
|
55
|
+
model,
|
56
|
+
data_wrangling_agent: DataWranglingAgent,
|
57
|
+
data_visualization_agent: DataVisualizationAgent,
|
58
|
+
checkpointer: Checkpointer = None,
|
59
|
+
):
|
60
|
+
self._params = {
|
61
|
+
"model": model,
|
62
|
+
"data_wrangling_agent": data_wrangling_agent,
|
63
|
+
"data_visualization_agent": data_visualization_agent,
|
64
|
+
"checkpointer": checkpointer,
|
65
|
+
}
|
66
|
+
self._compiled_graph = self._make_compiled_graph()
|
67
|
+
self.response = None
|
68
|
+
|
69
|
+
def _make_compiled_graph(self):
|
70
|
+
"""Create or rebuild the compiled graph. Resets response to None."""
|
71
|
+
self.response = None
|
72
|
+
return make_pandas_data_analyst(
|
73
|
+
model=self._params["model"],
|
74
|
+
data_wrangling_agent=self._params["data_wrangling_agent"]._compiled_graph,
|
75
|
+
data_visualization_agent=self._params["data_visualization_agent"]._compiled_graph,
|
76
|
+
checkpointer=self._params["checkpointer"],
|
77
|
+
)
|
78
|
+
|
79
|
+
def update_params(self, **kwargs):
|
80
|
+
"""Updates parameters and rebuilds the compiled graph."""
|
81
|
+
for k, v in kwargs.items():
|
82
|
+
self._params[k] = v
|
83
|
+
self._compiled_graph = self._make_compiled_graph()
|
84
|
+
|
85
|
+
async def ainvoke_agent(self, user_instructions, data_raw: Union[pd.DataFrame, dict, list], max_retries: int = 3, retry_count: int = 0, **kwargs):
|
86
|
+
"""Asynchronously invokes the multi-agent."""
|
87
|
+
response = await self._compiled_graph.ainvoke({
|
88
|
+
"user_instructions": user_instructions,
|
89
|
+
"data_raw": self._convert_data_input(data_raw),
|
90
|
+
"max_retries": max_retries,
|
91
|
+
"retry_count": retry_count,
|
92
|
+
}, **kwargs)
|
93
|
+
if response.get("messages"):
|
94
|
+
response["messages"] = remove_consecutive_duplicates(response["messages"])
|
95
|
+
self.response = response
|
96
|
+
|
97
|
+
def invoke_agent(self, user_instructions, data_raw: Union[pd.DataFrame, dict, list], max_retries: int = 3, retry_count: int = 0, **kwargs):
|
98
|
+
"""Synchronously invokes the multi-agent."""
|
99
|
+
response = self._compiled_graph.invoke({
|
100
|
+
"user_instructions": user_instructions,
|
101
|
+
"data_raw": self._convert_data_input(data_raw),
|
102
|
+
"max_retries": max_retries,
|
103
|
+
"retry_count": retry_count,
|
104
|
+
}, **kwargs)
|
105
|
+
if response.get("messages"):
|
106
|
+
response["messages"] = remove_consecutive_duplicates(response["messages"])
|
107
|
+
self.response = response
|
108
|
+
|
109
|
+
def get_data_wrangled(self):
|
110
|
+
"""Returns the wrangled data as a Pandas DataFrame."""
|
111
|
+
if self.response and self.response.get("data_wrangled"):
|
112
|
+
return pd.DataFrame(self.response.get("data_wrangled"))
|
113
|
+
|
114
|
+
def get_plotly_graph(self):
|
115
|
+
"""Returns the Plotly graph as a Plotly object."""
|
116
|
+
if self.response and self.response.get("plotly_graph"):
|
117
|
+
return plotly_from_dict(self.response.get("plotly_graph"))
|
118
|
+
|
119
|
+
def get_data_wrangler_function(self, markdown=False):
|
120
|
+
"""Returns the data wrangling function as a string."""
|
121
|
+
if self.response and self.response.get("data_wrangler_function"):
|
122
|
+
code = self.response.get("data_wrangler_function")
|
123
|
+
return Markdown(f"```python\n{code}\n```") if markdown else code
|
124
|
+
|
125
|
+
def get_data_visualization_function(self, markdown=False):
|
126
|
+
"""Returns the data visualization function as a string."""
|
127
|
+
if self.response and self.response.get("data_visualization_function"):
|
128
|
+
code = self.response.get("data_visualization_function")
|
129
|
+
return Markdown(f"```python\n{code}\n```") if markdown else code
|
130
|
+
|
131
|
+
def get_workflow_summary(self, markdown=False):
|
132
|
+
"""Returns a summary of the workflow."""
|
133
|
+
if self.response and self.response.get("messages"):
|
134
|
+
agents = [msg.role for msg in self.response["messages"]]
|
135
|
+
agent_labels = [f"- **Agent {i+1}:** {role}" for i, role in enumerate(agents)]
|
136
|
+
header = f"# Pandas Data Analyst Workflow Summary\n\nThis workflow contains {len(agents)} agents:\n\n" + "\n".join(agent_labels)
|
137
|
+
reports = [get_generic_summary(json.loads(msg.content)) for msg in self.response["messages"]]
|
138
|
+
summary = "\n" +header + "\n\n".join(reports)
|
139
|
+
return Markdown(summary) if markdown else summary
|
140
|
+
|
141
|
+
@staticmethod
|
142
|
+
def _convert_data_input(data_raw: Union[pd.DataFrame, dict, list]) -> Union[dict, list]:
|
143
|
+
"""Converts input data to the expected format (dict or list of dicts)."""
|
144
|
+
if isinstance(data_raw, pd.DataFrame):
|
145
|
+
return data_raw.to_dict()
|
146
|
+
if isinstance(data_raw, dict):
|
147
|
+
return data_raw
|
148
|
+
if isinstance(data_raw, list):
|
149
|
+
return [item.to_dict() if isinstance(item, pd.DataFrame) else item for item in data_raw]
|
150
|
+
raise ValueError("data_raw must be a DataFrame, dict, or list of DataFrames/dicts")
|
151
|
+
|
152
|
+
def make_pandas_data_analyst(
|
153
|
+
model,
|
154
|
+
data_wrangling_agent: CompiledStateGraph,
|
155
|
+
data_visualization_agent: CompiledStateGraph,
|
156
|
+
checkpointer: Checkpointer = None
|
157
|
+
):
|
158
|
+
"""
|
159
|
+
Creates a multi-agent system that wrangles data and optionally visualizes it.
|
160
|
+
|
161
|
+
Parameters:
|
162
|
+
-----------
|
163
|
+
model: The language model to be used.
|
164
|
+
data_wrangling_agent: CompiledStateGraph
|
165
|
+
The Data Wrangling Agent.
|
166
|
+
data_visualization_agent: CompiledStateGraph
|
167
|
+
The Data Visualization Agent.
|
168
|
+
checkpointer: Checkpointer (optional)
|
169
|
+
The checkpointer to save the state.
|
170
|
+
|
171
|
+
Returns:
|
172
|
+
--------
|
173
|
+
CompiledStateGraph: The compiled multi-agent system.
|
174
|
+
"""
|
175
|
+
|
176
|
+
llm = model
|
177
|
+
|
178
|
+
routing_preprocessor_prompt = PromptTemplate(
|
179
|
+
template="""
|
180
|
+
You are an expert in routing decisions for a Pandas Data Manipulation Wrangling Agent, a Charting Visualization Agent, and a Pandas Table Agent. Your job is to:
|
181
|
+
|
182
|
+
1. Determine what the correct format for a Users Question should be for use with a Pandas Data Wrangling Agent based on the incoming user question. Anything related to data wrangling and manipulation should be passed along.
|
183
|
+
2. Determine whether or not a chart should be generated or a table should be returned based on the users question.
|
184
|
+
3. If a chart is requested, determine the correct format of a Users Question should be used with a Data Visualization Agent. Anything related to plotting and visualization should be passed along.
|
185
|
+
|
186
|
+
Use the following criteria on how to route the the initial user question:
|
187
|
+
|
188
|
+
From the incoming user question, remove any details about the format of the final response as either a Chart or Table and return only the important part of the incoming user question that is relevant for the SQL generator agent. This will be the 'user_instructions_data_wrangling'. If 'None' is found, return the original user question.
|
189
|
+
|
190
|
+
Next, determine if the user would like a data visualization ('chart') or a 'table' returned with the results of the Data Wrangling Agent. If unknown, not specified or 'None' is found, then select 'table'.
|
191
|
+
|
192
|
+
If a 'chart' is requested, return the 'user_instructions_data_visualization'. If 'None' is found, return None.
|
193
|
+
|
194
|
+
Return JSON with 'user_instructions_data_wrangling', 'user_instructions_data_visualization' and 'routing_preprocessor_decision'.
|
195
|
+
|
196
|
+
INITIAL_USER_QUESTION: {user_instructions}
|
197
|
+
""",
|
198
|
+
input_variables=["user_instructions"]
|
199
|
+
)
|
200
|
+
|
201
|
+
routing_preprocessor = routing_preprocessor_prompt | llm | JsonOutputParser()
|
202
|
+
|
203
|
+
class PrimaryState(TypedDict):
|
204
|
+
messages: Annotated[Sequence[BaseMessage], operator.add]
|
205
|
+
user_instructions: str
|
206
|
+
user_instructions_data_wrangling: str
|
207
|
+
user_instructions_data_visualization: str
|
208
|
+
routing_preprocessor_decision: str
|
209
|
+
data_raw: Union[dict, list]
|
210
|
+
data_wrangled: dict
|
211
|
+
data_wrangler_function: str
|
212
|
+
data_visualization_function: str
|
213
|
+
plotly_graph: dict
|
214
|
+
plotly_error: str
|
215
|
+
max_retries: int
|
216
|
+
retry_count: int
|
217
|
+
|
218
|
+
|
219
|
+
def preprocess_routing(state: PrimaryState):
|
220
|
+
print("---PANDAS DATA ANALYST---")
|
221
|
+
print("*************************")
|
222
|
+
print("---PREPROCESS ROUTER---")
|
223
|
+
question = state.get("user_instructions")
|
224
|
+
|
225
|
+
# Chart Routing and SQL Prep
|
226
|
+
response = routing_preprocessor.invoke({"user_instructions": question})
|
227
|
+
|
228
|
+
return {
|
229
|
+
"user_instructions_data_wrangling": response.get('user_instructions_data_wrangling'),
|
230
|
+
"user_instructions_data_visualization": response.get('user_instructions_data_visualization'),
|
231
|
+
"routing_preprocessor_decision": response.get('routing_preprocessor_decision'),
|
232
|
+
}
|
233
|
+
|
234
|
+
def router_chart_or_table(state: PrimaryState):
|
235
|
+
print("---ROUTER: CHART OR TABLE---")
|
236
|
+
return "chart" if state.get('routing_preprocessor_decision') == "chart" else "table"
|
237
|
+
|
238
|
+
|
239
|
+
def invoke_data_wrangling_agent(state: PrimaryState):
|
240
|
+
|
241
|
+
response = data_wrangling_agent.invoke({
|
242
|
+
"user_instructions": state.get("user_instructions_data_wrangling"),
|
243
|
+
"data_raw": state.get("data_raw"),
|
244
|
+
"max_retries": state.get("max_retries"),
|
245
|
+
"retry_count": state.get("retry_count"),
|
246
|
+
})
|
247
|
+
|
248
|
+
return {
|
249
|
+
"messages": response.get("messages"),
|
250
|
+
"data_wrangled": response.get("data_wrangled"),
|
251
|
+
"data_wrangler_function": response.get("data_wrangler_function"),
|
252
|
+
"plotly_error": response.get("data_visualization_error"),
|
253
|
+
|
254
|
+
}
|
255
|
+
|
256
|
+
def invoke_data_visualization_agent(state: PrimaryState):
|
257
|
+
|
258
|
+
response = data_visualization_agent.invoke({
|
259
|
+
"user_instructions": state.get("user_instructions_data_visualization"),
|
260
|
+
"data_raw": state.get("data_wrangled") if state.get("data_wrangled") else state.get("data_raw"),
|
261
|
+
"max_retries": state.get("max_retries"),
|
262
|
+
"retry_count": state.get("retry_count"),
|
263
|
+
})
|
264
|
+
|
265
|
+
return {
|
266
|
+
"messages": response.get("messages"),
|
267
|
+
"data_visualization_function": response.get("data_visualization_function"),
|
268
|
+
"plotly_graph": response.get("plotly_graph"),
|
269
|
+
"plotly_error": response.get("data_visualization_error"),
|
270
|
+
}
|
271
|
+
|
272
|
+
def route_printer(state: PrimaryState):
|
273
|
+
print("---ROUTE PRINTER---")
|
274
|
+
print(f" Route: {state.get('routing_preprocessor_decision')}")
|
275
|
+
print("---END---")
|
276
|
+
return {}
|
277
|
+
|
278
|
+
workflow = StateGraph(PrimaryState)
|
279
|
+
|
280
|
+
workflow.add_node("routing_preprocessor", preprocess_routing)
|
281
|
+
workflow.add_node("data_wrangling_agent", invoke_data_wrangling_agent)
|
282
|
+
workflow.add_node("data_visualization_agent", invoke_data_visualization_agent)
|
283
|
+
workflow.add_node("route_printer", route_printer)
|
284
|
+
|
285
|
+
workflow.add_edge(START, "routing_preprocessor")
|
286
|
+
workflow.add_edge("routing_preprocessor", "data_wrangling_agent")
|
287
|
+
|
288
|
+
workflow.add_conditional_edges(
|
289
|
+
"data_wrangling_agent",
|
290
|
+
router_chart_or_table,
|
291
|
+
{
|
292
|
+
"chart": "data_visualization_agent",
|
293
|
+
"table": "route_printer"
|
294
|
+
}
|
295
|
+
)
|
296
|
+
|
297
|
+
workflow.add_edge("data_visualization_agent", "route_printer")
|
298
|
+
workflow.add_edge("route_printer", END)
|
299
|
+
|
300
|
+
app = workflow.compile(
|
301
|
+
checkpointer=checkpointer,
|
302
|
+
name=AGENT_NAME
|
303
|
+
)
|
304
|
+
|
305
|
+
return app
|
@@ -1,12 +1,14 @@
|
|
1
1
|
|
2
2
|
from langchain_core.messages import BaseMessage
|
3
|
-
|
3
|
+
|
4
|
+
from langchain.prompts import PromptTemplate
|
5
|
+
from langchain_core.output_parsers import JsonOutputParser
|
4
6
|
|
5
7
|
from langgraph.graph import START, END, StateGraph
|
6
8
|
from langgraph.graph.state import CompiledStateGraph
|
7
|
-
from langgraph.types import
|
9
|
+
from langgraph.types import Checkpointer
|
8
10
|
|
9
|
-
from typing import TypedDict, Annotated, Sequence
|
11
|
+
from typing import TypedDict, Annotated, Sequence
|
10
12
|
import operator
|
11
13
|
|
12
14
|
from typing_extensions import TypedDict
|
@@ -20,6 +22,7 @@ from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
|
|
20
22
|
from ai_data_science_team.utils.plotly import plotly_from_dict
|
21
23
|
from ai_data_science_team.utils.regex import remove_consecutive_duplicates, get_generic_summary
|
22
24
|
|
25
|
+
AGENT_NAME = "sql_data_analyst"
|
23
26
|
|
24
27
|
class SQLDataAnalyst(BaseAgent):
|
25
28
|
"""
|
@@ -33,6 +36,8 @@ class SQLDataAnalyst(BaseAgent):
|
|
33
36
|
The SQL Database Agent.
|
34
37
|
data_visualization_agent: DataVisualizationAgent
|
35
38
|
The Data Visualization Agent.
|
39
|
+
checkpointer: Checkpointer (optional)
|
40
|
+
The checkpointer to save the state of the multi-agent system.
|
36
41
|
|
37
42
|
Methods:
|
38
43
|
--------
|
@@ -326,17 +331,17 @@ def make_sql_data_analyst(
|
|
326
331
|
"""
|
327
332
|
Creates a multi-agent system that takes in a SQL query and returns a plot or table.
|
328
333
|
|
329
|
-
- Agent 1: SQL Database Agent made with `
|
330
|
-
- Agent 2: Data Visualization Agent made with `
|
334
|
+
- Agent 1: SQL Database Agent made with `SQLDatabaseAgent()`
|
335
|
+
- Agent 2: Data Visualization Agent made with `DataVisualizationAgent()`
|
331
336
|
|
332
337
|
Parameters:
|
333
338
|
----------
|
334
339
|
model:
|
335
340
|
The language model to be used for the agents.
|
336
341
|
sql_database_agent: CompiledStateGraph
|
337
|
-
The SQL Database Agent made with `
|
342
|
+
The SQL Database Agent made with `SQLDatabaseAgent()`.
|
338
343
|
data_visualization_agent: CompiledStateGraph
|
339
|
-
The Data Visualization Agent made with `
|
344
|
+
The Data Visualization Agent made with `DataVisualizationAgent()`.
|
340
345
|
checkpointer: Checkpointer (optional)
|
341
346
|
The checkpointer to save the state of the multi-agent system.
|
342
347
|
Default: None
|
@@ -348,10 +353,39 @@ def make_sql_data_analyst(
|
|
348
353
|
"""
|
349
354
|
|
350
355
|
llm = model
|
356
|
+
|
357
|
+
|
358
|
+
routing_preprocessor_prompt = PromptTemplate(
|
359
|
+
template="""
|
360
|
+
You are an expert in routing decisions for a SQL Database Agent, a Charting Visualization Agent, and a Pandas Table Agent. Your job is to:
|
361
|
+
|
362
|
+
1. Determine what the correct format for a Users Question should be for use with a SQL Database Agent based on the incoming user question. Anything related to database and data manipulation should be passed along.
|
363
|
+
2. Determine whether or not a chart should be generated or a table should be returned based on the users question.
|
364
|
+
3. If a chart is requested, determine the correct format of a Users Question should be used with a Data Visualization Agent. Anything related to plotting and visualization should be passed along.
|
365
|
+
|
366
|
+
Use the following criteria on how to route the the initial user question:
|
367
|
+
|
368
|
+
From the incoming user question, remove any details about the format of the final response as either a Chart or Table and return only the important part of the incoming user question that is relevant for the SQL generator agent. This will be the 'user_instructions_sql_database'. If 'None' is found, return the original user question.
|
369
|
+
|
370
|
+
Next, determine if the user would like a data visualization ('chart') or a 'table' returned with the results of the Data Wrangling Agent. If unknown, not specified or 'None' is found, then select 'table'.
|
371
|
+
|
372
|
+
If a 'chart' is requested, return the 'user_instructions_data_visualization'. If 'None' is found, return None.
|
373
|
+
|
374
|
+
Return JSON with 'user_instructions_sql_database', 'user_instructions_data_visualization' and 'routing_preprocessor_decision'.
|
375
|
+
|
376
|
+
INITIAL_USER_QUESTION: {user_instructions}
|
377
|
+
""",
|
378
|
+
input_variables=["user_instructions"]
|
379
|
+
)
|
380
|
+
|
381
|
+
routing_preprocessor = routing_preprocessor_prompt | llm | JsonOutputParser()
|
351
382
|
|
352
383
|
class PrimaryState(TypedDict):
|
353
384
|
messages: Annotated[Sequence[BaseMessage], operator.add]
|
354
385
|
user_instructions: str
|
386
|
+
user_instructions_sql_database: str
|
387
|
+
user_instructions_data_visualization: str
|
388
|
+
routing_preprocessor_decision: str
|
355
389
|
sql_query_code: str
|
356
390
|
sql_database_function: str
|
357
391
|
data_sql: dict
|
@@ -359,39 +393,94 @@ def make_sql_data_analyst(
|
|
359
393
|
plot_required: bool
|
360
394
|
data_visualization_function: str
|
361
395
|
plotly_graph: dict
|
396
|
+
plotly_error: str
|
362
397
|
max_retries: int
|
363
398
|
retry_count: int
|
364
399
|
|
365
|
-
def
|
400
|
+
def preprocess_routing(state: PrimaryState):
|
401
|
+
print("---SQL DATA ANALYST---")
|
402
|
+
print("*************************")
|
403
|
+
print("---PREPROCESS ROUTER---")
|
404
|
+
question = state.get("user_instructions")
|
366
405
|
|
367
|
-
|
406
|
+
# Chart Routing and SQL Prep
|
407
|
+
response = routing_preprocessor.invoke({"user_instructions": question})
|
368
408
|
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
409
|
+
return {
|
410
|
+
"user_instructions_sql_database": response.get('user_instructions_sql_database'),
|
411
|
+
"user_instructions_data_visualization": response.get('user_instructions_data_visualization'),
|
412
|
+
"routing_preprocessor_decision": response.get('routing_preprocessor_decision'),
|
413
|
+
}
|
414
|
+
|
415
|
+
def router_chart_or_table(state: PrimaryState):
|
416
|
+
print("---ROUTER: CHART OR TABLE---")
|
417
|
+
return "chart" if state.get('routing_preprocessor_decision') == "chart" else "table"
|
418
|
+
|
419
|
+
|
420
|
+
def invoke_sql_database_agent(state: PrimaryState):
|
375
421
|
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
goto=goto
|
382
|
-
)
|
422
|
+
response = sql_database_agent.invoke({
|
423
|
+
"user_instructions": state.get("user_instructions_sql_database"),
|
424
|
+
"max_retries": state.get("max_retries"),
|
425
|
+
"retry_count": state.get("retry_count"),
|
426
|
+
})
|
383
427
|
|
384
|
-
|
428
|
+
return {
|
429
|
+
"messages": response.get("messages"),
|
430
|
+
"data_sql": response.get("data_sql"),
|
431
|
+
"sql_query_code": response.get("sql_query_code"),
|
432
|
+
"sql_database_function": response.get("sql_database_function"),
|
433
|
+
|
434
|
+
}
|
435
|
+
|
436
|
+
def invoke_data_visualization_agent(state: PrimaryState):
|
437
|
+
|
438
|
+
response = data_visualization_agent.invoke({
|
439
|
+
"user_instructions": state.get("user_instructions_data_visualization"),
|
440
|
+
"data_raw": state.get("data_sql"),
|
441
|
+
"max_retries": state.get("max_retries"),
|
442
|
+
"retry_count": state.get("retry_count"),
|
443
|
+
})
|
444
|
+
|
445
|
+
return {
|
446
|
+
"messages": response.get("messages"),
|
447
|
+
"data_visualization_function": response.get("data_visualization_function"),
|
448
|
+
"plotly_graph": response.get("plotly_graph"),
|
449
|
+
"plotly_error": response.get("data_visualization_error"),
|
450
|
+
}
|
385
451
|
|
386
|
-
|
387
|
-
|
388
|
-
|
452
|
+
def route_printer(state: PrimaryState):
|
453
|
+
print("---ROUTE PRINTER---")
|
454
|
+
print(f" Route: {state.get('routing_preprocessor_decision')}")
|
455
|
+
print("---END---")
|
456
|
+
return {}
|
457
|
+
|
458
|
+
workflow = StateGraph(PrimaryState)
|
459
|
+
|
460
|
+
workflow.add_node("routing_preprocessor", preprocess_routing)
|
461
|
+
workflow.add_node("sql_database_agent", invoke_sql_database_agent)
|
462
|
+
workflow.add_node("data_visualization_agent", invoke_data_visualization_agent)
|
463
|
+
workflow.add_node("route_printer", route_printer)
|
389
464
|
|
390
|
-
workflow.add_edge(START, "
|
391
|
-
workflow.add_edge("
|
392
|
-
|
465
|
+
workflow.add_edge(START, "routing_preprocessor")
|
466
|
+
workflow.add_edge("routing_preprocessor", "sql_database_agent")
|
467
|
+
|
468
|
+
workflow.add_conditional_edges(
|
469
|
+
"sql_database_agent",
|
470
|
+
router_chart_or_table,
|
471
|
+
{
|
472
|
+
"chart": "data_visualization_agent",
|
473
|
+
"table": "route_printer"
|
474
|
+
}
|
475
|
+
)
|
476
|
+
|
477
|
+
workflow.add_edge("data_visualization_agent", "route_printer")
|
478
|
+
workflow.add_edge("route_printer", END)
|
393
479
|
|
394
|
-
app = workflow.compile(
|
480
|
+
app = workflow.compile(
|
481
|
+
checkpointer=checkpointer,
|
482
|
+
name=AGENT_NAME
|
483
|
+
)
|
395
484
|
|
396
485
|
return app
|
397
486
|
|
@@ -40,6 +40,21 @@ class BaseAgent(CompiledStateGraph):
|
|
40
40
|
self._params = params
|
41
41
|
self._compiled_graph = self._make_compiled_graph()
|
42
42
|
self.response = None
|
43
|
+
self.name = self._compiled_graph.name
|
44
|
+
self.checkpointer = self._compiled_graph.checkpointer
|
45
|
+
self.store = self._compiled_graph.store
|
46
|
+
self.output_channels = self._compiled_graph.output_channels
|
47
|
+
self.nodes = self._compiled_graph.nodes
|
48
|
+
self.stream_mode = self._compiled_graph.stream_mode
|
49
|
+
self.builder = self._compiled_graph.builder
|
50
|
+
self.channels = self._compiled_graph.channels
|
51
|
+
self.input_channels = self._compiled_graph.input_channels
|
52
|
+
self.input_schema = self._compiled_graph.input_schema
|
53
|
+
self.output_schema = self._compiled_graph.output_schema
|
54
|
+
self.debug = self._compiled_graph.debug
|
55
|
+
self.interrupt_after_nodes = self._compiled_graph.interrupt_after_nodes
|
56
|
+
self.interrupt_before_nodes = self._compiled_graph.interrupt_before_nodes
|
57
|
+
self.config = self._compiled_graph.config
|
43
58
|
|
44
59
|
def _make_compiled_graph(self):
|
45
60
|
"""
|
@@ -197,6 +212,24 @@ class BaseAgent(CompiledStateGraph):
|
|
197
212
|
"""
|
198
213
|
return self.get_output_jsonschema()['properties']
|
199
214
|
|
215
|
+
def get_state(self, config, *, subgraphs = False):
|
216
|
+
"""
|
217
|
+
Returns the state of the agent.
|
218
|
+
"""
|
219
|
+
return self._compiled_graph.get_state(config, subgraphs=subgraphs)
|
220
|
+
|
221
|
+
def get_state_history(self, config, *, filter = None, before = None, limit = None):
|
222
|
+
"""
|
223
|
+
Returns the state history of the agent.
|
224
|
+
"""
|
225
|
+
return self._compiled_graph.get_state_history(config, filter=filter, before=before, limit=limit)
|
226
|
+
|
227
|
+
def update_state(self, config, values, as_node = None):
|
228
|
+
"""
|
229
|
+
Updates the state of the agent.
|
230
|
+
"""
|
231
|
+
return self._compiled_graph.update_state(config, values, as_node)
|
232
|
+
|
200
233
|
def get_response(self):
|
201
234
|
"""
|
202
235
|
Returns the response generated by the agent.
|
@@ -237,6 +270,7 @@ def create_coding_agent_graph(
|
|
237
270
|
checkpointer: Optional[Callable] = None,
|
238
271
|
bypass_recommended_steps: bool = False,
|
239
272
|
bypass_explain_code: bool = False,
|
273
|
+
agent_name: str = "coding_agent"
|
240
274
|
):
|
241
275
|
"""
|
242
276
|
Creates a generic agent graph using the provided node functions and node names.
|
@@ -281,6 +315,8 @@ def create_coding_agent_graph(
|
|
281
315
|
Whether to skip the recommended steps node.
|
282
316
|
bypass_explain_code : bool, optional
|
283
317
|
Whether to skip the final explain code node.
|
318
|
+
name : str, optional
|
319
|
+
The name of the agent graph.
|
284
320
|
|
285
321
|
Returns
|
286
322
|
-------
|
@@ -366,10 +402,10 @@ def create_coding_agent_graph(
|
|
366
402
|
workflow.add_edge(explain_code_node_name, END)
|
367
403
|
|
368
404
|
# Finally, compile
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
405
|
+
app = workflow.compile(
|
406
|
+
checkpointer=checkpointer,
|
407
|
+
name=agent_name,
|
408
|
+
)
|
373
409
|
|
374
410
|
return app
|
375
411
|
|
@@ -574,7 +610,7 @@ def node_func_execute_agent_from_sql_connection(
|
|
574
610
|
|
575
611
|
# Retrieve SQLAlchemy connection and code snippet from the state
|
576
612
|
is_engine = isinstance(connection, sql.engine.base.Engine)
|
577
|
-
|
613
|
+
connection = connection.connect() if is_engine else connection
|
578
614
|
agent_code = state.get(code_snippet_key)
|
579
615
|
|
580
616
|
# Ensure the connection object is provided
|