ai-data-science-team 0.0.0.9007__py3-none-any.whl → 0.0.0.9009__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (25) hide show
  1. ai_data_science_team/_version.py +1 -1
  2. ai_data_science_team/agents/__init__.py +4 -5
  3. ai_data_science_team/agents/data_cleaning_agent.py +268 -116
  4. ai_data_science_team/agents/data_visualization_agent.py +470 -41
  5. ai_data_science_team/agents/data_wrangling_agent.py +471 -31
  6. ai_data_science_team/agents/feature_engineering_agent.py +426 -41
  7. ai_data_science_team/agents/sql_database_agent.py +458 -58
  8. ai_data_science_team/ml_agents/__init__.py +1 -0
  9. ai_data_science_team/ml_agents/h2o_ml_agent.py +1032 -0
  10. ai_data_science_team/multiagents/__init__.py +1 -0
  11. ai_data_science_team/multiagents/sql_data_analyst.py +398 -0
  12. ai_data_science_team/multiagents/supervised_data_analyst.py +2 -0
  13. ai_data_science_team/templates/__init__.py +3 -1
  14. ai_data_science_team/templates/agent_templates.py +319 -43
  15. ai_data_science_team/tools/metadata.py +94 -62
  16. ai_data_science_team/tools/regex.py +86 -1
  17. ai_data_science_team/utils/__init__.py +0 -0
  18. ai_data_science_team/utils/plotly.py +24 -0
  19. ai_data_science_team-0.0.0.9009.dist-info/METADATA +245 -0
  20. ai_data_science_team-0.0.0.9009.dist-info/RECORD +28 -0
  21. ai_data_science_team-0.0.0.9007.dist-info/METADATA +0 -183
  22. ai_data_science_team-0.0.0.9007.dist-info/RECORD +0 -21
  23. {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/LICENSE +0 -0
  24. {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/WHEEL +0 -0
  25. {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/top_level.txt +0 -0
@@ -17,25 +17,363 @@ from langgraph.types import Command
17
17
  from langgraph.checkpoint.memory import MemorySaver
18
18
 
19
19
  import os
20
- import io
20
+ import json
21
21
  import pandas as pd
22
22
 
23
+ from IPython.display import Markdown
24
+
23
25
  from ai_data_science_team.templates import(
24
26
  node_func_execute_agent_code_on_data,
25
27
  node_func_human_review,
26
28
  node_func_fix_agent_code,
27
- node_func_explain_agent_code,
28
- create_coding_agent_graph
29
+ node_func_report_agent_outputs,
30
+ create_coding_agent_graph,
31
+ BaseAgent,
29
32
  )
30
33
  from ai_data_science_team.tools.parsers import PythonOutputParser
31
- from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
34
+ from ai_data_science_team.tools.regex import (
35
+ relocate_imports_inside_function,
36
+ add_comments_to_top,
37
+ format_agent_name,
38
+ format_recommended_steps,
39
+ get_generic_summary,
40
+ )
32
41
  from ai_data_science_team.tools.metadata import get_dataframe_summary
33
42
  from ai_data_science_team.tools.logging import log_ai_function
43
+ from ai_data_science_team.utils.plotly import plotly_from_dict
34
44
 
35
45
  # Setup
36
46
  AGENT_NAME = "data_visualization_agent"
37
47
  LOG_PATH = os.path.join(os.getcwd(), "logs/")
38
48
 
49
+ # Class
50
+
51
class DataVisualizationAgent(BaseAgent):
    """
    Creates a data visualization agent that can generate Plotly charts based on user-defined instructions or
    default visualization steps (if any). The agent generates a Python function to produce the visualization,
    executes it, and logs the process, including code and errors. It is designed to facilitate reproducible
    and customizable data visualization workflows.

    The agent may use default instructions for creating charts unless instructed otherwise, such as:
    - Generating a recommended chart type (bar, scatter, line, etc.)
    - Creating user-friendly titles and axis labels
    - Applying consistent styling (template, font sizes, color themes)
    - Handling theme details (white background, base font size, line size, etc.)

    User instructions can modify, add, or remove any of these steps to tailor the visualization process.

    Parameters
    ----------
    model : langchain.llms.base.LLM
        The language model used to generate the data visualization function.
    n_samples : int, optional
        Number of samples used when summarizing the dataset for chart instructions. Defaults to 30.
        Reducing this number can help avoid exceeding the model's token limits.
    log : bool, optional
        Whether to log the generated code and errors. Defaults to False.
    log_path : str, optional
        Directory path for storing log files. Defaults to None.
    file_name : str, optional
        Name of the file for saving the generated response. Defaults to "data_visualization.py".
    function_name : str, optional
        Name of the function for data visualization. Defaults to "data_visualization".
    overwrite : bool, optional
        Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
    human_in_the_loop : bool, optional
        Enables user review of data visualization instructions. Defaults to False.
    bypass_recommended_steps : bool, optional
        If True, skips the default recommended visualization steps. Defaults to False.
    bypass_explain_code : bool, optional
        If True, skips the step that provides code explanations. Defaults to False.

    All parameters are forwarded unchanged to ``make_data_visualization_agent`` to build the
    underlying compiled graph (stored on ``self._compiled_graph``).

    Methods
    -------
    update_params(**kwargs)
        Updates the agent's parameters and rebuilds the compiled state graph.
    ainvoke_agent(data_raw: pd.DataFrame, user_instructions: str = None, max_retries=3, retry_count=0, **kwargs)
        Coroutine: asynchronously generates a visualization. Must be awaited.
    invoke_agent(data_raw: pd.DataFrame, user_instructions: str = None, max_retries=3, retry_count=0, **kwargs)
        Synchronously generates a visualization.
    get_workflow_summary(markdown=False)
        Summarizes the workflow from the last message in the response.
    get_log_summary(markdown=False)
        Returns the logged function path/name, if a function file path is present in the response.
    get_plotly_graph()
        Rebuilds the Plotly graph from the dictionary stored in the response.
    get_data_raw()
        Retrieves the raw dataset as a pandas DataFrame (based on the last response).
    get_data_visualization_function(markdown=False)
        Retrieves the generated Python function used for data visualization.
    get_recommended_visualization_steps(markdown=False)
        Retrieves the agent's recommended visualization steps.
    get_response()
        Returns the response from the agent as a dictionary.
    show()
        Displays the agent's mermaid diagram.

    Examples
    --------
    ```python
    import pandas as pd
    from langchain_openai import ChatOpenAI
    from ai_data_science_team.agents import DataVisualizationAgent

    llm = ChatOpenAI(model="gpt-4o-mini")

    data_visualization_agent = DataVisualizationAgent(
        model=llm,
        n_samples=30,
        log=True,
        log_path="logs",
        human_in_the_loop=True
    )

    df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")

    data_visualization_agent.invoke_agent(
        data_raw=df,
        user_instructions="Generate a scatter plot of age vs. total charges with a trend line.",
        max_retries=3,
        retry_count=0
    )

    # Plotly graph rebuilt from the dictionary produced by the generated function.
    fig = data_visualization_agent.get_plotly_graph()

    response = data_visualization_agent.get_response()

    # Asynchronous usage (inside a running event loop):
    # await data_visualization_agent.ainvoke_agent(data_raw=df, user_instructions="...")
    ```
    """

    def __init__(
        self,
        model,
        n_samples=30,
        log=False,
        log_path=None,
        file_name="data_visualization.py",
        function_name="data_visualization",
        overwrite=True,
        human_in_the_loop=False,
        bypass_recommended_steps=False,
        bypass_explain_code=False
    ):
        self._params = {
            "model": model,
            "n_samples": n_samples,
            "log": log,
            "log_path": log_path,
            "file_name": file_name,
            "function_name": function_name,
            "overwrite": overwrite,
            "human_in_the_loop": human_in_the_loop,
            "bypass_recommended_steps": bypass_recommended_steps,
            "bypass_explain_code": bypass_explain_code,
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Create the compiled graph for the data visualization agent.
        Running this method will reset the response to None.
        """
        self.response = None
        return make_data_visualization_agent(**self._params)

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters and rebuilds the compiled graph.

        Rebuilding the graph also clears any previously stored response.
        Unknown keys are passed through to ``make_data_visualization_agent``,
        which will raise ``TypeError`` for parameters it does not accept.
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    @staticmethod
    def _build_graph_input(data_raw, user_instructions, max_retries, retry_count):
        """Build the initial state dict shared by the sync and async invocation paths."""
        return {
            "user_instructions": user_instructions,
            # The graph state carries the data as a plain dict (DataFrame.to_dict()).
            "data_raw": data_raw.to_dict(),
            "max_retries": max_retries,
            "retry_count": retry_count,
        }

    async def ainvoke_agent(self, data_raw: pd.DataFrame, user_instructions: str=None, max_retries:int=3, retry_count:int=0, **kwargs):
        """
        Asynchronously invokes the agent to generate a visualization.
        The response is stored in the 'response' attribute.

        This is a coroutine and must be awaited, e.g.
        ``await agent.ainvoke_agent(data_raw=df, user_instructions="...")``.

        Parameters
        ----------
        data_raw : pd.DataFrame
            The raw dataset to be visualized.
        user_instructions : str
            Instructions for data visualization.
        max_retries : int
            Maximum retry attempts.
        retry_count : int
            Current retry attempt count.
        **kwargs : dict
            Additional keyword arguments passed to ainvoke().

        Returns
        -------
        None
        """
        # The graph's ainvoke() returns an awaitable; it must be awaited so that
        # self.response holds the final state dict rather than a pending coroutine.
        response = await self._compiled_graph.ainvoke(
            self._build_graph_input(data_raw, user_instructions, max_retries, retry_count),
            **kwargs,
        )
        self.response = response
        return None

    def invoke_agent(self, data_raw: pd.DataFrame, user_instructions: str=None, max_retries:int=3, retry_count:int=0, **kwargs):
        """
        Synchronously invokes the agent to generate a visualization.
        The response is stored in the 'response' attribute.

        Parameters
        ----------
        data_raw : pd.DataFrame
            The raw dataset to be visualized.
        user_instructions : str
            Instructions for data visualization agent.
        max_retries : int
            Maximum retry attempts.
        retry_count : int
            Current retry attempt count.
        **kwargs : dict
            Additional keyword arguments passed to invoke().

        Returns
        -------
        None
        """
        response = self._compiled_graph.invoke(
            self._build_graph_input(data_raw, user_instructions, max_retries, retry_count),
            **kwargs,
        )
        self.response = response
        return None

    def get_workflow_summary(self, markdown=False):
        """
        Summarizes the agent's workflow from the last message in the response.

        The last message's content is parsed as JSON and passed to
        ``get_generic_summary``. Returns None when there is no response or no messages.

        Parameters
        ----------
        markdown : bool, optional
            If True, wraps the summary in an IPython ``Markdown`` object.
        """
        if self.response and self.response.get("messages"):
            summary = get_generic_summary(json.loads(self.response.get("messages")[-1].content))
            if markdown:
                return Markdown(summary)
            return summary
        return None

    def get_log_summary(self, markdown=False):
        """
        Returns a short summary of the logged function (path and function name).

        Returns None when there is no response or when the response has no
        ``data_visualization_function_path`` (e.g. logging was disabled).

        Parameters
        ----------
        markdown : bool, optional
            If True, wraps the summary in an IPython ``Markdown`` object.
        """
        if self.response:
            if self.response.get('data_visualization_function_path'):
                log_details = f"""
## Data Visualization Agent Log Summary:

Function Path: {self.response.get('data_visualization_function_path')}

Function Name: {self.response.get('data_visualization_function_name')}
"""
                if markdown:
                    return Markdown(log_details)
                return log_details
        return None

    def get_plotly_graph(self):
        """
        Rebuilds the Plotly graph produced by the agent.

        The response stores the graph as a dictionary under ``plotly_graph``; it is
        converted with ``plotly_from_dict`` (presumably to a Plotly figure — see
        ai_data_science_team.utils.plotly).

        Returns
        -------
        object or None
            The result of ``plotly_from_dict`` if a graph is available, otherwise None.
        """
        if not self.response:
            return None
        plotly_graph = self.response.get("plotly_graph")
        # Avoid handing None to plotly_from_dict when the graph was never produced
        # (e.g. the generated code failed).
        if plotly_graph is None:
            return None
        return plotly_from_dict(plotly_graph)

    def get_data_raw(self):
        """
        Retrieves the raw dataset used in the last invocation.

        Returns
        -------
        pd.DataFrame or None
            The raw dataset as a DataFrame if available, otherwise None.
        """
        if self.response and self.response.get("data_raw"):
            return pd.DataFrame(self.response.get("data_raw"))
        return None

    def get_data_visualization_function(self, markdown=False):
        """
        Retrieves the generated Python function used for data visualization.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the function in Markdown code block format.

        Returns
        -------
        str or None
            The Python function code as a string if available, otherwise None.
        """
        if self.response:
            func_code = self.response.get("data_visualization_function", "")
            if markdown:
                return Markdown(f"```python\n{func_code}\n```")
            return func_code
        return None

    def get_recommended_visualization_steps(self, markdown=False):
        """
        Retrieves the agent's recommended visualization steps.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the steps in Markdown format.

        Returns
        -------
        str or None
            The recommended steps if available, otherwise None.
        """
        if self.response:
            steps = self.response.get("recommended_steps", "")
            if markdown:
                return Markdown(steps)
            return steps
        return None

    def get_response(self):
        """
        Returns the agent's full response dictionary.

        Returns
        -------
        dict or None
            The response dictionary if available, otherwise None.
        """
        return self.response

    def show(self):
        """
        Displays the agent's mermaid diagram for visual inspection of the compiled graph.
        """
        return self._compiled_graph.show()

39
377
  # Agent
40
378
 
41
379
  def make_data_visualization_agent(
@@ -44,14 +382,85 @@ def make_data_visualization_agent(
44
382
  log=False,
45
383
  log_path=None,
46
384
  file_name="data_visualization.py",
47
- overwrite = True,
385
+ function_name="data_visualization",
386
+ overwrite=True,
48
387
  human_in_the_loop=False,
49
388
  bypass_recommended_steps=False,
50
389
  bypass_explain_code=False
51
390
  ):
391
+ """
392
+ Creates a data visualization agent that can generate Plotly charts based on user-defined instructions or
393
+ default visualization steps. The agent generates a Python function to produce the visualization, executes it,
394
+ and logs the process, including code and errors. It is designed to facilitate reproducible and customizable
395
+ data visualization workflows.
396
+
397
+ The agent can perform the following default visualization steps unless instructed otherwise:
398
+ - Generating a recommended chart type (bar, scatter, line, etc.)
399
+ - Creating user-friendly titles and axis labels
400
+ - Applying consistent styling (template, font sizes, color themes)
401
+ - Handling theme details (white background, base font size, line size, etc.)
402
+
403
+ User instructions can modify, add, or remove any of these steps to tailor the visualization process.
404
+
405
+ Parameters
406
+ ----------
407
+ model : langchain.llms.base.LLM
408
+ The language model used to generate the data visualization function.
409
+ n_samples : int, optional
410
+ Number of samples used when summarizing the dataset for chart instructions. Defaults to 30.
411
+ log : bool, optional
412
+ Whether to log the generated code and errors. Defaults to False.
413
+ log_path : str, optional
414
+ Directory path for storing log files. Defaults to None.
415
+ file_name : str, optional
416
+ Name of the file for saving the generated response. Defaults to "data_visualization.py".
417
+ function_name : str, optional
418
+ Name of the function for data visualization. Defaults to "data_visualization".
419
+ overwrite : bool, optional
420
+ Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
421
+ human_in_the_loop : bool, optional
422
+ Enables user review of data visualization instructions. Defaults to False.
423
+ bypass_recommended_steps : bool, optional
424
+ If True, skips the default recommended visualization steps. Defaults to False.
425
+ bypass_explain_code : bool, optional
426
+ If True, skips the step that provides code explanations. Defaults to False.
427
+
428
+ Examples
429
+ --------
430
+ ``` python
431
+ import pandas as pd
432
+ from langchain_openai import ChatOpenAI
433
+ from ai_data_science_team.agents import data_visualization_agent
434
+
435
+ llm = ChatOpenAI(model="gpt-4o-mini")
436
+
437
+ data_visualization_agent = make_data_visualization_agent(llm)
438
+
439
+ df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")
440
+
441
+ response = data_visualization_agent.invoke({
442
+ "user_instructions": "Generate a scatter plot of tenure vs. total charges with a trend line.",
443
+ "data_raw": df.to_dict(),
444
+ "max_retries": 3,
445
+ "retry_count": 0
446
+ })
447
+
448
+ pd.DataFrame(response['plotly_graph'])
449
+ ```
450
+
451
+ Returns
452
+ -------
453
+ app : langchain.graphs.CompiledStateGraph
454
+ The data visualization agent as a state graph.
455
+ """
52
456
 
53
457
  llm = model
54
458
 
459
+ # Human in th loop requires recommended steps
460
+ if bypass_recommended_steps and human_in_the_loop:
461
+ bypass_recommended_steps = False
462
+ print("Bypass recommended steps set to False to enable human in the loop.")
463
+
55
464
  # Setup Log Directory
56
465
  if log:
57
466
  if log_path is None:
@@ -70,6 +479,7 @@ def make_data_visualization_agent(
70
479
  all_datasets_summary: str
71
480
  data_visualization_function: str
72
481
  data_visualization_function_path: str
482
+ data_visualization_function_file_name: str
73
483
  data_visualization_function_name: str
74
484
  data_visualization_error: str
75
485
  max_retries: int
@@ -140,7 +550,7 @@ def make_data_visualization_agent(
140
550
  })
141
551
 
142
552
  return {
143
- "recommended_steps": "\n\n# Recommended Data Cleaning Steps:\n" + recommended_steps.content.strip(),
553
+ "recommended_steps": format_recommended_steps(recommended_steps.content.strip(), heading="# Recommended Data Cleaning Steps:"),
144
554
  "all_datasets_summary": all_datasets_summary_str
145
555
  }
146
556
 
@@ -169,7 +579,7 @@ def make_data_visualization_agent(
169
579
  template="""
170
580
  You are a chart generator agent that is an expert in generating plotly charts. You must use plotly or plotly.express to produce plots.
171
581
 
172
- Your job is to produce python code to generate visualizations.
582
+ Your job is to produce python code to generate visualizations with a function named {function_name}.
173
583
 
174
584
  You will take instructions from a Chart Instructor and generate a plotly chart from the data provided.
175
585
 
@@ -181,13 +591,13 @@ def make_data_visualization_agent(
181
591
 
182
592
  RETURN:
183
593
 
184
- Return Python code in ```python ``` format with a single function definition, data_visualization(data_raw), that includes all imports inside the function.
594
+ Return Python code in ```python ``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.
185
595
 
186
596
  Return the plotly chart as a dictionary.
187
597
 
188
598
  Return code to provide the data visualization function:
189
599
 
190
- def data_visualization(data_raw):
600
+ def {function_name}(data_raw):
191
601
  import pandas as pd
192
602
  import numpy as np
193
603
  import json
@@ -206,14 +616,15 @@ def make_data_visualization_agent(
206
616
  2. Do not include unrelated user instructions that are not related to the chart generation.
207
617
 
208
618
  """,
209
- input_variables=["chart_generator_instructions", "all_datasets_summary"]
619
+ input_variables=["chart_generator_instructions", "all_datasets_summary", "function_name"]
210
620
  )
211
-
621
+
212
622
  data_visualization_agent = prompt_template | llm | PythonOutputParser()
213
623
 
214
624
  response = data_visualization_agent.invoke({
215
625
  "chart_generator_instructions": chart_generator_instructions,
216
- "all_datasets_summary": all_datasets_summary_str
626
+ "all_datasets_summary": all_datasets_summary_str,
627
+ "function_name": function_name
217
628
  })
218
629
 
219
630
  response = relocate_imports_inside_function(response)
@@ -231,19 +642,37 @@ def make_data_visualization_agent(
231
642
  return {
232
643
  "data_visualization_function": response,
233
644
  "data_visualization_function_path": file_path,
234
- "data_visualization_function_name": file_name_2,
645
+ "data_visualization_function_file_name": file_name_2,
646
+ "data_visualization_function_name": function_name,
235
647
  "all_datasets_summary": all_datasets_summary_str
236
648
  }
237
-
238
- def human_review(state: GraphState) -> Command[Literal["chart_instructor", "chart_generator"]]:
239
- return node_func_human_review(
240
- state=state,
241
- prompt_text="Is the following data visualization instructions correct? (Answer 'yes' or provide modifications)\n{steps}",
242
- yes_goto="chart_generator",
243
- no_goto="chart_instructor",
244
- user_instructions_key="user_instructions",
245
- recommended_steps_key="recommended_steps"
246
- )
649
+
650
+ # Human Review
651
+
652
+ prompt_text_human_review = "Are the following data visualization instructions correct? (Answer 'yes' or provide modifications)\n{steps}"
653
+
654
+ if not bypass_explain_code:
655
+ def human_review(state: GraphState) -> Command[Literal["chart_instructor", "explain_data_visualization_code"]]:
656
+ return node_func_human_review(
657
+ state=state,
658
+ prompt_text=prompt_text_human_review,
659
+ yes_goto= 'explain_data_visualization_code',
660
+ no_goto="chart_instructor",
661
+ user_instructions_key="user_instructions",
662
+ recommended_steps_key="recommended_steps",
663
+ code_snippet_key="data_visualization_function",
664
+ )
665
+ else:
666
+ def human_review(state: GraphState) -> Command[Literal["chart_instructor", "__end__"]]:
667
+ return node_func_human_review(
668
+ state=state,
669
+ prompt_text=prompt_text_human_review,
670
+ yes_goto= '__end__',
671
+ no_goto="chart_instructor",
672
+ user_instructions_key="user_instructions",
673
+ recommended_steps_key="recommended_steps",
674
+ code_snippet_key="data_visualization_function",
675
+ )
247
676
 
248
677
 
249
678
  def execute_data_visualization_code(state):
@@ -253,7 +682,7 @@ def make_data_visualization_agent(
253
682
  result_key="plotly_graph",
254
683
  error_key="data_visualization_error",
255
684
  code_snippet_key="data_visualization_function",
256
- agent_function_name="data_visualization",
685
+ agent_function_name=state.get("data_visualization_function_name"),
257
686
  pre_processing=lambda data: pd.DataFrame.from_dict(data),
258
687
  # post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
259
688
  error_message_prefix="An error occurred during data visualization: "
@@ -261,11 +690,11 @@ def make_data_visualization_agent(
261
690
 
262
691
  def fix_data_visualization_code(state: GraphState):
263
692
  prompt = """
264
- You are a Data Visualization Agent. Your job is to create a data_visualization() function that can be run on the data provided. The function is currently broken and needs to be fixed.
693
+ You are a Data Visualization Agent. Your job is to create a {function_name}() function that can be run on the data provided. The function is currently broken and needs to be fixed.
265
694
 
266
- Make sure to only return the function definition for data_visualization().
695
+ Make sure to only return the function definition for {function_name}().
267
696
 
268
- Return Python code in ```python``` format with a single function definition, data_visualization(data_raw), that includes all imports inside the function.
697
+ Return Python code in ```python``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.
269
698
 
270
699
  This is the broken code (please fix):
271
700
  {code_snippet}
@@ -283,22 +712,23 @@ def make_data_visualization_agent(
283
712
  agent_name=AGENT_NAME,
284
713
  log=log,
285
714
  file_path=state.get("data_visualization_function_path"),
715
+ function_name=state.get("data_visualization_function_name"),
286
716
  )
287
717
 
288
- def explain_data_visualization_code(state: GraphState):
289
- return node_func_explain_agent_code(
718
+ # Final reporting node
719
+ def report_agent_outputs(state: GraphState):
720
+ return node_func_report_agent_outputs(
290
721
  state=state,
291
- code_snippet_key="data_visualization_function",
722
+ keys_to_include=[
723
+ "recommended_steps",
724
+ "data_visualization_function",
725
+ "data_visualization_function_path",
726
+ "data_visualization_function_name",
727
+ "data_visualization_error",
728
+ ],
292
729
  result_key="messages",
293
- error_key="data_visualization_error",
294
- llm=llm,
295
730
  role=AGENT_NAME,
296
- explanation_prompt_template="""
297
- Explain the data visualization steps that the data visualization agent performed in this function.
298
- Keep the summary succinct and to the point.\n\n# Data Visualization Agent:\n\n{code}
299
- """,
300
- success_prefix="# Data Visualization Agent:\n\n ",
301
- error_message="The Data Visualization Agent encountered an error during data visualization. No explanation could be provided."
731
+ custom_title="Data Visualization Agent Outputs"
302
732
  )
303
733
 
304
734
  # Define the graph
@@ -308,7 +738,7 @@ def make_data_visualization_agent(
308
738
  "chart_generator": chart_generator,
309
739
  "execute_data_visualization_code": execute_data_visualization_code,
310
740
  "fix_data_visualization_code": fix_data_visualization_code,
311
- "explain_data_visualization_code": explain_data_visualization_code
741
+ "report_agent_outputs": report_agent_outputs,
312
742
  }
313
743
 
314
744
  app = create_coding_agent_graph(
@@ -318,7 +748,7 @@ def make_data_visualization_agent(
318
748
  create_code_node_name="chart_generator",
319
749
  execute_code_node_name="execute_data_visualization_code",
320
750
  fix_code_node_name="fix_data_visualization_code",
321
- explain_code_node_name="explain_data_visualization_code",
751
+ explain_code_node_name="report_agent_outputs",
322
752
  error_key="data_visualization_error",
323
753
  human_in_the_loop=human_in_the_loop, # or False
324
754
  human_review_node_name="human_review",
@@ -328,4 +758,3 @@ def make_data_visualization_agent(
328
758
  )
329
759
 
330
760
  return app
331
-