ai-data-science-team 0.0.0.9013__py3-none-any.whl → 0.0.0.9014__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -10,6 +10,7 @@ from langchain_core.messages import BaseMessage, AIMessage
10
10
 
11
11
  from langgraph.prebuilt import create_react_agent, ToolNode
12
12
  from langgraph.prebuilt.chat_agent_executor import AgentState
13
+ from langgraph.types import Checkpointer
13
14
  from langgraph.graph import START, END, StateGraph
14
15
 
15
16
  from ai_data_science_team.templates import BaseAgent
@@ -68,6 +69,8 @@ class MLflowToolsAgent(BaseAgent):
68
69
  Additional keyword arguments to pass to the create_react_agent function.
69
70
  invoke_react_agent_kwargs : dict
70
71
  Additional keyword arguments to pass to the invoke method of the react agent.
72
+ checkpointer : langchain.checkpointing.Checkpointer, optional
73
+ A checkpointer to use for saving and loading the agent's state. Defaults to None.
71
74
 
72
75
  Methods:
73
76
  --------
@@ -119,6 +122,7 @@ class MLflowToolsAgent(BaseAgent):
119
122
  mlflow_registry_uri: Optional[str]=None,
120
123
  create_react_agent_kwargs: Optional[Dict]={},
121
124
  invoke_react_agent_kwargs: Optional[Dict]={},
125
+ checkpointer: Optional[Checkpointer]=None,
122
126
  ):
123
127
  self._params = {
124
128
  "model": model,
@@ -126,6 +130,7 @@ class MLflowToolsAgent(BaseAgent):
126
130
  "mlflow_registry_uri": mlflow_registry_uri,
127
131
  "create_react_agent_kwargs": create_react_agent_kwargs,
128
132
  "invoke_react_agent_kwargs": invoke_react_agent_kwargs,
133
+ "checkpointer": checkpointer,
129
134
  }
130
135
  self._compiled_graph = self._make_compiled_graph()
131
136
  self.response = None
@@ -245,6 +250,7 @@ def make_mlflow_tools_agent(
245
250
  mlflow_registry_uri: str=None,
246
251
  create_react_agent_kwargs: Optional[Dict]={},
247
252
  invoke_react_agent_kwargs: Optional[Dict]={},
253
+ checkpointer: Optional[Checkpointer]=None,
248
254
  ):
249
255
  """
250
256
  MLflow Tool Calling Agent
@@ -261,6 +267,8 @@ def make_mlflow_tools_agent(
261
267
  Additional keyword arguments to pass to the agent's create_react_agent method.
262
268
  invoke_react_agent_kwargs : dict, optional
263
269
  Additional keyword arguments to pass to the agent's invoke method.
270
+ checkpointer : langchain.checkpointing.Checkpointer, optional
271
+ A checkpointer to use for saving and loading the agent's state. Defaults to None.
264
272
 
265
273
  Returns
266
274
  -------
@@ -303,6 +311,7 @@ def make_mlflow_tools_agent(
303
311
  model,
304
312
  tools=tool_node,
305
313
  state_schema=GraphState,
314
+ checkpointer=checkpointer,
306
315
  **create_react_agent_kwargs,
307
316
  )
308
317
 
@@ -354,7 +363,10 @@ def make_mlflow_tools_agent(
354
363
  workflow.add_edge(START, "mlflow_tools_agent")
355
364
  workflow.add_edge("mlflow_tools_agent", END)
356
365
 
357
- app = workflow.compile()
366
+ app = workflow.compile(
367
+ checkpointer=checkpointer,
368
+ name=AGENT_NAME,
369
+ )
358
370
 
359
371
  return app
360
372
 
@@ -1 +1,2 @@
1
- from ai_data_science_team.multiagents.sql_data_analyst import SQLDataAnalyst, make_sql_data_analyst
1
+ from ai_data_science_team.multiagents.sql_data_analyst import SQLDataAnalyst, make_sql_data_analyst
2
+ from ai_data_science_team.multiagents.pandas_data_analyst import PandasDataAnalyst, make_pandas_data_analyst
@@ -0,0 +1,305 @@
1
+ from langchain_core.messages import BaseMessage
2
+ from langchain.prompts import PromptTemplate
3
+ from langchain_core.output_parsers import JsonOutputParser
4
+ from langgraph.types import Checkpointer
5
+ from langgraph.graph import START, END, StateGraph
6
+ from langgraph.graph.state import CompiledStateGraph
7
+
8
+ from typing import TypedDict, Annotated, Sequence, Union
9
+ import operator
10
+
11
+ import pandas as pd
12
+ import json
13
+ from IPython.display import Markdown
14
+
15
+ from ai_data_science_team.templates import BaseAgent
16
+ from ai_data_science_team.agents import DataWranglingAgent, DataVisualizationAgent
17
+ from ai_data_science_team.utils.plotly import plotly_from_dict
18
+ from ai_data_science_team.utils.regex import remove_consecutive_duplicates, get_generic_summary
19
+
20
+ AGENT_NAME = "pandas_data_analyst"
21
+
22
+ class PandasDataAnalyst(BaseAgent):
23
+ """
24
+ PandasDataAnalyst is a multi-agent class that combines data wrangling and visualization capabilities.
25
+
26
+ Parameters:
27
+ -----------
28
+ model:
29
+ The language model to be used for the agents.
30
+ data_wrangling_agent: DataWranglingAgent
31
+ The Data Wrangling Agent for transforming raw data.
32
+ data_visualization_agent: DataVisualizationAgent
33
+ The Data Visualization Agent for generating plots.
34
+ checkpointer: Checkpointer (optional)
35
+ The checkpointer to save the state of the multi-agent system.
36
+
37
+ Methods:
38
+ --------
39
+ ainvoke_agent(user_instructions, data_raw, **kwargs)
40
+ Asynchronously invokes the multi-agent with user instructions and raw data.
41
+ invoke_agent(user_instructions, data_raw, **kwargs)
42
+ Synchronously invokes the multi-agent with user instructions and raw data.
43
+ get_data_wrangled()
44
+ Returns the wrangled data as a Pandas DataFrame.
45
+ get_plotly_graph()
46
+ Returns the Plotly graph as a Plotly object.
47
+ get_data_wrangler_function(markdown=False)
48
+ Returns the data wrangling function as a string, optionally in Markdown.
49
+ get_data_visualization_function(markdown=False)
50
+ Returns the data visualization function as a string, optionally in Markdown.
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ model,
56
+ data_wrangling_agent: DataWranglingAgent,
57
+ data_visualization_agent: DataVisualizationAgent,
58
+ checkpointer: Checkpointer = None,
59
+ ):
60
+ self._params = {
61
+ "model": model,
62
+ "data_wrangling_agent": data_wrangling_agent,
63
+ "data_visualization_agent": data_visualization_agent,
64
+ "checkpointer": checkpointer,
65
+ }
66
+ self._compiled_graph = self._make_compiled_graph()
67
+ self.response = None
68
+
69
+ def _make_compiled_graph(self):
70
+ """Create or rebuild the compiled graph. Resets response to None."""
71
+ self.response = None
72
+ return make_pandas_data_analyst(
73
+ model=self._params["model"],
74
+ data_wrangling_agent=self._params["data_wrangling_agent"]._compiled_graph,
75
+ data_visualization_agent=self._params["data_visualization_agent"]._compiled_graph,
76
+ checkpointer=self._params["checkpointer"],
77
+ )
78
+
79
+ def update_params(self, **kwargs):
80
+ """Updates parameters and rebuilds the compiled graph."""
81
+ for k, v in kwargs.items():
82
+ self._params[k] = v
83
+ self._compiled_graph = self._make_compiled_graph()
84
+
85
+ async def ainvoke_agent(self, user_instructions, data_raw: Union[pd.DataFrame, dict, list], max_retries: int = 3, retry_count: int = 0, **kwargs):
86
+ """Asynchronously invokes the multi-agent."""
87
+ response = await self._compiled_graph.ainvoke({
88
+ "user_instructions": user_instructions,
89
+ "data_raw": self._convert_data_input(data_raw),
90
+ "max_retries": max_retries,
91
+ "retry_count": retry_count,
92
+ }, **kwargs)
93
+ if response.get("messages"):
94
+ response["messages"] = remove_consecutive_duplicates(response["messages"])
95
+ self.response = response
96
+
97
+ def invoke_agent(self, user_instructions, data_raw: Union[pd.DataFrame, dict, list], max_retries: int = 3, retry_count: int = 0, **kwargs):
98
+ """Synchronously invokes the multi-agent."""
99
+ response = self._compiled_graph.invoke({
100
+ "user_instructions": user_instructions,
101
+ "data_raw": self._convert_data_input(data_raw),
102
+ "max_retries": max_retries,
103
+ "retry_count": retry_count,
104
+ }, **kwargs)
105
+ if response.get("messages"):
106
+ response["messages"] = remove_consecutive_duplicates(response["messages"])
107
+ self.response = response
108
+
109
+ def get_data_wrangled(self):
110
+ """Returns the wrangled data as a Pandas DataFrame."""
111
+ if self.response and self.response.get("data_wrangled"):
112
+ return pd.DataFrame(self.response.get("data_wrangled"))
113
+
114
+ def get_plotly_graph(self):
115
+ """Returns the Plotly graph as a Plotly object."""
116
+ if self.response and self.response.get("plotly_graph"):
117
+ return plotly_from_dict(self.response.get("plotly_graph"))
118
+
119
+ def get_data_wrangler_function(self, markdown=False):
120
+ """Returns the data wrangling function as a string."""
121
+ if self.response and self.response.get("data_wrangler_function"):
122
+ code = self.response.get("data_wrangler_function")
123
+ return Markdown(f"```python\n{code}\n```") if markdown else code
124
+
125
+ def get_data_visualization_function(self, markdown=False):
126
+ """Returns the data visualization function as a string."""
127
+ if self.response and self.response.get("data_visualization_function"):
128
+ code = self.response.get("data_visualization_function")
129
+ return Markdown(f"```python\n{code}\n```") if markdown else code
130
+
131
+ def get_workflow_summary(self, markdown=False):
132
+ """Returns a summary of the workflow."""
133
+ if self.response and self.response.get("messages"):
134
+ agents = [msg.role for msg in self.response["messages"]]
135
+ agent_labels = [f"- **Agent {i+1}:** {role}" for i, role in enumerate(agents)]
136
+ header = f"# Pandas Data Analyst Workflow Summary\n\nThis workflow contains {len(agents)} agents:\n\n" + "\n".join(agent_labels)
137
+ reports = [get_generic_summary(json.loads(msg.content)) for msg in self.response["messages"]]
138
+ summary = "\n" +header + "\n\n".join(reports)
139
+ return Markdown(summary) if markdown else summary
140
+
141
+ @staticmethod
142
+ def _convert_data_input(data_raw: Union[pd.DataFrame, dict, list]) -> Union[dict, list]:
143
+ """Converts input data to the expected format (dict or list of dicts)."""
144
+ if isinstance(data_raw, pd.DataFrame):
145
+ return data_raw.to_dict()
146
+ if isinstance(data_raw, dict):
147
+ return data_raw
148
+ if isinstance(data_raw, list):
149
+ return [item.to_dict() if isinstance(item, pd.DataFrame) else item for item in data_raw]
150
+ raise ValueError("data_raw must be a DataFrame, dict, or list of DataFrames/dicts")
151
+
152
+ def make_pandas_data_analyst(
153
+ model,
154
+ data_wrangling_agent: CompiledStateGraph,
155
+ data_visualization_agent: CompiledStateGraph,
156
+ checkpointer: Checkpointer = None
157
+ ):
158
+ """
159
+ Creates a multi-agent system that wrangles data and optionally visualizes it.
160
+
161
+ Parameters:
162
+ -----------
163
+ model: The language model to be used.
164
+ data_wrangling_agent: CompiledStateGraph
165
+ The Data Wrangling Agent.
166
+ data_visualization_agent: CompiledStateGraph
167
+ The Data Visualization Agent.
168
+ checkpointer: Checkpointer (optional)
169
+ The checkpointer to save the state.
170
+
171
+ Returns:
172
+ --------
173
+ CompiledStateGraph: The compiled multi-agent system.
174
+ """
175
+
176
+ llm = model
177
+
178
+ routing_preprocessor_prompt = PromptTemplate(
179
+ template="""
180
+ You are an expert in routing decisions for a Pandas Data Manipulation Wrangling Agent, a Charting Visualization Agent, and a Pandas Table Agent. Your job is to:
181
+
182
+ 1. Determine what the correct format for a Users Question should be for use with a Pandas Data Wrangling Agent based on the incoming user question. Anything related to data wrangling and manipulation should be passed along.
183
+ 2. Determine whether or not a chart should be generated or a table should be returned based on the users question.
184
+ 3. If a chart is requested, determine the correct format of a Users Question should be used with a Data Visualization Agent. Anything related to plotting and visualization should be passed along.
185
+
186
+ Use the following criteria on how to route the the initial user question:
187
+
188
+ From the incoming user question, remove any details about the format of the final response as either a Chart or Table and return only the important part of the incoming user question that is relevant for the SQL generator agent. This will be the 'user_instructions_data_wrangling'. If 'None' is found, return the original user question.
189
+
190
+ Next, determine if the user would like a data visualization ('chart') or a 'table' returned with the results of the Data Wrangling Agent. If unknown, not specified or 'None' is found, then select 'table'.
191
+
192
+ If a 'chart' is requested, return the 'user_instructions_data_visualization'. If 'None' is found, return None.
193
+
194
+ Return JSON with 'user_instructions_data_wrangling', 'user_instructions_data_visualization' and 'routing_preprocessor_decision'.
195
+
196
+ INITIAL_USER_QUESTION: {user_instructions}
197
+ """,
198
+ input_variables=["user_instructions"]
199
+ )
200
+
201
+ routing_preprocessor = routing_preprocessor_prompt | llm | JsonOutputParser()
202
+
203
+ class PrimaryState(TypedDict):
204
+ messages: Annotated[Sequence[BaseMessage], operator.add]
205
+ user_instructions: str
206
+ user_instructions_data_wrangling: str
207
+ user_instructions_data_visualization: str
208
+ routing_preprocessor_decision: str
209
+ data_raw: Union[dict, list]
210
+ data_wrangled: dict
211
+ data_wrangler_function: str
212
+ data_visualization_function: str
213
+ plotly_graph: dict
214
+ plotly_error: str
215
+ max_retries: int
216
+ retry_count: int
217
+
218
+
219
+ def preprocess_routing(state: PrimaryState):
220
+ print("---PANDAS DATA ANALYST---")
221
+ print("*************************")
222
+ print("---PREPROCESS ROUTER---")
223
+ question = state.get("user_instructions")
224
+
225
+ # Chart Routing and SQL Prep
226
+ response = routing_preprocessor.invoke({"user_instructions": question})
227
+
228
+ return {
229
+ "user_instructions_data_wrangling": response.get('user_instructions_data_wrangling'),
230
+ "user_instructions_data_visualization": response.get('user_instructions_data_visualization'),
231
+ "routing_preprocessor_decision": response.get('routing_preprocessor_decision'),
232
+ }
233
+
234
+ def router_chart_or_table(state: PrimaryState):
235
+ print("---ROUTER: CHART OR TABLE---")
236
+ return "chart" if state.get('routing_preprocessor_decision') == "chart" else "table"
237
+
238
+
239
+ def invoke_data_wrangling_agent(state: PrimaryState):
240
+
241
+ response = data_wrangling_agent.invoke({
242
+ "user_instructions": state.get("user_instructions_data_wrangling"),
243
+ "data_raw": state.get("data_raw"),
244
+ "max_retries": state.get("max_retries"),
245
+ "retry_count": state.get("retry_count"),
246
+ })
247
+
248
+ return {
249
+ "messages": response.get("messages"),
250
+ "data_wrangled": response.get("data_wrangled"),
251
+ "data_wrangler_function": response.get("data_wrangler_function"),
252
+ "plotly_error": response.get("data_visualization_error"),
253
+
254
+ }
255
+
256
+ def invoke_data_visualization_agent(state: PrimaryState):
257
+
258
+ response = data_visualization_agent.invoke({
259
+ "user_instructions": state.get("user_instructions_data_visualization"),
260
+ "data_raw": state.get("data_wrangled") if state.get("data_wrangled") else state.get("data_raw"),
261
+ "max_retries": state.get("max_retries"),
262
+ "retry_count": state.get("retry_count"),
263
+ })
264
+
265
+ return {
266
+ "messages": response.get("messages"),
267
+ "data_visualization_function": response.get("data_visualization_function"),
268
+ "plotly_graph": response.get("plotly_graph"),
269
+ "plotly_error": response.get("data_visualization_error"),
270
+ }
271
+
272
+ def route_printer(state: PrimaryState):
273
+ print("---ROUTE PRINTER---")
274
+ print(f" Route: {state.get('routing_preprocessor_decision')}")
275
+ print("---END---")
276
+ return {}
277
+
278
+ workflow = StateGraph(PrimaryState)
279
+
280
+ workflow.add_node("routing_preprocessor", preprocess_routing)
281
+ workflow.add_node("data_wrangling_agent", invoke_data_wrangling_agent)
282
+ workflow.add_node("data_visualization_agent", invoke_data_visualization_agent)
283
+ workflow.add_node("route_printer", route_printer)
284
+
285
+ workflow.add_edge(START, "routing_preprocessor")
286
+ workflow.add_edge("routing_preprocessor", "data_wrangling_agent")
287
+
288
+ workflow.add_conditional_edges(
289
+ "data_wrangling_agent",
290
+ router_chart_or_table,
291
+ {
292
+ "chart": "data_visualization_agent",
293
+ "table": "route_printer"
294
+ }
295
+ )
296
+
297
+ workflow.add_edge("data_visualization_agent", "route_printer")
298
+ workflow.add_edge("route_printer", END)
299
+
300
+ app = workflow.compile(
301
+ checkpointer=checkpointer,
302
+ name=AGENT_NAME
303
+ )
304
+
305
+ return app
@@ -1,12 +1,14 @@
1
1
 
2
2
  from langchain_core.messages import BaseMessage
3
- from langgraph.types import Checkpointer
3
+
4
+ from langchain.prompts import PromptTemplate
5
+ from langchain_core.output_parsers import JsonOutputParser
4
6
 
5
7
  from langgraph.graph import START, END, StateGraph
6
8
  from langgraph.graph.state import CompiledStateGraph
7
- from langgraph.types import Command
9
+ from langgraph.types import Checkpointer
8
10
 
9
- from typing import TypedDict, Annotated, Sequence, Literal
11
+ from typing import TypedDict, Annotated, Sequence
10
12
  import operator
11
13
 
12
14
  from typing_extensions import TypedDict
@@ -20,6 +22,7 @@ from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
20
22
  from ai_data_science_team.utils.plotly import plotly_from_dict
21
23
  from ai_data_science_team.utils.regex import remove_consecutive_duplicates, get_generic_summary
22
24
 
25
+ AGENT_NAME = "sql_data_analyst"
23
26
 
24
27
  class SQLDataAnalyst(BaseAgent):
25
28
  """
@@ -33,6 +36,8 @@ class SQLDataAnalyst(BaseAgent):
33
36
  The SQL Database Agent.
34
37
  data_visualization_agent: DataVisualizationAgent
35
38
  The Data Visualization Agent.
39
+ checkpointer: Checkpointer (optional)
40
+ The checkpointer to save the state of the multi-agent system.
36
41
 
37
42
  Methods:
38
43
  --------
@@ -326,17 +331,17 @@ def make_sql_data_analyst(
326
331
  """
327
332
  Creates a multi-agent system that takes in a SQL query and returns a plot or table.
328
333
 
329
- - Agent 1: SQL Database Agent made with `make_sql_database_agent()`
330
- - Agent 2: Data Visualization Agent made with `make_data_visualization_agent()`
334
+ - Agent 1: SQL Database Agent made with `SQLDatabaseAgent()`
335
+ - Agent 2: Data Visualization Agent made with `DataVisualizationAgent()`
331
336
 
332
337
  Parameters:
333
338
  ----------
334
339
  model:
335
340
  The language model to be used for the agents.
336
341
  sql_database_agent: CompiledStateGraph
337
- The SQL Database Agent made with `make_sql_database_agent()`.
342
+ The SQL Database Agent made with `SQLDatabaseAgent()`.
338
343
  data_visualization_agent: CompiledStateGraph
339
- The Data Visualization Agent made with `make_data_visualization_agent()`.
344
+ The Data Visualization Agent made with `DataVisualizationAgent()`.
340
345
  checkpointer: Checkpointer (optional)
341
346
  The checkpointer to save the state of the multi-agent system.
342
347
  Default: None
@@ -348,10 +353,39 @@ def make_sql_data_analyst(
348
353
  """
349
354
 
350
355
  llm = model
356
+
357
+
358
+ routing_preprocessor_prompt = PromptTemplate(
359
+ template="""
360
+ You are an expert in routing decisions for a SQL Database Agent, a Charting Visualization Agent, and a Pandas Table Agent. Your job is to:
361
+
362
+ 1. Determine what the correct format for a Users Question should be for use with a SQL Database Agent based on the incoming user question. Anything related to database and data manipulation should be passed along.
363
+ 2. Determine whether or not a chart should be generated or a table should be returned based on the users question.
364
+ 3. If a chart is requested, determine the correct format of a Users Question should be used with a Data Visualization Agent. Anything related to plotting and visualization should be passed along.
365
+
366
+ Use the following criteria on how to route the the initial user question:
367
+
368
+ From the incoming user question, remove any details about the format of the final response as either a Chart or Table and return only the important part of the incoming user question that is relevant for the SQL generator agent. This will be the 'user_instructions_sql_database'. If 'None' is found, return the original user question.
369
+
370
+ Next, determine if the user would like a data visualization ('chart') or a 'table' returned with the results of the Data Wrangling Agent. If unknown, not specified or 'None' is found, then select 'table'.
371
+
372
+ If a 'chart' is requested, return the 'user_instructions_data_visualization'. If 'None' is found, return None.
373
+
374
+ Return JSON with 'user_instructions_sql_database', 'user_instructions_data_visualization' and 'routing_preprocessor_decision'.
375
+
376
+ INITIAL_USER_QUESTION: {user_instructions}
377
+ """,
378
+ input_variables=["user_instructions"]
379
+ )
380
+
381
+ routing_preprocessor = routing_preprocessor_prompt | llm | JsonOutputParser()
351
382
 
352
383
  class PrimaryState(TypedDict):
353
384
  messages: Annotated[Sequence[BaseMessage], operator.add]
354
385
  user_instructions: str
386
+ user_instructions_sql_database: str
387
+ user_instructions_data_visualization: str
388
+ routing_preprocessor_decision: str
355
389
  sql_query_code: str
356
390
  sql_database_function: str
357
391
  data_sql: dict
@@ -359,39 +393,94 @@ def make_sql_data_analyst(
359
393
  plot_required: bool
360
394
  data_visualization_function: str
361
395
  plotly_graph: dict
396
+ plotly_error: str
362
397
  max_retries: int
363
398
  retry_count: int
364
399
 
365
- def route_to_visualization(state) -> Command[Literal["data_visualization_agent", "__end__"]]:
400
+ def preprocess_routing(state: PrimaryState):
401
+ print("---SQL DATA ANALYST---")
402
+ print("*************************")
403
+ print("---PREPROCESS ROUTER---")
404
+ question = state.get("user_instructions")
366
405
 
367
- response = llm.invoke(f"Respond in 1 word ('plot' or 'table'). Is the user requesting a plot? If unknown, select 'table'. \n\n User Instructions:\n{state.get('user_instructions')}")
406
+ # Chart Routing and SQL Prep
407
+ response = routing_preprocessor.invoke({"user_instructions": question})
368
408
 
369
- if response.content == 'plot':
370
- plot_required = True
371
- goto="data_visualization_agent"
372
- else:
373
- plot_required = False
374
- goto="__end__"
409
+ return {
410
+ "user_instructions_sql_database": response.get('user_instructions_sql_database'),
411
+ "user_instructions_data_visualization": response.get('user_instructions_data_visualization'),
412
+ "routing_preprocessor_decision": response.get('routing_preprocessor_decision'),
413
+ }
414
+
415
+ def router_chart_or_table(state: PrimaryState):
416
+ print("---ROUTER: CHART OR TABLE---")
417
+ return "chart" if state.get('routing_preprocessor_decision') == "chart" else "table"
418
+
419
+
420
+ def invoke_sql_database_agent(state: PrimaryState):
375
421
 
376
- return Command(
377
- update={
378
- 'data_raw': state.get("data_sql"),
379
- 'plot_required': plot_required,
380
- },
381
- goto=goto
382
- )
422
+ response = sql_database_agent.invoke({
423
+ "user_instructions": state.get("user_instructions_sql_database"),
424
+ "max_retries": state.get("max_retries"),
425
+ "retry_count": state.get("retry_count"),
426
+ })
383
427
 
384
- workflow = StateGraph(PrimaryState)
428
+ return {
429
+ "messages": response.get("messages"),
430
+ "data_sql": response.get("data_sql"),
431
+ "sql_query_code": response.get("sql_query_code"),
432
+ "sql_database_function": response.get("sql_database_function"),
433
+
434
+ }
435
+
436
+ def invoke_data_visualization_agent(state: PrimaryState):
437
+
438
+ response = data_visualization_agent.invoke({
439
+ "user_instructions": state.get("user_instructions_data_visualization"),
440
+ "data_raw": state.get("data_sql"),
441
+ "max_retries": state.get("max_retries"),
442
+ "retry_count": state.get("retry_count"),
443
+ })
444
+
445
+ return {
446
+ "messages": response.get("messages"),
447
+ "data_visualization_function": response.get("data_visualization_function"),
448
+ "plotly_graph": response.get("plotly_graph"),
449
+ "plotly_error": response.get("data_visualization_error"),
450
+ }
385
451
 
386
- workflow.add_node("sql_database_agent", sql_database_agent)
387
- workflow.add_node("route_to_visualization", route_to_visualization)
388
- workflow.add_node("data_visualization_agent", data_visualization_agent)
452
+ def route_printer(state: PrimaryState):
453
+ print("---ROUTE PRINTER---")
454
+ print(f" Route: {state.get('routing_preprocessor_decision')}")
455
+ print("---END---")
456
+ return {}
457
+
458
+ workflow = StateGraph(PrimaryState)
459
+
460
+ workflow.add_node("routing_preprocessor", preprocess_routing)
461
+ workflow.add_node("sql_database_agent", invoke_sql_database_agent)
462
+ workflow.add_node("data_visualization_agent", invoke_data_visualization_agent)
463
+ workflow.add_node("route_printer", route_printer)
389
464
 
390
- workflow.add_edge(START, "sql_database_agent")
391
- workflow.add_edge("sql_database_agent", "route_to_visualization")
392
- workflow.add_edge("data_visualization_agent", END)
465
+ workflow.add_edge(START, "routing_preprocessor")
466
+ workflow.add_edge("routing_preprocessor", "sql_database_agent")
467
+
468
+ workflow.add_conditional_edges(
469
+ "sql_database_agent",
470
+ router_chart_or_table,
471
+ {
472
+ "chart": "data_visualization_agent",
473
+ "table": "route_printer"
474
+ }
475
+ )
476
+
477
+ workflow.add_edge("data_visualization_agent", "route_printer")
478
+ workflow.add_edge("route_printer", END)
393
479
 
394
- app = workflow.compile(checkpointer=checkpointer)
480
+ app = workflow.compile(
481
+ checkpointer=checkpointer,
482
+ name=AGENT_NAME
483
+ )
395
484
 
396
485
  return app
397
486
 
@@ -40,6 +40,21 @@ class BaseAgent(CompiledStateGraph):
40
40
  self._params = params
41
41
  self._compiled_graph = self._make_compiled_graph()
42
42
  self.response = None
43
+ self.name = self._compiled_graph.name
44
+ self.checkpointer = self._compiled_graph.checkpointer
45
+ self.store = self._compiled_graph.store
46
+ self.output_channels = self._compiled_graph.output_channels
47
+ self.nodes = self._compiled_graph.nodes
48
+ self.stream_mode = self._compiled_graph.stream_mode
49
+ self.builder = self._compiled_graph.builder
50
+ self.channels = self._compiled_graph.channels
51
+ self.input_channels = self._compiled_graph.input_channels
52
+ self.input_schema = self._compiled_graph.input_schema
53
+ self.output_schema = self._compiled_graph.output_schema
54
+ self.debug = self._compiled_graph.debug
55
+ self.interrupt_after_nodes = self._compiled_graph.interrupt_after_nodes
56
+ self.interrupt_before_nodes = self._compiled_graph.interrupt_before_nodes
57
+ self.config = self._compiled_graph.config
43
58
 
44
59
  def _make_compiled_graph(self):
45
60
  """
@@ -197,6 +212,24 @@ class BaseAgent(CompiledStateGraph):
197
212
  """
198
213
  return self.get_output_jsonschema()['properties']
199
214
 
215
+ def get_state(self, config, *, subgraphs = False):
216
+ """
217
+ Returns the state of the agent.
218
+ """
219
+ return self._compiled_graph.get_state(config, subgraphs=subgraphs)
220
+
221
+ def get_state_history(self, config, *, filter = None, before = None, limit = None):
222
+ """
223
+ Returns the state history of the agent.
224
+ """
225
+ return self._compiled_graph.get_state_history(config, filter=filter, before=before, limit=limit)
226
+
227
+ def update_state(self, config, values, as_node = None):
228
+ """
229
+ Updates the state of the agent.
230
+ """
231
+ return self._compiled_graph.update_state(config, values, as_node)
232
+
200
233
  def get_response(self):
201
234
  """
202
235
  Returns the response generated by the agent.
@@ -237,6 +270,7 @@ def create_coding_agent_graph(
237
270
  checkpointer: Optional[Callable] = None,
238
271
  bypass_recommended_steps: bool = False,
239
272
  bypass_explain_code: bool = False,
273
+ agent_name: str = "coding_agent"
240
274
  ):
241
275
  """
242
276
  Creates a generic agent graph using the provided node functions and node names.
@@ -281,6 +315,8 @@ def create_coding_agent_graph(
281
315
  Whether to skip the recommended steps node.
282
316
  bypass_explain_code : bool, optional
283
317
  Whether to skip the final explain code node.
318
+ name : str, optional
319
+ The name of the agent graph.
284
320
 
285
321
  Returns
286
322
  -------
@@ -366,10 +402,10 @@ def create_coding_agent_graph(
366
402
  workflow.add_edge(explain_code_node_name, END)
367
403
 
368
404
  # Finally, compile
369
- if human_in_the_loop:
370
- app = workflow.compile(checkpointer=checkpointer)
371
- else:
372
- app = workflow.compile()
405
+ app = workflow.compile(
406
+ checkpointer=checkpointer,
407
+ name=agent_name,
408
+ )
373
409
 
374
410
  return app
375
411
 
@@ -574,7 +610,7 @@ def node_func_execute_agent_from_sql_connection(
574
610
 
575
611
  # Retrieve SQLAlchemy connection and code snippet from the state
576
612
  is_engine = isinstance(connection, sql.engine.base.Engine)
577
- conn = connection.connect() if is_engine else connection
613
+ connection = connection.connect() if is_engine else connection
578
614
  agent_code = state.get(code_snippet_key)
579
615
 
580
616
  # Ensure the connection object is provided