ai-data-science-team 0.0.0.9013__py3-none-any.whl → 0.0.0.9014__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_data_science_team/__init__.py +22 -0
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/data_cleaning_agent.py +17 -3
- ai_data_science_team/agents/data_loader_tools_agent.py +13 -1
- ai_data_science_team/agents/data_visualization_agent.py +17 -3
- ai_data_science_team/agents/data_wrangling_agent.py +30 -10
- ai_data_science_team/agents/feature_engineering_agent.py +17 -4
- ai_data_science_team/agents/sql_database_agent.py +15 -2
- ai_data_science_team/ds_agents/eda_tools_agent.py +15 -6
- ai_data_science_team/ml_agents/h2o_ml_agent.py +15 -3
- ai_data_science_team/ml_agents/mlflow_tools_agent.py +13 -1
- ai_data_science_team/multiagents/__init__.py +2 -1
- ai_data_science_team/multiagents/pandas_data_analyst.py +305 -0
- ai_data_science_team/multiagents/sql_data_analyst.py +119 -30
- ai_data_science_team/templates/agent_templates.py +41 -5
- {ai_data_science_team-0.0.0.9013.dist-info → ai_data_science_team-0.0.0.9014.dist-info}/METADATA +2 -2
- {ai_data_science_team-0.0.0.9013.dist-info → ai_data_science_team-0.0.0.9014.dist-info}/RECORD +20 -19
- {ai_data_science_team-0.0.0.9013.dist-info → ai_data_science_team-0.0.0.9014.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9013.dist-info → ai_data_science_team-0.0.0.9014.dist-info}/WHEEL +0 -0
- {ai_data_science_team-0.0.0.9013.dist-info → ai_data_science_team-0.0.0.9014.dist-info}/top_level.txt +0 -0
@@ -10,6 +10,7 @@ from langchain_core.messages import BaseMessage, AIMessage
|
|
10
10
|
|
11
11
|
from langgraph.prebuilt import create_react_agent, ToolNode
|
12
12
|
from langgraph.prebuilt.chat_agent_executor import AgentState
|
13
|
+
from langgraph.types import Checkpointer
|
13
14
|
from langgraph.graph import START, END, StateGraph
|
14
15
|
|
15
16
|
from ai_data_science_team.templates import BaseAgent
|
@@ -68,6 +69,8 @@ class MLflowToolsAgent(BaseAgent):
|
|
68
69
|
Additional keyword arguments to pass to the create_react_agent function.
|
69
70
|
invoke_react_agent_kwargs : dict
|
70
71
|
Additional keyword arguments to pass to the invoke method of the react agent.
|
72
|
+
checkpointer : langchain.checkpointing.Checkpointer, optional
|
73
|
+
A checkpointer to use for saving and loading the agent's state. Defaults to None.
|
71
74
|
|
72
75
|
Methods:
|
73
76
|
--------
|
@@ -119,6 +122,7 @@ class MLflowToolsAgent(BaseAgent):
|
|
119
122
|
mlflow_registry_uri: Optional[str]=None,
|
120
123
|
create_react_agent_kwargs: Optional[Dict]={},
|
121
124
|
invoke_react_agent_kwargs: Optional[Dict]={},
|
125
|
+
checkpointer: Optional[Checkpointer]=None,
|
122
126
|
):
|
123
127
|
self._params = {
|
124
128
|
"model": model,
|
@@ -126,6 +130,7 @@ class MLflowToolsAgent(BaseAgent):
|
|
126
130
|
"mlflow_registry_uri": mlflow_registry_uri,
|
127
131
|
"create_react_agent_kwargs": create_react_agent_kwargs,
|
128
132
|
"invoke_react_agent_kwargs": invoke_react_agent_kwargs,
|
133
|
+
"checkpointer": checkpointer,
|
129
134
|
}
|
130
135
|
self._compiled_graph = self._make_compiled_graph()
|
131
136
|
self.response = None
|
@@ -245,6 +250,7 @@ def make_mlflow_tools_agent(
|
|
245
250
|
mlflow_registry_uri: str=None,
|
246
251
|
create_react_agent_kwargs: Optional[Dict]={},
|
247
252
|
invoke_react_agent_kwargs: Optional[Dict]={},
|
253
|
+
checkpointer: Optional[Checkpointer]=None,
|
248
254
|
):
|
249
255
|
"""
|
250
256
|
MLflow Tool Calling Agent
|
@@ -261,6 +267,8 @@ def make_mlflow_tools_agent(
|
|
261
267
|
Additional keyword arguments to pass to the agent's create_react_agent method.
|
262
268
|
invoke_react_agent_kwargs : dict, optional
|
263
269
|
Additional keyword arguments to pass to the agent's invoke method.
|
270
|
+
checkpointer : langchain.checkpointing.Checkpointer, optional
|
271
|
+
A checkpointer to use for saving and loading the agent's state. Defaults to None.
|
264
272
|
|
265
273
|
Returns
|
266
274
|
-------
|
@@ -303,6 +311,7 @@ def make_mlflow_tools_agent(
|
|
303
311
|
model,
|
304
312
|
tools=tool_node,
|
305
313
|
state_schema=GraphState,
|
314
|
+
checkpointer=checkpointer,
|
306
315
|
**create_react_agent_kwargs,
|
307
316
|
)
|
308
317
|
|
@@ -354,7 +363,10 @@ def make_mlflow_tools_agent(
|
|
354
363
|
workflow.add_edge(START, "mlflow_tools_agent")
|
355
364
|
workflow.add_edge("mlflow_tools_agent", END)
|
356
365
|
|
357
|
-
app = workflow.compile(
|
366
|
+
app = workflow.compile(
|
367
|
+
checkpointer=checkpointer,
|
368
|
+
name=AGENT_NAME,
|
369
|
+
)
|
358
370
|
|
359
371
|
return app
|
360
372
|
|
@@ -1 +1,2 @@
|
|
1
|
-
from ai_data_science_team.multiagents.sql_data_analyst import SQLDataAnalyst, make_sql_data_analyst
|
1
|
+
from ai_data_science_team.multiagents.sql_data_analyst import SQLDataAnalyst, make_sql_data_analyst
|
2
|
+
from ai_data_science_team.multiagents.pandas_data_analyst import PandasDataAnalyst, make_pandas_data_analyst
|
@@ -0,0 +1,305 @@
|
|
1
|
+
from langchain_core.messages import BaseMessage
|
2
|
+
from langchain.prompts import PromptTemplate
|
3
|
+
from langchain_core.output_parsers import JsonOutputParser
|
4
|
+
from langgraph.types import Checkpointer
|
5
|
+
from langgraph.graph import START, END, StateGraph
|
6
|
+
from langgraph.graph.state import CompiledStateGraph
|
7
|
+
|
8
|
+
from typing import TypedDict, Annotated, Sequence, Union
|
9
|
+
import operator
|
10
|
+
|
11
|
+
import pandas as pd
|
12
|
+
import json
|
13
|
+
from IPython.display import Markdown
|
14
|
+
|
15
|
+
from ai_data_science_team.templates import BaseAgent
|
16
|
+
from ai_data_science_team.agents import DataWranglingAgent, DataVisualizationAgent
|
17
|
+
from ai_data_science_team.utils.plotly import plotly_from_dict
|
18
|
+
from ai_data_science_team.utils.regex import remove_consecutive_duplicates, get_generic_summary
|
19
|
+
|
20
|
+
AGENT_NAME = "pandas_data_analyst"
|
21
|
+
|
22
|
+
class PandasDataAnalyst(BaseAgent):
|
23
|
+
"""
|
24
|
+
PandasDataAnalyst is a multi-agent class that combines data wrangling and visualization capabilities.
|
25
|
+
|
26
|
+
Parameters:
|
27
|
+
-----------
|
28
|
+
model:
|
29
|
+
The language model to be used for the agents.
|
30
|
+
data_wrangling_agent: DataWranglingAgent
|
31
|
+
The Data Wrangling Agent for transforming raw data.
|
32
|
+
data_visualization_agent: DataVisualizationAgent
|
33
|
+
The Data Visualization Agent for generating plots.
|
34
|
+
checkpointer: Checkpointer (optional)
|
35
|
+
The checkpointer to save the state of the multi-agent system.
|
36
|
+
|
37
|
+
Methods:
|
38
|
+
--------
|
39
|
+
ainvoke_agent(user_instructions, data_raw, **kwargs)
|
40
|
+
Asynchronously invokes the multi-agent with user instructions and raw data.
|
41
|
+
invoke_agent(user_instructions, data_raw, **kwargs)
|
42
|
+
Synchronously invokes the multi-agent with user instructions and raw data.
|
43
|
+
get_data_wrangled()
|
44
|
+
Returns the wrangled data as a Pandas DataFrame.
|
45
|
+
get_plotly_graph()
|
46
|
+
Returns the Plotly graph as a Plotly object.
|
47
|
+
get_data_wrangler_function(markdown=False)
|
48
|
+
Returns the data wrangling function as a string, optionally in Markdown.
|
49
|
+
get_data_visualization_function(markdown=False)
|
50
|
+
Returns the data visualization function as a string, optionally in Markdown.
|
51
|
+
"""
|
52
|
+
|
53
|
+
def __init__(
|
54
|
+
self,
|
55
|
+
model,
|
56
|
+
data_wrangling_agent: DataWranglingAgent,
|
57
|
+
data_visualization_agent: DataVisualizationAgent,
|
58
|
+
checkpointer: Checkpointer = None,
|
59
|
+
):
|
60
|
+
self._params = {
|
61
|
+
"model": model,
|
62
|
+
"data_wrangling_agent": data_wrangling_agent,
|
63
|
+
"data_visualization_agent": data_visualization_agent,
|
64
|
+
"checkpointer": checkpointer,
|
65
|
+
}
|
66
|
+
self._compiled_graph = self._make_compiled_graph()
|
67
|
+
self.response = None
|
68
|
+
|
69
|
+
def _make_compiled_graph(self):
|
70
|
+
"""Create or rebuild the compiled graph. Resets response to None."""
|
71
|
+
self.response = None
|
72
|
+
return make_pandas_data_analyst(
|
73
|
+
model=self._params["model"],
|
74
|
+
data_wrangling_agent=self._params["data_wrangling_agent"]._compiled_graph,
|
75
|
+
data_visualization_agent=self._params["data_visualization_agent"]._compiled_graph,
|
76
|
+
checkpointer=self._params["checkpointer"],
|
77
|
+
)
|
78
|
+
|
79
|
+
def update_params(self, **kwargs):
|
80
|
+
"""Updates parameters and rebuilds the compiled graph."""
|
81
|
+
for k, v in kwargs.items():
|
82
|
+
self._params[k] = v
|
83
|
+
self._compiled_graph = self._make_compiled_graph()
|
84
|
+
|
85
|
+
async def ainvoke_agent(self, user_instructions, data_raw: Union[pd.DataFrame, dict, list], max_retries: int = 3, retry_count: int = 0, **kwargs):
|
86
|
+
"""Asynchronously invokes the multi-agent."""
|
87
|
+
response = await self._compiled_graph.ainvoke({
|
88
|
+
"user_instructions": user_instructions,
|
89
|
+
"data_raw": self._convert_data_input(data_raw),
|
90
|
+
"max_retries": max_retries,
|
91
|
+
"retry_count": retry_count,
|
92
|
+
}, **kwargs)
|
93
|
+
if response.get("messages"):
|
94
|
+
response["messages"] = remove_consecutive_duplicates(response["messages"])
|
95
|
+
self.response = response
|
96
|
+
|
97
|
+
def invoke_agent(self, user_instructions, data_raw: Union[pd.DataFrame, dict, list], max_retries: int = 3, retry_count: int = 0, **kwargs):
|
98
|
+
"""Synchronously invokes the multi-agent."""
|
99
|
+
response = self._compiled_graph.invoke({
|
100
|
+
"user_instructions": user_instructions,
|
101
|
+
"data_raw": self._convert_data_input(data_raw),
|
102
|
+
"max_retries": max_retries,
|
103
|
+
"retry_count": retry_count,
|
104
|
+
}, **kwargs)
|
105
|
+
if response.get("messages"):
|
106
|
+
response["messages"] = remove_consecutive_duplicates(response["messages"])
|
107
|
+
self.response = response
|
108
|
+
|
109
|
+
def get_data_wrangled(self):
|
110
|
+
"""Returns the wrangled data as a Pandas DataFrame."""
|
111
|
+
if self.response and self.response.get("data_wrangled"):
|
112
|
+
return pd.DataFrame(self.response.get("data_wrangled"))
|
113
|
+
|
114
|
+
def get_plotly_graph(self):
|
115
|
+
"""Returns the Plotly graph as a Plotly object."""
|
116
|
+
if self.response and self.response.get("plotly_graph"):
|
117
|
+
return plotly_from_dict(self.response.get("plotly_graph"))
|
118
|
+
|
119
|
+
def get_data_wrangler_function(self, markdown=False):
|
120
|
+
"""Returns the data wrangling function as a string."""
|
121
|
+
if self.response and self.response.get("data_wrangler_function"):
|
122
|
+
code = self.response.get("data_wrangler_function")
|
123
|
+
return Markdown(f"```python\n{code}\n```") if markdown else code
|
124
|
+
|
125
|
+
def get_data_visualization_function(self, markdown=False):
|
126
|
+
"""Returns the data visualization function as a string."""
|
127
|
+
if self.response and self.response.get("data_visualization_function"):
|
128
|
+
code = self.response.get("data_visualization_function")
|
129
|
+
return Markdown(f"```python\n{code}\n```") if markdown else code
|
130
|
+
|
131
|
+
def get_workflow_summary(self, markdown=False):
|
132
|
+
"""Returns a summary of the workflow."""
|
133
|
+
if self.response and self.response.get("messages"):
|
134
|
+
agents = [msg.role for msg in self.response["messages"]]
|
135
|
+
agent_labels = [f"- **Agent {i+1}:** {role}" for i, role in enumerate(agents)]
|
136
|
+
header = f"# Pandas Data Analyst Workflow Summary\n\nThis workflow contains {len(agents)} agents:\n\n" + "\n".join(agent_labels)
|
137
|
+
reports = [get_generic_summary(json.loads(msg.content)) for msg in self.response["messages"]]
|
138
|
+
summary = "\n" +header + "\n\n".join(reports)
|
139
|
+
return Markdown(summary) if markdown else summary
|
140
|
+
|
141
|
+
@staticmethod
|
142
|
+
def _convert_data_input(data_raw: Union[pd.DataFrame, dict, list]) -> Union[dict, list]:
|
143
|
+
"""Converts input data to the expected format (dict or list of dicts)."""
|
144
|
+
if isinstance(data_raw, pd.DataFrame):
|
145
|
+
return data_raw.to_dict()
|
146
|
+
if isinstance(data_raw, dict):
|
147
|
+
return data_raw
|
148
|
+
if isinstance(data_raw, list):
|
149
|
+
return [item.to_dict() if isinstance(item, pd.DataFrame) else item for item in data_raw]
|
150
|
+
raise ValueError("data_raw must be a DataFrame, dict, or list of DataFrames/dicts")
|
151
|
+
|
152
|
+
def make_pandas_data_analyst(
|
153
|
+
model,
|
154
|
+
data_wrangling_agent: CompiledStateGraph,
|
155
|
+
data_visualization_agent: CompiledStateGraph,
|
156
|
+
checkpointer: Checkpointer = None
|
157
|
+
):
|
158
|
+
"""
|
159
|
+
Creates a multi-agent system that wrangles data and optionally visualizes it.
|
160
|
+
|
161
|
+
Parameters:
|
162
|
+
-----------
|
163
|
+
model: The language model to be used.
|
164
|
+
data_wrangling_agent: CompiledStateGraph
|
165
|
+
The Data Wrangling Agent.
|
166
|
+
data_visualization_agent: CompiledStateGraph
|
167
|
+
The Data Visualization Agent.
|
168
|
+
checkpointer: Checkpointer (optional)
|
169
|
+
The checkpointer to save the state.
|
170
|
+
|
171
|
+
Returns:
|
172
|
+
--------
|
173
|
+
CompiledStateGraph: The compiled multi-agent system.
|
174
|
+
"""
|
175
|
+
|
176
|
+
llm = model
|
177
|
+
|
178
|
+
routing_preprocessor_prompt = PromptTemplate(
|
179
|
+
template="""
|
180
|
+
You are an expert in routing decisions for a Pandas Data Manipulation Wrangling Agent, a Charting Visualization Agent, and a Pandas Table Agent. Your job is to:
|
181
|
+
|
182
|
+
1. Determine what the correct format for a Users Question should be for use with a Pandas Data Wrangling Agent based on the incoming user question. Anything related to data wrangling and manipulation should be passed along.
|
183
|
+
2. Determine whether or not a chart should be generated or a table should be returned based on the users question.
|
184
|
+
3. If a chart is requested, determine the correct format of a Users Question should be used with a Data Visualization Agent. Anything related to plotting and visualization should be passed along.
|
185
|
+
|
186
|
+
Use the following criteria on how to route the the initial user question:
|
187
|
+
|
188
|
+
From the incoming user question, remove any details about the format of the final response as either a Chart or Table and return only the important part of the incoming user question that is relevant for the SQL generator agent. This will be the 'user_instructions_data_wrangling'. If 'None' is found, return the original user question.
|
189
|
+
|
190
|
+
Next, determine if the user would like a data visualization ('chart') or a 'table' returned with the results of the Data Wrangling Agent. If unknown, not specified or 'None' is found, then select 'table'.
|
191
|
+
|
192
|
+
If a 'chart' is requested, return the 'user_instructions_data_visualization'. If 'None' is found, return None.
|
193
|
+
|
194
|
+
Return JSON with 'user_instructions_data_wrangling', 'user_instructions_data_visualization' and 'routing_preprocessor_decision'.
|
195
|
+
|
196
|
+
INITIAL_USER_QUESTION: {user_instructions}
|
197
|
+
""",
|
198
|
+
input_variables=["user_instructions"]
|
199
|
+
)
|
200
|
+
|
201
|
+
routing_preprocessor = routing_preprocessor_prompt | llm | JsonOutputParser()
|
202
|
+
|
203
|
+
class PrimaryState(TypedDict):
|
204
|
+
messages: Annotated[Sequence[BaseMessage], operator.add]
|
205
|
+
user_instructions: str
|
206
|
+
user_instructions_data_wrangling: str
|
207
|
+
user_instructions_data_visualization: str
|
208
|
+
routing_preprocessor_decision: str
|
209
|
+
data_raw: Union[dict, list]
|
210
|
+
data_wrangled: dict
|
211
|
+
data_wrangler_function: str
|
212
|
+
data_visualization_function: str
|
213
|
+
plotly_graph: dict
|
214
|
+
plotly_error: str
|
215
|
+
max_retries: int
|
216
|
+
retry_count: int
|
217
|
+
|
218
|
+
|
219
|
+
def preprocess_routing(state: PrimaryState):
|
220
|
+
print("---PANDAS DATA ANALYST---")
|
221
|
+
print("*************************")
|
222
|
+
print("---PREPROCESS ROUTER---")
|
223
|
+
question = state.get("user_instructions")
|
224
|
+
|
225
|
+
# Chart Routing and SQL Prep
|
226
|
+
response = routing_preprocessor.invoke({"user_instructions": question})
|
227
|
+
|
228
|
+
return {
|
229
|
+
"user_instructions_data_wrangling": response.get('user_instructions_data_wrangling'),
|
230
|
+
"user_instructions_data_visualization": response.get('user_instructions_data_visualization'),
|
231
|
+
"routing_preprocessor_decision": response.get('routing_preprocessor_decision'),
|
232
|
+
}
|
233
|
+
|
234
|
+
def router_chart_or_table(state: PrimaryState):
|
235
|
+
print("---ROUTER: CHART OR TABLE---")
|
236
|
+
return "chart" if state.get('routing_preprocessor_decision') == "chart" else "table"
|
237
|
+
|
238
|
+
|
239
|
+
def invoke_data_wrangling_agent(state: PrimaryState):
|
240
|
+
|
241
|
+
response = data_wrangling_agent.invoke({
|
242
|
+
"user_instructions": state.get("user_instructions_data_wrangling"),
|
243
|
+
"data_raw": state.get("data_raw"),
|
244
|
+
"max_retries": state.get("max_retries"),
|
245
|
+
"retry_count": state.get("retry_count"),
|
246
|
+
})
|
247
|
+
|
248
|
+
return {
|
249
|
+
"messages": response.get("messages"),
|
250
|
+
"data_wrangled": response.get("data_wrangled"),
|
251
|
+
"data_wrangler_function": response.get("data_wrangler_function"),
|
252
|
+
"plotly_error": response.get("data_visualization_error"),
|
253
|
+
|
254
|
+
}
|
255
|
+
|
256
|
+
def invoke_data_visualization_agent(state: PrimaryState):
|
257
|
+
|
258
|
+
response = data_visualization_agent.invoke({
|
259
|
+
"user_instructions": state.get("user_instructions_data_visualization"),
|
260
|
+
"data_raw": state.get("data_wrangled") if state.get("data_wrangled") else state.get("data_raw"),
|
261
|
+
"max_retries": state.get("max_retries"),
|
262
|
+
"retry_count": state.get("retry_count"),
|
263
|
+
})
|
264
|
+
|
265
|
+
return {
|
266
|
+
"messages": response.get("messages"),
|
267
|
+
"data_visualization_function": response.get("data_visualization_function"),
|
268
|
+
"plotly_graph": response.get("plotly_graph"),
|
269
|
+
"plotly_error": response.get("data_visualization_error"),
|
270
|
+
}
|
271
|
+
|
272
|
+
def route_printer(state: PrimaryState):
|
273
|
+
print("---ROUTE PRINTER---")
|
274
|
+
print(f" Route: {state.get('routing_preprocessor_decision')}")
|
275
|
+
print("---END---")
|
276
|
+
return {}
|
277
|
+
|
278
|
+
workflow = StateGraph(PrimaryState)
|
279
|
+
|
280
|
+
workflow.add_node("routing_preprocessor", preprocess_routing)
|
281
|
+
workflow.add_node("data_wrangling_agent", invoke_data_wrangling_agent)
|
282
|
+
workflow.add_node("data_visualization_agent", invoke_data_visualization_agent)
|
283
|
+
workflow.add_node("route_printer", route_printer)
|
284
|
+
|
285
|
+
workflow.add_edge(START, "routing_preprocessor")
|
286
|
+
workflow.add_edge("routing_preprocessor", "data_wrangling_agent")
|
287
|
+
|
288
|
+
workflow.add_conditional_edges(
|
289
|
+
"data_wrangling_agent",
|
290
|
+
router_chart_or_table,
|
291
|
+
{
|
292
|
+
"chart": "data_visualization_agent",
|
293
|
+
"table": "route_printer"
|
294
|
+
}
|
295
|
+
)
|
296
|
+
|
297
|
+
workflow.add_edge("data_visualization_agent", "route_printer")
|
298
|
+
workflow.add_edge("route_printer", END)
|
299
|
+
|
300
|
+
app = workflow.compile(
|
301
|
+
checkpointer=checkpointer,
|
302
|
+
name=AGENT_NAME
|
303
|
+
)
|
304
|
+
|
305
|
+
return app
|
@@ -1,12 +1,14 @@
|
|
1
1
|
|
2
2
|
from langchain_core.messages import BaseMessage
|
3
|
-
|
3
|
+
|
4
|
+
from langchain.prompts import PromptTemplate
|
5
|
+
from langchain_core.output_parsers import JsonOutputParser
|
4
6
|
|
5
7
|
from langgraph.graph import START, END, StateGraph
|
6
8
|
from langgraph.graph.state import CompiledStateGraph
|
7
|
-
from langgraph.types import
|
9
|
+
from langgraph.types import Checkpointer
|
8
10
|
|
9
|
-
from typing import TypedDict, Annotated, Sequence
|
11
|
+
from typing import TypedDict, Annotated, Sequence
|
10
12
|
import operator
|
11
13
|
|
12
14
|
from typing_extensions import TypedDict
|
@@ -20,6 +22,7 @@ from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
|
|
20
22
|
from ai_data_science_team.utils.plotly import plotly_from_dict
|
21
23
|
from ai_data_science_team.utils.regex import remove_consecutive_duplicates, get_generic_summary
|
22
24
|
|
25
|
+
AGENT_NAME = "sql_data_analyst"
|
23
26
|
|
24
27
|
class SQLDataAnalyst(BaseAgent):
|
25
28
|
"""
|
@@ -33,6 +36,8 @@ class SQLDataAnalyst(BaseAgent):
|
|
33
36
|
The SQL Database Agent.
|
34
37
|
data_visualization_agent: DataVisualizationAgent
|
35
38
|
The Data Visualization Agent.
|
39
|
+
checkpointer: Checkpointer (optional)
|
40
|
+
The checkpointer to save the state of the multi-agent system.
|
36
41
|
|
37
42
|
Methods:
|
38
43
|
--------
|
@@ -326,17 +331,17 @@ def make_sql_data_analyst(
|
|
326
331
|
"""
|
327
332
|
Creates a multi-agent system that takes in a SQL query and returns a plot or table.
|
328
333
|
|
329
|
-
- Agent 1: SQL Database Agent made with `
|
330
|
-
- Agent 2: Data Visualization Agent made with `
|
334
|
+
- Agent 1: SQL Database Agent made with `SQLDatabaseAgent()`
|
335
|
+
- Agent 2: Data Visualization Agent made with `DataVisualizationAgent()`
|
331
336
|
|
332
337
|
Parameters:
|
333
338
|
----------
|
334
339
|
model:
|
335
340
|
The language model to be used for the agents.
|
336
341
|
sql_database_agent: CompiledStateGraph
|
337
|
-
The SQL Database Agent made with `
|
342
|
+
The SQL Database Agent made with `SQLDatabaseAgent()`.
|
338
343
|
data_visualization_agent: CompiledStateGraph
|
339
|
-
The Data Visualization Agent made with `
|
344
|
+
The Data Visualization Agent made with `DataVisualizationAgent()`.
|
340
345
|
checkpointer: Checkpointer (optional)
|
341
346
|
The checkpointer to save the state of the multi-agent system.
|
342
347
|
Default: None
|
@@ -348,10 +353,39 @@ def make_sql_data_analyst(
|
|
348
353
|
"""
|
349
354
|
|
350
355
|
llm = model
|
356
|
+
|
357
|
+
|
358
|
+
routing_preprocessor_prompt = PromptTemplate(
|
359
|
+
template="""
|
360
|
+
You are an expert in routing decisions for a SQL Database Agent, a Charting Visualization Agent, and a Pandas Table Agent. Your job is to:
|
361
|
+
|
362
|
+
1. Determine what the correct format for a Users Question should be for use with a SQL Database Agent based on the incoming user question. Anything related to database and data manipulation should be passed along.
|
363
|
+
2. Determine whether or not a chart should be generated or a table should be returned based on the users question.
|
364
|
+
3. If a chart is requested, determine the correct format of a Users Question should be used with a Data Visualization Agent. Anything related to plotting and visualization should be passed along.
|
365
|
+
|
366
|
+
Use the following criteria on how to route the the initial user question:
|
367
|
+
|
368
|
+
From the incoming user question, remove any details about the format of the final response as either a Chart or Table and return only the important part of the incoming user question that is relevant for the SQL generator agent. This will be the 'user_instructions_sql_database'. If 'None' is found, return the original user question.
|
369
|
+
|
370
|
+
Next, determine if the user would like a data visualization ('chart') or a 'table' returned with the results of the Data Wrangling Agent. If unknown, not specified or 'None' is found, then select 'table'.
|
371
|
+
|
372
|
+
If a 'chart' is requested, return the 'user_instructions_data_visualization'. If 'None' is found, return None.
|
373
|
+
|
374
|
+
Return JSON with 'user_instructions_sql_database', 'user_instructions_data_visualization' and 'routing_preprocessor_decision'.
|
375
|
+
|
376
|
+
INITIAL_USER_QUESTION: {user_instructions}
|
377
|
+
""",
|
378
|
+
input_variables=["user_instructions"]
|
379
|
+
)
|
380
|
+
|
381
|
+
routing_preprocessor = routing_preprocessor_prompt | llm | JsonOutputParser()
|
351
382
|
|
352
383
|
class PrimaryState(TypedDict):
|
353
384
|
messages: Annotated[Sequence[BaseMessage], operator.add]
|
354
385
|
user_instructions: str
|
386
|
+
user_instructions_sql_database: str
|
387
|
+
user_instructions_data_visualization: str
|
388
|
+
routing_preprocessor_decision: str
|
355
389
|
sql_query_code: str
|
356
390
|
sql_database_function: str
|
357
391
|
data_sql: dict
|
@@ -359,39 +393,94 @@ def make_sql_data_analyst(
|
|
359
393
|
plot_required: bool
|
360
394
|
data_visualization_function: str
|
361
395
|
plotly_graph: dict
|
396
|
+
plotly_error: str
|
362
397
|
max_retries: int
|
363
398
|
retry_count: int
|
364
399
|
|
365
|
-
def
|
400
|
+
def preprocess_routing(state: PrimaryState):
|
401
|
+
print("---SQL DATA ANALYST---")
|
402
|
+
print("*************************")
|
403
|
+
print("---PREPROCESS ROUTER---")
|
404
|
+
question = state.get("user_instructions")
|
366
405
|
|
367
|
-
|
406
|
+
# Chart Routing and SQL Prep
|
407
|
+
response = routing_preprocessor.invoke({"user_instructions": question})
|
368
408
|
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
409
|
+
return {
|
410
|
+
"user_instructions_sql_database": response.get('user_instructions_sql_database'),
|
411
|
+
"user_instructions_data_visualization": response.get('user_instructions_data_visualization'),
|
412
|
+
"routing_preprocessor_decision": response.get('routing_preprocessor_decision'),
|
413
|
+
}
|
414
|
+
|
415
|
+
def router_chart_or_table(state: PrimaryState):
|
416
|
+
print("---ROUTER: CHART OR TABLE---")
|
417
|
+
return "chart" if state.get('routing_preprocessor_decision') == "chart" else "table"
|
418
|
+
|
419
|
+
|
420
|
+
def invoke_sql_database_agent(state: PrimaryState):
|
375
421
|
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
goto=goto
|
382
|
-
)
|
422
|
+
response = sql_database_agent.invoke({
|
423
|
+
"user_instructions": state.get("user_instructions_sql_database"),
|
424
|
+
"max_retries": state.get("max_retries"),
|
425
|
+
"retry_count": state.get("retry_count"),
|
426
|
+
})
|
383
427
|
|
384
|
-
|
428
|
+
return {
|
429
|
+
"messages": response.get("messages"),
|
430
|
+
"data_sql": response.get("data_sql"),
|
431
|
+
"sql_query_code": response.get("sql_query_code"),
|
432
|
+
"sql_database_function": response.get("sql_database_function"),
|
433
|
+
|
434
|
+
}
|
435
|
+
|
436
|
+
def invoke_data_visualization_agent(state: PrimaryState):
|
437
|
+
|
438
|
+
response = data_visualization_agent.invoke({
|
439
|
+
"user_instructions": state.get("user_instructions_data_visualization"),
|
440
|
+
"data_raw": state.get("data_sql"),
|
441
|
+
"max_retries": state.get("max_retries"),
|
442
|
+
"retry_count": state.get("retry_count"),
|
443
|
+
})
|
444
|
+
|
445
|
+
return {
|
446
|
+
"messages": response.get("messages"),
|
447
|
+
"data_visualization_function": response.get("data_visualization_function"),
|
448
|
+
"plotly_graph": response.get("plotly_graph"),
|
449
|
+
"plotly_error": response.get("data_visualization_error"),
|
450
|
+
}
|
385
451
|
|
386
|
-
|
387
|
-
|
388
|
-
|
452
|
+
def route_printer(state: PrimaryState):
|
453
|
+
print("---ROUTE PRINTER---")
|
454
|
+
print(f" Route: {state.get('routing_preprocessor_decision')}")
|
455
|
+
print("---END---")
|
456
|
+
return {}
|
457
|
+
|
458
|
+
workflow = StateGraph(PrimaryState)
|
459
|
+
|
460
|
+
workflow.add_node("routing_preprocessor", preprocess_routing)
|
461
|
+
workflow.add_node("sql_database_agent", invoke_sql_database_agent)
|
462
|
+
workflow.add_node("data_visualization_agent", invoke_data_visualization_agent)
|
463
|
+
workflow.add_node("route_printer", route_printer)
|
389
464
|
|
390
|
-
workflow.add_edge(START, "
|
391
|
-
workflow.add_edge("
|
392
|
-
|
465
|
+
workflow.add_edge(START, "routing_preprocessor")
|
466
|
+
workflow.add_edge("routing_preprocessor", "sql_database_agent")
|
467
|
+
|
468
|
+
workflow.add_conditional_edges(
|
469
|
+
"sql_database_agent",
|
470
|
+
router_chart_or_table,
|
471
|
+
{
|
472
|
+
"chart": "data_visualization_agent",
|
473
|
+
"table": "route_printer"
|
474
|
+
}
|
475
|
+
)
|
476
|
+
|
477
|
+
workflow.add_edge("data_visualization_agent", "route_printer")
|
478
|
+
workflow.add_edge("route_printer", END)
|
393
479
|
|
394
|
-
app = workflow.compile(
|
480
|
+
app = workflow.compile(
|
481
|
+
checkpointer=checkpointer,
|
482
|
+
name=AGENT_NAME
|
483
|
+
)
|
395
484
|
|
396
485
|
return app
|
397
486
|
|
@@ -40,6 +40,21 @@ class BaseAgent(CompiledStateGraph):
|
|
40
40
|
self._params = params
|
41
41
|
self._compiled_graph = self._make_compiled_graph()
|
42
42
|
self.response = None
|
43
|
+
self.name = self._compiled_graph.name
|
44
|
+
self.checkpointer = self._compiled_graph.checkpointer
|
45
|
+
self.store = self._compiled_graph.store
|
46
|
+
self.output_channels = self._compiled_graph.output_channels
|
47
|
+
self.nodes = self._compiled_graph.nodes
|
48
|
+
self.stream_mode = self._compiled_graph.stream_mode
|
49
|
+
self.builder = self._compiled_graph.builder
|
50
|
+
self.channels = self._compiled_graph.channels
|
51
|
+
self.input_channels = self._compiled_graph.input_channels
|
52
|
+
self.input_schema = self._compiled_graph.input_schema
|
53
|
+
self.output_schema = self._compiled_graph.output_schema
|
54
|
+
self.debug = self._compiled_graph.debug
|
55
|
+
self.interrupt_after_nodes = self._compiled_graph.interrupt_after_nodes
|
56
|
+
self.interrupt_before_nodes = self._compiled_graph.interrupt_before_nodes
|
57
|
+
self.config = self._compiled_graph.config
|
43
58
|
|
44
59
|
def _make_compiled_graph(self):
|
45
60
|
"""
|
@@ -197,6 +212,24 @@ class BaseAgent(CompiledStateGraph):
|
|
197
212
|
"""
|
198
213
|
return self.get_output_jsonschema()['properties']
|
199
214
|
|
215
|
+
def get_state(self, config, *, subgraphs = False):
|
216
|
+
"""
|
217
|
+
Returns the state of the agent.
|
218
|
+
"""
|
219
|
+
return self._compiled_graph.get_state(config, subgraphs=subgraphs)
|
220
|
+
|
221
|
+
def get_state_history(self, config, *, filter = None, before = None, limit = None):
|
222
|
+
"""
|
223
|
+
Returns the state history of the agent.
|
224
|
+
"""
|
225
|
+
return self._compiled_graph.get_state_history(config, filter=filter, before=before, limit=limit)
|
226
|
+
|
227
|
+
def update_state(self, config, values, as_node = None):
|
228
|
+
"""
|
229
|
+
Updates the state of the agent.
|
230
|
+
"""
|
231
|
+
return self._compiled_graph.update_state(config, values, as_node)
|
232
|
+
|
200
233
|
def get_response(self):
|
201
234
|
"""
|
202
235
|
Returns the response generated by the agent.
|
@@ -237,6 +270,7 @@ def create_coding_agent_graph(
|
|
237
270
|
checkpointer: Optional[Callable] = None,
|
238
271
|
bypass_recommended_steps: bool = False,
|
239
272
|
bypass_explain_code: bool = False,
|
273
|
+
agent_name: str = "coding_agent"
|
240
274
|
):
|
241
275
|
"""
|
242
276
|
Creates a generic agent graph using the provided node functions and node names.
|
@@ -281,6 +315,8 @@ def create_coding_agent_graph(
|
|
281
315
|
Whether to skip the recommended steps node.
|
282
316
|
bypass_explain_code : bool, optional
|
283
317
|
Whether to skip the final explain code node.
|
318
|
+
name : str, optional
|
319
|
+
The name of the agent graph.
|
284
320
|
|
285
321
|
Returns
|
286
322
|
-------
|
@@ -366,10 +402,10 @@ def create_coding_agent_graph(
|
|
366
402
|
workflow.add_edge(explain_code_node_name, END)
|
367
403
|
|
368
404
|
# Finally, compile
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
405
|
+
app = workflow.compile(
|
406
|
+
checkpointer=checkpointer,
|
407
|
+
name=agent_name,
|
408
|
+
)
|
373
409
|
|
374
410
|
return app
|
375
411
|
|
@@ -574,7 +610,7 @@ def node_func_execute_agent_from_sql_connection(
|
|
574
610
|
|
575
611
|
# Retrieve SQLAlchemy connection and code snippet from the state
|
576
612
|
is_engine = isinstance(connection, sql.engine.base.Engine)
|
577
|
-
|
613
|
+
connection = connection.connect() if is_engine else connection
|
578
614
|
agent_code = state.get(code_snippet_key)
|
579
615
|
|
580
616
|
# Ensure the connection object is provided
|