ai-data-science-team 0.0.0.9012__py3-none-any.whl → 0.0.0.9014__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23) hide show
  1. ai_data_science_team/__init__.py +22 -0
  2. ai_data_science_team/_version.py +1 -1
  3. ai_data_science_team/agents/data_cleaning_agent.py +17 -3
  4. ai_data_science_team/agents/data_loader_tools_agent.py +24 -1
  5. ai_data_science_team/agents/data_visualization_agent.py +17 -3
  6. ai_data_science_team/agents/data_wrangling_agent.py +30 -10
  7. ai_data_science_team/agents/feature_engineering_agent.py +17 -4
  8. ai_data_science_team/agents/sql_database_agent.py +15 -2
  9. ai_data_science_team/ds_agents/eda_tools_agent.py +28 -6
  10. ai_data_science_team/ml_agents/h2o_ml_agent.py +15 -3
  11. ai_data_science_team/ml_agents/mlflow_tools_agent.py +23 -1
  12. ai_data_science_team/multiagents/__init__.py +2 -1
  13. ai_data_science_team/multiagents/pandas_data_analyst.py +305 -0
  14. ai_data_science_team/multiagents/sql_data_analyst.py +119 -30
  15. ai_data_science_team/templates/agent_templates.py +41 -5
  16. ai_data_science_team/tools/dataframe.py +6 -1
  17. ai_data_science_team/tools/eda.py +75 -16
  18. ai_data_science_team/utils/messages.py +27 -0
  19. {ai_data_science_team-0.0.0.9012.dist-info → ai_data_science_team-0.0.0.9014.dist-info}/METADATA +7 -3
  20. {ai_data_science_team-0.0.0.9012.dist-info → ai_data_science_team-0.0.0.9014.dist-info}/RECORD +23 -21
  21. {ai_data_science_team-0.0.0.9012.dist-info → ai_data_science_team-0.0.0.9014.dist-info}/LICENSE +0 -0
  22. {ai_data_science_team-0.0.0.9012.dist-info → ai_data_science_team-0.0.0.9014.dist-info}/WHEEL +0 -0
  23. {ai_data_science_team-0.0.0.9012.dist-info → ai_data_science_team-0.0.0.9014.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,8 @@
1
1
 
2
2
 
3
- from typing import Any, Optional, Annotated, Sequence, List, Dict, Tuple
3
+ from typing import Any, Optional, Annotated, Sequence, Dict
4
4
  import operator
5
5
  import pandas as pd
6
- import os
7
- from io import StringIO, BytesIO
8
- import base64
9
- import matplotlib.pyplot as plt
10
6
 
11
7
  from IPython.display import Markdown
12
8
 
@@ -14,22 +10,26 @@ from langchain_core.messages import BaseMessage, AIMessage
14
10
  from langgraph.prebuilt import create_react_agent, ToolNode
15
11
  from langgraph.prebuilt.chat_agent_executor import AgentState
16
12
  from langgraph.graph import START, END, StateGraph
13
+ from langgraph.types import Checkpointer
17
14
 
18
15
  from ai_data_science_team.templates import BaseAgent
19
16
  from ai_data_science_team.utils.regex import format_agent_name
20
17
 
21
18
  from ai_data_science_team.tools.eda import (
19
+ explain_data,
22
20
  describe_dataset,
23
21
  visualize_missing,
24
22
  correlation_funnel,
25
23
  generate_sweetviz_report,
26
24
  )
25
+ from ai_data_science_team.utils.messages import get_tool_call_names
27
26
 
28
27
 
29
28
  AGENT_NAME = "exploratory_data_analyst_agent"
30
29
 
31
30
  # Updated tool list for EDA
32
31
  EDA_TOOLS = [
32
+ explain_data,
33
33
  describe_dataset,
34
34
  visualize_missing,
35
35
  correlation_funnel,
@@ -49,6 +49,8 @@ class EDAToolsAgent(BaseAgent):
49
49
  Additional kwargs for create_react_agent.
50
50
  invoke_react_agent_kwargs : dict
51
51
  Additional kwargs for agent invocation.
52
+ checkpointer : Checkpointer, optional
53
+ The checkpointer for the agent.
52
54
  """
53
55
 
54
56
  def __init__(
@@ -56,11 +58,13 @@ class EDAToolsAgent(BaseAgent):
56
58
  model: Any,
57
59
  create_react_agent_kwargs: Optional[Dict] = {},
58
60
  invoke_react_agent_kwargs: Optional[Dict] = {},
61
+ checkpointer: Optional[Checkpointer] = None,
59
62
  ):
60
63
  self._params = {
61
64
  "model": model,
62
65
  "create_react_agent_kwargs": create_react_agent_kwargs,
63
66
  "invoke_react_agent_kwargs": invoke_react_agent_kwargs,
67
+ "checkpointer": checkpointer
64
68
  }
65
69
  self._compiled_graph = self._make_compiled_graph()
66
70
  self.response = None
@@ -162,11 +166,18 @@ class EDAToolsAgent(BaseAgent):
162
166
  return Markdown(self.response["messages"][0].content)
163
167
  else:
164
168
  return self.response["messages"][0].content
169
+
170
+ def get_tool_calls(self):
171
+ """
172
+ Returns the tool calls made by the agent.
173
+ """
174
+ return self.response["tool_calls"]
165
175
 
166
176
  def make_eda_tools_agent(
167
177
  model: Any,
168
178
  create_react_agent_kwargs: Optional[Dict] = {},
169
179
  invoke_react_agent_kwargs: Optional[Dict] = {},
180
+ checkpointer: Optional[Checkpointer] = None,
170
181
  ):
171
182
  """
172
183
  Creates an Exploratory Data Analyst Agent that can interact with EDA tools.
@@ -179,6 +190,8 @@ def make_eda_tools_agent(
179
190
  Additional kwargs for create_react_agent.
180
191
  invoke_react_agent_kwargs : dict
181
192
  Additional kwargs for agent invocation.
193
+ checkpointer : Checkpointer, optional
194
+ The checkpointer for the agent.
182
195
 
183
196
  Returns:
184
197
  -------
@@ -191,6 +204,7 @@ def make_eda_tools_agent(
191
204
  user_instructions: str
192
205
  data_raw: dict
193
206
  eda_artifacts: dict
207
+ tool_calls: list
194
208
 
195
209
  def exploratory_agent(state):
196
210
  print(format_agent_name(AGENT_NAME))
@@ -205,6 +219,7 @@ def make_eda_tools_agent(
205
219
  tools=tool_node,
206
220
  state_schema=GraphState,
207
221
  **create_react_agent_kwargs,
222
+ checkpointer=checkpointer,
208
223
  )
209
224
 
210
225
  response = eda_agent.invoke(
@@ -229,11 +244,14 @@ def make_eda_tools_agent(
229
244
  last_tool_artifact = last_message.artifact
230
245
  elif isinstance(last_message, dict) and "artifact" in last_message:
231
246
  last_tool_artifact = last_message["artifact"]
247
+
248
+ tool_calls = get_tool_call_names(internal_messages)
232
249
 
233
250
  return {
234
251
  "messages": [last_ai_message],
235
252
  "internal_messages": internal_messages,
236
253
  "eda_artifacts": last_tool_artifact,
254
+ "tool_calls": tool_calls,
237
255
  }
238
256
 
239
257
  workflow = StateGraph(GraphState)
@@ -241,5 +259,9 @@ def make_eda_tools_agent(
241
259
  workflow.add_edge(START, "exploratory_agent")
242
260
  workflow.add_edge("exploratory_agent", END)
243
261
 
244
- app = workflow.compile()
262
+ app = workflow.compile(
263
+ checkpointer=checkpointer,
264
+ name=AGENT_NAME,
265
+ )
266
+
245
267
  return app
@@ -5,7 +5,7 @@
5
5
 
6
6
  import os
7
7
  import json
8
- from typing import TypedDict, Annotated, Sequence, Literal
8
+ from typing import TypedDict, Annotated, Sequence, Literal, Optional
9
9
  import operator
10
10
 
11
11
  import pandas as pd
@@ -14,7 +14,7 @@ from IPython.display import Markdown
14
14
  from langchain.prompts import PromptTemplate
15
15
  from langchain_core.messages import BaseMessage
16
16
 
17
- from langgraph.types import Command
17
+ from langgraph.types import Command, Checkpointer
18
18
  from langgraph.checkpoint.memory import MemorySaver
19
19
 
20
20
  from ai_data_science_team.templates import(
@@ -79,6 +79,8 @@ class H2OMLAgent(BaseAgent):
79
79
  Name of the MLflow experiment (created if doesn't exist).
80
80
  mlflow_run_name : str, default None
81
81
  A custom name for the MLflow run.
82
+ checkpointer : langgraph.types.Checkpointer, optional
83
+ A checkpointer object for saving the agent's state. Defaults to None.
82
84
 
83
85
 
84
86
  Methods
@@ -176,6 +178,7 @@ class H2OMLAgent(BaseAgent):
176
178
  mlflow_tracking_uri=None,
177
179
  mlflow_experiment_name="H2O AutoML",
178
180
  mlflow_run_name=None,
181
+ checkpointer: Optional[Checkpointer]=None,
179
182
  ):
180
183
  self._params = {
181
184
  "model": model,
@@ -193,6 +196,7 @@ class H2OMLAgent(BaseAgent):
193
196
  "mlflow_tracking_uri": mlflow_tracking_uri,
194
197
  "mlflow_experiment_name": mlflow_experiment_name,
195
198
  "mlflow_run_name": mlflow_run_name,
199
+ "checkpointer": checkpointer,
196
200
  }
197
201
  self._compiled_graph = self._make_compiled_graph()
198
202
  self.response = None
@@ -350,6 +354,7 @@ def make_h2o_ml_agent(
350
354
  mlflow_tracking_uri=None,
351
355
  mlflow_experiment_name="H2O AutoML",
352
356
  mlflow_run_name=None,
357
+ checkpointer=None,
353
358
  ):
354
359
  """
355
360
  Creates a machine learning agent that uses H2O for AutoML.
@@ -384,6 +389,12 @@ def make_h2o_ml_agent(
384
389
  " pip install h2o\n\n"
385
390
  "Visit https://docs.h2o.ai/h2o/latest-stable/h2o-docs/downloading.html for details."
386
391
  ) from e
392
+
393
+ if human_in_the_loop:
394
+ if checkpointer is None:
395
+ print("Human in the loop is enabled. A checkpointer is required. Setting to MemorySaver().")
396
+ checkpointer = MemorySaver()
397
+
387
398
 
388
399
  # Define GraphState
389
400
  class GraphState(TypedDict):
@@ -844,9 +855,10 @@ def make_h2o_ml_agent(
844
855
  retry_count_key="retry_count",
845
856
  human_in_the_loop=human_in_the_loop,
846
857
  human_review_node_name="human_review",
847
- checkpointer=MemorySaver(),
858
+ checkpointer=checkpointer,
848
859
  bypass_recommended_steps=bypass_recommended_steps,
849
860
  bypass_explain_code=bypass_explain_code,
861
+ agent_name=AGENT_NAME,
850
862
  )
851
863
 
852
864
  return app
@@ -10,6 +10,7 @@ from langchain_core.messages import BaseMessage, AIMessage
10
10
 
11
11
  from langgraph.prebuilt import create_react_agent, ToolNode
12
12
  from langgraph.prebuilt.chat_agent_executor import AgentState
13
+ from langgraph.types import Checkpointer
13
14
  from langgraph.graph import START, END, StateGraph
14
15
 
15
16
  from ai_data_science_team.templates import BaseAgent
@@ -27,6 +28,7 @@ from ai_data_science_team.tools.mlflow import (
27
28
  mlflow_search_registered_models,
28
29
  mlflow_get_model_version_details,
29
30
  )
31
+ from ai_data_science_team.utils.messages import get_tool_call_names
30
32
 
31
33
  AGENT_NAME = "mlflow_tools_agent"
32
34
 
@@ -67,6 +69,8 @@ class MLflowToolsAgent(BaseAgent):
67
69
  Additional keyword arguments to pass to the create_react_agent function.
68
70
  invoke_react_agent_kwargs : dict
69
71
  Additional keyword arguments to pass to the invoke method of the react agent.
72
+ checkpointer : langgraph.types.Checkpointer, optional
73
+ A checkpointer to use for saving and loading the agent's state. Defaults to None.
70
74
 
71
75
  Methods:
72
76
  --------
@@ -118,6 +122,7 @@ class MLflowToolsAgent(BaseAgent):
118
122
  mlflow_registry_uri: Optional[str]=None,
119
123
  create_react_agent_kwargs: Optional[Dict]={},
120
124
  invoke_react_agent_kwargs: Optional[Dict]={},
125
+ checkpointer: Optional[Checkpointer]=None,
121
126
  ):
122
127
  self._params = {
123
128
  "model": model,
@@ -125,6 +130,7 @@ class MLflowToolsAgent(BaseAgent):
125
130
  "mlflow_registry_uri": mlflow_registry_uri,
126
131
  "create_react_agent_kwargs": create_react_agent_kwargs,
127
132
  "invoke_react_agent_kwargs": invoke_react_agent_kwargs,
133
+ "checkpointer": checkpointer,
128
134
  }
129
135
  self._compiled_graph = self._make_compiled_graph()
130
136
  self.response = None
@@ -228,6 +234,12 @@ class MLflowToolsAgent(BaseAgent):
228
234
  return Markdown(self.response["messages"][0].content)
229
235
  else:
230
236
  return self.response["messages"][0].content
237
+
238
+ def get_tool_calls(self):
239
+ """
240
+ Returns the tool calls made by the agent.
241
+ """
242
+ return self.response["tool_calls"]
231
243
 
232
244
 
233
245
 
@@ -238,6 +250,7 @@ def make_mlflow_tools_agent(
238
250
  mlflow_registry_uri: str=None,
239
251
  create_react_agent_kwargs: Optional[Dict]={},
240
252
  invoke_react_agent_kwargs: Optional[Dict]={},
253
+ checkpointer: Optional[Checkpointer]=None,
241
254
  ):
242
255
  """
243
256
  MLflow Tool Calling Agent
@@ -254,6 +267,8 @@ def make_mlflow_tools_agent(
254
267
  Additional keyword arguments to pass to the agent's create_react_agent method.
255
268
  invoke_react_agent_kwargs : dict, optional
256
269
  Additional keyword arguments to pass to the agent's invoke method.
270
+ checkpointer : langgraph.types.Checkpointer, optional
271
+ A checkpointer to use for saving and loading the agent's state. Defaults to None.
257
272
 
258
273
  Returns
259
274
  -------
@@ -296,6 +311,7 @@ def make_mlflow_tools_agent(
296
311
  model,
297
312
  tools=tool_node,
298
313
  state_schema=GraphState,
314
+ checkpointer=checkpointer,
299
315
  **create_react_agent_kwargs,
300
316
  )
301
317
 
@@ -330,10 +346,13 @@ def make_mlflow_tools_agent(
330
346
  elif isinstance(last_message, dict) and "artifact" in last_message:
331
347
  last_tool_artifact = last_message["artifact"]
332
348
 
349
+ tool_calls = get_tool_call_names(internal_messages)
350
+
333
351
  return {
334
352
  "messages": [last_ai_message],
335
353
  "internal_messages": internal_messages,
336
354
  "mlflow_artifacts": last_tool_artifact,
355
+ "tool_calls": tool_calls,
337
356
  }
338
357
 
339
358
 
@@ -344,7 +363,10 @@ def make_mlflow_tools_agent(
344
363
  workflow.add_edge(START, "mlflow_tools_agent")
345
364
  workflow.add_edge("mlflow_tools_agent", END)
346
365
 
347
- app = workflow.compile()
366
+ app = workflow.compile(
367
+ checkpointer=checkpointer,
368
+ name=AGENT_NAME,
369
+ )
348
370
 
349
371
  return app
350
372
 
@@ -1 +1,2 @@
1
- from ai_data_science_team.multiagents.sql_data_analyst import SQLDataAnalyst, make_sql_data_analyst
1
+ from ai_data_science_team.multiagents.sql_data_analyst import SQLDataAnalyst, make_sql_data_analyst
2
+ from ai_data_science_team.multiagents.pandas_data_analyst import PandasDataAnalyst, make_pandas_data_analyst
@@ -0,0 +1,305 @@
1
+ from langchain_core.messages import BaseMessage
2
+ from langchain.prompts import PromptTemplate
3
+ from langchain_core.output_parsers import JsonOutputParser
4
+ from langgraph.types import Checkpointer
5
+ from langgraph.graph import START, END, StateGraph
6
+ from langgraph.graph.state import CompiledStateGraph
7
+
8
+ from typing import TypedDict, Annotated, Sequence, Union
9
+ import operator
10
+
11
+ import pandas as pd
12
+ import json
13
+ from IPython.display import Markdown
14
+
15
+ from ai_data_science_team.templates import BaseAgent
16
+ from ai_data_science_team.agents import DataWranglingAgent, DataVisualizationAgent
17
+ from ai_data_science_team.utils.plotly import plotly_from_dict
18
+ from ai_data_science_team.utils.regex import remove_consecutive_duplicates, get_generic_summary
19
+
20
+ AGENT_NAME = "pandas_data_analyst"
21
+
22
class PandasDataAnalyst(BaseAgent):
    """
    Multi-agent wrapper that pairs a data wrangling agent with a data
    visualization agent behind a single invoke/ainvoke interface.

    Parameters:
    -----------
    model:
        The language model forwarded to `make_pandas_data_analyst`.
    data_wrangling_agent: DataWranglingAgent
        Agent whose compiled graph transforms the raw data.
    data_visualization_agent: DataVisualizationAgent
        Agent whose compiled graph produces the plot.
    checkpointer: Checkpointer (optional)
        Checkpointer forwarded to the compiled multi-agent graph.

    Methods:
    --------
    update_params(**kwargs)
        Overrides stored parameters and rebuilds the compiled graph.
    ainvoke_agent(user_instructions, data_raw, **kwargs)
        Asynchronously runs the multi-agent and stores the result in `response`.
    invoke_agent(user_instructions, data_raw, **kwargs)
        Synchronously runs the multi-agent and stores the result in `response`.
    get_data_wrangled()
        Returns the wrangled data as a Pandas DataFrame.
    get_plotly_graph()
        Returns the Plotly graph as a Plotly object.
    get_data_wrangler_function(markdown=False)
        Returns the data wrangling function as a string, optionally in Markdown.
    get_data_visualization_function(markdown=False)
        Returns the data visualization function as a string, optionally in Markdown.
    get_workflow_summary(markdown=False)
        Returns a text summary built from the response messages.
    """

    def __init__(
        self,
        model,
        data_wrangling_agent: DataWranglingAgent,
        data_visualization_agent: DataVisualizationAgent,
        checkpointer: Checkpointer = None,
    ):
        # Keep the constructor arguments so update_params() can rebuild later.
        self._params = dict(
            model=model,
            data_wrangling_agent=data_wrangling_agent,
            data_visualization_agent=data_visualization_agent,
            checkpointer=checkpointer,
        )
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """Build the compiled graph from the stored params; clears `response`."""
        self.response = None
        params = self._params
        # The factory expects the sub-agents' compiled graphs, not the wrappers.
        wrangler_graph = params["data_wrangling_agent"]._compiled_graph
        visualizer_graph = params["data_visualization_agent"]._compiled_graph
        return make_pandas_data_analyst(
            model=params["model"],
            data_wrangling_agent=wrangler_graph,
            data_visualization_agent=visualizer_graph,
            checkpointer=params["checkpointer"],
        )

    def update_params(self, **kwargs):
        """Merge `kwargs` into the stored parameters, then rebuild the graph."""
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    async def ainvoke_agent(self, user_instructions, data_raw: Union[pd.DataFrame, dict, list], max_retries: int = 3, retry_count: int = 0, **kwargs):
        """
        Asynchronously run the multi-agent and store the result on `self.response`.

        Extra keyword arguments are passed through to the graph's `ainvoke`.
        """
        graph_input = {
            "user_instructions": user_instructions,
            "data_raw": self._convert_data_input(data_raw),
            "max_retries": max_retries,
            "retry_count": retry_count,
        }
        result = await self._compiled_graph.ainvoke(graph_input, **kwargs)
        if result.get("messages"):
            result["messages"] = remove_consecutive_duplicates(result["messages"])
        self.response = result

    def invoke_agent(self, user_instructions, data_raw: Union[pd.DataFrame, dict, list], max_retries: int = 3, retry_count: int = 0, **kwargs):
        """
        Synchronously run the multi-agent and store the result on `self.response`.

        Extra keyword arguments are passed through to the graph's `invoke`.
        """
        graph_input = {
            "user_instructions": user_instructions,
            "data_raw": self._convert_data_input(data_raw),
            "max_retries": max_retries,
            "retry_count": retry_count,
        }
        result = self._compiled_graph.invoke(graph_input, **kwargs)
        if result.get("messages"):
            result["messages"] = remove_consecutive_duplicates(result["messages"])
        self.response = result

    def get_data_wrangled(self):
        """Return the wrangled data as a DataFrame, or None when unavailable."""
        if not self.response:
            return None
        wrangled = self.response.get("data_wrangled")
        if not wrangled:
            return None
        return pd.DataFrame(wrangled)

    def get_plotly_graph(self):
        """Return the Plotly object rebuilt from the response, or None when unavailable."""
        if not self.response:
            return None
        graph_dict = self.response.get("plotly_graph")
        if not graph_dict:
            return None
        return plotly_from_dict(graph_dict)

    def get_data_wrangler_function(self, markdown=False):
        """Return the data wrangling function source (Markdown-wrapped if requested)."""
        if not self.response:
            return None
        code = self.response.get("data_wrangler_function")
        if not code:
            return None
        if markdown:
            return Markdown(f"```python\n{code}\n```")
        return code

    def get_data_visualization_function(self, markdown=False):
        """Return the data visualization function source (Markdown-wrapped if requested)."""
        if not self.response:
            return None
        code = self.response.get("data_visualization_function")
        if not code:
            return None
        if markdown:
            return Markdown(f"```python\n{code}\n```")
        return code

    def get_workflow_summary(self, markdown=False):
        """
        Return a summary of the workflow built from the response messages.

        Each message contributes its `role` to the header and its JSON
        `content` (parsed and passed to `get_generic_summary`) to the body.
        Returns None when there is no response or it has no messages.
        """
        if not (self.response and self.response.get("messages")):
            return None
        messages = self.response["messages"]
        agents = [message.role for message in messages]
        agent_labels = [
            f"- **Agent {position}:** {role}"
            for position, role in enumerate(agents, start=1)
        ]
        header = f"# Pandas Data Analyst Workflow Summary\n\nThis workflow contains {len(agents)} agents:\n\n" + "\n".join(agent_labels)
        reports = [
            get_generic_summary(json.loads(message.content))
            for message in messages
        ]
        summary = "\n" + header + "\n\n".join(reports)
        if markdown:
            return Markdown(summary)
        return summary

    @staticmethod
    def _convert_data_input(data_raw: Union[pd.DataFrame, dict, list]) -> Union[dict, list]:
        """
        Convert input data to the format the graph expects.

        A DataFrame becomes a dict via `to_dict()`, a dict is returned as is,
        and a list has each DataFrame item converted while other items pass
        through. Any other type raises ValueError.
        """
        if isinstance(data_raw, pd.DataFrame):
            return data_raw.to_dict()
        if isinstance(data_raw, dict):
            return data_raw
        if isinstance(data_raw, list):
            converted = []
            for item in data_raw:
                converted.append(item.to_dict() if isinstance(item, pd.DataFrame) else item)
            return converted
        raise ValueError("data_raw must be a DataFrame, dict, or list of DataFrames/dicts")
151
+
152
def make_pandas_data_analyst(
    model,
    data_wrangling_agent: CompiledStateGraph,
    data_visualization_agent: CompiledStateGraph,
    checkpointer: Checkpointer = None
):
    """
    Creates a multi-agent system that wrangles data and optionally visualizes it.

    The compiled graph runs these nodes in order:
    1. routing_preprocessor: asks the model to split the user question into
       wrangling instructions, visualization instructions and a
       'chart'/'table' decision.
    2. data_wrangling_agent: always invoked with the wrangling instructions.
    3. data_visualization_agent: invoked only when the decision is 'chart'.
    4. route_printer: prints the chosen route before the graph ends.

    Parameters:
    -----------
    model: The language model to be used.
    data_wrangling_agent: CompiledStateGraph
        The Data Wrangling Agent.
    data_visualization_agent: CompiledStateGraph
        The Data Visualization Agent.
    checkpointer: Checkpointer (optional)
        The checkpointer to save the state.

    Returns:
    --------
    CompiledStateGraph: The compiled multi-agent system.
    """

    llm = model

    routing_preprocessor_prompt = PromptTemplate(
        template="""
        You are an expert in routing decisions for a Pandas Data Manipulation Wrangling Agent, a Charting Visualization Agent, and a Pandas Table Agent. Your job is to:

        1. Determine what the correct format for a Users Question should be for use with a Pandas Data Wrangling Agent based on the incoming user question. Anything related to data wrangling and manipulation should be passed along.
        2. Determine whether or not a chart should be generated or a table should be returned based on the users question.
        3. If a chart is requested, determine the correct format of a Users Question should be used with a Data Visualization Agent. Anything related to plotting and visualization should be passed along.

        Use the following criteria on how to route the the initial user question:

        From the incoming user question, remove any details about the format of the final response as either a Chart or Table and return only the important part of the incoming user question that is relevant for the Pandas Data Wrangling Agent. This will be the 'user_instructions_data_wrangling'. If 'None' is found, return the original user question.

        Next, determine if the user would like a data visualization ('chart') or a 'table' returned with the results of the Data Wrangling Agent. If unknown, not specified or 'None' is found, then select 'table'.

        If a 'chart' is requested, return the 'user_instructions_data_visualization'. If 'None' is found, return None.

        Return JSON with 'user_instructions_data_wrangling', 'user_instructions_data_visualization' and 'routing_preprocessor_decision'.

        INITIAL_USER_QUESTION: {user_instructions}
        """,
        input_variables=["user_instructions"]
    )

    routing_preprocessor = routing_preprocessor_prompt | llm | JsonOutputParser()

    class PrimaryState(TypedDict):
        # Messages from the sub-agents are accumulated with operator.add.
        messages: Annotated[Sequence[BaseMessage], operator.add]
        user_instructions: str
        user_instructions_data_wrangling: str
        user_instructions_data_visualization: str
        routing_preprocessor_decision: str
        data_raw: Union[dict, list]
        data_wrangled: dict
        data_wrangler_function: str
        data_visualization_function: str
        plotly_graph: dict
        plotly_error: str
        max_retries: int
        retry_count: int

    def preprocess_routing(state: PrimaryState):
        """Ask the model for the per-agent instructions and the chart/table decision."""
        print("---PANDAS DATA ANALYST---")
        print("*************************")
        print("---PREPROCESS ROUTER---")
        question = state.get("user_instructions")

        # Chart routing and data wrangling prep
        response = routing_preprocessor.invoke({"user_instructions": question})

        return {
            "user_instructions_data_wrangling": response.get('user_instructions_data_wrangling'),
            "user_instructions_data_visualization": response.get('user_instructions_data_visualization'),
            "routing_preprocessor_decision": response.get('routing_preprocessor_decision'),
        }

    def router_chart_or_table(state: PrimaryState):
        """Return 'chart' when the routing decision asks for a chart, else 'table'."""
        print("---ROUTER: CHART OR TABLE---")
        decision = state.get('routing_preprocessor_decision')
        # The decision is free-form model output: tolerate case and whitespace
        # differences so "Chart" or " chart " do not silently fall back to table.
        if isinstance(decision, str) and decision.strip().lower() == "chart":
            return "chart"
        return "table"

    def invoke_data_wrangling_agent(state: PrimaryState):
        """Run the wrangling sub-agent and copy its outputs into the primary state."""
        response = data_wrangling_agent.invoke({
            "user_instructions": state.get("user_instructions_data_wrangling"),
            "data_raw": state.get("data_raw"),
            "max_retries": state.get("max_retries"),
            "retry_count": state.get("retry_count"),
        })

        return {
            # operator.add cannot combine the accumulated list with None, so
            # fall back to an empty list when the sub-agent returns no messages.
            "messages": response.get("messages") or [],
            "data_wrangled": response.get("data_wrangled"),
            "data_wrangler_function": response.get("data_wrangler_function"),
            # NOTE(review): this reads the *visualization* error key from the
            # wrangling agent's response; presumably it is absent there, which
            # sets plotly_error to None. Kept as is - confirm intent.
            "plotly_error": response.get("data_visualization_error"),
        }

    def invoke_data_visualization_agent(state: PrimaryState):
        """Run the visualization sub-agent on the wrangled data (or raw data as fallback)."""
        response = data_visualization_agent.invoke({
            "user_instructions": state.get("user_instructions_data_visualization"),
            # Prefer the wrangled data; use the raw input when wrangling produced nothing.
            "data_raw": state.get("data_wrangled") or state.get("data_raw"),
            "max_retries": state.get("max_retries"),
            "retry_count": state.get("retry_count"),
        })

        return {
            # Same None guard as in the wrangling node (operator.add reducer).
            "messages": response.get("messages") or [],
            "data_visualization_function": response.get("data_visualization_function"),
            "plotly_graph": response.get("plotly_graph"),
            "plotly_error": response.get("data_visualization_error"),
        }

    def route_printer(state: PrimaryState):
        """Print the route that was taken; makes no state changes."""
        print("---ROUTE PRINTER---")
        print(f"    Route: {state.get('routing_preprocessor_decision')}")
        print("---END---")
        return {}

    workflow = StateGraph(PrimaryState)

    workflow.add_node("routing_preprocessor", preprocess_routing)
    workflow.add_node("data_wrangling_agent", invoke_data_wrangling_agent)
    workflow.add_node("data_visualization_agent", invoke_data_visualization_agent)
    workflow.add_node("route_printer", route_printer)

    workflow.add_edge(START, "routing_preprocessor")
    workflow.add_edge("routing_preprocessor", "data_wrangling_agent")

    # After wrangling, branch to the visualization agent only for chart requests.
    workflow.add_conditional_edges(
        "data_wrangling_agent",
        router_chart_or_table,
        {
            "chart": "data_visualization_agent",
            "table": "route_printer"
        }
    )

    workflow.add_edge("data_visualization_agent", "route_printer")
    workflow.add_edge("route_printer", END)

    app = workflow.compile(
        checkpointer=checkpointer,
        name=AGENT_NAME
    )

    return app