ai-data-science-team 0.0.0.9007__py3-none-any.whl → 0.0.0.9009__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. ai_data_science_team/_version.py +1 -1
  2. ai_data_science_team/agents/__init__.py +4 -5
  3. ai_data_science_team/agents/data_cleaning_agent.py +268 -116
  4. ai_data_science_team/agents/data_visualization_agent.py +470 -41
  5. ai_data_science_team/agents/data_wrangling_agent.py +471 -31
  6. ai_data_science_team/agents/feature_engineering_agent.py +426 -41
  7. ai_data_science_team/agents/sql_database_agent.py +458 -58
  8. ai_data_science_team/ml_agents/__init__.py +1 -0
  9. ai_data_science_team/ml_agents/h2o_ml_agent.py +1032 -0
  10. ai_data_science_team/multiagents/__init__.py +1 -0
  11. ai_data_science_team/multiagents/sql_data_analyst.py +398 -0
  12. ai_data_science_team/multiagents/supervised_data_analyst.py +2 -0
  13. ai_data_science_team/templates/__init__.py +3 -1
  14. ai_data_science_team/templates/agent_templates.py +319 -43
  15. ai_data_science_team/tools/metadata.py +94 -62
  16. ai_data_science_team/tools/regex.py +86 -1
  17. ai_data_science_team/utils/__init__.py +0 -0
  18. ai_data_science_team/utils/plotly.py +24 -0
  19. ai_data_science_team-0.0.0.9009.dist-info/METADATA +245 -0
  20. ai_data_science_team-0.0.0.9009.dist-info/RECORD +28 -0
  21. ai_data_science_team-0.0.0.9007.dist-info/METADATA +0 -183
  22. ai_data_science_team-0.0.0.9007.dist-info/RECORD +0 -21
  23. {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/LICENSE +0 -0
  24. {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/WHEEL +0 -0
  25. {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/top_level.txt +0 -0
@@ -17,25 +17,363 @@ from langgraph.types import Command
17
17
  from langgraph.checkpoint.memory import MemorySaver
18
18
 
19
19
  import os
20
- import io
20
+ import json
21
21
  import pandas as pd
22
22
 
23
+ from IPython.display import Markdown
24
+
23
25
  from ai_data_science_team.templates import(
24
26
  node_func_execute_agent_code_on_data,
25
27
  node_func_human_review,
26
28
  node_func_fix_agent_code,
27
- node_func_explain_agent_code,
28
- create_coding_agent_graph
29
+ node_func_report_agent_outputs,
30
+ create_coding_agent_graph,
31
+ BaseAgent,
29
32
  )
30
33
  from ai_data_science_team.tools.parsers import PythonOutputParser
31
- from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
34
+ from ai_data_science_team.tools.regex import (
35
+ relocate_imports_inside_function,
36
+ add_comments_to_top,
37
+ format_agent_name,
38
+ format_recommended_steps,
39
+ get_generic_summary,
40
+ )
32
41
  from ai_data_science_team.tools.metadata import get_dataframe_summary
33
42
  from ai_data_science_team.tools.logging import log_ai_function
43
+ from ai_data_science_team.utils.plotly import plotly_from_dict
34
44
 
35
45
  # Setup
36
46
  AGENT_NAME = "data_visualization_agent"
37
47
  LOG_PATH = os.path.join(os.getcwd(), "logs/")
38
48
 
49
+ # Class
50
+
51
class DataVisualizationAgent(BaseAgent):
    """
    Creates a data visualization agent that can generate Plotly charts based on user-defined instructions or
    default visualization steps (if any). The agent generates a Python function to produce the visualization,
    executes it, and logs the process, including code and errors. It is designed to facilitate reproducible
    and customizable data visualization workflows.

    The agent may use default instructions for creating charts unless instructed otherwise, such as:
    - Generating a recommended chart type (bar, scatter, line, etc.)
    - Creating user-friendly titles and axis labels
    - Applying consistent styling (template, font sizes, color themes)
    - Handling theme details (white background, base font size, line size, etc.)

    User instructions can modify, add, or remove any of these steps to tailor the visualization process.

    Parameters
    ----------
    model : langchain.llms.base.LLM
        The language model used to generate the data visualization function.
    n_samples : int, optional
        Number of samples used when summarizing the dataset for chart instructions. Defaults to 30.
        Reducing this number can help avoid exceeding the model's token limits.
    log : bool, optional
        Whether to log the generated code and errors. Defaults to False.
    log_path : str, optional
        Directory path for storing log files. Defaults to None.
    file_name : str, optional
        Name of the file for saving the generated response. Defaults to "data_visualization.py".
    function_name : str, optional
        Name of the function for data visualization. Defaults to "data_visualization".
    overwrite : bool, optional
        Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
    human_in_the_loop : bool, optional
        Enables user review of data visualization instructions. Defaults to False.
    bypass_recommended_steps : bool, optional
        If True, skips the default recommended visualization steps. Defaults to False.
    bypass_explain_code : bool, optional
        If True, skips the step that provides code explanations. Defaults to False.

    Attributes
    ----------
    response : dict or None
        The state dictionary returned by the last invocation, or None if the agent has not been
        invoked since it was created or its parameters were last updated.

    Methods
    -------
    update_params(**kwargs)
        Updates the agent's parameters and rebuilds the compiled state graph.
    ainvoke_agent(data_raw: pd.DataFrame, user_instructions: str = None, max_retries=3, retry_count=0)
        Coroutine that asynchronously generates a visualization based on user instructions (must be awaited).
    invoke_agent(data_raw: pd.DataFrame, user_instructions: str = None, max_retries=3, retry_count=0)
        Synchronously generates a visualization based on user instructions.
    get_workflow_summary(markdown=False)
        Retrieves a summary of the agent's workflow from the last response message.
    get_log_summary(markdown=False)
        Retrieves a summary of logged operations if logging is enabled.
    get_plotly_graph()
        Converts the Plotly graph dictionary produced by the agent with `plotly_from_dict`.
    get_data_raw()
        Retrieves the raw dataset as a pandas DataFrame (based on the last response).
    get_data_visualization_function(markdown=False)
        Retrieves the generated Python function used for data visualization.
    get_recommended_visualization_steps(markdown=False)
        Retrieves the agent's recommended visualization steps.
    get_response()
        Returns the response from the agent as a dictionary.
    show()
        Displays the agent's mermaid diagram.

    Examples
    --------
    ```python
    import pandas as pd
    from langchain_openai import ChatOpenAI
    from ai_data_science_team.agents import DataVisualizationAgent

    llm = ChatOpenAI(model="gpt-4o-mini")

    data_visualization_agent = DataVisualizationAgent(
        model=llm,
        n_samples=30,
        log=True,
        log_path="logs",
        human_in_the_loop=True
    )

    df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")

    data_visualization_agent.invoke_agent(
        user_instructions="Generate a scatter plot of age vs. total charges with a trend line.",
        data_raw=df,
        max_retries=3,
        retry_count=0
    )

    # get_plotly_graph() converts the stored chart dictionary with plotly_from_dict;
    # the raw dictionary itself is available as get_response()["plotly_graph"].
    plotly_graph = data_visualization_agent.get_plotly_graph()

    response = data_visualization_agent.get_response()

    # Async usage (inside a coroutine or a notebook cell that supports top-level await):
    # await data_visualization_agent.ainvoke_agent(data_raw=df, user_instructions="...")
    ```
    """

    def __init__(
        self,
        model,
        n_samples=30,
        log=False,
        log_path=None,
        file_name="data_visualization.py",
        function_name="data_visualization",
        overwrite=True,
        human_in_the_loop=False,
        bypass_recommended_steps=False,
        bypass_explain_code=False
    ):
        self._params = {
            "model": model,
            "n_samples": n_samples,
            "log": log,
            "log_path": log_path,
            "file_name": file_name,
            "function_name": function_name,
            "overwrite": overwrite,
            "human_in_the_loop": human_in_the_loop,
            "bypass_recommended_steps": bypass_recommended_steps,
            "bypass_explain_code": bypass_explain_code,
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Create the compiled graph for the data visualization agent.
        Running this method will reset the response to None.
        """
        self.response = None
        return make_data_visualization_agent(**self._params)

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters and rebuilds the compiled graph.

        Any previously stored response is discarded, because the graph is rebuilt
        from the new parameters.
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    async def ainvoke_agent(self, data_raw: pd.DataFrame, user_instructions: str = None, max_retries: int = 3, retry_count: int = 0, **kwargs):
        """
        Asynchronously invokes the agent to generate a visualization.
        The response is stored in the 'response' attribute.

        This method is a coroutine and must be awaited, e.g.
        ``await agent.ainvoke_agent(data_raw=df, user_instructions="...")``.

        Parameters
        ----------
        data_raw : pd.DataFrame
            The raw dataset to be visualized.
        user_instructions : str
            Instructions for data visualization.
        max_retries : int
            Maximum retry attempts.
        retry_count : int
            Current retry attempt count.
        **kwargs : dict
            Additional keyword arguments passed to ainvoke().

        Returns
        -------
        None
        """
        # The compiled graph's ainvoke() returns an awaitable; it must be awaited so that
        # self.response holds the resulting state dict rather than an un-awaited coroutine.
        response = await self._compiled_graph.ainvoke({
            "user_instructions": user_instructions,
            "data_raw": data_raw.to_dict(),
            "max_retries": max_retries,
            "retry_count": retry_count,
        }, **kwargs)
        self.response = response
        return None

    def invoke_agent(self, data_raw: pd.DataFrame, user_instructions: str = None, max_retries: int = 3, retry_count: int = 0, **kwargs):
        """
        Synchronously invokes the agent to generate a visualization.
        The response is stored in the 'response' attribute.

        Parameters
        ----------
        data_raw : pd.DataFrame
            The raw dataset to be visualized.
        user_instructions : str
            Instructions for data visualization agent.
        max_retries : int
            Maximum retry attempts.
        retry_count : int
            Current retry attempt count.
        **kwargs : dict
            Additional keyword arguments passed to invoke().

        Returns
        -------
        None
        """
        response = self._compiled_graph.invoke({
            "user_instructions": user_instructions,
            "data_raw": data_raw.to_dict(),
            "max_retries": max_retries,
            "retry_count": retry_count,
        }, **kwargs)
        self.response = response
        return None

    def get_workflow_summary(self, markdown=False):
        """
        Retrieves the agent's workflow summary.

        The content of the last message in the response is parsed as JSON and
        summarized with ``get_generic_summary``.

        Parameters
        ----------
        markdown : bool, optional
            If True, wraps the summary in an IPython ``Markdown`` object.

        Returns
        -------
        str, Markdown, or None
            The summary, or None if there is no response or it has no messages.
        """
        if self.response and self.response.get("messages"):
            summary = get_generic_summary(json.loads(self.response.get("messages")[-1].content))
            if markdown:
                return Markdown(summary)
            return summary
        return None

    def get_log_summary(self, markdown=False):
        """
        Returns a summary of the agent's logged function file, if logging is enabled.

        Parameters
        ----------
        markdown : bool, optional
            If True, wraps the summary in an IPython ``Markdown`` object.

        Returns
        -------
        str, Markdown, or None
            The summary, or None if there is no response or no function path was recorded.
        """
        if self.response and self.response.get('data_visualization_function_path'):
            log_details = f"""
## Data Visualization Agent Log Summary:

Function Path: {self.response.get('data_visualization_function_path')}

Function Name: {self.response.get('data_visualization_function_name')}
"""
            if markdown:
                return Markdown(log_details)
            return log_details
        return None

    def get_plotly_graph(self):
        """
        Retrieves the Plotly graph produced by the agent.

        The stored chart dictionary (``response["plotly_graph"]``) is converted with
        ``plotly_from_dict`` (presumably into a Plotly figure object).

        Returns
        -------
        object or None
            The converted graph if a chart dictionary is available, otherwise None.
        """
        graph_dict = self.response.get("plotly_graph") if self.response else None
        # Don't hand None to plotly_from_dict (e.g. when chart generation failed);
        # report "no graph" consistently instead.
        if graph_dict is None:
            return None
        return plotly_from_dict(graph_dict)

    def get_data_raw(self):
        """
        Retrieves the raw dataset used in the last invocation.

        Returns
        -------
        pd.DataFrame or None
            The raw dataset as a DataFrame if available, otherwise None.
        """
        if self.response and self.response.get("data_raw"):
            return pd.DataFrame(self.response.get("data_raw"))
        return None

    def get_data_visualization_function(self, markdown=False):
        """
        Retrieves the generated Python function used for data visualization.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the function in Markdown code block format.

        Returns
        -------
        str or None
            The Python function code as a string if available, otherwise None.
        """
        if self.response:
            func_code = self.response.get("data_visualization_function", "")
            if markdown:
                return Markdown(f"```python\n{func_code}\n```")
            return func_code
        return None

    def get_recommended_visualization_steps(self, markdown=False):
        """
        Retrieves the agent's recommended visualization steps.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the steps in Markdown format.

        Returns
        -------
        str or None
            The recommended steps if available, otherwise None.
        """
        if self.response:
            steps = self.response.get("recommended_steps", "")
            if markdown:
                return Markdown(steps)
            return steps
        return None

    def get_response(self):
        """
        Returns the agent's full response dictionary.

        Returns
        -------
        dict or None
            The response dictionary if available, otherwise None.
        """
        return self.response

    def show(self):
        """
        Displays the agent's mermaid diagram for visual inspection of the compiled graph.
        """
        return self._compiled_graph.show()
375
+
376
+
39
377
  # Agent
40
378
 
41
379
  def make_data_visualization_agent(
@@ -44,14 +382,85 @@ def make_data_visualization_agent(
44
382
  log=False,
45
383
  log_path=None,
46
384
  file_name="data_visualization.py",
47
- overwrite = True,
385
+ function_name="data_visualization",
386
+ overwrite=True,
48
387
  human_in_the_loop=False,
49
388
  bypass_recommended_steps=False,
50
389
  bypass_explain_code=False
51
390
  ):
391
+ """
392
+ Creates a data visualization agent that can generate Plotly charts based on user-defined instructions or
393
+ default visualization steps. The agent generates a Python function to produce the visualization, executes it,
394
+ and logs the process, including code and errors. It is designed to facilitate reproducible and customizable
395
+ data visualization workflows.
396
+
397
+ The agent can perform the following default visualization steps unless instructed otherwise:
398
+ - Generating a recommended chart type (bar, scatter, line, etc.)
399
+ - Creating user-friendly titles and axis labels
400
+ - Applying consistent styling (template, font sizes, color themes)
401
+ - Handling theme details (white background, base font size, line size, etc.)
402
+
403
+ User instructions can modify, add, or remove any of these steps to tailor the visualization process.
404
+
405
+ Parameters
406
+ ----------
407
+ model : langchain.llms.base.LLM
408
+ The language model used to generate the data visualization function.
409
+ n_samples : int, optional
410
+ Number of samples used when summarizing the dataset for chart instructions. Defaults to 30.
411
+ log : bool, optional
412
+ Whether to log the generated code and errors. Defaults to False.
413
+ log_path : str, optional
414
+ Directory path for storing log files. Defaults to None.
415
+ file_name : str, optional
416
+ Name of the file for saving the generated response. Defaults to "data_visualization.py".
417
+ function_name : str, optional
418
+ Name of the function for data visualization. Defaults to "data_visualization".
419
+ overwrite : bool, optional
420
+ Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
421
+ human_in_the_loop : bool, optional
422
+ Enables user review of data visualization instructions. Defaults to False.
423
+ bypass_recommended_steps : bool, optional
424
+ If True, skips the default recommended visualization steps. Defaults to False.
425
+ bypass_explain_code : bool, optional
426
+ If True, skips the step that provides code explanations. Defaults to False.
427
+
428
+ Examples
429
+ --------
430
+ ``` python
431
+ import pandas as pd
432
+ from langchain_openai import ChatOpenAI
433
+ from ai_data_science_team.agents import data_visualization_agent
434
+
435
+ llm = ChatOpenAI(model="gpt-4o-mini")
436
+
437
+ data_visualization_agent = make_data_visualization_agent(llm)
438
+
439
+ df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")
440
+
441
+ response = data_visualization_agent.invoke({
442
+ "user_instructions": "Generate a scatter plot of tenure vs. total charges with a trend line.",
443
+ "data_raw": df.to_dict(),
444
+ "max_retries": 3,
445
+ "retry_count": 0
446
+ })
447
+
448
+ pd.DataFrame(response['plotly_graph'])
449
+ ```
450
+
451
+ Returns
452
+ -------
453
+ app : langchain.graphs.CompiledStateGraph
454
+ The data visualization agent as a state graph.
455
+ """
52
456
 
53
457
  llm = model
54
458
 
459
+ # Human in th loop requires recommended steps
460
+ if bypass_recommended_steps and human_in_the_loop:
461
+ bypass_recommended_steps = False
462
+ print("Bypass recommended steps set to False to enable human in the loop.")
463
+
55
464
  # Setup Log Directory
56
465
  if log:
57
466
  if log_path is None:
@@ -70,6 +479,7 @@ def make_data_visualization_agent(
70
479
  all_datasets_summary: str
71
480
  data_visualization_function: str
72
481
  data_visualization_function_path: str
482
+ data_visualization_function_file_name: str
73
483
  data_visualization_function_name: str
74
484
  data_visualization_error: str
75
485
  max_retries: int
@@ -140,7 +550,7 @@ def make_data_visualization_agent(
140
550
  })
141
551
 
142
552
  return {
143
- "recommended_steps": "\n\n# Recommended Data Cleaning Steps:\n" + recommended_steps.content.strip(),
553
+ "recommended_steps": format_recommended_steps(recommended_steps.content.strip(), heading="# Recommended Data Cleaning Steps:"),
144
554
  "all_datasets_summary": all_datasets_summary_str
145
555
  }
146
556
 
@@ -169,7 +579,7 @@ def make_data_visualization_agent(
169
579
  template="""
170
580
  You are a chart generator agent that is an expert in generating plotly charts. You must use plotly or plotly.express to produce plots.
171
581
 
172
- Your job is to produce python code to generate visualizations.
582
+ Your job is to produce python code to generate visualizations with a function named {function_name}.
173
583
 
174
584
  You will take instructions from a Chart Instructor and generate a plotly chart from the data provided.
175
585
 
@@ -181,13 +591,13 @@ def make_data_visualization_agent(
181
591
 
182
592
  RETURN:
183
593
 
184
- Return Python code in ```python ``` format with a single function definition, data_visualization(data_raw), that includes all imports inside the function.
594
+ Return Python code in ```python ``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.
185
595
 
186
596
  Return the plotly chart as a dictionary.
187
597
 
188
598
  Return code to provide the data visualization function:
189
599
 
190
- def data_visualization(data_raw):
600
+ def {function_name}(data_raw):
191
601
  import pandas as pd
192
602
  import numpy as np
193
603
  import json
@@ -206,14 +616,15 @@ def make_data_visualization_agent(
206
616
  2. Do not include unrelated user instructions that are not related to the chart generation.
207
617
 
208
618
  """,
209
- input_variables=["chart_generator_instructions", "all_datasets_summary"]
619
+ input_variables=["chart_generator_instructions", "all_datasets_summary", "function_name"]
210
620
  )
211
-
621
+
212
622
  data_visualization_agent = prompt_template | llm | PythonOutputParser()
213
623
 
214
624
  response = data_visualization_agent.invoke({
215
625
  "chart_generator_instructions": chart_generator_instructions,
216
- "all_datasets_summary": all_datasets_summary_str
626
+ "all_datasets_summary": all_datasets_summary_str,
627
+ "function_name": function_name
217
628
  })
218
629
 
219
630
  response = relocate_imports_inside_function(response)
@@ -231,19 +642,37 @@ def make_data_visualization_agent(
231
642
  return {
232
643
  "data_visualization_function": response,
233
644
  "data_visualization_function_path": file_path,
234
- "data_visualization_function_name": file_name_2,
645
+ "data_visualization_function_file_name": file_name_2,
646
+ "data_visualization_function_name": function_name,
235
647
  "all_datasets_summary": all_datasets_summary_str
236
648
  }
237
-
238
- def human_review(state: GraphState) -> Command[Literal["chart_instructor", "chart_generator"]]:
239
- return node_func_human_review(
240
- state=state,
241
- prompt_text="Is the following data visualization instructions correct? (Answer 'yes' or provide modifications)\n{steps}",
242
- yes_goto="chart_generator",
243
- no_goto="chart_instructor",
244
- user_instructions_key="user_instructions",
245
- recommended_steps_key="recommended_steps"
246
- )
649
+
650
+ # Human Review
651
+
652
+ prompt_text_human_review = "Are the following data visualization instructions correct? (Answer 'yes' or provide modifications)\n{steps}"
653
+
654
# Human review node. On approval ("yes") the graph must route to a node that is
# actually registered: the reporting node "report_agent_outputs" (the node passed
# as explain_code_node_name), or "__end__" when that step is bypassed. Otherwise
# it routes back to "chart_instructor".
if not bypass_explain_code:
    def human_review(state: GraphState) -> Command[Literal["chart_instructor", "report_agent_outputs"]]:
        """
        Ask the user to approve or revise the recommended visualization steps.

        Approval routes to "report_agent_outputs"; any other answer routes back
        to "chart_instructor".
        """
        return node_func_human_review(
            state=state,
            prompt_text=prompt_text_human_review,
            yes_goto="report_agent_outputs",
            no_goto="chart_instructor",
            user_instructions_key="user_instructions",
            recommended_steps_key="recommended_steps",
            code_snippet_key="data_visualization_function",
        )
else:
    def human_review(state: GraphState) -> Command[Literal["chart_instructor", "__end__"]]:
        """
        Ask the user to approve or revise the recommended visualization steps.

        Approval ends the graph (the reporting step is bypassed); any other
        answer routes back to "chart_instructor".
        """
        return node_func_human_review(
            state=state,
            prompt_text=prompt_text_human_review,
            yes_goto="__end__",
            no_goto="chart_instructor",
            user_instructions_key="user_instructions",
            recommended_steps_key="recommended_steps",
            code_snippet_key="data_visualization_function",
        )
247
676
 
248
677
 
249
678
  def execute_data_visualization_code(state):
@@ -253,7 +682,7 @@ def make_data_visualization_agent(
253
682
  result_key="plotly_graph",
254
683
  error_key="data_visualization_error",
255
684
  code_snippet_key="data_visualization_function",
256
- agent_function_name="data_visualization",
685
+ agent_function_name=state.get("data_visualization_function_name"),
257
686
  pre_processing=lambda data: pd.DataFrame.from_dict(data),
258
687
  # post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
259
688
  error_message_prefix="An error occurred during data visualization: "
@@ -261,11 +690,11 @@ def make_data_visualization_agent(
261
690
 
262
691
  def fix_data_visualization_code(state: GraphState):
263
692
  prompt = """
264
- You are a Data Visualization Agent. Your job is to create a data_visualization() function that can be run on the data provided. The function is currently broken and needs to be fixed.
693
+ You are a Data Visualization Agent. Your job is to create a {function_name}() function that can be run on the data provided. The function is currently broken and needs to be fixed.
265
694
 
266
- Make sure to only return the function definition for data_visualization().
695
+ Make sure to only return the function definition for {function_name}().
267
696
 
268
- Return Python code in ```python``` format with a single function definition, data_visualization(data_raw), that includes all imports inside the function.
697
+ Return Python code in ```python``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.
269
698
 
270
699
  This is the broken code (please fix):
271
700
  {code_snippet}
@@ -283,22 +712,23 @@ def make_data_visualization_agent(
283
712
  agent_name=AGENT_NAME,
284
713
  log=log,
285
714
  file_path=state.get("data_visualization_function_path"),
715
+ function_name=state.get("data_visualization_function_name"),
286
716
  )
287
717
 
288
- def explain_data_visualization_code(state: GraphState):
289
- return node_func_explain_agent_code(
718
+ # Final reporting node
719
+ def report_agent_outputs(state: GraphState):
720
+ return node_func_report_agent_outputs(
290
721
  state=state,
291
- code_snippet_key="data_visualization_function",
722
+ keys_to_include=[
723
+ "recommended_steps",
724
+ "data_visualization_function",
725
+ "data_visualization_function_path",
726
+ "data_visualization_function_name",
727
+ "data_visualization_error",
728
+ ],
292
729
  result_key="messages",
293
- error_key="data_visualization_error",
294
- llm=llm,
295
730
  role=AGENT_NAME,
296
- explanation_prompt_template="""
297
- Explain the data visualization steps that the data visualization agent performed in this function.
298
- Keep the summary succinct and to the point.\n\n# Data Visualization Agent:\n\n{code}
299
- """,
300
- success_prefix="# Data Visualization Agent:\n\n ",
301
- error_message="The Data Visualization Agent encountered an error during data visualization. No explanation could be provided."
731
+ custom_title="Data Visualization Agent Outputs"
302
732
  )
303
733
 
304
734
  # Define the graph
@@ -308,7 +738,7 @@ def make_data_visualization_agent(
308
738
  "chart_generator": chart_generator,
309
739
  "execute_data_visualization_code": execute_data_visualization_code,
310
740
  "fix_data_visualization_code": fix_data_visualization_code,
311
- "explain_data_visualization_code": explain_data_visualization_code
741
+ "report_agent_outputs": report_agent_outputs,
312
742
  }
313
743
 
314
744
  app = create_coding_agent_graph(
@@ -318,7 +748,7 @@ def make_data_visualization_agent(
318
748
  create_code_node_name="chart_generator",
319
749
  execute_code_node_name="execute_data_visualization_code",
320
750
  fix_code_node_name="fix_data_visualization_code",
321
- explain_code_node_name="explain_data_visualization_code",
751
+ explain_code_node_name="report_agent_outputs",
322
752
  error_key="data_visualization_error",
323
753
  human_in_the_loop=human_in_the_loop, # or False
324
754
  human_review_node_name="human_review",
@@ -328,4 +758,3 @@ def make_data_visualization_agent(
328
758
  )
329
759
 
330
760
  return app
331
-