ai-data-science-team 0.0.0.9013__py3-none-any.whl → 0.0.0.9015__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,305 @@
1
+ from langchain_core.messages import BaseMessage
2
+ from langchain.prompts import PromptTemplate
3
+ from langchain_core.output_parsers import JsonOutputParser
4
+ from langgraph.types import Checkpointer
5
+ from langgraph.graph import START, END, StateGraph
6
+ from langgraph.graph.state import CompiledStateGraph
7
+
8
+ from typing import TypedDict, Annotated, Sequence, Union
9
+ import operator
10
+
11
+ import pandas as pd
12
+ import json
13
+ from IPython.display import Markdown
14
+
15
+ from ai_data_science_team.templates import BaseAgent
16
+ from ai_data_science_team.agents import DataWranglingAgent, DataVisualizationAgent
17
+ from ai_data_science_team.utils.plotly import plotly_from_dict
18
+ from ai_data_science_team.utils.regex import remove_consecutive_duplicates, get_generic_summary
19
+
20
# Name given to the compiled LangGraph workflow built by make_pandas_data_analyst
# (passed as ``name=`` to ``workflow.compile``).
AGENT_NAME = "pandas_data_analyst"
21
+
22
class PandasDataAnalyst(BaseAgent):
    """
    PandasDataAnalyst is a multi-agent class that combines data wrangling and visualization capabilities.

    Parameters:
    -----------
    model:
        The language model to be used for the agents.
    data_wrangling_agent: DataWranglingAgent
        The Data Wrangling Agent for transforming raw data.
    data_visualization_agent: DataVisualizationAgent
        The Data Visualization Agent for generating plots.
    checkpointer: Checkpointer (optional)
        The checkpointer to save the state of the multi-agent system.

    Methods:
    --------
    ainvoke_agent(user_instructions, data_raw, **kwargs)
        Asynchronously invokes the multi-agent with user instructions and raw data.
    invoke_agent(user_instructions, data_raw, **kwargs)
        Synchronously invokes the multi-agent with user instructions and raw data.
    get_data_wrangled()
        Returns the wrangled data as a Pandas DataFrame.
    get_plotly_graph()
        Returns the Plotly graph as a Plotly object.
    get_data_wrangler_function(markdown=False)
        Returns the data wrangling function as a string, optionally in Markdown.
    get_data_visualization_function(markdown=False)
        Returns the data visualization function as a string, optionally in Markdown.
    get_workflow_summary(markdown=False)
        Returns a summary of the agents that ran and their reports, optionally in Markdown.
    """

    def __init__(
        self,
        model,
        data_wrangling_agent: DataWranglingAgent,
        data_visualization_agent: DataVisualizationAgent,
        checkpointer: Checkpointer = None,
    ):
        self._params = {
            "model": model,
            "data_wrangling_agent": data_wrangling_agent,
            "data_visualization_agent": data_visualization_agent,
            "checkpointer": checkpointer,
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None
        # This __init__ replaces BaseAgent.__init__ rather than calling it, so the
        # compiled-graph attributes BaseAgent.__init__ mirrors must be copied here.
        self._sync_graph_attributes()

    def _make_compiled_graph(self):
        """Create or rebuild the compiled graph. Resets response to None."""
        self.response = None
        return make_pandas_data_analyst(
            model=self._params["model"],
            data_wrangling_agent=self._params["data_wrangling_agent"]._compiled_graph,
            data_visualization_agent=self._params["data_visualization_agent"]._compiled_graph,
            checkpointer=self._params["checkpointer"],
        )

    def _sync_graph_attributes(self):
        """Copy the compiled graph's public attributes onto this instance."""
        graph = self._compiled_graph
        # Same attribute set that BaseAgent.__init__ mirrors, so this agent exposes
        # the same graph-level attributes (name, nodes, checkpointer, ...) as other agents.
        for attr in (
            "name",
            "checkpointer",
            "store",
            "output_channels",
            "nodes",
            "stream_mode",
            "builder",
            "channels",
            "input_channels",
            "input_schema",
            "output_schema",
            "debug",
            "interrupt_after_nodes",
            "interrupt_before_nodes",
            "config",
        ):
            setattr(self, attr, getattr(graph, attr))

    def update_params(self, **kwargs):
        """Updates parameters, rebuilds the compiled graph, and re-syncs its attributes."""
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()
        # Keep the mirrored attributes pointing at the freshly compiled graph.
        self._sync_graph_attributes()

    def _build_graph_input(self, user_instructions, data_raw, max_retries, retry_count):
        """Assemble the initial state dict handed to the compiled graph."""
        return {
            "user_instructions": user_instructions,
            "data_raw": self._convert_data_input(data_raw),
            "max_retries": max_retries,
            "retry_count": retry_count,
        }

    def _store_response(self, response):
        """Collapse consecutive duplicate messages and keep the response for the getters."""
        if response.get("messages"):
            response["messages"] = remove_consecutive_duplicates(response["messages"])
        self.response = response

    async def ainvoke_agent(self, user_instructions, data_raw: Union[pd.DataFrame, dict, list], max_retries: int = 3, retry_count: int = 0, **kwargs):
        """
        Asynchronously invokes the multi-agent and stores the result in ``self.response``.

        Parameters
        ----------
        user_instructions : str
            The user's question / instructions.
        data_raw : pd.DataFrame, dict, or list
            Input data; converted with ``_convert_data_input`` before entering the graph state.
        max_retries : int
            Placed in the graph state for the sub-agents.
        retry_count : int
            Placed in the graph state for the sub-agents.
        **kwargs
            Forwarded as keyword arguments to the compiled graph's ``ainvoke``.

        Returns
        -------
        None. Use the ``get_*`` methods to read results.
        """
        response = await self._compiled_graph.ainvoke(
            self._build_graph_input(user_instructions, data_raw, max_retries, retry_count),
            **kwargs,
        )
        self._store_response(response)

    def invoke_agent(self, user_instructions, data_raw: Union[pd.DataFrame, dict, list], max_retries: int = 3, retry_count: int = 0, **kwargs):
        """
        Synchronously invokes the multi-agent and stores the result in ``self.response``.

        Parameters are the same as for ``ainvoke_agent``; ``**kwargs`` is forwarded
        as keyword arguments to the compiled graph's ``invoke``.

        Returns
        -------
        None. Use the ``get_*`` methods to read results.
        """
        response = self._compiled_graph.invoke(
            self._build_graph_input(user_instructions, data_raw, max_retries, retry_count),
            **kwargs,
        )
        self._store_response(response)

    def get_data_wrangled(self):
        """Returns the wrangled data as a Pandas DataFrame, or None if not available."""
        if self.response and self.response.get("data_wrangled"):
            return pd.DataFrame(self.response.get("data_wrangled"))

    def get_plotly_graph(self):
        """Returns the Plotly graph as a Plotly object, or None if no chart was produced."""
        if self.response and self.response.get("plotly_graph"):
            return plotly_from_dict(self.response.get("plotly_graph"))

    def get_data_wrangler_function(self, markdown=False):
        """Returns the data wrangling function as a string (or Markdown if ``markdown``)."""
        if self.response and self.response.get("data_wrangler_function"):
            code = self.response.get("data_wrangler_function")
            return Markdown(f"```python\n{code}\n```") if markdown else code

    def get_data_visualization_function(self, markdown=False):
        """Returns the data visualization function as a string (or Markdown if ``markdown``)."""
        if self.response and self.response.get("data_visualization_function"):
            code = self.response.get("data_visualization_function")
            return Markdown(f"```python\n{code}\n```") if markdown else code

    def get_workflow_summary(self, markdown=False):
        """
        Returns a summary of the workflow: a header listing each message's ``role``
        followed by ``get_generic_summary`` of each message's JSON content.
        """
        if self.response and self.response.get("messages"):
            agents = [msg.role for msg in self.response["messages"]]
            agent_labels = [f"- **Agent {i+1}:** {role}\n" for i, role in enumerate(agents)]
            header = f"# Pandas Data Analyst Workflow Summary\n\nThis workflow contains {len(agents)} agents:\n\n" + "\n".join(agent_labels)
            # NOTE(review): assumes every message's content is a JSON string.
            reports = [get_generic_summary(json.loads(msg.content)) for msg in self.response["messages"]]
            summary = "\n\n" + header + "\n\n".join(reports)
            return Markdown(summary) if markdown else summary

    @staticmethod
    def _convert_data_input(data_raw: Union[pd.DataFrame, dict, list]) -> Union[dict, list]:
        """
        Converts input data to the expected format (dict or list of dicts).

        DataFrames are converted with ``DataFrame.to_dict()``; dicts pass through;
        lists have any DataFrame items converted and other items kept as-is.

        Raises
        ------
        ValueError
            If ``data_raw`` is not a DataFrame, dict, or list.
        """
        if isinstance(data_raw, pd.DataFrame):
            return data_raw.to_dict()
        if isinstance(data_raw, dict):
            return data_raw
        if isinstance(data_raw, list):
            return [item.to_dict() if isinstance(item, pd.DataFrame) else item for item in data_raw]
        raise ValueError("data_raw must be a DataFrame, dict, or list of DataFrames/dicts")
151
+
152
def make_pandas_data_analyst(
    model,
    data_wrangling_agent: CompiledStateGraph,
    data_visualization_agent: CompiledStateGraph,
    checkpointer: Checkpointer = None
):
    """
    Creates a multi-agent system that wrangles data and optionally visualizes it.

    Flow: routing_preprocessor -> data_wrangling_agent -> (chart: data_visualization_agent)
    -> route_printer -> END.

    Parameters:
    -----------
    model: The language model to be used.
    data_wrangling_agent: CompiledStateGraph
        The Data Wrangling Agent.
    data_visualization_agent: CompiledStateGraph
        The Data Visualization Agent.
    checkpointer: Checkpointer (optional)
        The checkpointer to save the state.

    Returns:
    --------
    CompiledStateGraph: The compiled multi-agent system.
    """

    llm = model

    routing_preprocessor_prompt = PromptTemplate(
        template="""
You are an expert in routing decisions for a Pandas Data Manipulation Wrangling Agent, a Charting Visualization Agent, and a Pandas Table Agent. Your job is to tell the agents which actions to perform and determine the correct routing for the incoming user question:

1. Determine what the correct format for a Users Question should be for use with a Pandas Data Wrangling Agent based on the incoming user question. Anything related to data wrangling and manipulation should be passed along. Anything related to data analysis can be handled by the Pandas Agent. Anything that uses Pandas can be passed along. Tables can be returned from this agent. Don't pass along anything about plotting or visualization.
2. Determine whether or not a chart should be generated or a table should be returned based on the users question.
3. If a chart is requested, determine the correct format of a Users Question should be used with a Data Visualization Agent. Anything related to plotting and visualization should be passed along.

Use the following criteria on how to route the the initial user question:

From the incoming user question, remove any details about the format of the final response as either a Chart or Table and return only the important part of the incoming user question that is relevant for the Pandas Data Wrangling and Transformation agent. This will be the 'user_instructions_data_wrangling'. If 'None' is found, return the original user question.

Next, determine if the user would like a data visualization ('chart') or a 'table' returned with the results of the Data Wrangling Agent. If unknown, not specified or 'None' is found, then select 'table'.

If a 'chart' is requested, return the 'user_instructions_data_visualization'. If 'None' is found, return None.

Return JSON with 'user_instructions_data_wrangling', 'user_instructions_data_visualization' and 'routing_preprocessor_decision'.

INITIAL_USER_QUESTION: {user_instructions}
""",
        input_variables=["user_instructions"]
    )

    routing_preprocessor = routing_preprocessor_prompt | llm | JsonOutputParser()

    class PrimaryState(TypedDict):
        messages: Annotated[Sequence[BaseMessage], operator.add]
        user_instructions: str
        user_instructions_data_wrangling: str
        user_instructions_data_visualization: str
        routing_preprocessor_decision: str
        data_raw: Union[dict, list]
        data_wrangled: dict
        data_wrangler_function: str
        data_visualization_function: str
        plotly_graph: dict
        plotly_error: str
        max_retries: int
        retry_count: int

    def preprocess_routing(state: PrimaryState):
        """Ask the LLM to split the question into per-agent instructions and pick a route."""
        print("---PANDAS DATA ANALYST---")
        print("*************************")
        print("---PREPROCESS ROUTER---")
        question = state.get("user_instructions")

        # Chart/table routing and data-wrangling instruction prep
        response = routing_preprocessor.invoke({"user_instructions": question})

        return {
            "user_instructions_data_wrangling": response.get('user_instructions_data_wrangling'),
            "user_instructions_data_visualization": response.get('user_instructions_data_visualization'),
            "routing_preprocessor_decision": response.get('routing_preprocessor_decision'),
        }

    def router_chart_or_table(state: PrimaryState):
        """Return 'chart' when the router chose a chart, otherwise 'table'."""
        print("---ROUTER: CHART OR TABLE---")
        # LLM output may vary in case/whitespace (e.g. "Chart"); normalise before comparing.
        decision = str(state.get('routing_preprocessor_decision') or "").strip().lower()
        return "chart" if decision == "chart" else "table"

    def invoke_data_wrangling_agent(state: PrimaryState):
        """Run the wrangling sub-graph and copy its outputs into the shared state."""
        # The prompt asks for the original question when nothing wrangling-specific is
        # found; fall back to it here as well in case the LLM returned null/empty.
        instructions = state.get("user_instructions_data_wrangling") or state.get("user_instructions")

        response = data_wrangling_agent.invoke({
            "user_instructions": instructions,
            "data_raw": state.get("data_raw"),
            "max_retries": state.get("max_retries"),
            "retry_count": state.get("retry_count"),
        })

        # Only wrangling outputs here: plotly_error belongs to the visualization step.
        return {
            "messages": response.get("messages"),
            "data_wrangled": response.get("data_wrangled"),
            "data_wrangler_function": response.get("data_wrangler_function"),
        }

    def invoke_data_visualization_agent(state: PrimaryState):
        """Run the visualization sub-graph on the wrangled data (or the raw data if none)."""
        # Fall back to the full question if the router did not supply chart instructions.
        instructions = state.get("user_instructions_data_visualization") or state.get("user_instructions")

        response = data_visualization_agent.invoke({
            "user_instructions": instructions,
            "data_raw": state.get("data_wrangled") or state.get("data_raw"),
            "max_retries": state.get("max_retries"),
            "retry_count": state.get("retry_count"),
        })

        return {
            "messages": response.get("messages"),
            "data_visualization_function": response.get("data_visualization_function"),
            "plotly_graph": response.get("plotly_graph"),
            "plotly_error": response.get("data_visualization_error"),
        }

    def route_printer(state: PrimaryState):
        """Terminal node: print the chosen route; makes no state changes."""
        print("---ROUTE PRINTER---")
        print(f"    Route: {state.get('routing_preprocessor_decision')}")
        print("---END---")
        return {}

    workflow = StateGraph(PrimaryState)

    workflow.add_node("routing_preprocessor", preprocess_routing)
    workflow.add_node("data_wrangling_agent", invoke_data_wrangling_agent)
    workflow.add_node("data_visualization_agent", invoke_data_visualization_agent)
    workflow.add_node("route_printer", route_printer)

    workflow.add_edge(START, "routing_preprocessor")
    workflow.add_edge("routing_preprocessor", "data_wrangling_agent")

    workflow.add_conditional_edges(
        "data_wrangling_agent",
        router_chart_or_table,
        {
            "chart": "data_visualization_agent",
            "table": "route_printer"
        }
    )

    workflow.add_edge("data_visualization_agent", "route_printer")
    workflow.add_edge("route_printer", END)

    app = workflow.compile(
        checkpointer=checkpointer,
        name=AGENT_NAME
    )

    return app
@@ -1,12 +1,14 @@
1
1
 
2
2
  from langchain_core.messages import BaseMessage
3
- from langgraph.types import Checkpointer
3
+
4
+ from langchain.prompts import PromptTemplate
5
+ from langchain_core.output_parsers import JsonOutputParser
4
6
 
5
7
  from langgraph.graph import START, END, StateGraph
6
8
  from langgraph.graph.state import CompiledStateGraph
7
- from langgraph.types import Command
9
+ from langgraph.types import Checkpointer
8
10
 
9
- from typing import TypedDict, Annotated, Sequence, Literal
11
+ from typing import TypedDict, Annotated, Sequence
10
12
  import operator
11
13
 
12
14
  from typing_extensions import TypedDict
@@ -20,6 +22,7 @@ from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
20
22
  from ai_data_science_team.utils.plotly import plotly_from_dict
21
23
  from ai_data_science_team.utils.regex import remove_consecutive_duplicates, get_generic_summary
22
24
 
25
+ AGENT_NAME = "sql_data_analyst"
23
26
 
24
27
  class SQLDataAnalyst(BaseAgent):
25
28
  """
@@ -33,6 +36,8 @@ class SQLDataAnalyst(BaseAgent):
33
36
  The SQL Database Agent.
34
37
  data_visualization_agent: DataVisualizationAgent
35
38
  The Data Visualization Agent.
39
+ checkpointer: Checkpointer (optional)
40
+ The checkpointer to save the state of the multi-agent system.
36
41
 
37
42
  Methods:
38
43
  --------
@@ -296,24 +301,13 @@ class SQLDataAnalyst(BaseAgent):
296
301
  markdown: bool
297
302
  If True, returns the summary as a Markdown-formatted string.
298
303
  """
299
- if self.response and self.get_response()['messages']:
300
-
301
- agents = [self.get_response()['messages'][i].role for i in range(len(self.get_response()['messages']))]
302
-
303
- agent_labels = []
304
- for i in range(len(agents)):
305
- agent_labels.append(f"- **Agent {i+1}:** {agents[i]}")
306
-
307
- # Construct header
308
- header = f"# SQL Data Analyst Workflow Summary Report\n\nThis agentic workflow contains {len(agents)} agents:\n\n" + "\n".join(agent_labels)
309
-
310
- reports = []
311
- for msg in self.get_response()['messages']:
312
- reports.append(get_generic_summary(json.loads(msg.content)))
313
-
314
- if markdown:
315
- return Markdown(header + "\n\n".join(reports))
316
- return "\n\n".join(reports)
304
+ if self.response and self.response.get("messages"):
305
+ agents = [msg.role for msg in self.response["messages"]]
306
+ agent_labels = [f"- **Agent {i+1}:** {role}\n" for i, role in enumerate(agents)]
307
+ header = f"# SQL Data Analyst Workflow Summary\n\nThis workflow contains {len(agents)} agents:\n\n" + "\n".join(agent_labels)
308
+ reports = [get_generic_summary(json.loads(msg.content)) for msg in self.response["messages"]]
309
+ summary = "\n\n" + header + "\n\n".join(reports)
310
+ return Markdown(summary) if markdown else summary
317
311
 
318
312
 
319
313
 
@@ -326,17 +320,17 @@ def make_sql_data_analyst(
326
320
  """
327
321
  Creates a multi-agent system that takes in a SQL query and returns a plot or table.
328
322
 
329
- - Agent 1: SQL Database Agent made with `make_sql_database_agent()`
330
- - Agent 2: Data Visualization Agent made with `make_data_visualization_agent()`
323
+ - Agent 1: SQL Database Agent made with `SQLDatabaseAgent()`
324
+ - Agent 2: Data Visualization Agent made with `DataVisualizationAgent()`
331
325
 
332
326
  Parameters:
333
327
  ----------
334
328
  model:
335
329
  The language model to be used for the agents.
336
330
  sql_database_agent: CompiledStateGraph
337
- The SQL Database Agent made with `make_sql_database_agent()`.
331
+ The SQL Database Agent made with `SQLDatabaseAgent()`.
338
332
  data_visualization_agent: CompiledStateGraph
339
- The Data Visualization Agent made with `make_data_visualization_agent()`.
333
+ The Data Visualization Agent made with `DataVisualizationAgent()`.
340
334
  checkpointer: Checkpointer (optional)
341
335
  The checkpointer to save the state of the multi-agent system.
342
336
  Default: None
@@ -348,10 +342,39 @@ def make_sql_data_analyst(
348
342
  """
349
343
 
350
344
  llm = model
345
+
346
+
347
+ routing_preprocessor_prompt = PromptTemplate(
348
+ template="""
349
+ You are an expert in routing decisions for a SQL Database Agent, a Charting Visualization Agent, and a Pandas Table Agent. Your job is to:
350
+
351
+ 1. Determine what the correct format for a Users Question should be for use with a SQL Database Agent based on the incoming user question. Anything related to database and data manipulation should be passed along.
352
+ 2. Determine whether or not a chart should be generated or a table should be returned based on the users question.
353
+ 3. If a chart is requested, determine the correct format of a Users Question should be used with a Data Visualization Agent. Anything related to plotting and visualization should be passed along.
354
+
355
+ Use the following criteria on how to route the the initial user question:
356
+
357
+ From the incoming user question, remove any details about the format of the final response as either a Chart or Table and return only the important part of the incoming user question that is relevant for the SQL generator agent. This will be the 'user_instructions_sql_database'. If 'None' is found, return the original user question.
358
+
359
+ Next, determine if the user would like a data visualization ('chart') or a 'table' returned with the results of the Data Wrangling Agent. If unknown, not specified or 'None' is found, then select 'table'.
360
+
361
+ If a 'chart' is requested, return the 'user_instructions_data_visualization'. If 'None' is found, return None.
362
+
363
+ Return JSON with 'user_instructions_sql_database', 'user_instructions_data_visualization' and 'routing_preprocessor_decision'.
364
+
365
+ INITIAL_USER_QUESTION: {user_instructions}
366
+ """,
367
+ input_variables=["user_instructions"]
368
+ )
369
+
370
+ routing_preprocessor = routing_preprocessor_prompt | llm | JsonOutputParser()
351
371
 
352
372
  class PrimaryState(TypedDict):
353
373
  messages: Annotated[Sequence[BaseMessage], operator.add]
354
374
  user_instructions: str
375
+ user_instructions_sql_database: str
376
+ user_instructions_data_visualization: str
377
+ routing_preprocessor_decision: str
355
378
  sql_query_code: str
356
379
  sql_database_function: str
357
380
  data_sql: dict
@@ -359,39 +382,94 @@ def make_sql_data_analyst(
359
382
  plot_required: bool
360
383
  data_visualization_function: str
361
384
  plotly_graph: dict
385
+ plotly_error: str
362
386
  max_retries: int
363
387
  retry_count: int
364
388
 
365
- def route_to_visualization(state) -> Command[Literal["data_visualization_agent", "__end__"]]:
389
+ def preprocess_routing(state: PrimaryState):
390
+ print("---SQL DATA ANALYST---")
391
+ print("*************************")
392
+ print("---PREPROCESS ROUTER---")
393
+ question = state.get("user_instructions")
366
394
 
367
- response = llm.invoke(f"Respond in 1 word ('plot' or 'table'). Is the user requesting a plot? If unknown, select 'table'. \n\n User Instructions:\n{state.get('user_instructions')}")
395
+ # Chart Routing and SQL Prep
396
+ response = routing_preprocessor.invoke({"user_instructions": question})
368
397
 
369
- if response.content == 'plot':
370
- plot_required = True
371
- goto="data_visualization_agent"
372
- else:
373
- plot_required = False
374
- goto="__end__"
398
+ return {
399
+ "user_instructions_sql_database": response.get('user_instructions_sql_database'),
400
+ "user_instructions_data_visualization": response.get('user_instructions_data_visualization'),
401
+ "routing_preprocessor_decision": response.get('routing_preprocessor_decision'),
402
+ }
403
+
404
+ def router_chart_or_table(state: PrimaryState):
405
+ print("---ROUTER: CHART OR TABLE---")
406
+ return "chart" if state.get('routing_preprocessor_decision') == "chart" else "table"
407
+
408
+
409
+ def invoke_sql_database_agent(state: PrimaryState):
375
410
 
376
- return Command(
377
- update={
378
- 'data_raw': state.get("data_sql"),
379
- 'plot_required': plot_required,
380
- },
381
- goto=goto
382
- )
411
+ response = sql_database_agent.invoke({
412
+ "user_instructions": state.get("user_instructions_sql_database"),
413
+ "max_retries": state.get("max_retries"),
414
+ "retry_count": state.get("retry_count"),
415
+ })
383
416
 
384
- workflow = StateGraph(PrimaryState)
417
+ return {
418
+ "messages": response.get("messages"),
419
+ "data_sql": response.get("data_sql"),
420
+ "sql_query_code": response.get("sql_query_code"),
421
+ "sql_database_function": response.get("sql_database_function"),
422
+
423
+ }
424
+
425
+ def invoke_data_visualization_agent(state: PrimaryState):
426
+
427
+ response = data_visualization_agent.invoke({
428
+ "user_instructions": state.get("user_instructions_data_visualization"),
429
+ "data_raw": state.get("data_sql"),
430
+ "max_retries": state.get("max_retries"),
431
+ "retry_count": state.get("retry_count"),
432
+ })
433
+
434
+ return {
435
+ "messages": response.get("messages"),
436
+ "data_visualization_function": response.get("data_visualization_function"),
437
+ "plotly_graph": response.get("plotly_graph"),
438
+ "plotly_error": response.get("data_visualization_error"),
439
+ }
385
440
 
386
- workflow.add_node("sql_database_agent", sql_database_agent)
387
- workflow.add_node("route_to_visualization", route_to_visualization)
388
- workflow.add_node("data_visualization_agent", data_visualization_agent)
441
+ def route_printer(state: PrimaryState):
442
+ print("---ROUTE PRINTER---")
443
+ print(f" Route: {state.get('routing_preprocessor_decision')}")
444
+ print("---END---")
445
+ return {}
389
446
 
390
- workflow.add_edge(START, "sql_database_agent")
391
- workflow.add_edge("sql_database_agent", "route_to_visualization")
392
- workflow.add_edge("data_visualization_agent", END)
447
+ workflow = StateGraph(PrimaryState)
448
+
449
+ workflow.add_node("routing_preprocessor", preprocess_routing)
450
+ workflow.add_node("sql_database_agent", invoke_sql_database_agent)
451
+ workflow.add_node("data_visualization_agent", invoke_data_visualization_agent)
452
+ workflow.add_node("route_printer", route_printer)
453
+
454
+ workflow.add_edge(START, "routing_preprocessor")
455
+ workflow.add_edge("routing_preprocessor", "sql_database_agent")
456
+
457
+ workflow.add_conditional_edges(
458
+ "sql_database_agent",
459
+ router_chart_or_table,
460
+ {
461
+ "chart": "data_visualization_agent",
462
+ "table": "route_printer"
463
+ }
464
+ )
465
+
466
+ workflow.add_edge("data_visualization_agent", "route_printer")
467
+ workflow.add_edge("route_printer", END)
393
468
 
394
- app = workflow.compile(checkpointer=checkpointer)
469
+ app = workflow.compile(
470
+ checkpointer=checkpointer,
471
+ name=AGENT_NAME
472
+ )
395
473
 
396
474
  return app
397
475
 
@@ -40,6 +40,21 @@ class BaseAgent(CompiledStateGraph):
40
40
  self._params = params
41
41
  self._compiled_graph = self._make_compiled_graph()
42
42
  self.response = None
43
+ self.name = self._compiled_graph.name
44
+ self.checkpointer = self._compiled_graph.checkpointer
45
+ self.store = self._compiled_graph.store
46
+ self.output_channels = self._compiled_graph.output_channels
47
+ self.nodes = self._compiled_graph.nodes
48
+ self.stream_mode = self._compiled_graph.stream_mode
49
+ self.builder = self._compiled_graph.builder
50
+ self.channels = self._compiled_graph.channels
51
+ self.input_channels = self._compiled_graph.input_channels
52
+ self.input_schema = self._compiled_graph.input_schema
53
+ self.output_schema = self._compiled_graph.output_schema
54
+ self.debug = self._compiled_graph.debug
55
+ self.interrupt_after_nodes = self._compiled_graph.interrupt_after_nodes
56
+ self.interrupt_before_nodes = self._compiled_graph.interrupt_before_nodes
57
+ self.config = self._compiled_graph.config
43
58
 
44
59
  def _make_compiled_graph(self):
45
60
  """
@@ -197,6 +212,24 @@ class BaseAgent(CompiledStateGraph):
197
212
  """
198
213
  return self.get_output_jsonschema()['properties']
199
214
 
215
def get_state(self, config, *, subgraphs=False):
    """
    Fetch the current state snapshot from the underlying compiled graph.

    Parameters
    ----------
    config :
        Run configuration, forwarded unchanged.
    subgraphs : bool, optional
        Forwarded as ``subgraphs=`` to the compiled graph's ``get_state``.

    Returns
    -------
    Whatever ``self._compiled_graph.get_state`` returns.
    """
    graph = self._compiled_graph
    snapshot = graph.get_state(config, subgraphs=subgraphs)
    return snapshot
220
+
221
def get_state_history(self, config, *, filter=None, before=None, limit=None):
    """
    Delegate a state-history query to the underlying compiled graph.

    ``filter``, ``before`` and ``limit`` are passed through by keyword exactly as
    given; ``config`` is passed positionally.

    Returns
    -------
    Whatever ``self._compiled_graph.get_state_history`` returns.
    """
    history_options = {"filter": filter, "before": before, "limit": limit}
    return self._compiled_graph.get_state_history(config, **history_options)
226
+
227
def update_state(self, config, values, as_node=None):
    """
    Apply ``values`` to the state of the underlying compiled graph.

    ``config``, ``values`` and ``as_node`` are forwarded positionally, in that order.

    Returns
    -------
    Whatever ``self._compiled_graph.update_state`` returns.
    """
    graph = self._compiled_graph
    result = graph.update_state(config, values, as_node)
    return result
232
+
200
233
  def get_response(self):
201
234
  """
202
235
  Returns the response generated by the agent.
@@ -237,6 +270,7 @@ def create_coding_agent_graph(
237
270
  checkpointer: Optional[Callable] = None,
238
271
  bypass_recommended_steps: bool = False,
239
272
  bypass_explain_code: bool = False,
273
+ agent_name: str = "coding_agent"
240
274
  ):
241
275
  """
242
276
  Creates a generic agent graph using the provided node functions and node names.
@@ -281,6 +315,8 @@ def create_coding_agent_graph(
281
315
  Whether to skip the recommended steps node.
282
316
  bypass_explain_code : bool, optional
283
317
  Whether to skip the final explain code node.
318
+ name : str, optional
319
+ The name of the agent graph.
284
320
 
285
321
  Returns
286
322
  -------
@@ -366,10 +402,10 @@ def create_coding_agent_graph(
366
402
  workflow.add_edge(explain_code_node_name, END)
367
403
 
368
404
  # Finally, compile
369
- if human_in_the_loop:
370
- app = workflow.compile(checkpointer=checkpointer)
371
- else:
372
- app = workflow.compile()
405
+ app = workflow.compile(
406
+ checkpointer=checkpointer,
407
+ name=agent_name,
408
+ )
373
409
 
374
410
  return app
375
411
 
@@ -574,7 +610,7 @@ def node_func_execute_agent_from_sql_connection(
574
610
 
575
611
  # Retrieve SQLAlchemy connection and code snippet from the state
576
612
  is_engine = isinstance(connection, sql.engine.base.Engine)
577
- conn = connection.connect() if is_engine else connection
613
+ connection = connection.connect() if is_engine else connection
578
614
  agent_code = state.get(code_snippet_key)
579
615
 
580
616
  # Ensure the connection object is provided
@@ -235,6 +235,7 @@ def correlation_funnel(
235
235
  df_correlated = df_binarized.correlate(target=full_target, method=corr_method)
236
236
 
237
237
  # Attempt to generate a static plot.
238
+ encoded = None
238
239
  try:
239
240
  # Here we assume that your DataFrame has a method plot_correlation_funnel.
240
241
  fig = df_correlated.plot_correlation_funnel(engine='plotnine', height=600)
@@ -248,6 +249,7 @@ def correlation_funnel(
248
249
  encoded = {"error": str(e)}
249
250
 
250
251
  # Attempt to generate a Plotly plot.
252
+ fig_dict = None
251
253
  try:
252
254
  fig = df_correlated.plot_correlation_funnel(engine='plotly')
253
255
  fig_json = pio.to_json(fig)