ai-data-science-team 0.0.0.9012__py3-none-any.whl → 0.0.0.9014__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (23) hide show
  1. ai_data_science_team/__init__.py +22 -0
  2. ai_data_science_team/_version.py +1 -1
  3. ai_data_science_team/agents/data_cleaning_agent.py +17 -3
  4. ai_data_science_team/agents/data_loader_tools_agent.py +24 -1
  5. ai_data_science_team/agents/data_visualization_agent.py +17 -3
  6. ai_data_science_team/agents/data_wrangling_agent.py +30 -10
  7. ai_data_science_team/agents/feature_engineering_agent.py +17 -4
  8. ai_data_science_team/agents/sql_database_agent.py +15 -2
  9. ai_data_science_team/ds_agents/eda_tools_agent.py +28 -6
  10. ai_data_science_team/ml_agents/h2o_ml_agent.py +15 -3
  11. ai_data_science_team/ml_agents/mlflow_tools_agent.py +23 -1
  12. ai_data_science_team/multiagents/__init__.py +2 -1
  13. ai_data_science_team/multiagents/pandas_data_analyst.py +305 -0
  14. ai_data_science_team/multiagents/sql_data_analyst.py +119 -30
  15. ai_data_science_team/templates/agent_templates.py +41 -5
  16. ai_data_science_team/tools/dataframe.py +6 -1
  17. ai_data_science_team/tools/eda.py +75 -16
  18. ai_data_science_team/utils/messages.py +27 -0
  19. {ai_data_science_team-0.0.0.9012.dist-info → ai_data_science_team-0.0.0.9014.dist-info}/METADATA +7 -3
  20. {ai_data_science_team-0.0.0.9012.dist-info → ai_data_science_team-0.0.0.9014.dist-info}/RECORD +23 -21
  21. {ai_data_science_team-0.0.0.9012.dist-info → ai_data_science_team-0.0.0.9014.dist-info}/LICENSE +0 -0
  22. {ai_data_science_team-0.0.0.9012.dist-info → ai_data_science_team-0.0.0.9014.dist-info}/WHEEL +0 -0
  23. {ai_data_science_team-0.0.0.9012.dist-info → ai_data_science_team-0.0.0.9014.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,8 @@
1
1
 
2
2
 
3
- from typing import Any, Optional, Annotated, Sequence, List, Dict, Tuple
3
+ from typing import Any, Optional, Annotated, Sequence, Dict
4
4
  import operator
5
5
  import pandas as pd
6
- import os
7
- from io import StringIO, BytesIO
8
- import base64
9
- import matplotlib.pyplot as plt
10
6
 
11
7
  from IPython.display import Markdown
12
8
 
@@ -14,22 +10,26 @@ from langchain_core.messages import BaseMessage, AIMessage
14
10
  from langgraph.prebuilt import create_react_agent, ToolNode
15
11
  from langgraph.prebuilt.chat_agent_executor import AgentState
16
12
  from langgraph.graph import START, END, StateGraph
13
+ from langgraph.types import Checkpointer
17
14
 
18
15
  from ai_data_science_team.templates import BaseAgent
19
16
  from ai_data_science_team.utils.regex import format_agent_name
20
17
 
21
18
  from ai_data_science_team.tools.eda import (
19
+ explain_data,
22
20
  describe_dataset,
23
21
  visualize_missing,
24
22
  correlation_funnel,
25
23
  generate_sweetviz_report,
26
24
  )
25
+ from ai_data_science_team.utils.messages import get_tool_call_names
27
26
 
28
27
 
29
28
  AGENT_NAME = "exploratory_data_analyst_agent"
30
29
 
31
30
  # Updated tool list for EDA
32
31
  EDA_TOOLS = [
32
+ explain_data,
33
33
  describe_dataset,
34
34
  visualize_missing,
35
35
  correlation_funnel,
@@ -49,6 +49,8 @@ class EDAToolsAgent(BaseAgent):
49
49
  Additional kwargs for create_react_agent.
50
50
  invoke_react_agent_kwargs : dict
51
51
  Additional kwargs for agent invocation.
52
+ checkpointer : Checkpointer, optional
53
+ The checkpointer for the agent.
52
54
  """
53
55
 
54
56
  def __init__(
@@ -56,11 +58,13 @@ class EDAToolsAgent(BaseAgent):
56
58
  model: Any,
57
59
  create_react_agent_kwargs: Optional[Dict] = {},
58
60
  invoke_react_agent_kwargs: Optional[Dict] = {},
61
+ checkpointer: Optional[Checkpointer] = None,
59
62
  ):
60
63
  self._params = {
61
64
  "model": model,
62
65
  "create_react_agent_kwargs": create_react_agent_kwargs,
63
66
  "invoke_react_agent_kwargs": invoke_react_agent_kwargs,
67
+ "checkpointer": checkpointer
64
68
  }
65
69
  self._compiled_graph = self._make_compiled_graph()
66
70
  self.response = None
@@ -162,11 +166,18 @@ class EDAToolsAgent(BaseAgent):
162
166
  return Markdown(self.response["messages"][0].content)
163
167
  else:
164
168
  return self.response["messages"][0].content
169
+
170
+ def get_tool_calls(self):
171
+ """
172
+ Returns the tool calls made by the agent.
173
+ """
174
+ return self.response["tool_calls"]
165
175
 
166
176
  def make_eda_tools_agent(
167
177
  model: Any,
168
178
  create_react_agent_kwargs: Optional[Dict] = {},
169
179
  invoke_react_agent_kwargs: Optional[Dict] = {},
180
+ checkpointer: Optional[Checkpointer] = None,
170
181
  ):
171
182
  """
172
183
  Creates an Exploratory Data Analyst Agent that can interact with EDA tools.
@@ -179,6 +190,8 @@ def make_eda_tools_agent(
179
190
  Additional kwargs for create_react_agent.
180
191
  invoke_react_agent_kwargs : dict
181
192
  Additional kwargs for agent invocation.
193
+ checkpointer : Checkpointer, optional
194
+ The checkpointer for the agent.
182
195
 
183
196
  Returns:
184
197
  -------
@@ -191,6 +204,7 @@ def make_eda_tools_agent(
191
204
  user_instructions: str
192
205
  data_raw: dict
193
206
  eda_artifacts: dict
207
+ tool_calls: list
194
208
 
195
209
  def exploratory_agent(state):
196
210
  print(format_agent_name(AGENT_NAME))
@@ -205,6 +219,7 @@ def make_eda_tools_agent(
205
219
  tools=tool_node,
206
220
  state_schema=GraphState,
207
221
  **create_react_agent_kwargs,
222
+ checkpointer=checkpointer,
208
223
  )
209
224
 
210
225
  response = eda_agent.invoke(
@@ -229,11 +244,14 @@ def make_eda_tools_agent(
229
244
  last_tool_artifact = last_message.artifact
230
245
  elif isinstance(last_message, dict) and "artifact" in last_message:
231
246
  last_tool_artifact = last_message["artifact"]
247
+
248
+ tool_calls = get_tool_call_names(internal_messages)
232
249
 
233
250
  return {
234
251
  "messages": [last_ai_message],
235
252
  "internal_messages": internal_messages,
236
253
  "eda_artifacts": last_tool_artifact,
254
+ "tool_calls": tool_calls,
237
255
  }
238
256
 
239
257
  workflow = StateGraph(GraphState)
@@ -241,5 +259,9 @@ def make_eda_tools_agent(
241
259
  workflow.add_edge(START, "exploratory_agent")
242
260
  workflow.add_edge("exploratory_agent", END)
243
261
 
244
- app = workflow.compile()
262
+ app = workflow.compile(
263
+ checkpointer=checkpointer,
264
+ name=AGENT_NAME,
265
+ )
266
+
245
267
  return app
@@ -5,7 +5,7 @@
5
5
 
6
6
  import os
7
7
  import json
8
- from typing import TypedDict, Annotated, Sequence, Literal
8
+ from typing import TypedDict, Annotated, Sequence, Literal, Optional
9
9
  import operator
10
10
 
11
11
  import pandas as pd
@@ -14,7 +14,7 @@ from IPython.display import Markdown
14
14
  from langchain.prompts import PromptTemplate
15
15
  from langchain_core.messages import BaseMessage
16
16
 
17
- from langgraph.types import Command
17
+ from langgraph.types import Command, Checkpointer
18
18
  from langgraph.checkpoint.memory import MemorySaver
19
19
 
20
20
  from ai_data_science_team.templates import(
@@ -79,6 +79,8 @@ class H2OMLAgent(BaseAgent):
79
79
  Name of the MLflow experiment (created if doesn't exist).
80
80
  mlflow_run_name : str, default None
81
81
  A custom name for the MLflow run.
82
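+ checkpointer : langgraph.types.Checkpointer, optional
83
+ A checkpointer object for saving the agent's state. Defaults to None; when human_in_the_loop is enabled and no checkpointer is supplied, make_h2o_ml_agent falls back to MemorySaver().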
82
84
 
83
85
 
84
86
  Methods
@@ -176,6 +178,7 @@ class H2OMLAgent(BaseAgent):
176
178
  mlflow_tracking_uri=None,
177
179
  mlflow_experiment_name="H2O AutoML",
178
180
  mlflow_run_name=None,
181
+ checkpointer: Optional[Checkpointer]=None,
179
182
  ):
180
183
  self._params = {
181
184
  "model": model,
@@ -193,6 +196,7 @@ class H2OMLAgent(BaseAgent):
193
196
  "mlflow_tracking_uri": mlflow_tracking_uri,
194
197
  "mlflow_experiment_name": mlflow_experiment_name,
195
198
  "mlflow_run_name": mlflow_run_name,
199
+ "checkpointer": checkpointer,
196
200
  }
197
201
  self._compiled_graph = self._make_compiled_graph()
198
202
  self.response = None
@@ -350,6 +354,7 @@ def make_h2o_ml_agent(
350
354
  mlflow_tracking_uri=None,
351
355
  mlflow_experiment_name="H2O AutoML",
352
356
  mlflow_run_name=None,
357
+ checkpointer=None,
353
358
  ):
354
359
  """
355
360
  Creates a machine learning agent that uses H2O for AutoML.
@@ -384,6 +389,12 @@ def make_h2o_ml_agent(
384
389
  " pip install h2o\n\n"
385
390
  "Visit https://docs.h2o.ai/h2o/latest-stable/h2o-docs/downloading.html for details."
386
391
  ) from e
392
+
393
+ if human_in_the_loop:
394
+ if checkpointer is None:
395
+ print("Human in the loop is enabled. A checkpointer is required. Setting to MemorySaver().")
396
+ checkpointer = MemorySaver()
397
+
387
398
 
388
399
  # Define GraphState
389
400
  class GraphState(TypedDict):
@@ -844,9 +855,10 @@ def make_h2o_ml_agent(
844
855
  retry_count_key="retry_count",
845
856
  human_in_the_loop=human_in_the_loop,
846
857
  human_review_node_name="human_review",
847
- checkpointer=MemorySaver(),
858
+ checkpointer=checkpointer,
848
859
  bypass_recommended_steps=bypass_recommended_steps,
849
860
  bypass_explain_code=bypass_explain_code,
861
+ agent_name=AGENT_NAME,
850
862
  )
851
863
 
852
864
  return app
@@ -10,6 +10,7 @@ from langchain_core.messages import BaseMessage, AIMessage
10
10
 
11
11
  from langgraph.prebuilt import create_react_agent, ToolNode
12
12
  from langgraph.prebuilt.chat_agent_executor import AgentState
13
+ from langgraph.types import Checkpointer
13
14
  from langgraph.graph import START, END, StateGraph
14
15
 
15
16
  from ai_data_science_team.templates import BaseAgent
@@ -27,6 +28,7 @@ from ai_data_science_team.tools.mlflow import (
27
28
  mlflow_search_registered_models,
28
29
  mlflow_get_model_version_details,
29
30
  )
31
+ from ai_data_science_team.utils.messages import get_tool_call_names
30
32
 
31
33
  AGENT_NAME = "mlflow_tools_agent"
32
34
 
@@ -67,6 +69,8 @@ class MLflowToolsAgent(BaseAgent):
67
69
  Additional keyword arguments to pass to the create_react_agent function.
68
70
  invoke_react_agent_kwargs : dict
69
71
  Additional keyword arguments to pass to the invoke method of the react agent.
72
+ checkpointer : langgraph.types.Checkpointer, optional
73
+ A checkpointer to use for saving and loading the agent's state. Defaults to None.
70
74
 
71
75
  Methods:
72
76
  --------
@@ -118,6 +122,7 @@ class MLflowToolsAgent(BaseAgent):
118
122
  mlflow_registry_uri: Optional[str]=None,
119
123
  create_react_agent_kwargs: Optional[Dict]={},
120
124
  invoke_react_agent_kwargs: Optional[Dict]={},
125
+ checkpointer: Optional[Checkpointer]=None,
121
126
  ):
122
127
  self._params = {
123
128
  "model": model,
@@ -125,6 +130,7 @@ class MLflowToolsAgent(BaseAgent):
125
130
  "mlflow_registry_uri": mlflow_registry_uri,
126
131
  "create_react_agent_kwargs": create_react_agent_kwargs,
127
132
  "invoke_react_agent_kwargs": invoke_react_agent_kwargs,
133
+ "checkpointer": checkpointer,
128
134
  }
129
135
  self._compiled_graph = self._make_compiled_graph()
130
136
  self.response = None
@@ -228,6 +234,12 @@ class MLflowToolsAgent(BaseAgent):
228
234
  return Markdown(self.response["messages"][0].content)
229
235
  else:
230
236
  return self.response["messages"][0].content
237
+
238
+ def get_tool_calls(self):
239
+ """
240
+ Returns the tool calls made by the agent.
241
+ """
242
+ return self.response["tool_calls"]
231
243
 
232
244
 
233
245
 
@@ -238,6 +250,7 @@ def make_mlflow_tools_agent(
238
250
  mlflow_registry_uri: str=None,
239
251
  create_react_agent_kwargs: Optional[Dict]={},
240
252
  invoke_react_agent_kwargs: Optional[Dict]={},
253
+ checkpointer: Optional[Checkpointer]=None,
241
254
  ):
242
255
  """
243
256
  MLflow Tool Calling Agent
@@ -254,6 +267,8 @@ def make_mlflow_tools_agent(
254
267
  Additional keyword arguments to pass to the agent's create_react_agent method.
255
268
  invoke_react_agent_kwargs : dict, optional
256
269
  Additional keyword arguments to pass to the agent's invoke method.
270
+ checkpointer : langgraph.types.Checkpointer, optional
271
+ A checkpointer to use for saving and loading the agent's state. Defaults to None.
257
272
 
258
273
  Returns
259
274
  -------
@@ -296,6 +311,7 @@ def make_mlflow_tools_agent(
296
311
  model,
297
312
  tools=tool_node,
298
313
  state_schema=GraphState,
314
+ checkpointer=checkpointer,
299
315
  **create_react_agent_kwargs,
300
316
  )
301
317
 
@@ -330,10 +346,13 @@ def make_mlflow_tools_agent(
330
346
  elif isinstance(last_message, dict) and "artifact" in last_message:
331
347
  last_tool_artifact = last_message["artifact"]
332
348
 
349
+ tool_calls = get_tool_call_names(internal_messages)
350
+
333
351
  return {
334
352
  "messages": [last_ai_message],
335
353
  "internal_messages": internal_messages,
336
354
  "mlflow_artifacts": last_tool_artifact,
355
+ "tool_calls": tool_calls,
337
356
  }
338
357
 
339
358
 
@@ -344,7 +363,10 @@ def make_mlflow_tools_agent(
344
363
  workflow.add_edge(START, "mlflow_tools_agent")
345
364
  workflow.add_edge("mlflow_tools_agent", END)
346
365
 
347
- app = workflow.compile()
366
+ app = workflow.compile(
367
+ checkpointer=checkpointer,
368
+ name=AGENT_NAME,
369
+ )
348
370
 
349
371
  return app
350
372
 
@@ -1 +1,2 @@
1
- from ai_data_science_team.multiagents.sql_data_analyst import SQLDataAnalyst, make_sql_data_analyst
1
+ from ai_data_science_team.multiagents.sql_data_analyst import SQLDataAnalyst, make_sql_data_analyst
2
+ from ai_data_science_team.multiagents.pandas_data_analyst import PandasDataAnalyst, make_pandas_data_analyst
@@ -0,0 +1,305 @@
1
+ from langchain_core.messages import BaseMessage
2
+ from langchain.prompts import PromptTemplate
3
+ from langchain_core.output_parsers import JsonOutputParser
4
+ from langgraph.types import Checkpointer
5
+ from langgraph.graph import START, END, StateGraph
6
+ from langgraph.graph.state import CompiledStateGraph
7
+
8
+ from typing import TypedDict, Annotated, Sequence, Union
9
+ import operator
10
+
11
+ import pandas as pd
12
+ import json
13
+ from IPython.display import Markdown
14
+
15
+ from ai_data_science_team.templates import BaseAgent
16
+ from ai_data_science_team.agents import DataWranglingAgent, DataVisualizationAgent
17
+ from ai_data_science_team.utils.plotly import plotly_from_dict
18
+ from ai_data_science_team.utils.regex import remove_consecutive_duplicates, get_generic_summary
19
+
20
+ AGENT_NAME = "pandas_data_analyst"
21
+
22
+ class PandasDataAnalyst(BaseAgent):
23
+ """
24
+ PandasDataAnalyst is a multi-agent class that combines data wrangling and visualization capabilities.
25
+
26
+ Parameters:
27
+ -----------
28
+ model:
29
+ The language model to be used for the agents.
30
+ data_wrangling_agent: DataWranglingAgent
31
+ The Data Wrangling Agent for transforming raw data.
32
+ data_visualization_agent: DataVisualizationAgent
33
+ The Data Visualization Agent for generating plots.
34
+ checkpointer: Checkpointer (optional)
35
+ The checkpointer to save the state of the multi-agent system.
36
+
37
+ Methods:
38
+ --------
39
+ ainvoke_agent(user_instructions, data_raw, **kwargs)
40
+ Asynchronously invokes the multi-agent with user instructions and raw data.
41
+ invoke_agent(user_instructions, data_raw, **kwargs)
42
+ Synchronously invokes the multi-agent with user instructions and raw data.
43
+ get_data_wrangled()
44
+ Returns the wrangled data as a Pandas DataFrame.
45
+ get_plotly_graph()
46
+ Returns the Plotly graph as a Plotly object.
47
+ get_data_wrangler_function(markdown=False)
48
+ Returns the data wrangling function as a string, optionally in Markdown.
49
+ get_data_visualization_function(markdown=False)
50
+ Returns the data visualization function as a string, optionally in Markdown.
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ model,
56
+ data_wrangling_agent: DataWranglingAgent,
57
+ data_visualization_agent: DataVisualizationAgent,
58
+ checkpointer: Checkpointer = None,
59
+ ):
60
+ self._params = {
61
+ "model": model,
62
+ "data_wrangling_agent": data_wrangling_agent,
63
+ "data_visualization_agent": data_visualization_agent,
64
+ "checkpointer": checkpointer,
65
+ }
66
+ self._compiled_graph = self._make_compiled_graph()
67
+ self.response = None
68
+
69
+ def _make_compiled_graph(self):
70
+ """Create or rebuild the compiled graph. Resets response to None."""
71
+ self.response = None
72
+ return make_pandas_data_analyst(
73
+ model=self._params["model"],
74
+ data_wrangling_agent=self._params["data_wrangling_agent"]._compiled_graph,
75
+ data_visualization_agent=self._params["data_visualization_agent"]._compiled_graph,
76
+ checkpointer=self._params["checkpointer"],
77
+ )
78
+
79
+ def update_params(self, **kwargs):
80
+ """Updates parameters and rebuilds the compiled graph."""
81
+ for k, v in kwargs.items():
82
+ self._params[k] = v
83
+ self._compiled_graph = self._make_compiled_graph()
84
+
85
+ async def ainvoke_agent(self, user_instructions, data_raw: Union[pd.DataFrame, dict, list], max_retries: int = 3, retry_count: int = 0, **kwargs):
86
+ """Asynchronously invokes the multi-agent."""
87
+ response = await self._compiled_graph.ainvoke({
88
+ "user_instructions": user_instructions,
89
+ "data_raw": self._convert_data_input(data_raw),
90
+ "max_retries": max_retries,
91
+ "retry_count": retry_count,
92
+ }, **kwargs)
93
+ if response.get("messages"):
94
+ response["messages"] = remove_consecutive_duplicates(response["messages"])
95
+ self.response = response
96
+
97
+ def invoke_agent(self, user_instructions, data_raw: Union[pd.DataFrame, dict, list], max_retries: int = 3, retry_count: int = 0, **kwargs):
98
+ """Synchronously invokes the multi-agent."""
99
+ response = self._compiled_graph.invoke({
100
+ "user_instructions": user_instructions,
101
+ "data_raw": self._convert_data_input(data_raw),
102
+ "max_retries": max_retries,
103
+ "retry_count": retry_count,
104
+ }, **kwargs)
105
+ if response.get("messages"):
106
+ response["messages"] = remove_consecutive_duplicates(response["messages"])
107
+ self.response = response
108
+
109
+ def get_data_wrangled(self):
110
+ """Returns the wrangled data as a Pandas DataFrame."""
111
+ if self.response and self.response.get("data_wrangled"):
112
+ return pd.DataFrame(self.response.get("data_wrangled"))
113
+
114
+ def get_plotly_graph(self):
115
+ """Returns the Plotly graph as a Plotly object."""
116
+ if self.response and self.response.get("plotly_graph"):
117
+ return plotly_from_dict(self.response.get("plotly_graph"))
118
+
119
+ def get_data_wrangler_function(self, markdown=False):
120
+ """Returns the data wrangling function as a string, or wrapped in a Markdown code block when markdown=True."""
121
+ if self.response and self.response.get("data_wrangler_function"):
122
+ code = self.response.get("data_wrangler_function")
123
+ return Markdown(f"```python\n{code}\n```") if markdown else code
124
+
125
+ def get_data_visualization_function(self, markdown=False):
126
+ """Returns the data visualization function as a string, or wrapped in a Markdown code block when markdown=True."""
127
+ if self.response and self.response.get("data_visualization_function"):
128
+ code = self.response.get("data_visualization_function")
129
+ return Markdown(f"```python\n{code}\n```") if markdown else code
130
+
131
+ def get_workflow_summary(self, markdown=False):
132
+ """Returns a summary of the workflow."""
133
+ if self.response and self.response.get("messages"):
134
+ agents = [msg.role for msg in self.response["messages"]]
135
+ agent_labels = [f"- **Agent {i+1}:** {role}" for i, role in enumerate(agents)]
136
+ header = f"# Pandas Data Analyst Workflow Summary\n\nThis workflow contains {len(agents)} agents:\n\n" + "\n".join(agent_labels)
137
+ reports = [get_generic_summary(json.loads(msg.content)) for msg in self.response["messages"]]
138
+ summary = "\n" +header + "\n\n".join(reports)
139
+ return Markdown(summary) if markdown else summary
140
+
141
+ @staticmethod
142
+ def _convert_data_input(data_raw: Union[pd.DataFrame, dict, list]) -> Union[dict, list]:
143
+ """Converts input data to the expected format (dict or list of dicts)."""
144
+ if isinstance(data_raw, pd.DataFrame):
145
+ return data_raw.to_dict()
146
+ if isinstance(data_raw, dict):
147
+ return data_raw
148
+ if isinstance(data_raw, list):
149
+ return [item.to_dict() if isinstance(item, pd.DataFrame) else item for item in data_raw]
150
+ raise ValueError("data_raw must be a DataFrame, dict, or list of DataFrames/dicts")
151
+
152
+ def make_pandas_data_analyst(
153
+ model,
154
+ data_wrangling_agent: CompiledStateGraph,
155
+ data_visualization_agent: CompiledStateGraph,
156
+ checkpointer: Checkpointer = None
157
+ ):
158
+ """
159
+ Creates a multi-agent system that wrangles data and optionally visualizes it.
160
+
161
+ Parameters:
162
+ -----------
163
+ model: The language model to be used.
164
+ data_wrangling_agent: CompiledStateGraph
165
+ The Data Wrangling Agent.
166
+ data_visualization_agent: CompiledStateGraph
167
+ The Data Visualization Agent.
168
+ checkpointer: Checkpointer (optional)
169
+ The checkpointer to save the state.
170
+
171
+ Returns:
172
+ --------
173
+ CompiledStateGraph: The compiled multi-agent system.
174
+ """
175
+
176
+ llm = model
177
+
178
+ routing_preprocessor_prompt = PromptTemplate(
179
+ template="""
180
+ You are an expert in routing decisions for a Pandas Data Manipulation Wrangling Agent, a Charting Visualization Agent, and a Pandas Table Agent. Your job is to:
181
+
182
+ 1. Determine what the correct format for a Users Question should be for use with a Pandas Data Wrangling Agent based on the incoming user question. Anything related to data wrangling and manipulation should be passed along.
183
+ 2. Determine whether or not a chart should be generated or a table should be returned based on the users question.
184
+ 3. If a chart is requested, determine the correct format of a Users Question should be used with a Data Visualization Agent. Anything related to plotting and visualization should be passed along.
185
+
186
+ Use the following criteria on how to route the the initial user question:
187
+
188
+ From the incoming user question, remove any details about the format of the final response as either a Chart or Table and return only the important part of the incoming user question that is relevant for the SQL generator agent. This will be the 'user_instructions_data_wrangling'. If 'None' is found, return the original user question.
189
+
190
+ Next, determine if the user would like a data visualization ('chart') or a 'table' returned with the results of the Data Wrangling Agent. If unknown, not specified or 'None' is found, then select 'table'.
191
+
192
+ If a 'chart' is requested, return the 'user_instructions_data_visualization'. If 'None' is found, return None.
193
+
194
+ Return JSON with 'user_instructions_data_wrangling', 'user_instructions_data_visualization' and 'routing_preprocessor_decision'.
195
+
196
+ INITIAL_USER_QUESTION: {user_instructions}
197
+ """,
198
+ input_variables=["user_instructions"]
199
+ )
200
+
201
+ routing_preprocessor = routing_preprocessor_prompt | llm | JsonOutputParser()
202
+
203
+ class PrimaryState(TypedDict):
204
+ messages: Annotated[Sequence[BaseMessage], operator.add]
205
+ user_instructions: str
206
+ user_instructions_data_wrangling: str
207
+ user_instructions_data_visualization: str
208
+ routing_preprocessor_decision: str
209
+ data_raw: Union[dict, list]
210
+ data_wrangled: dict
211
+ data_wrangler_function: str
212
+ data_visualization_function: str
213
+ plotly_graph: dict
214
+ plotly_error: str
215
+ max_retries: int
216
+ retry_count: int
217
+
218
+
219
+ def preprocess_routing(state: PrimaryState):
220
+ print("---PANDAS DATA ANALYST---")
221
+ print("*************************")
222
+ print("---PREPROCESS ROUTER---")
223
+ question = state.get("user_instructions")
224
+
225
+ # Chart routing and data wrangling / visualization instruction prep
226
+ response = routing_preprocessor.invoke({"user_instructions": question})
227
+
228
+ return {
229
+ "user_instructions_data_wrangling": response.get('user_instructions_data_wrangling'),
230
+ "user_instructions_data_visualization": response.get('user_instructions_data_visualization'),
231
+ "routing_preprocessor_decision": response.get('routing_preprocessor_decision'),
232
+ }
233
+
234
+ def router_chart_or_table(state: PrimaryState):
235
+ print("---ROUTER: CHART OR TABLE---")
236
+ return "chart" if state.get('routing_preprocessor_decision') == "chart" else "table"
237
+
238
+
239
+ def invoke_data_wrangling_agent(state: PrimaryState):
240
+
241
+ response = data_wrangling_agent.invoke({
242
+ "user_instructions": state.get("user_instructions_data_wrangling"),
243
+ "data_raw": state.get("data_raw"),
244
+ "max_retries": state.get("max_retries"),
245
+ "retry_count": state.get("retry_count"),
246
+ })
247
+
248
+ return {
249
+ "messages": response.get("messages"),
250
+ "data_wrangled": response.get("data_wrangled"),
251
+ "data_wrangler_function": response.get("data_wrangler_function"),
252
+ "plotly_error": response.get("data_visualization_error"),
253
+
254
+ }
255
+
256
+ def invoke_data_visualization_agent(state: PrimaryState):
257
+
258
+ response = data_visualization_agent.invoke({
259
+ "user_instructions": state.get("user_instructions_data_visualization"),
260
+ "data_raw": state.get("data_wrangled") if state.get("data_wrangled") else state.get("data_raw"),
261
+ "max_retries": state.get("max_retries"),
262
+ "retry_count": state.get("retry_count"),
263
+ })
264
+
265
+ return {
266
+ "messages": response.get("messages"),
267
+ "data_visualization_function": response.get("data_visualization_function"),
268
+ "plotly_graph": response.get("plotly_graph"),
269
+ "plotly_error": response.get("data_visualization_error"),
270
+ }
271
+
272
+ def route_printer(state: PrimaryState):
273
+ print("---ROUTE PRINTER---")
274
+ print(f" Route: {state.get('routing_preprocessor_decision')}")
275
+ print("---END---")
276
+ return {}
277
+
278
+ workflow = StateGraph(PrimaryState)
279
+
280
+ workflow.add_node("routing_preprocessor", preprocess_routing)
281
+ workflow.add_node("data_wrangling_agent", invoke_data_wrangling_agent)
282
+ workflow.add_node("data_visualization_agent", invoke_data_visualization_agent)
283
+ workflow.add_node("route_printer", route_printer)
284
+
285
+ workflow.add_edge(START, "routing_preprocessor")
286
+ workflow.add_edge("routing_preprocessor", "data_wrangling_agent")
287
+
288
+ workflow.add_conditional_edges(
289
+ "data_wrangling_agent",
290
+ router_chart_or_table,
291
+ {
292
+ "chart": "data_visualization_agent",
293
+ "table": "route_printer"
294
+ }
295
+ )
296
+
297
+ workflow.add_edge("data_visualization_agent", "route_printer")
298
+ workflow.add_edge("route_printer", END)
299
+
300
+ app = workflow.compile(
301
+ checkpointer=checkpointer,
302
+ name=AGENT_NAME
303
+ )
304
+
305
+ return app