ai-data-science-team 0.0.0.9013__py3-none-any.whl → 0.0.0.9014__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
@@ -10,6 +10,7 @@ from langchain_core.messages import BaseMessage, AIMessage
10
10
 
11
11
  from langgraph.prebuilt import create_react_agent, ToolNode
12
12
  from langgraph.prebuilt.chat_agent_executor import AgentState
13
+ from langgraph.types import Checkpointer
13
14
  from langgraph.graph import START, END, StateGraph
14
15
 
15
16
  from ai_data_science_team.templates import BaseAgent
@@ -68,6 +69,8 @@ class MLflowToolsAgent(BaseAgent):
68
69
  Additional keyword arguments to pass to the create_react_agent function.
69
70
  invoke_react_agent_kwargs : dict
70
71
  Additional keyword arguments to pass to the invoke method of the react agent.
72
+ checkpointer : langchain.checkpointing.Checkpointer, optional
73
+ A checkpointer to use for saving and loading the agent's state. Defaults to None.
71
74
 
72
75
  Methods:
73
76
  --------
@@ -119,6 +122,7 @@ class MLflowToolsAgent(BaseAgent):
119
122
  mlflow_registry_uri: Optional[str]=None,
120
123
  create_react_agent_kwargs: Optional[Dict]={},
121
124
  invoke_react_agent_kwargs: Optional[Dict]={},
125
+ checkpointer: Optional[Checkpointer]=None,
122
126
  ):
123
127
  self._params = {
124
128
  "model": model,
@@ -126,6 +130,7 @@ class MLflowToolsAgent(BaseAgent):
126
130
  "mlflow_registry_uri": mlflow_registry_uri,
127
131
  "create_react_agent_kwargs": create_react_agent_kwargs,
128
132
  "invoke_react_agent_kwargs": invoke_react_agent_kwargs,
133
+ "checkpointer": checkpointer,
129
134
  }
130
135
  self._compiled_graph = self._make_compiled_graph()
131
136
  self.response = None
@@ -245,6 +250,7 @@ def make_mlflow_tools_agent(
245
250
  mlflow_registry_uri: str=None,
246
251
  create_react_agent_kwargs: Optional[Dict]={},
247
252
  invoke_react_agent_kwargs: Optional[Dict]={},
253
+ checkpointer: Optional[Checkpointer]=None,
248
254
  ):
249
255
  """
250
256
  MLflow Tool Calling Agent
@@ -261,6 +267,8 @@ def make_mlflow_tools_agent(
261
267
  Additional keyword arguments to pass to the agent's create_react_agent method.
262
268
  invoke_react_agent_kwargs : dict, optional
263
269
  Additional keyword arguments to pass to the agent's invoke method.
270
+ checkpointer : langchain.checkpointing.Checkpointer, optional
271
+ A checkpointer to use for saving and loading the agent's state. Defaults to None.
264
272
 
265
273
  Returns
266
274
  -------
@@ -303,6 +311,7 @@ def make_mlflow_tools_agent(
303
311
  model,
304
312
  tools=tool_node,
305
313
  state_schema=GraphState,
314
+ checkpointer=checkpointer,
306
315
  **create_react_agent_kwargs,
307
316
  )
308
317
 
@@ -354,7 +363,10 @@ def make_mlflow_tools_agent(
354
363
  workflow.add_edge(START, "mlflow_tools_agent")
355
364
  workflow.add_edge("mlflow_tools_agent", END)
356
365
 
357
- app = workflow.compile()
366
+ app = workflow.compile(
367
+ checkpointer=checkpointer,
368
+ name=AGENT_NAME,
369
+ )
358
370
 
359
371
  return app
360
372
 
@@ -1 +1,2 @@
1
- from ai_data_science_team.multiagents.sql_data_analyst import SQLDataAnalyst, make_sql_data_analyst
1
+ from ai_data_science_team.multiagents.sql_data_analyst import SQLDataAnalyst, make_sql_data_analyst
2
+ from ai_data_science_team.multiagents.pandas_data_analyst import PandasDataAnalyst, make_pandas_data_analyst
@@ -0,0 +1,305 @@
1
+ from langchain_core.messages import BaseMessage
2
+ from langchain.prompts import PromptTemplate
3
+ from langchain_core.output_parsers import JsonOutputParser
4
+ from langgraph.types import Checkpointer
5
+ from langgraph.graph import START, END, StateGraph
6
+ from langgraph.graph.state import CompiledStateGraph
7
+
8
+ from typing import TypedDict, Annotated, Sequence, Union
9
+ import operator
10
+
11
+ import pandas as pd
12
+ import json
13
+ from IPython.display import Markdown
14
+
15
+ from ai_data_science_team.templates import BaseAgent
16
+ from ai_data_science_team.agents import DataWranglingAgent, DataVisualizationAgent
17
+ from ai_data_science_team.utils.plotly import plotly_from_dict
18
+ from ai_data_science_team.utils.regex import remove_consecutive_duplicates, get_generic_summary
19
+
20
+ AGENT_NAME = "pandas_data_analyst"
21
+
22
+ class PandasDataAnalyst(BaseAgent):
23
+ """
24
+ PandasDataAnalyst is a multi-agent class that combines data wrangling and visualization capabilities.
25
+
26
+ Parameters:
27
+ -----------
28
+ model:
29
+ The language model to be used for the agents.
30
+ data_wrangling_agent: DataWranglingAgent
31
+ The Data Wrangling Agent for transforming raw data.
32
+ data_visualization_agent: DataVisualizationAgent
33
+ The Data Visualization Agent for generating plots.
34
+ checkpointer: Checkpointer (optional)
35
+ The checkpointer to save the state of the multi-agent system.
36
+
37
+ Methods:
38
+ --------
39
+ ainvoke_agent(user_instructions, data_raw, **kwargs)
40
+ Asynchronously invokes the multi-agent with user instructions and raw data.
41
+ invoke_agent(user_instructions, data_raw, **kwargs)
42
+ Synchronously invokes the multi-agent with user instructions and raw data.
43
+ get_data_wrangled()
44
+ Returns the wrangled data as a Pandas DataFrame.
45
+ get_plotly_graph()
46
+ Returns the Plotly graph as a Plotly object.
47
+ get_data_wrangler_function(markdown=False)
48
+ Returns the data wrangling function as a string, optionally in Markdown.
49
+ get_data_visualization_function(markdown=False)
50
+ Returns the data visualization function as a string, optionally in Markdown.
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ model,
56
+ data_wrangling_agent: DataWranglingAgent,
57
+ data_visualization_agent: DataVisualizationAgent,
58
+ checkpointer: Checkpointer = None,
59
+ ):
60
+ self._params = {
61
+ "model": model,
62
+ "data_wrangling_agent": data_wrangling_agent,
63
+ "data_visualization_agent": data_visualization_agent,
64
+ "checkpointer": checkpointer,
65
+ }
66
+ self._compiled_graph = self._make_compiled_graph()
67
+ self.response = None
68
+
69
+ def _make_compiled_graph(self):
70
+ """Create or rebuild the compiled graph. Resets response to None."""
71
+ self.response = None
72
+ return make_pandas_data_analyst(
73
+ model=self._params["model"],
74
+ data_wrangling_agent=self._params["data_wrangling_agent"]._compiled_graph,
75
+ data_visualization_agent=self._params["data_visualization_agent"]._compiled_graph,
76
+ checkpointer=self._params["checkpointer"],
77
+ )
78
+
79
+ def update_params(self, **kwargs):
80
+ """Updates parameters and rebuilds the compiled graph."""
81
+ for k, v in kwargs.items():
82
+ self._params[k] = v
83
+ self._compiled_graph = self._make_compiled_graph()
84
+
85
+ async def ainvoke_agent(self, user_instructions, data_raw: Union[pd.DataFrame, dict, list], max_retries: int = 3, retry_count: int = 0, **kwargs):
86
+ """Asynchronously invokes the multi-agent."""
87
+ response = await self._compiled_graph.ainvoke({
88
+ "user_instructions": user_instructions,
89
+ "data_raw": self._convert_data_input(data_raw),
90
+ "max_retries": max_retries,
91
+ "retry_count": retry_count,
92
+ }, **kwargs)
93
+ if response.get("messages"):
94
+ response["messages"] = remove_consecutive_duplicates(response["messages"])
95
+ self.response = response
96
+
97
+ def invoke_agent(self, user_instructions, data_raw: Union[pd.DataFrame, dict, list], max_retries: int = 3, retry_count: int = 0, **kwargs):
98
+ """Synchronously invokes the multi-agent."""
99
+ response = self._compiled_graph.invoke({
100
+ "user_instructions": user_instructions,
101
+ "data_raw": self._convert_data_input(data_raw),
102
+ "max_retries": max_retries,
103
+ "retry_count": retry_count,
104
+ }, **kwargs)
105
+ if response.get("messages"):
106
+ response["messages"] = remove_consecutive_duplicates(response["messages"])
107
+ self.response = response
108
+
109
+ def get_data_wrangled(self):
110
+ """Returns the wrangled data as a Pandas DataFrame."""
111
+ if self.response and self.response.get("data_wrangled"):
112
+ return pd.DataFrame(self.response.get("data_wrangled"))
113
+
114
+ def get_plotly_graph(self):
115
+ """Returns the Plotly graph as a Plotly object."""
116
+ if self.response and self.response.get("plotly_graph"):
117
+ return plotly_from_dict(self.response.get("plotly_graph"))
118
+
119
+ def get_data_wrangler_function(self, markdown=False):
120
+ """Returns the data wrangling function as a string."""
121
+ if self.response and self.response.get("data_wrangler_function"):
122
+ code = self.response.get("data_wrangler_function")
123
+ return Markdown(f"```python\n{code}\n```") if markdown else code
124
+
125
+ def get_data_visualization_function(self, markdown=False):
126
+ """Returns the data visualization function as a string."""
127
+ if self.response and self.response.get("data_visualization_function"):
128
+ code = self.response.get("data_visualization_function")
129
+ return Markdown(f"```python\n{code}\n```") if markdown else code
130
+
131
+ def get_workflow_summary(self, markdown=False):
132
+ """Returns a summary of the workflow."""
133
+ if self.response and self.response.get("messages"):
134
+ agents = [msg.role for msg in self.response["messages"]]
135
+ agent_labels = [f"- **Agent {i+1}:** {role}" for i, role in enumerate(agents)]
136
+ header = f"# Pandas Data Analyst Workflow Summary\n\nThis workflow contains {len(agents)} agents:\n\n" + "\n".join(agent_labels)
137
+ reports = [get_generic_summary(json.loads(msg.content)) for msg in self.response["messages"]]
138
+ summary = "\n" +header + "\n\n".join(reports)
139
+ return Markdown(summary) if markdown else summary
140
+
141
+ @staticmethod
142
+ def _convert_data_input(data_raw: Union[pd.DataFrame, dict, list]) -> Union[dict, list]:
143
+ """Converts input data to the expected format (dict or list of dicts)."""
144
+ if isinstance(data_raw, pd.DataFrame):
145
+ return data_raw.to_dict()
146
+ if isinstance(data_raw, dict):
147
+ return data_raw
148
+ if isinstance(data_raw, list):
149
+ return [item.to_dict() if isinstance(item, pd.DataFrame) else item for item in data_raw]
150
+ raise ValueError("data_raw must be a DataFrame, dict, or list of DataFrames/dicts")
151
+
152
+ def make_pandas_data_analyst(
153
+ model,
154
+ data_wrangling_agent: CompiledStateGraph,
155
+ data_visualization_agent: CompiledStateGraph,
156
+ checkpointer: Checkpointer = None
157
+ ):
158
+ """
159
+ Creates a multi-agent system that wrangles data and optionally visualizes it.
160
+
161
+ Parameters:
162
+ -----------
163
+ model: The language model to be used.
164
+ data_wrangling_agent: CompiledStateGraph
165
+ The Data Wrangling Agent.
166
+ data_visualization_agent: CompiledStateGraph
167
+ The Data Visualization Agent.
168
+ checkpointer: Checkpointer (optional)
169
+ The checkpointer to save the state.
170
+
171
+ Returns:
172
+ --------
173
+ CompiledStateGraph: The compiled multi-agent system.
174
+ """
175
+
176
+ llm = model
177
+
178
+ routing_preprocessor_prompt = PromptTemplate(
179
+ template="""
180
+ You are an expert in routing decisions for a Pandas Data Manipulation Wrangling Agent, a Charting Visualization Agent, and a Pandas Table Agent. Your job is to:
181
+
182
+ 1. Determine what the correct format for a Users Question should be for use with a Pandas Data Wrangling Agent based on the incoming user question. Anything related to data wrangling and manipulation should be passed along.
183
+ 2. Determine whether or not a chart should be generated or a table should be returned based on the users question.
184
+ 3. If a chart is requested, determine the correct format of a Users Question should be used with a Data Visualization Agent. Anything related to plotting and visualization should be passed along.
185
+
186
+ Use the following criteria on how to route the the initial user question:
187
+
188
+ From the incoming user question, remove any details about the format of the final response as either a Chart or Table and return only the important part of the incoming user question that is relevant for the SQL generator agent. This will be the 'user_instructions_data_wrangling'. If 'None' is found, return the original user question.
189
+
190
+ Next, determine if the user would like a data visualization ('chart') or a 'table' returned with the results of the Data Wrangling Agent. If unknown, not specified or 'None' is found, then select 'table'.
191
+
192
+ If a 'chart' is requested, return the 'user_instructions_data_visualization'. If 'None' is found, return None.
193
+
194
+ Return JSON with 'user_instructions_data_wrangling', 'user_instructions_data_visualization' and 'routing_preprocessor_decision'.
195
+
196
+ INITIAL_USER_QUESTION: {user_instructions}
197
+ """,
198
+ input_variables=["user_instructions"]
199
+ )
200
+
201
+ routing_preprocessor = routing_preprocessor_prompt | llm | JsonOutputParser()
202
+
203
+ class PrimaryState(TypedDict):
204
+ messages: Annotated[Sequence[BaseMessage], operator.add]
205
+ user_instructions: str
206
+ user_instructions_data_wrangling: str
207
+ user_instructions_data_visualization: str
208
+ routing_preprocessor_decision: str
209
+ data_raw: Union[dict, list]
210
+ data_wrangled: dict
211
+ data_wrangler_function: str
212
+ data_visualization_function: str
213
+ plotly_graph: dict
214
+ plotly_error: str
215
+ max_retries: int
216
+ retry_count: int
217
+
218
+
219
+ def preprocess_routing(state: PrimaryState):
220
+ print("---PANDAS DATA ANALYST---")
221
+ print("*************************")
222
+ print("---PREPROCESS ROUTER---")
223
+ question = state.get("user_instructions")
224
+
225
+ # Chart Routing and SQL Prep
226
+ response = routing_preprocessor.invoke({"user_instructions": question})
227
+
228
+ return {
229
+ "user_instructions_data_wrangling": response.get('user_instructions_data_wrangling'),
230
+ "user_instructions_data_visualization": response.get('user_instructions_data_visualization'),
231
+ "routing_preprocessor_decision": response.get('routing_preprocessor_decision'),
232
+ }
233
+
234
+ def router_chart_or_table(state: PrimaryState):
235
+ print("---ROUTER: CHART OR TABLE---")
236
+ return "chart" if state.get('routing_preprocessor_decision') == "chart" else "table"
237
+
238
+
239
+ def invoke_data_wrangling_agent(state: PrimaryState):
240
+
241
+ response = data_wrangling_agent.invoke({
242
+ "user_instructions": state.get("user_instructions_data_wrangling"),
243
+ "data_raw": state.get("data_raw"),
244
+ "max_retries": state.get("max_retries"),
245
+ "retry_count": state.get("retry_count"),
246
+ })
247
+
248
+ return {
249
+ "messages": response.get("messages"),
250
+ "data_wrangled": response.get("data_wrangled"),
251
+ "data_wrangler_function": response.get("data_wrangler_function"),
252
+ "plotly_error": response.get("data_visualization_error"),
253
+
254
+ }
255
+
256
+ def invoke_data_visualization_agent(state: PrimaryState):
257
+
258
+ response = data_visualization_agent.invoke({
259
+ "user_instructions": state.get("user_instructions_data_visualization"),
260
+ "data_raw": state.get("data_wrangled") if state.get("data_wrangled") else state.get("data_raw"),
261
+ "max_retries": state.get("max_retries"),
262
+ "retry_count": state.get("retry_count"),
263
+ })
264
+
265
+ return {
266
+ "messages": response.get("messages"),
267
+ "data_visualization_function": response.get("data_visualization_function"),
268
+ "plotly_graph": response.get("plotly_graph"),
269
+ "plotly_error": response.get("data_visualization_error"),
270
+ }
271
+
272
+ def route_printer(state: PrimaryState):
273
+ print("---ROUTE PRINTER---")
274
+ print(f" Route: {state.get('routing_preprocessor_decision')}")
275
+ print("---END---")
276
+ return {}
277
+
278
+ workflow = StateGraph(PrimaryState)
279
+
280
+ workflow.add_node("routing_preprocessor", preprocess_routing)
281
+ workflow.add_node("data_wrangling_agent", invoke_data_wrangling_agent)
282
+ workflow.add_node("data_visualization_agent", invoke_data_visualization_agent)
283
+ workflow.add_node("route_printer", route_printer)
284
+
285
+ workflow.add_edge(START, "routing_preprocessor")
286
+ workflow.add_edge("routing_preprocessor", "data_wrangling_agent")
287
+
288
+ workflow.add_conditional_edges(
289
+ "data_wrangling_agent",
290
+ router_chart_or_table,
291
+ {
292
+ "chart": "data_visualization_agent",
293
+ "table": "route_printer"
294
+ }
295
+ )
296
+
297
+ workflow.add_edge("data_visualization_agent", "route_printer")
298
+ workflow.add_edge("route_printer", END)
299
+
300
+ app = workflow.compile(
301
+ checkpointer=checkpointer,
302
+ name=AGENT_NAME
303
+ )
304
+
305
+ return app
@@ -1,12 +1,14 @@
1
1
 
2
2
  from langchain_core.messages import BaseMessage
3
- from langgraph.types import Checkpointer
3
+
4
+ from langchain.prompts import PromptTemplate
5
+ from langchain_core.output_parsers import JsonOutputParser
4
6
 
5
7
  from langgraph.graph import START, END, StateGraph
6
8
  from langgraph.graph.state import CompiledStateGraph
7
- from langgraph.types import Command
9
+ from langgraph.types import Checkpointer
8
10
 
9
- from typing import TypedDict, Annotated, Sequence, Literal
11
+ from typing import TypedDict, Annotated, Sequence
10
12
  import operator
11
13
 
12
14
  from typing_extensions import TypedDict
@@ -20,6 +22,7 @@ from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
20
22
  from ai_data_science_team.utils.plotly import plotly_from_dict
21
23
  from ai_data_science_team.utils.regex import remove_consecutive_duplicates, get_generic_summary
22
24
 
25
+ AGENT_NAME = "sql_data_analyst"
23
26
 
24
27
  class SQLDataAnalyst(BaseAgent):
25
28
  """
@@ -33,6 +36,8 @@ class SQLDataAnalyst(BaseAgent):
33
36
  The SQL Database Agent.
34
37
  data_visualization_agent: DataVisualizationAgent
35
38
  The Data Visualization Agent.
39
+ checkpointer: Checkpointer (optional)
40
+ The checkpointer to save the state of the multi-agent system.
36
41
 
37
42
  Methods:
38
43
  --------
@@ -326,17 +331,17 @@ def make_sql_data_analyst(
326
331
  """
327
332
  Creates a multi-agent system that takes in a SQL query and returns a plot or table.
328
333
 
329
- - Agent 1: SQL Database Agent made with `make_sql_database_agent()`
330
- - Agent 2: Data Visualization Agent made with `make_data_visualization_agent()`
334
+ - Agent 1: SQL Database Agent made with `SQLDatabaseAgent()`
335
+ - Agent 2: Data Visualization Agent made with `DataVisualizationAgent()`
331
336
 
332
337
  Parameters:
333
338
  ----------
334
339
  model:
335
340
  The language model to be used for the agents.
336
341
  sql_database_agent: CompiledStateGraph
337
- The SQL Database Agent made with `make_sql_database_agent()`.
342
+ The SQL Database Agent made with `SQLDatabaseAgent()`.
338
343
  data_visualization_agent: CompiledStateGraph
339
- The Data Visualization Agent made with `make_data_visualization_agent()`.
344
+ The Data Visualization Agent made with `DataVisualizationAgent()`.
340
345
  checkpointer: Checkpointer (optional)
341
346
  The checkpointer to save the state of the multi-agent system.
342
347
  Default: None
@@ -348,10 +353,39 @@ def make_sql_data_analyst(
348
353
  """
349
354
 
350
355
  llm = model
356
+
357
+
358
+ routing_preprocessor_prompt = PromptTemplate(
359
+ template="""
360
+ You are an expert in routing decisions for a SQL Database Agent, a Charting Visualization Agent, and a Pandas Table Agent. Your job is to:
361
+
362
+ 1. Determine what the correct format for a Users Question should be for use with a SQL Database Agent based on the incoming user question. Anything related to database and data manipulation should be passed along.
363
+ 2. Determine whether or not a chart should be generated or a table should be returned based on the users question.
364
+ 3. If a chart is requested, determine the correct format of a Users Question should be used with a Data Visualization Agent. Anything related to plotting and visualization should be passed along.
365
+
366
+ Use the following criteria on how to route the the initial user question:
367
+
368
+ From the incoming user question, remove any details about the format of the final response as either a Chart or Table and return only the important part of the incoming user question that is relevant for the SQL generator agent. This will be the 'user_instructions_sql_database'. If 'None' is found, return the original user question.
369
+
370
+ Next, determine if the user would like a data visualization ('chart') or a 'table' returned with the results of the Data Wrangling Agent. If unknown, not specified or 'None' is found, then select 'table'.
371
+
372
+ If a 'chart' is requested, return the 'user_instructions_data_visualization'. If 'None' is found, return None.
373
+
374
+ Return JSON with 'user_instructions_sql_database', 'user_instructions_data_visualization' and 'routing_preprocessor_decision'.
375
+
376
+ INITIAL_USER_QUESTION: {user_instructions}
377
+ """,
378
+ input_variables=["user_instructions"]
379
+ )
380
+
381
+ routing_preprocessor = routing_preprocessor_prompt | llm | JsonOutputParser()
351
382
 
352
383
  class PrimaryState(TypedDict):
353
384
  messages: Annotated[Sequence[BaseMessage], operator.add]
354
385
  user_instructions: str
386
+ user_instructions_sql_database: str
387
+ user_instructions_data_visualization: str
388
+ routing_preprocessor_decision: str
355
389
  sql_query_code: str
356
390
  sql_database_function: str
357
391
  data_sql: dict
@@ -359,39 +393,94 @@ def make_sql_data_analyst(
359
393
  plot_required: bool
360
394
  data_visualization_function: str
361
395
  plotly_graph: dict
396
+ plotly_error: str
362
397
  max_retries: int
363
398
  retry_count: int
364
399
 
365
- def route_to_visualization(state) -> Command[Literal["data_visualization_agent", "__end__"]]:
400
+ def preprocess_routing(state: PrimaryState):
401
+ print("---SQL DATA ANALYST---")
402
+ print("*************************")
403
+ print("---PREPROCESS ROUTER---")
404
+ question = state.get("user_instructions")
366
405
 
367
- response = llm.invoke(f"Respond in 1 word ('plot' or 'table'). Is the user requesting a plot? If unknown, select 'table'. \n\n User Instructions:\n{state.get('user_instructions')}")
406
+ # Chart Routing and SQL Prep
407
+ response = routing_preprocessor.invoke({"user_instructions": question})
368
408
 
369
- if response.content == 'plot':
370
- plot_required = True
371
- goto="data_visualization_agent"
372
- else:
373
- plot_required = False
374
- goto="__end__"
409
+ return {
410
+ "user_instructions_sql_database": response.get('user_instructions_sql_database'),
411
+ "user_instructions_data_visualization": response.get('user_instructions_data_visualization'),
412
+ "routing_preprocessor_decision": response.get('routing_preprocessor_decision'),
413
+ }
414
+
415
+ def router_chart_or_table(state: PrimaryState):
416
+ print("---ROUTER: CHART OR TABLE---")
417
+ return "chart" if state.get('routing_preprocessor_decision') == "chart" else "table"
418
+
419
+
420
+ def invoke_sql_database_agent(state: PrimaryState):
375
421
 
376
- return Command(
377
- update={
378
- 'data_raw': state.get("data_sql"),
379
- 'plot_required': plot_required,
380
- },
381
- goto=goto
382
- )
422
+ response = sql_database_agent.invoke({
423
+ "user_instructions": state.get("user_instructions_sql_database"),
424
+ "max_retries": state.get("max_retries"),
425
+ "retry_count": state.get("retry_count"),
426
+ })
383
427
 
384
- workflow = StateGraph(PrimaryState)
428
+ return {
429
+ "messages": response.get("messages"),
430
+ "data_sql": response.get("data_sql"),
431
+ "sql_query_code": response.get("sql_query_code"),
432
+ "sql_database_function": response.get("sql_database_function"),
433
+
434
+ }
435
+
436
+ def invoke_data_visualization_agent(state: PrimaryState):
437
+
438
+ response = data_visualization_agent.invoke({
439
+ "user_instructions": state.get("user_instructions_data_visualization"),
440
+ "data_raw": state.get("data_sql"),
441
+ "max_retries": state.get("max_retries"),
442
+ "retry_count": state.get("retry_count"),
443
+ })
444
+
445
+ return {
446
+ "messages": response.get("messages"),
447
+ "data_visualization_function": response.get("data_visualization_function"),
448
+ "plotly_graph": response.get("plotly_graph"),
449
+ "plotly_error": response.get("data_visualization_error"),
450
+ }
385
451
 
386
- workflow.add_node("sql_database_agent", sql_database_agent)
387
- workflow.add_node("route_to_visualization", route_to_visualization)
388
- workflow.add_node("data_visualization_agent", data_visualization_agent)
452
+ def route_printer(state: PrimaryState):
453
+ print("---ROUTE PRINTER---")
454
+ print(f" Route: {state.get('routing_preprocessor_decision')}")
455
+ print("---END---")
456
+ return {}
457
+
458
+ workflow = StateGraph(PrimaryState)
459
+
460
+ workflow.add_node("routing_preprocessor", preprocess_routing)
461
+ workflow.add_node("sql_database_agent", invoke_sql_database_agent)
462
+ workflow.add_node("data_visualization_agent", invoke_data_visualization_agent)
463
+ workflow.add_node("route_printer", route_printer)
389
464
 
390
- workflow.add_edge(START, "sql_database_agent")
391
- workflow.add_edge("sql_database_agent", "route_to_visualization")
392
- workflow.add_edge("data_visualization_agent", END)
465
+ workflow.add_edge(START, "routing_preprocessor")
466
+ workflow.add_edge("routing_preprocessor", "sql_database_agent")
467
+
468
+ workflow.add_conditional_edges(
469
+ "sql_database_agent",
470
+ router_chart_or_table,
471
+ {
472
+ "chart": "data_visualization_agent",
473
+ "table": "route_printer"
474
+ }
475
+ )
476
+
477
+ workflow.add_edge("data_visualization_agent", "route_printer")
478
+ workflow.add_edge("route_printer", END)
393
479
 
394
- app = workflow.compile(checkpointer=checkpointer)
480
+ app = workflow.compile(
481
+ checkpointer=checkpointer,
482
+ name=AGENT_NAME
483
+ )
395
484
 
396
485
  return app
397
486
 
@@ -40,6 +40,21 @@ class BaseAgent(CompiledStateGraph):
40
40
  self._params = params
41
41
  self._compiled_graph = self._make_compiled_graph()
42
42
  self.response = None
43
+ self.name = self._compiled_graph.name
44
+ self.checkpointer = self._compiled_graph.checkpointer
45
+ self.store = self._compiled_graph.store
46
+ self.output_channels = self._compiled_graph.output_channels
47
+ self.nodes = self._compiled_graph.nodes
48
+ self.stream_mode = self._compiled_graph.stream_mode
49
+ self.builder = self._compiled_graph.builder
50
+ self.channels = self._compiled_graph.channels
51
+ self.input_channels = self._compiled_graph.input_channels
52
+ self.input_schema = self._compiled_graph.input_schema
53
+ self.output_schema = self._compiled_graph.output_schema
54
+ self.debug = self._compiled_graph.debug
55
+ self.interrupt_after_nodes = self._compiled_graph.interrupt_after_nodes
56
+ self.interrupt_before_nodes = self._compiled_graph.interrupt_before_nodes
57
+ self.config = self._compiled_graph.config
43
58
 
44
59
  def _make_compiled_graph(self):
45
60
  """
@@ -197,6 +212,24 @@ class BaseAgent(CompiledStateGraph):
197
212
  """
198
213
  return self.get_output_jsonschema()['properties']
199
214
 
215
+ def get_state(self, config, *, subgraphs = False):
216
+ """
217
+ Returns the state of the agent.
218
+ """
219
+ return self._compiled_graph.get_state(config, subgraphs=subgraphs)
220
+
221
+ def get_state_history(self, config, *, filter = None, before = None, limit = None):
222
+ """
223
+ Returns the state history of the agent.
224
+ """
225
+ return self._compiled_graph.get_state_history(config, filter=filter, before=before, limit=limit)
226
+
227
+ def update_state(self, config, values, as_node = None):
228
+ """
229
+ Updates the state of the agent.
230
+ """
231
+ return self._compiled_graph.update_state(config, values, as_node)
232
+
200
233
  def get_response(self):
201
234
  """
202
235
  Returns the response generated by the agent.
@@ -237,6 +270,7 @@ def create_coding_agent_graph(
237
270
  checkpointer: Optional[Callable] = None,
238
271
  bypass_recommended_steps: bool = False,
239
272
  bypass_explain_code: bool = False,
273
+ agent_name: str = "coding_agent"
240
274
  ):
241
275
  """
242
276
  Creates a generic agent graph using the provided node functions and node names.
@@ -281,6 +315,8 @@ def create_coding_agent_graph(
281
315
  Whether to skip the recommended steps node.
282
316
  bypass_explain_code : bool, optional
283
317
  Whether to skip the final explain code node.
318
+ name : str, optional
319
+ The name of the agent graph.
284
320
 
285
321
  Returns
286
322
  -------
@@ -366,10 +402,10 @@ def create_coding_agent_graph(
366
402
  workflow.add_edge(explain_code_node_name, END)
367
403
 
368
404
  # Finally, compile
369
- if human_in_the_loop:
370
- app = workflow.compile(checkpointer=checkpointer)
371
- else:
372
- app = workflow.compile()
405
+ app = workflow.compile(
406
+ checkpointer=checkpointer,
407
+ name=agent_name,
408
+ )
373
409
 
374
410
  return app
375
411
 
@@ -574,7 +610,7 @@ def node_func_execute_agent_from_sql_connection(
574
610
 
575
611
  # Retrieve SQLAlchemy connection and code snippet from the state
576
612
  is_engine = isinstance(connection, sql.engine.base.Engine)
577
- conn = connection.connect() if is_engine else connection
613
+ connection = connection.connect() if is_engine else connection
578
614
  agent_code = state.get(code_snippet_key)
579
615
 
580
616
  # Ensure the connection object is provided