ai-data-science-team 0.0.0.9006__py3-none-any.whl → 0.0.0.9008__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,764 @@
1
+ # BUSINESS SCIENCE UNIVERSITY
2
+ # AI DATA SCIENCE TEAM
3
+ # ***
4
+ # * Agents: Data Visualization Agent
5
+
6
+
7
+
8
+ # Libraries
9
+ from typing import TypedDict, Annotated, Sequence, Literal
10
+ import operator
11
+
12
+ from langchain.prompts import PromptTemplate
13
+ from langchain_core.output_parsers import StrOutputParser
14
+ from langchain_core.messages import BaseMessage
15
+
16
+ from langgraph.types import Command
17
+ from langgraph.checkpoint.memory import MemorySaver
18
+
19
+ import os
20
+ import pandas as pd
21
+
22
+ from IPython.display import Markdown
23
+
24
+ from ai_data_science_team.templates import(
25
+ node_func_execute_agent_code_on_data,
26
+ node_func_human_review,
27
+ node_func_fix_agent_code,
28
+ node_func_explain_agent_code,
29
+ create_coding_agent_graph,
30
+ BaseAgent,
31
+ )
32
+ from ai_data_science_team.tools.parsers import PythonOutputParser
33
+ from ai_data_science_team.tools.regex import (
34
+ relocate_imports_inside_function,
35
+ add_comments_to_top,
36
+ format_agent_name,
37
+ format_recommended_steps
38
+ )
39
+ from ai_data_science_team.tools.metadata import get_dataframe_summary
40
+ from ai_data_science_team.tools.logging import log_ai_function
41
+ from ai_data_science_team.utils.plotly import plotly_from_dict
42
+
43
# Setup
# Identifier for this agent; passed to the console banner formatter, the
# generated-code header, and the shared fix/explain node helpers below.
AGENT_NAME = "data_visualization_agent"
# Default log directory ("logs/" under the working directory at import time);
# only used when logging is enabled and no explicit log_path is supplied.
LOG_PATH = os.path.join(os.getcwd(), "logs/")
46
+
47
+ # Class
48
+
49
class DataVisualizationAgent(BaseAgent):
    """
    Creates a data visualization agent that can generate Plotly charts based on user-defined instructions or
    default visualization steps (if any). The agent generates a Python function to produce the visualization,
    executes it, and logs the process, including code and errors. It is designed to facilitate reproducible
    and customizable data visualization workflows.

    The agent may use default instructions for creating charts unless instructed otherwise, such as:
    - Generating a recommended chart type (bar, scatter, line, etc.)
    - Creating user-friendly titles and axis labels
    - Applying consistent styling (template, font sizes, color themes)
    - Handling theme details (white background, base font size, line size, etc.)

    User instructions can modify, add, or remove any of these steps to tailor the visualization process.

    Parameters
    ----------
    model : langchain.llms.base.LLM
        The language model used to generate the data visualization function.
    n_samples : int, optional
        Number of samples used when summarizing the dataset for chart instructions. Defaults to 30.
        Reducing this number can help avoid exceeding the model's token limits.
    log : bool, optional
        Whether to log the generated code and errors. Defaults to False.
    log_path : str, optional
        Directory path for storing log files. Defaults to None.
    file_name : str, optional
        Name of the file for saving the generated response. Defaults to "data_visualization.py".
    function_name : str, optional
        Name of the function for data visualization. Defaults to "data_visualization".
    overwrite : bool, optional
        Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
    human_in_the_loop : bool, optional
        Enables user review of data visualization instructions. Defaults to False.
    bypass_recommended_steps : bool, optional
        If True, skips the default recommended visualization steps. Defaults to False.
    bypass_explain_code : bool, optional
        If True, skips the step that provides code explanations. Defaults to False.

    Methods
    -------
    update_params(**kwargs)
        Updates the agent's parameters and rebuilds the compiled state graph.
    ainvoke_agent(data_raw: pd.DataFrame, user_instructions: str=None, max_retries=3, retry_count=0, **kwargs)
        Asynchronously generates a visualization based on user instructions (coroutine; must be awaited).
    invoke_agent(data_raw: pd.DataFrame, user_instructions: str=None, max_retries=3, retry_count=0, **kwargs)
        Synchronously generates a visualization based on user instructions.
    explain_visualization_steps()
        Returns the explanation messages for the visualization steps performed by the agent.
    get_log_summary(markdown=False)
        Retrieves a summary of logged operations if logging is enabled.
    get_plotly_graph()
        Retrieves the Plotly graph produced by the agent, converted with plotly_from_dict().
    get_data_raw()
        Retrieves the raw dataset as a pandas DataFrame (based on the last response).
    get_data_visualization_function(markdown=False)
        Retrieves the generated Python function used for data visualization.
    get_recommended_visualization_steps(markdown=False)
        Retrieves the agent's recommended visualization steps.
    get_response()
        Returns the response from the agent as a dictionary.
    show()
        Displays the agent's mermaid diagram.

    Examples
    --------
    ```python
    import pandas as pd
    from langchain_openai import ChatOpenAI
    from ai_data_science_team.agents import DataVisualizationAgent

    llm = ChatOpenAI(model="gpt-4o-mini")

    data_visualization_agent = DataVisualizationAgent(
        model=llm,
        n_samples=30,
        log=True,
        log_path="logs",
        human_in_the_loop=True
    )

    df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")

    data_visualization_agent.invoke_agent(
        user_instructions="Generate a scatter plot of age vs. total charges with a trend line.",
        data_raw=df,
        max_retries=3,
        retry_count=0
    )

    # The raw figure dictionary is available as get_response()["plotly_graph"];
    # get_plotly_graph() returns it after conversion with plotly_from_dict().
    plotly_graph = data_visualization_agent.get_plotly_graph()

    response = data_visualization_agent.get_response()
    ```

    Returns
    --------
    DataVisualizationAgent
        A data visualization agent that wraps the compiled state graph built by
        make_data_visualization_agent().
    """

    def __init__(
        self,
        model,
        n_samples=30,
        log=False,
        log_path=None,
        file_name="data_visualization.py",
        function_name="data_visualization",
        overwrite=True,
        human_in_the_loop=False,
        bypass_recommended_steps=False,
        bypass_explain_code=False
    ):
        # The parameter dict is forwarded verbatim to make_data_visualization_agent(),
        # so its keys must match that function's keyword arguments.
        self._params = {
            "model": model,
            "n_samples": n_samples,
            "log": log,
            "log_path": log_path,
            "file_name": file_name,
            "function_name": function_name,
            "overwrite": overwrite,
            "human_in_the_loop": human_in_the_loop,
            "bypass_recommended_steps": bypass_recommended_steps,
            "bypass_explain_code": bypass_explain_code,
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Create the compiled graph for the data visualization agent.
        Running this method will reset the response to None.
        """
        self.response = None
        return make_data_visualization_agent(**self._params)

    @staticmethod
    def _make_inputs(data_raw, user_instructions, max_retries, retry_count):
        """
        Build the initial graph state shared by invoke_agent() and ainvoke_agent().

        The DataFrame is converted with to_dict() because the graph state stores
        the dataset as a plain dictionary.
        """
        return {
            "user_instructions": user_instructions,
            "data_raw": data_raw.to_dict(),
            "max_retries": max_retries,
            "retry_count": retry_count,
        }

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters and rebuilds the compiled graph.

        Rebuilding the graph also resets the stored response to None.
        """
        # Update parameters
        self._params.update(kwargs)
        # Rebuild the compiled graph
        self._compiled_graph = self._make_compiled_graph()

    async def ainvoke_agent(self, data_raw: pd.DataFrame, user_instructions: str = None, max_retries: int = 3, retry_count: int = 0, **kwargs):
        """
        Asynchronously invokes the agent to generate a visualization.
        The response is stored in the 'response' attribute.

        This is a coroutine and must be awaited, e.g.
        ``await agent.ainvoke_agent(data_raw=df, user_instructions="...")``.

        Parameters
        ----------
        data_raw : pd.DataFrame
            The raw dataset to be visualized.
        user_instructions : str
            Instructions for data visualization.
        max_retries : int
            Maximum retry attempts.
        retry_count : int
            Current retry attempt count.
        **kwargs : dict
            Additional keyword arguments passed to ainvoke().

        Returns
        -------
        None
        """
        # ainvoke() returns an awaitable; it must be awaited so that the final
        # state dictionary (not a pending coroutine) is stored on the agent.
        response = await self._compiled_graph.ainvoke(
            self._make_inputs(data_raw, user_instructions, max_retries, retry_count),
            **kwargs
        )
        self.response = response
        return None

    def invoke_agent(self, data_raw: pd.DataFrame, user_instructions: str = None, max_retries: int = 3, retry_count: int = 0, **kwargs):
        """
        Synchronously invokes the agent to generate a visualization.
        The response is stored in the 'response' attribute.

        Parameters
        ----------
        data_raw : pd.DataFrame
            The raw dataset to be visualized.
        user_instructions : str
            Instructions for data visualization agent.
        max_retries : int
            Maximum retry attempts.
        retry_count : int
            Current retry attempt count.
        **kwargs : dict
            Additional keyword arguments passed to invoke().

        Returns
        -------
        None
        """
        response = self._compiled_graph.invoke(
            self._make_inputs(data_raw, user_instructions, max_retries, retry_count),
            **kwargs
        )
        self.response = response
        return None

    def explain_visualization_steps(self):
        """
        Provides an explanation of the visualization steps performed by the agent.

        Returns
        -------
        list
            The "messages" entry of the last response (the explanation messages),
            or an empty list if the agent has not produced a response yet.
        """
        if self.response:
            return self.response.get("messages", [])
        return []

    def get_log_summary(self, markdown=False):
        """
        Returns a summary of the agent's logged operations, if logging is enabled.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns Markdown-formatted output.

        Returns
        -------
        str, Markdown or None
            Summary of logs or None if no logs are available.
        """
        if self.response and self.response.get('data_visualization_function_path'):
            log_details = f"Log Path: {self.response.get('data_visualization_function_path')}"
            if markdown:
                return Markdown(log_details)
            return log_details
        return None

    def get_plotly_graph(self):
        """
        Retrieves the Plotly graph produced by the agent.

        Returns
        -------
        object or None
            The result of plotly_from_dict() applied to the stored "plotly_graph"
            dictionary (presumably a Plotly figure -- confirm against
            ai_data_science_team.utils.plotly), or None if no graph is available.
        """
        if not self.response:
            return None
        graph_dict = self.response.get("plotly_graph", None)
        # A run that failed before producing a chart has no "plotly_graph" entry;
        # return None as documented instead of handing None to the converter.
        if graph_dict is None:
            return None
        return plotly_from_dict(graph_dict)

    def get_data_raw(self):
        """
        Retrieves the raw dataset used in the last invocation.

        Returns
        -------
        pd.DataFrame or None
            The raw dataset as a DataFrame if available, otherwise None.
        """
        if self.response and self.response.get("data_raw"):
            return pd.DataFrame(self.response.get("data_raw"))
        return None

    def get_data_visualization_function(self, markdown=False):
        """
        Retrieves the generated Python function used for data visualization.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the function in Markdown code block format.

        Returns
        -------
        str, Markdown or None
            The Python function code if available, otherwise None.
        """
        if self.response:
            func_code = self.response.get("data_visualization_function", "")
            if markdown:
                return Markdown(f"```python\n{func_code}\n```")
            return func_code
        return None

    def get_recommended_visualization_steps(self, markdown=False):
        """
        Retrieves the agent's recommended visualization steps.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the steps in Markdown format.

        Returns
        -------
        str, Markdown or None
            The recommended steps if available, otherwise None.
        """
        if self.response:
            steps = self.response.get("recommended_steps", "")
            if markdown:
                return Markdown(steps)
            return steps
        return None

    def get_response(self):
        """
        Returns the agent's full response dictionary.

        Returns
        -------
        dict or None
            The response dictionary if available, otherwise None.
        """
        return self.response

    def show(self):
        """
        Displays the agent's mermaid diagram for visual inspection of the compiled graph.
        """
        # NOTE(review): relies on the compiled graph object exposing a show()
        # method -- confirm that create_coding_agent_graph() returns such an object.
        return self._compiled_graph.show()
379
+
380
+
381
+ # Agent
382
+
383
def make_data_visualization_agent(
    model,
    n_samples=30,
    log=False,
    log_path=None,
    file_name="data_visualization.py",
    function_name="data_visualization",
    overwrite=True,
    human_in_the_loop=False,
    bypass_recommended_steps=False,
    bypass_explain_code=False
):
    """
    Creates a data visualization agent that can generate Plotly charts based on user-defined instructions or
    default visualization steps. The agent generates a Python function to produce the visualization, executes it,
    and logs the process, including code and errors. It is designed to facilitate reproducible and customizable
    data visualization workflows.

    The agent can perform the following default visualization steps unless instructed otherwise:
    - Generating a recommended chart type (bar, scatter, line, etc.)
    - Creating user-friendly titles and axis labels
    - Applying consistent styling (template, font sizes, color themes)
    - Handling theme details (white background, base font size, line size, etc.)

    User instructions can modify, add, or remove any of these steps to tailor the visualization process.

    Parameters
    ----------
    model : langchain.llms.base.LLM
        The language model used to generate the data visualization function.
    n_samples : int, optional
        Number of samples used when summarizing the dataset for chart instructions. Defaults to 30.
    log : bool, optional
        Whether to log the generated code and errors. Defaults to False.
    log_path : str, optional
        Directory path for storing log files. Defaults to None, in which case the
        module-level LOG_PATH is used when logging is enabled.
    file_name : str, optional
        Name of the file for saving the generated response. Defaults to "data_visualization.py".
    function_name : str, optional
        Name of the function for data visualization. Defaults to "data_visualization".
    overwrite : bool, optional
        Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
    human_in_the_loop : bool, optional
        Enables user review of data visualization instructions. Defaults to False.
    bypass_recommended_steps : bool, optional
        If True, skips the default recommended visualization steps. Defaults to False.
        Forced back to False when human_in_the_loop is True.
    bypass_explain_code : bool, optional
        If True, skips the step that provides code explanations. Defaults to False.

    Examples
    --------
    ``` python
    import pandas as pd
    from langchain_openai import ChatOpenAI
    from ai_data_science_team.agents import make_data_visualization_agent

    llm = ChatOpenAI(model="gpt-4o-mini")

    data_visualization_agent = make_data_visualization_agent(llm)

    df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")

    response = data_visualization_agent.invoke({
        "user_instructions": "Generate a scatter plot of tenure vs. total charges with a trend line.",
        "data_raw": df.to_dict(),
        "max_retries": 3,
        "retry_count": 0
    })

    # The generated figure, as a plain dictionary
    response['plotly_graph']
    ```

    Returns
    -------
    app : langchain.graphs.CompiledStateGraph
        The data visualization agent as a state graph.
    """

    llm = model

    # Human in the loop requires recommended steps
    if bypass_recommended_steps and human_in_the_loop:
        bypass_recommended_steps = False
        print("Bypass recommended steps set to False to enable human in the loop.")

    # Setup Log Directory
    if log:
        if log_path is None:
            log_path = LOG_PATH
        # exist_ok avoids the check-then-create race of a separate os.path.exists() test.
        os.makedirs(log_path, exist_ok=True)

    # Define GraphState for the router
    class GraphState(TypedDict):
        messages: Annotated[Sequence[BaseMessage], operator.add]
        user_instructions: str
        user_instructions_processed: str
        recommended_steps: str
        data_raw: dict
        plotly_graph: dict
        all_datasets_summary: str
        data_visualization_function: str
        data_visualization_function_path: str
        data_visualization_function_file_name: str
        data_visualization_function_name: str
        data_visualization_error: str
        max_retries: int
        retry_count: int

    def _summarize_data_raw(state):
        # Rebuild the DataFrame from the state's dict form and render the text summary fed to the prompts.
        df = pd.DataFrame.from_dict(state.get("data_raw"))
        all_datasets_summary = get_dataframe_summary([df], n_sample=n_samples, skip_stats=False)
        return "\n\n".join(all_datasets_summary)

    def chart_instructor(state: GraphState):
        """Ask the LLM for chart-generator instructions based on the user request and a data summary."""

        print(format_agent_name(AGENT_NAME))
        print("    * CREATE CHART GENERATOR INSTRUCTIONS")

        recommend_steps_prompt = PromptTemplate(
            template="""
            You are a supervisor that is an expert in providing instructions to a chart generator agent for plotting.

            You will take a question that a user has and the data that was generated to answer the question, and create instructions to create a chart from the data that will be passed to a chart generator agent.

            USER QUESTION / INSTRUCTIONS:
            {user_instructions}

            Previously Recommended Instructions (if any):
            {recommended_steps}

            DATA:
            {all_datasets_summary}

            Formulate chart generator instructions by informing the chart generator of what type of plotly plot to use (e.g. bar, line, scatter, etc) to best represent the data.

            Come up with an informative title from the user's question and data provided. Also provide X and Y axis titles.

            Instruct the chart generator to use the following theme colors, sizes, etc:

            - Start with the "plotly_white" template
            - Use a white background
            - Use this color for bars and lines:
                'blue': '#3381ff',
            - Base Font Size: 8.8 (Used for x and y axes tickfont, any annotations, hovertips)
            - Title Font Size: 13.2
            - Line Size: 0.65 (specify these within the xaxis and yaxis dictionaries)
            - Add smoothers or trendlines to scatter plots unless not desired by the user
            - Do not use color_discrete_map (this will result in an error)
            - Hover tip size: 8.8

            Return your instructions in the following format:
            CHART GENERATOR INSTRUCTIONS:
            FILL IN THE INSTRUCTIONS HERE

            Avoid these:
            1. Do not include steps to save files.
            2. Do not include unrelated user instructions that are not related to the chart generation.
            """,
            input_variables=["user_instructions", "recommended_steps", "all_datasets_summary"]
        )

        all_datasets_summary_str = _summarize_data_raw(state)

        # Named differently from the enclosing node function to avoid shadowing it.
        chart_instructor_chain = recommend_steps_prompt | llm

        recommended_steps = chart_instructor_chain.invoke({
            "user_instructions": state.get("user_instructions"),
            "recommended_steps": state.get("recommended_steps"),
            "all_datasets_summary": all_datasets_summary_str
        })

        return {
            # Heading names this agent's task (it previously carried the data-cleaning agent's heading).
            "recommended_steps": format_recommended_steps(recommended_steps.content.strip(), heading="# Recommended Data Visualization Steps:"),
            "all_datasets_summary": all_datasets_summary_str
        }

    def chart_generator(state: GraphState):
        """Ask the LLM to write the plotting function, post-process the code, and optionally log it."""

        print("    * CREATE DATA VISUALIZATION CODE")

        if bypass_recommended_steps:
            # chart_instructor did not run, so the banner and the data summary are produced here,
            # and the raw user instructions drive code generation directly.
            print(format_agent_name(AGENT_NAME))
            all_datasets_summary_str = _summarize_data_raw(state)
            chart_generator_instructions = state.get("user_instructions")
        else:
            all_datasets_summary_str = state.get("all_datasets_summary")
            chart_generator_instructions = state.get("recommended_steps")

        prompt_template = PromptTemplate(
            template="""
            You are a chart generator agent that is an expert in generating plotly charts. You must use plotly or plotly.express to produce plots.

            Your job is to produce python code to generate visualizations with a function named {function_name}.

            You will take instructions from a Chart Instructor and generate a plotly chart from the data provided.

            CHART INSTRUCTIONS:
            {chart_generator_instructions}

            DATA:
            {all_datasets_summary}

            RETURN:

            Return Python code in ```python ``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.

            Return the plotly chart as a dictionary.

            Return code to provide the data visualization function:

            def {function_name}(data_raw):
                import pandas as pd
                import numpy as np
                import json
                import plotly.graph_objects as go
                import plotly.io as pio

                ...

                fig_json = pio.to_json(fig)
                fig_dict = json.loads(fig_json)

                return fig_dict

            Avoid these:
            1. Do not include steps to save files.
            2. Do not include unrelated user instructions that are not related to the chart generation.

            """,
            input_variables=["chart_generator_instructions", "all_datasets_summary", "function_name"]
        )

        data_visualization_agent = prompt_template | llm | PythonOutputParser()

        response = data_visualization_agent.invoke({
            "chart_generator_instructions": chart_generator_instructions,
            "all_datasets_summary": all_datasets_summary_str,
            "function_name": function_name
        })

        response = relocate_imports_inside_function(response)
        response = add_comments_to_top(response, agent_name=AGENT_NAME)

        # For logging: store the code generated:
        file_path, file_name_2 = log_ai_function(
            response=response,
            file_name=file_name,
            log=log,
            log_path=log_path,
            overwrite=overwrite
        )

        return {
            "data_visualization_function": response,
            "data_visualization_function_path": file_path,
            "data_visualization_function_file_name": file_name_2,
            "data_visualization_function_name": function_name,
            "all_datasets_summary": all_datasets_summary_str
        }

    # Human Review

    prompt_text_human_review = "Are the following data visualization instructions correct? (Answer 'yes' or provide modifications)\n{steps}"

    # Two separate definitions are kept on purpose: the Command[Literal[...]] return
    # annotation differs with the approval target, so it cannot be a single function.
    if not bypass_explain_code:
        def human_review(state: GraphState) -> Command[Literal["chart_instructor", "explain_data_visualization_code"]]:
            return node_func_human_review(
                state=state,
                prompt_text=prompt_text_human_review,
                yes_goto='explain_data_visualization_code',
                no_goto="chart_instructor",
                user_instructions_key="user_instructions",
                recommended_steps_key="recommended_steps",
                code_snippet_key="data_visualization_function",
            )
    else:
        def human_review(state: GraphState) -> Command[Literal["chart_instructor", "__end__"]]:
            return node_func_human_review(
                state=state,
                prompt_text=prompt_text_human_review,
                yes_goto='__end__',
                no_goto="chart_instructor",
                user_instructions_key="user_instructions",
                recommended_steps_key="recommended_steps",
                code_snippet_key="data_visualization_function",
            )

    def execute_data_visualization_code(state):
        """Run the generated function on the raw data and store the result under "plotly_graph"."""
        return node_func_execute_agent_code_on_data(
            state=state,
            data_key="data_raw",
            result_key="plotly_graph",
            error_key="data_visualization_error",
            code_snippet_key="data_visualization_function",
            agent_function_name=state.get("data_visualization_function_name"),
            # The state holds the dataset as a dict; the generated function receives a DataFrame.
            pre_processing=lambda data: pd.DataFrame.from_dict(data),
            # No post-processing: the generated function is instructed to return a dictionary already.
            error_message_prefix="An error occurred during data visualization: "
        )

    def fix_data_visualization_code(state: GraphState):
        """Ask the LLM to repair the generated function using the last recorded error."""
        prompt = """
        You are a Data Visualization Agent. Your job is to create a {function_name}() function that can be run on the data provided. The function is currently broken and needs to be fixed.

        Make sure to only return the function definition for {function_name}().

        Return Python code in ```python``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.

        This is the broken code (please fix):
        {code_snippet}

        Last Known Error:
        {error}
        """

        return node_func_fix_agent_code(
            state=state,
            code_snippet_key="data_visualization_function",
            error_key="data_visualization_error",
            llm=llm,
            prompt_template=prompt,
            agent_name=AGENT_NAME,
            log=log,
            file_path=state.get("data_visualization_function_path"),
            function_name=state.get("data_visualization_function_name"),
        )

    def explain_data_visualization_code(state: GraphState):
        """Ask the LLM for a succinct explanation of the generated function; stored under "messages"."""
        return node_func_explain_agent_code(
            state=state,
            code_snippet_key="data_visualization_function",
            result_key="messages",
            error_key="data_visualization_error",
            llm=llm,
            role=AGENT_NAME,
            explanation_prompt_template="""
            Explain the data visualization steps that the data visualization agent performed in this function.
            Keep the summary succinct and to the point.\n\n# Data Visualization Agent:\n\n{code}
            """,
            success_prefix="# Data Visualization Agent:\n\n ",
            error_message="The Data Visualization Agent encountered an error during data visualization. No explanation could be provided."
        )

    # Define the graph
    node_functions = {
        "chart_instructor": chart_instructor,
        "human_review": human_review,
        "chart_generator": chart_generator,
        "execute_data_visualization_code": execute_data_visualization_code,
        "fix_data_visualization_code": fix_data_visualization_code,
        "explain_data_visualization_code": explain_data_visualization_code
    }

    app = create_coding_agent_graph(
        GraphState=GraphState,
        node_functions=node_functions,
        recommended_steps_node_name="chart_instructor",
        create_code_node_name="chart_generator",
        execute_code_node_name="execute_data_visualization_code",
        fix_code_node_name="fix_data_visualization_code",
        explain_code_node_name="explain_data_visualization_code",
        error_key="data_visualization_error",
        human_in_the_loop=human_in_the_loop,
        human_review_node_name="human_review",
        # A checkpointer is only supplied when human review is enabled.
        checkpointer=MemorySaver() if human_in_the_loop else None,
        bypass_recommended_steps=bypass_recommended_steps,
        bypass_explain_code=bypass_explain_code,
    )

    return app