ai-data-science-team 0.0.0.9006__py3-none-any.whl → 0.0.0.9007__py3-none-any.whl

@@ -1 +1 @@
- __version__ = "0.0.0.9006"
+ __version__ = "0.0.0.9007"
@@ -1,5 +1,6 @@
- from ai_data_science_team.agents.data_cleaning_agent import make_data_cleaning_agent
+ from ai_data_science_team.agents.data_cleaning_agent import make_data_cleaning_agent, DataCleaningAgent
  from ai_data_science_team.agents.feature_engineering_agent import make_feature_engineering_agent
  from ai_data_science_team.agents.data_wrangling_agent import make_data_wrangling_agent
  from ai_data_science_team.agents.sql_database_agent import make_sql_database_agent
+ from ai_data_science_team.agents.data_visualization_agent import make_data_visualization_agent
 
@@ -13,11 +13,13 @@ from langchain_core.messages import BaseMessage
  from langgraph.types import Command
  from langgraph.checkpoint.memory import MemorySaver
 
+ from langgraph.graph.state import CompiledStateGraph
+
  import os
  import io
  import pandas as pd
 
- from ai_data_science_team.templates.agent_templates import(
+ from ai_data_science_team.templates import(
      node_func_execute_agent_code_on_data,
      node_func_human_review,
      node_func_fix_agent_code,
@@ -25,7 +27,7 @@ from ai_data_science_team.templates.agent_templates import(
      create_coding_agent_graph
  )
  from ai_data_science_team.tools.parsers import PythonOutputParser
- from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
+ from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
  from ai_data_science_team.tools.metadata import get_dataframe_summary
  from ai_data_science_team.tools.logging import log_ai_function
 
@@ -33,9 +35,170 @@ from ai_data_science_team.tools.logging import log_ai_function
  AGENT_NAME = "data_cleaning_agent"
  LOG_PATH = os.path.join(os.getcwd(), "logs/")
 
+
+
+ # Class
+ class DataCleaningAgent(CompiledStateGraph):
+
+     def __init__(
+         self,
+         model,
+         n_samples=30,
+         log=False,
+         log_path=None,
+         file_name="data_cleaner.py",
+         overwrite=True,
+         human_in_the_loop=False,
+         bypass_recommended_steps=False,
+         bypass_explain_code=False
+     ):
+         self._params = {
+             "model": model,
+             "n_samples": n_samples,
+             "log": log,
+             "log_path": log_path,
+             "file_name": file_name,
+             "overwrite": overwrite,
+             "human_in_the_loop": human_in_the_loop,
+             "bypass_recommended_steps": bypass_recommended_steps,
+             "bypass_explain_code": bypass_explain_code,
+         }
+         self._compiled_graph = self._make_compiled_graph()
+         self.response = None
+
+     def _make_compiled_graph(self):
+         self.response = None
+         return make_data_cleaning_agent(**self._params)
+
+     def update_params(self, **kwargs):
+         """
+         Update one or more parameters at once, then rebuild the compiled graph.
+         e.g. agent.update_params(model=new_llm, n_samples=100)
+         """
+         self._params.update(kwargs)
+         self._compiled_graph = self._make_compiled_graph()
+
+     def __getattr__(self, name: str):
+         """
+         Delegate attribute access to `_compiled_graph` if `name` is not
+         found in this instance. This 'inherits' methods from the compiled graph.
+         """
+         return getattr(self._compiled_graph, name)
+
+     async def ainvoke(self, user_instructions: str, data_raw: pd.DataFrame, max_retries=3, retry_count=0):
+         """
+         Cleans the provided dataset based on user instructions.
+
+         Parameters:
+             user_instructions (str): Instructions for data cleaning.
+             data_raw (pd.DataFrame): The raw dataset to be cleaned.
+             max_retries (int): Maximum retry attempts for cleaning.
+             retry_count (int): Current retry attempt.
+
+         Returns:
+             None. The response is stored in the response attribute.
+         """
+         response = await self._compiled_graph.ainvoke({
+             "user_instructions": user_instructions,
+             "data_raw": data_raw.to_dict(),
+             "max_retries": max_retries,
+             "retry_count": retry_count,
+         })
+         self.response = response
+         return None
+
+     def invoke(self, user_instructions: str, data_raw: pd.DataFrame, max_retries=3, retry_count=0):
+         """
+         Cleans the provided dataset based on user instructions.
+
+         Parameters:
+             user_instructions (str): Instructions for data cleaning.
+             data_raw (pd.DataFrame): The raw dataset to be cleaned.
+             max_retries (int): Maximum retry attempts for cleaning.
+             retry_count (int): Current retry attempt.
+
+         Returns:
+             None. The response is stored in the response attribute.
+         """
+         response = self._compiled_graph.invoke({
+             "user_instructions": user_instructions,
+             "data_raw": data_raw.to_dict(),
+             "max_retries": max_retries,
+             "retry_count": retry_count,
+         })
+         self.response = response
+         return None
+
+     def explain_cleaning_steps(self):
+         """
+         Provides an explanation of the cleaning steps performed by the agent.
+
+         Returns:
+             list: Messages explaining the cleaning steps.
+         """
+         messages = self.response.get("messages", [])
+         return messages
+
+     def get_log_summary(self):
+         """
+         Returns a summary of the agent's log, if logging is enabled.
+         """
+         if self.response:
+             if self._params["log"]:
+                 log_details = f"Log Path: {self.response.get('data_cleaner_function_path')}"
+                 return log_details
+
+     def get_state_keys(self):
+         """
+         Returns a list of keys that the state graph returns in a response.
+         """
+         return list(self.get_output_jsonschema()['properties'].keys())
+
+     def get_state_properties(self):
+         """
+         Returns the schema properties of the keys that the state graph returns in a response.
+         """
+         return self.get_output_jsonschema()['properties']
+
+     def get_data_cleaned(self):
+         """
+         Retrieves the cleaned data stored after running the invoke or ainvoke methods.
+         """
+         if self.response:
+             return pd.DataFrame(self.response.get("data_cleaned"))
+
+     def get_data_raw(self):
+         """
+         Retrieves the raw data.
+         """
+         if self.response:
+             return pd.DataFrame(self.response.get("data_raw"))
+
+     def get_data_cleaner_function(self):
+         """
+         Retrieves the agent's pipeline function.
+         """
+         if self.response:
+             return self.response.get("data_cleaner_function")
+
+
+
+
+
+
  # Agent
 
- def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True, human_in_the_loop=False, bypass_recommended_steps=False, bypass_explain_code=False):
+ def make_data_cleaning_agent(
+     model,
+     n_samples = 30,
+     log=False,
+     log_path=None,
+     file_name="data_cleaner.py",
+     overwrite = True,
+     human_in_the_loop=False,
+     bypass_recommended_steps=False,
+     bypass_explain_code=False
+ ):
      """
      Creates a data cleaning agent that can be run on a dataset. The agent can be used to clean a dataset in a variety of
      ways, such as removing columns with more than 40% missing values, imputing missing
@@ -44,9 +207,9 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
      The agent takes in a dataset and some user instructions, and outputs a python
      function that can be used to clean the dataset. The agent also logs the code
      generated and any errors that occur.
- 
+ 
      The agent is instructed to perform the following data cleaning steps:
- 
+ 
      - Removing columns if more than 40 percent of the data is missing
      - Imputing missing values with the mean of the column if the column is numeric
      - Imputing missing values with the mode of the column if the column is categorical
@@ -60,12 +223,18 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
      ----------
      model : langchain.llms.base.LLM
          The language model to use to generate code.
+     n_samples : int, optional
+         The number of samples to use when summarizing the dataset. Defaults to 30.
+         If you get an error due to maximum tokens, try reducing this number.
+         > "This model's maximum context length is 128000 tokens. However, your messages resulted in 333858 tokens. Please reduce the length of the messages."
      log : bool, optional
          Whether or not to log the code generated and any errors that occur.
          Defaults to False.
      log_path : str, optional
          The path to the directory where the log files should be stored. Defaults to
          "logs/".
+     file_name : str, optional
+         The name of the file to save the response to. Defaults to "data_cleaner.py".
      overwrite : bool, optional
          Whether or not to overwrite the log file if it already exists. If False, a unique file name will be created.
          Defaults to True.
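A minimal sketch of how the new `n_samples` and `file_name` parameters might be used when the dataset summary overflows the model's context window (the model name and log file name below are illustrative, not from the release):

```python
from langchain_openai import ChatOpenAI
from ai_data_science_team.agents import make_data_cleaning_agent

llm = ChatOpenAI(model="gpt-4o-mini")

# Summarize fewer rows to keep the prompt under the model's token limit,
# and write the generated cleaning code to a custom log file name.
data_cleaning_agent = make_data_cleaning_agent(
    llm,
    n_samples=10,
    log=True,
    file_name="churn_data_cleaner.py",
    overwrite=True,
)
```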
@@ -82,26 +251,26 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
      import pandas as pd
      from langchain_openai import ChatOpenAI
      from ai_data_science_team.agents import data_cleaning_agent
- 
+ 
      llm = ChatOpenAI(model = "gpt-4o-mini")
 
      data_cleaning_agent = make_data_cleaning_agent(llm)
- 
+ 
      df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")
- 
+ 
      response = data_cleaning_agent.invoke({
          "user_instructions": "Don't remove outliers when cleaning the data.",
          "data_raw": df.to_dict(),
          "max_retries":3,
          "retry_count":0
      })
- 
+ 
      pd.DataFrame(response['data_cleaned'])
      ```
 
      Returns
      -------
-     app : langchain.graphs.StateGraph
+     app : langchain.graphs.CompiledStateGraph
          The data cleaning agent as a state graph.
      """
      llm = model
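For illustration, a hand-written `data_cleaner()` of the kind the agent is prompted to generate might look like the sketch below. This is not the agent's actual output; it simply applies a few of the steps listed in the docstring above:

```python
def data_cleaner(data_raw):
    import pandas as pd

    df = pd.DataFrame(data_raw).copy()

    # Remove columns with more than 40 percent missing values
    df = df.loc[:, df.isna().mean() <= 0.40]

    # Impute numeric columns with the mean, categorical columns with the mode
    for col in df.columns:
        if df[col].isna().any():
            if pd.api.types.is_numeric_dtype(df[col]):
                df[col] = df[col].fillna(df[col].mean())
            else:
                df[col] = df[col].fillna(df[col].mode().iloc[0])

    # Remove duplicate rows
    df = df.drop_duplicates()

    return df
```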
@@ -134,7 +303,7 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
          Recommend a series of data cleaning steps based on the input data.
          These recommended steps will be appended to the user_instructions.
          """
-         print("---DATA CLEANING AGENT----")
+         print(format_agent_name(AGENT_NAME))
          print(" * RECOMMEND CLEANING STEPS")
 
          # Prompt to get recommended steps from the LLM
@@ -177,6 +346,7 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
 
              Avoid these:
              1. Do not include steps to save files.
+             2. Do not include unrelated user instructions that are not related to the data cleaning.
              """,
              input_variables=["user_instructions", "recommended_steps", "all_datasets_summary"]
          )
@@ -184,7 +354,7 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
          data_raw = state.get("data_raw")
          df = pd.DataFrame.from_dict(data_raw)
 
-         all_datasets_summary = get_dataframe_summary([df])
+         all_datasets_summary = get_dataframe_summary([df], n_sample=n_samples)
 
          all_datasets_summary_str = "\n\n".join(all_datasets_summary)
 
@@ -201,10 +371,21 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
          }
 
      def create_data_cleaner_code(state: GraphState):
-         if bypass_recommended_steps:
-             print("---DATA CLEANING AGENT----")
+ 
          print(" * CREATE DATA CLEANER CODE")
 
+         if bypass_recommended_steps:
+             print(format_agent_name(AGENT_NAME))
+ 
+             data_raw = state.get("data_raw")
+             df = pd.DataFrame.from_dict(data_raw)
+ 
+             all_datasets_summary = get_dataframe_summary([df], n_sample=n_samples)
+ 
+             all_datasets_summary_str = "\n\n".join(all_datasets_summary)
+         else:
+             all_datasets_summary_str = state.get("all_datasets_summary")
+ 
          data_cleaning_prompt = PromptTemplate(
              template="""
              You are a Data Cleaning Agent. Your job is to create a data_cleaner() function that can be run on the data provided using the following recommended steps.
@@ -218,7 +399,7 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
 
              {all_datasets_summary}
 
-             Return Python code in ```python ``` format with a single function definition, data_cleaner(data_raw), that incldues all imports inside the function.
+             Return Python code in ```python ``` format with a single function definition, data_cleaner(data_raw), that includes all imports inside the function.
 
              Return code to provide the data cleaning function:
 
@@ -240,16 +421,16 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
 
          response = data_cleaning_agent.invoke({
              "recommended_steps": state.get("recommended_steps"),
-             "all_datasets_summary": state.get("all_datasets_summary")
+             "all_datasets_summary": all_datasets_summary_str
          })
 
          response = relocate_imports_inside_function(response)
          response = add_comments_to_top(response, agent_name=AGENT_NAME)
 
          # For logging: store the code generated:
-         file_path, file_name = log_ai_function(
+         file_path, file_name_2 = log_ai_function(
              response=response,
-             file_name="data_cleaner.py",
+             file_name=file_name,
              log=log,
              log_path=log_path,
              overwrite=overwrite
@@ -258,7 +439,8 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
          return {
              "data_cleaner_function" : response,
              "data_cleaner_function_path": file_path,
-             "data_cleaner_function_name": file_name
+             "data_cleaner_function_name": file_name_2,
+             "all_datasets_summary": all_datasets_summary_str
          }
 
      def human_review(state: GraphState) -> Command[Literal["recommend_cleaning_steps", "create_data_cleaner_code"]]:
@@ -353,3 +535,6 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
      )
 
      return app
+
+
+
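A minimal usage sketch of the new `DataCleaningAgent` wrapper class added above (it assumes only the constructor and accessor methods shown in this diff; the model name is illustrative):

```python
import pandas as pd
from langchain_openai import ChatOpenAI
from ai_data_science_team.agents import DataCleaningAgent

llm = ChatOpenAI(model="gpt-4o-mini")
df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")

# The class wraps the compiled LangGraph graph and stores the last result on .response
agent = DataCleaningAgent(model=llm, n_samples=30, log=False)

agent.invoke(
    user_instructions="Don't remove outliers when cleaning the data.",
    data_raw=df,
    max_retries=3,
)

cleaned_df = agent.get_data_cleaned()            # cleaned dataset as a DataFrame
cleaner_src = agent.get_data_cleaner_function()  # generated data_cleaner() source code
```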
@@ -0,0 +1,331 @@
+ # BUSINESS SCIENCE UNIVERSITY
+ # AI DATA SCIENCE TEAM
+ # ***
+ # * Agents: Data Visualization Agent
+
+
+
+ # Libraries
+ from typing import TypedDict, Annotated, Sequence, Literal
+ import operator
+
+ from langchain.prompts import PromptTemplate
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.messages import BaseMessage
+
+ from langgraph.types import Command
+ from langgraph.checkpoint.memory import MemorySaver
+
+ import os
+ import io
+ import pandas as pd
+
+ from ai_data_science_team.templates import(
+     node_func_execute_agent_code_on_data,
+     node_func_human_review,
+     node_func_fix_agent_code,
+     node_func_explain_agent_code,
+     create_coding_agent_graph
+ )
+ from ai_data_science_team.tools.parsers import PythonOutputParser
+ from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
+ from ai_data_science_team.tools.metadata import get_dataframe_summary
+ from ai_data_science_team.tools.logging import log_ai_function
+
+ # Setup
+ AGENT_NAME = "data_visualization_agent"
+ LOG_PATH = os.path.join(os.getcwd(), "logs/")
+
+ # Agent
+
+ def make_data_visualization_agent(
+     model,
+     n_samples=30,
+     log=False,
+     log_path=None,
+     file_name="data_visualization.py",
+     overwrite = True,
+     human_in_the_loop=False,
+     bypass_recommended_steps=False,
+     bypass_explain_code=False
+ ):
+
+     llm = model
+
+     # Setup Log Directory
+     if log:
+         if log_path is None:
+             log_path = LOG_PATH
+         if not os.path.exists(log_path):
+             os.makedirs(log_path)
+
+     # Define GraphState for the router
+     class GraphState(TypedDict):
+         messages: Annotated[Sequence[BaseMessage], operator.add]
+         user_instructions: str
+         user_instructions_processed: str
+         recommended_steps: str
+         data_raw: dict
+         plotly_graph: dict
+         all_datasets_summary: str
+         data_visualization_function: str
+         data_visualization_function_path: str
+         data_visualization_function_name: str
+         data_visualization_error: str
+         max_retries: int
+         retry_count: int
+
+     def chart_instructor(state: GraphState):
+
+         print(format_agent_name(AGENT_NAME))
+         print(" * CREATE CHART GENERATOR INSTRUCTIONS")
+
+         recommend_steps_prompt = PromptTemplate(
+             template="""
+             You are a supervisor that is an expert in providing instructions to a chart generator agent for plotting.
+
+             You will take a question that a user has and the data that was generated to answer the question, and create instructions to create a chart from the data that will be passed to a chart generator agent.
+
+             USER QUESTION / INSTRUCTIONS:
+             {user_instructions}
+
+             Previously Recommended Instructions (if any):
+             {recommended_steps}
+
+             DATA:
+             {all_datasets_summary}
+
+             Formulate chart generator instructions by informing the chart generator of what type of plotly plot to use (e.g. bar, line, scatter, etc) to best represent the data.
+
+             Come up with an informative title from the user's question and data provided. Also provide X and Y axis titles.
+
+             Instruct the chart generator to use the following theme colors, sizes, etc:
+
+             - Start with the "plotly_white" template
+             - Use a white background
+             - Use this color for bars and lines:
+                 'blue': '#3381ff',
+             - Base Font Size: 8.8 (Used for x and y axes tickfont, any annotations, hovertips)
+             - Title Font Size: 13.2
+             - Line Size: 0.65 (specify these within the xaxis and yaxis dictionaries)
+             - Add smoothers or trendlines to scatter plots unless not desired by the user
+             - Do not use color_discrete_map (this will result in an error)
+             - Hover tip size: 8.8
+
+             Return your instructions in the following format:
+             CHART GENERATOR INSTRUCTIONS:
+             FILL IN THE INSTRUCTIONS HERE
+
+             Avoid these:
+             1. Do not include steps to save files.
+             2. Do not include unrelated user instructions that are not related to the chart generation.
+             """,
+             input_variables=["user_instructions", "recommended_steps", "all_datasets_summary"]
+
+         )
+
+         data_raw = state.get("data_raw")
+         df = pd.DataFrame.from_dict(data_raw)
+
+         all_datasets_summary = get_dataframe_summary([df], n_sample=n_samples, skip_stats=False)
+
+         all_datasets_summary_str = "\n\n".join(all_datasets_summary)
+
+         chart_instructor = recommend_steps_prompt | llm
+
+         recommended_steps = chart_instructor.invoke({
+             "user_instructions": state.get("user_instructions"),
+             "recommended_steps": state.get("recommended_steps"),
+             "all_datasets_summary": all_datasets_summary_str
+         })
+
+         return {
+             "recommended_steps": "\n\n# Recommended Chart Generator Instructions:\n" + recommended_steps.content.strip(),
+             "all_datasets_summary": all_datasets_summary_str
+         }
+
+     def chart_generator(state: GraphState):
+
+         print(" * CREATE DATA VISUALIZATION CODE")
+
+
+         if bypass_recommended_steps:
+             print(format_agent_name(AGENT_NAME))
+
+             data_raw = state.get("data_raw")
+             df = pd.DataFrame.from_dict(data_raw)
+
+             all_datasets_summary = get_dataframe_summary([df], n_sample=n_samples, skip_stats=False)
+
+             all_datasets_summary_str = "\n\n".join(all_datasets_summary)
+
+             chart_generator_instructions = state.get("user_instructions")
+
+         else:
+             all_datasets_summary_str = state.get("all_datasets_summary")
+             chart_generator_instructions = state.get("recommended_steps")
+
+         prompt_template = PromptTemplate(
+             template="""
+             You are a chart generator agent that is an expert in generating plotly charts. You must use plotly or plotly.express to produce plots.
+
+             Your job is to produce python code to generate visualizations.
+
+             You will take instructions from a Chart Instructor and generate a plotly chart from the data provided.
+
+             CHART INSTRUCTIONS:
+             {chart_generator_instructions}
+
+             DATA:
+             {all_datasets_summary}
+
+             RETURN:
+
+             Return Python code in ```python ``` format with a single function definition, data_visualization(data_raw), that includes all imports inside the function.
+
+             Return the plotly chart as a dictionary.
+
+             Return code to provide the data visualization function:
+
+             def data_visualization(data_raw):
+                 import pandas as pd
+                 import numpy as np
+                 import json
+                 import plotly.graph_objects as go
+                 import plotly.io as pio
+
+                 ...
+
+                 fig_json = pio.to_json(fig)
+                 fig_dict = json.loads(fig_json)
+
+                 return fig_dict
+
+             Avoid these:
+             1. Do not include steps to save files.
+             2. Do not include unrelated user instructions that are not related to the chart generation.
+
+             """,
+             input_variables=["chart_generator_instructions", "all_datasets_summary"]
+         )
+
+         data_visualization_agent = prompt_template | llm | PythonOutputParser()
+
+         response = data_visualization_agent.invoke({
+             "chart_generator_instructions": chart_generator_instructions,
+             "all_datasets_summary": all_datasets_summary_str
+         })
+
+         response = relocate_imports_inside_function(response)
+         response = add_comments_to_top(response, agent_name=AGENT_NAME)
+
+         # For logging: store the code generated:
+         file_path, file_name_2 = log_ai_function(
+             response=response,
+             file_name=file_name,
+             log=log,
+             log_path=log_path,
+             overwrite=overwrite
+         )
+
+         return {
+             "data_visualization_function": response,
+             "data_visualization_function_path": file_path,
+             "data_visualization_function_name": file_name_2,
+             "all_datasets_summary": all_datasets_summary_str
+         }
+
+     def human_review(state: GraphState) -> Command[Literal["chart_instructor", "chart_generator"]]:
+         return node_func_human_review(
+             state=state,
+             prompt_text="Is the following data visualization instructions correct? (Answer 'yes' or provide modifications)\n{steps}",
+             yes_goto="chart_generator",
+             no_goto="chart_instructor",
+             user_instructions_key="user_instructions",
+             recommended_steps_key="recommended_steps"
+         )
+
+
+     def execute_data_visualization_code(state):
+         return node_func_execute_agent_code_on_data(
+             state=state,
+             data_key="data_raw",
+             result_key="plotly_graph",
+             error_key="data_visualization_error",
+             code_snippet_key="data_visualization_function",
+             agent_function_name="data_visualization",
+             pre_processing=lambda data: pd.DataFrame.from_dict(data),
+             # post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
+             error_message_prefix="An error occurred during data visualization: "
+         )
+
+     def fix_data_visualization_code(state: GraphState):
+         prompt = """
+         You are a Data Visualization Agent. Your job is to create a data_visualization() function that can be run on the data provided. The function is currently broken and needs to be fixed.
+
+         Make sure to only return the function definition for data_visualization().
+
+         Return Python code in ```python``` format with a single function definition, data_visualization(data_raw), that includes all imports inside the function.
+
+         This is the broken code (please fix):
+         {code_snippet}
+
+         Last Known Error:
+         {error}
+         """
+
+         return node_func_fix_agent_code(
+             state=state,
+             code_snippet_key="data_visualization_function",
+             error_key="data_visualization_error",
+             llm=llm,
+             prompt_template=prompt,
+             agent_name=AGENT_NAME,
+             log=log,
+             file_path=state.get("data_visualization_function_path"),
+         )
+
+     def explain_data_visualization_code(state: GraphState):
+         return node_func_explain_agent_code(
+             state=state,
+             code_snippet_key="data_visualization_function",
+             result_key="messages",
+             error_key="data_visualization_error",
+             llm=llm,
+             role=AGENT_NAME,
+             explanation_prompt_template="""
+             Explain the data visualization steps that the data visualization agent performed in this function.
+             Keep the summary succinct and to the point.\n\n# Data Visualization Agent:\n\n{code}
+             """,
+             success_prefix="# Data Visualization Agent:\n\n ",
+             error_message="The Data Visualization Agent encountered an error during data visualization. No explanation could be provided."
+         )
+
+     # Define the graph
+     node_functions = {
+         "chart_instructor": chart_instructor,
+         "human_review": human_review,
+         "chart_generator": chart_generator,
+         "execute_data_visualization_code": execute_data_visualization_code,
+         "fix_data_visualization_code": fix_data_visualization_code,
+         "explain_data_visualization_code": explain_data_visualization_code
+     }
+
+     app = create_coding_agent_graph(
+         GraphState=GraphState,
+         node_functions=node_functions,
+         recommended_steps_node_name="chart_instructor",
+         create_code_node_name="chart_generator",
+         execute_code_node_name="execute_data_visualization_code",
+         fix_code_node_name="fix_data_visualization_code",
+         explain_code_node_name="explain_data_visualization_code",
+         error_key="data_visualization_error",
+         human_in_the_loop=human_in_the_loop, # or False
+         human_review_node_name="human_review",
+         checkpointer=MemorySaver() if human_in_the_loop else None,
+         bypass_recommended_steps=bypass_recommended_steps,
+         bypass_explain_code=bypass_explain_code,
+     )
+
+     return app
+
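A minimal usage sketch for the new data visualization agent, mirroring the data cleaning example above (the user instruction and model name are illustrative; the `plotly_graph` key comes from the GraphState defined in this file):

```python
import pandas as pd
import plotly.graph_objects as go
from langchain_openai import ChatOpenAI
from ai_data_science_team.agents import make_data_visualization_agent

llm = ChatOpenAI(model="gpt-4o-mini")
df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")

data_visualization_agent = make_data_visualization_agent(llm)

response = data_visualization_agent.invoke({
    "user_instructions": "Plot the distribution of monthly charges by churn status.",
    "data_raw": df.to_dict(),
    "max_retries": 3,
    "retry_count": 0,
})

# The graph returns the chart as a plotly figure dictionary; rebuild it to display.
fig = go.Figure(response["plotly_graph"])
fig.show()
```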