ai-data-science-team 0.0.0.9006__py3-none-any.whl → 0.0.0.9008__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,764 @@
1
+ # BUSINESS SCIENCE UNIVERSITY
2
+ # AI DATA SCIENCE TEAM
3
+ # ***
4
+ # * Agents: Data Visualization Agent
5
+
6
+
7
+
8
+ # Libraries
9
+ from typing import TypedDict, Annotated, Sequence, Literal
10
+ import operator
11
+
12
+ from langchain.prompts import PromptTemplate
13
+ from langchain_core.output_parsers import StrOutputParser
14
+ from langchain_core.messages import BaseMessage
15
+
16
+ from langgraph.types import Command
17
+ from langgraph.checkpoint.memory import MemorySaver
18
+
19
+ import os
20
+ import pandas as pd
21
+
22
+ from IPython.display import Markdown
23
+
24
+ from ai_data_science_team.templates import(
25
+ node_func_execute_agent_code_on_data,
26
+ node_func_human_review,
27
+ node_func_fix_agent_code,
28
+ node_func_explain_agent_code,
29
+ create_coding_agent_graph,
30
+ BaseAgent,
31
+ )
32
+ from ai_data_science_team.tools.parsers import PythonOutputParser
33
+ from ai_data_science_team.tools.regex import (
34
+ relocate_imports_inside_function,
35
+ add_comments_to_top,
36
+ format_agent_name,
37
+ format_recommended_steps
38
+ )
39
+ from ai_data_science_team.tools.metadata import get_dataframe_summary
40
+ from ai_data_science_team.tools.logging import log_ai_function
41
+ from ai_data_science_team.utils.plotly import plotly_from_dict
42
+
43
+ # Setup
44
+ AGENT_NAME = "data_visualization_agent"  # label passed to format_agent_name(), add_comments_to_top() and the fix/explain nodes below
44
+ LOG_PATH = os.path.join(os.getcwd(), "logs/")  # default log directory; used when log=True and no log_path is supplied
46
+
47
+ # Class
48
+
49
class DataVisualizationAgent(BaseAgent):
    """
    Creates a data visualization agent that can generate Plotly charts based on user-defined instructions or
    default visualization steps (if any). The agent generates a Python function to produce the visualization,
    executes it, and logs the process, including code and errors. It is designed to facilitate reproducible
    and customizable data visualization workflows.

    The agent may use default instructions for creating charts unless instructed otherwise, such as:
    - Generating a recommended chart type (bar, scatter, line, etc.)
    - Creating user-friendly titles and axis labels
    - Applying consistent styling (template, font sizes, color themes)
    - Handling theme details (white background, base font size, line size, etc.)

    User instructions can modify, add, or remove any of these steps to tailor the visualization process.

    Parameters
    ----------
    model : langchain.llms.base.LLM
        The language model used to generate the data visualization function.
    n_samples : int, optional
        Number of samples used when summarizing the dataset for chart instructions. Defaults to 30.
        Reducing this number can help avoid exceeding the model's token limits.
    log : bool, optional
        Whether to log the generated code and errors. Defaults to False.
    log_path : str, optional
        Directory path for storing log files. Defaults to None.
    file_name : str, optional
        Name of the file for saving the generated response. Defaults to "data_visualization.py".
    function_name : str, optional
        Name of the function for data visualization. Defaults to "data_visualization".
    overwrite : bool, optional
        Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
    human_in_the_loop : bool, optional
        Enables user review of data visualization instructions. Defaults to False.
    bypass_recommended_steps : bool, optional
        If True, skips the default recommended visualization steps. Defaults to False.
    bypass_explain_code : bool, optional
        If True, skips the step that provides code explanations. Defaults to False.

    Methods
    -------
    update_params(**kwargs)
        Updates the agent's parameters and rebuilds the compiled state graph.
    ainvoke_agent(data_raw: pd.DataFrame, user_instructions: str=None, max_retries=3, retry_count=0, **kwargs)
        Asynchronously generates a visualization based on user instructions (coroutine; must be awaited).
    invoke_agent(data_raw: pd.DataFrame, user_instructions: str=None, max_retries=3, retry_count=0, **kwargs)
        Synchronously generates a visualization based on user instructions.
    explain_visualization_steps()
        Returns the explanation messages stored in the last response.
    get_log_summary()
        Retrieves a summary of logged operations if logging is enabled.
    get_plotly_graph()
        Retrieves the Plotly graph produced by the agent, converted with plotly_from_dict().
    get_data_raw()
        Retrieves the raw dataset as a pandas DataFrame (based on the last response).
    get_data_visualization_function()
        Retrieves the generated Python function used for data visualization.
    get_recommended_visualization_steps()
        Retrieves the agent's recommended visualization steps.
    get_response()
        Returns the response from the agent as a dictionary.
    show()
        Displays the agent's mermaid diagram.

    Examples
    --------
    ```python
    import pandas as pd
    from langchain_openai import ChatOpenAI
    from ai_data_science_team.agents import DataVisualizationAgent

    llm = ChatOpenAI(model="gpt-4o-mini")

    data_visualization_agent = DataVisualizationAgent(
        model=llm,
        n_samples=30,
        log=True,
        log_path="logs",
        human_in_the_loop=True
    )

    df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")

    data_visualization_agent.invoke_agent(
        user_instructions="Generate a scatter plot of age vs. total charges with a trend line.",
        data_raw=df,
        max_retries=3,
        retry_count=0
    )

    plotly_graph = data_visualization_agent.get_plotly_graph()
    # The stored graph dictionary is passed through plotly_from_dict() before it is
    # returned; the raw dictionary is available via get_response()["plotly_graph"].

    response = data_visualization_agent.get_response()
    ```

    Returns
    --------
    DataVisualizationAgent : langchain.graphs.CompiledStateGraph
        A data visualization agent implemented as a compiled state graph.
    """

    def __init__(
        self,
        model,
        n_samples=30,
        log=False,
        log_path=None,
        file_name="data_visualization.py",
        function_name="data_visualization",
        overwrite=True,
        human_in_the_loop=False,
        bypass_recommended_steps=False,
        bypass_explain_code=False
    ):
        # Keep every constructor argument so update_params() can rebuild the
        # graph later with a modified copy of the same settings.
        self._params = {
            "model": model,
            "n_samples": n_samples,
            "log": log,
            "log_path": log_path,
            "file_name": file_name,
            "function_name": function_name,
            "overwrite": overwrite,
            "human_in_the_loop": human_in_the_loop,
            "bypass_recommended_steps": bypass_recommended_steps,
            "bypass_explain_code": bypass_explain_code,
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Create the compiled graph for the data visualization agent.
        Running this method will reset the response to None.
        """
        self.response = None
        return make_data_visualization_agent(**self._params)

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters and rebuilds the compiled graph.

        Parameters
        ----------
        **kwargs : dict
            Parameter names and new values. They are merged into the stored
            parameters and forwarded to make_data_visualization_agent().
        """
        self._params.update(kwargs)
        # Rebuild the compiled graph (this also resets the stored response)
        self._compiled_graph = self._make_compiled_graph()

    async def ainvoke_agent(self, data_raw: pd.DataFrame, user_instructions: str = None, max_retries: int = 3, retry_count: int = 0, **kwargs) -> None:
        """
        Asynchronously invokes the agent to generate a visualization.
        The response is stored in the 'response' attribute.

        This is a coroutine and must be awaited. The graph's ainvoke() returns
        an awaitable, so its result is awaited here before being stored;
        otherwise 'response' would hold a pending coroutine instead of the
        result dictionary.

        Parameters
        ----------
        data_raw : pd.DataFrame
            The raw dataset to be visualized.
        user_instructions : str
            Instructions for data visualization.
        max_retries : int
            Maximum retry attempts.
        retry_count : int
            Current retry attempt count.
        **kwargs : dict
            Additional keyword arguments passed to ainvoke().

        Returns
        -------
        None
        """
        response = await self._compiled_graph.ainvoke({
            "user_instructions": user_instructions,
            "data_raw": data_raw.to_dict(),
            "max_retries": max_retries,
            "retry_count": retry_count,
        }, **kwargs)
        self.response = response
        return None

    def invoke_agent(self, data_raw: pd.DataFrame, user_instructions: str = None, max_retries: int = 3, retry_count: int = 0, **kwargs) -> None:
        """
        Synchronously invokes the agent to generate a visualization.
        The response is stored in the 'response' attribute.

        Parameters
        ----------
        data_raw : pd.DataFrame
            The raw dataset to be visualized.
        user_instructions : str
            Instructions for data visualization agent.
        max_retries : int
            Maximum retry attempts.
        retry_count : int
            Current retry attempt count.
        **kwargs : dict
            Additional keyword arguments passed to invoke().

        Returns
        -------
        None
        """
        response = self._compiled_graph.invoke({
            "user_instructions": user_instructions,
            "data_raw": data_raw.to_dict(),
            "max_retries": max_retries,
            "retry_count": retry_count,
        }, **kwargs)
        self.response = response
        return None

    def explain_visualization_steps(self):
        """
        Provides the explanation of the visualization steps performed by the agent.

        Returns
        -------
        list
            The "messages" entry of the last response, or an empty list when
            the agent has not produced a response (or it has no messages).
        """
        if self.response:
            return self.response.get("messages", [])
        return []

    def get_log_summary(self, markdown=False):
        """
        Returns a summary of the agent's logged output, if logging is enabled.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns Markdown-formatted output.

        Returns
        -------
        str, Markdown or None
            "Log Path: <path>" for the logged function file, or None if the
            last response carries no log path.
        """
        if self.response and self.response.get('data_visualization_function_path'):
            log_details = f"Log Path: {self.response.get('data_visualization_function_path')}"
            if markdown:
                return Markdown(log_details)
            return log_details
        return None

    def get_plotly_graph(self):
        """
        Retrieves the Plotly graph produced by the agent.

        Returns
        -------
        object or None
            The result of plotly_from_dict() applied to the stored
            "plotly_graph" dictionary, or None when there is no response or
            no graph was produced.
        """
        if not self.response:
            return None
        graph_dict = self.response.get("plotly_graph", None)
        # Do not hand a missing graph to the converter; report "no graph" instead.
        if graph_dict is None:
            return None
        return plotly_from_dict(graph_dict)

    def get_data_raw(self):
        """
        Retrieves the raw dataset used in the last invocation.

        Returns
        -------
        pd.DataFrame or None
            The raw dataset as a DataFrame if available, otherwise None.
        """
        if self.response and self.response.get("data_raw"):
            return pd.DataFrame(self.response.get("data_raw"))
        return None

    def get_data_visualization_function(self, markdown=False):
        """
        Retrieves the generated Python function used for data visualization.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the function in Markdown code block format.

        Returns
        -------
        str, Markdown or None
            The Python function code if a response is available, otherwise None.
        """
        if self.response:
            func_code = self.response.get("data_visualization_function", "")
            if markdown:
                return Markdown(f"```python\n{func_code}\n```")
            return func_code
        return None

    def get_recommended_visualization_steps(self, markdown=False):
        """
        Retrieves the agent's recommended visualization steps.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the steps in Markdown format.

        Returns
        -------
        str, Markdown or None
            The recommended steps if a response is available, otherwise None.
        """
        if self.response:
            steps = self.response.get("recommended_steps", "")
            if markdown:
                return Markdown(steps)
            return steps
        return None

    def get_response(self):
        """
        Returns the agent's full response dictionary.

        Returns
        -------
        dict or None
            The response dictionary if available, otherwise None.
        """
        return self.response

    def show(self):
        """
        Displays the agent's mermaid diagram for visual inspection of the compiled graph.
        """
        return self._compiled_graph.show()
379
+
380
+
381
+ # Agent
382
+
383
def make_data_visualization_agent(
    model,
    n_samples=30,
    log=False,
    log_path=None,
    file_name="data_visualization.py",
    function_name="data_visualization",
    overwrite=True,
    human_in_the_loop=False,
    bypass_recommended_steps=False,
    bypass_explain_code=False
):
    """
    Creates a data visualization agent that can generate Plotly charts based on user-defined instructions or
    default visualization steps. The agent generates a Python function to produce the visualization, executes it,
    and logs the process, including code and errors. It is designed to facilitate reproducible and customizable
    data visualization workflows.

    The agent can perform the following default visualization steps unless instructed otherwise:
    - Generating a recommended chart type (bar, scatter, line, etc.)
    - Creating user-friendly titles and axis labels
    - Applying consistent styling (template, font sizes, color themes)
    - Handling theme details (white background, base font size, line size, etc.)

    User instructions can modify, add, or remove any of these steps to tailor the visualization process.

    Parameters
    ----------
    model : langchain.llms.base.LLM
        The language model used to generate the data visualization function.
    n_samples : int, optional
        Number of samples used when summarizing the dataset for chart instructions. Defaults to 30.
    log : bool, optional
        Whether to log the generated code and errors. Defaults to False.
    log_path : str, optional
        Directory path for storing log files. Defaults to None, in which case
        the module-level LOG_PATH is used when logging is enabled.
    file_name : str, optional
        Name of the file for saving the generated response. Defaults to "data_visualization.py".
    function_name : str, optional
        Name of the function for data visualization. Defaults to "data_visualization".
    overwrite : bool, optional
        Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
    human_in_the_loop : bool, optional
        Enables user review of data visualization instructions. Defaults to False.
    bypass_recommended_steps : bool, optional
        If True, skips the default recommended visualization steps. Defaults to False.
        Forced back to False when human_in_the_loop is True.
    bypass_explain_code : bool, optional
        If True, skips the step that provides code explanations. Defaults to False.

    Examples
    --------
    ``` python
    import pandas as pd
    from langchain_openai import ChatOpenAI
    from ai_data_science_team.agents import make_data_visualization_agent

    llm = ChatOpenAI(model="gpt-4o-mini")

    data_visualization_agent = make_data_visualization_agent(llm)

    df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")

    response = data_visualization_agent.invoke({
        "user_instructions": "Generate a scatter plot of tenure vs. total charges with a trend line.",
        "data_raw": df.to_dict(),
        "max_retries": 3,
        "retry_count": 0
    })

    response['plotly_graph']  # the Plotly figure as a dictionary
    ```

    Returns
    -------
    app : langchain.graphs.CompiledStateGraph
        The data visualization agent as a state graph.
    """

    llm = model

    # Human in the loop requires recommended steps
    if bypass_recommended_steps and human_in_the_loop:
        bypass_recommended_steps = False
        print("Bypass recommended steps set to False to enable human in the loop.")

    # Setup Log Directory
    if log:
        if log_path is None:
            log_path = LOG_PATH
        # exist_ok avoids the check-then-create race of os.path.exists() + makedirs()
        os.makedirs(log_path, exist_ok=True)

    # Define GraphState for the router
    class GraphState(TypedDict):
        messages: Annotated[Sequence[BaseMessage], operator.add]
        user_instructions: str
        user_instructions_processed: str
        recommended_steps: str
        data_raw: dict
        plotly_graph: dict
        all_datasets_summary: str
        data_visualization_function: str
        data_visualization_function_path: str
        data_visualization_function_file_name: str
        data_visualization_function_name: str
        data_visualization_error: str
        max_retries: int
        retry_count: int

    def _summarize_data_raw(state):
        """Rebuild the DataFrame from state["data_raw"] and return its text summary."""
        df = pd.DataFrame.from_dict(state.get("data_raw"))
        all_datasets_summary = get_dataframe_summary([df], n_sample=n_samples, skip_stats=False)
        return "\n\n".join(all_datasets_summary)

    def chart_instructor(state: GraphState):
        """Ask the LLM for chart-generator instructions based on the user request and a data summary."""

        print(format_agent_name(AGENT_NAME))
        print(" * CREATE CHART GENERATOR INSTRUCTIONS")

        recommend_steps_prompt = PromptTemplate(
            template="""
            You are a supervisor that is an expert in providing instructions to a chart generator agent for plotting.

            You will take a question that a user has and the data that was generated to answer the question, and create instructions to create a chart from the data that will be passed to a chart generator agent.

            USER QUESTION / INSTRUCTIONS:
            {user_instructions}

            Previously Recommended Instructions (if any):
            {recommended_steps}

            DATA:
            {all_datasets_summary}

            Formulate chart generator instructions by informing the chart generator of what type of plotly plot to use (e.g. bar, line, scatter, etc) to best represent the data.

            Come up with an informative title from the user's question and data provided. Also provide X and Y axis titles.

            Instruct the chart generator to use the following theme colors, sizes, etc:

            - Start with the "plotly_white" template
            - Use a white background
            - Use this color for bars and lines:
                'blue': '#3381ff',
            - Base Font Size: 8.8 (Used for x and y axes tickfont, any annotations, hovertips)
            - Title Font Size: 13.2
            - Line Size: 0.65 (specify these within the xaxis and yaxis dictionaries)
            - Add smoothers or trendlines to scatter plots unless not desired by the user
            - Do not use color_discrete_map (this will result in an error)
            - Hover tip size: 8.8

            Return your instructions in the following format:
            CHART GENERATOR INSTRUCTIONS:
            FILL IN THE INSTRUCTIONS HERE

            Avoid these:
            1. Do not include steps to save files.
            2. Do not include unrelated user instructions that are not related to the chart generation.
            """,
            input_variables=["user_instructions", "recommended_steps", "all_datasets_summary"]

        )

        all_datasets_summary_str = _summarize_data_raw(state)

        # Named *_chain so the local does not shadow this node function.
        chart_instructor_chain = recommend_steps_prompt | llm

        recommended_steps = chart_instructor_chain.invoke({
            "user_instructions": state.get("user_instructions"),
            "recommended_steps": state.get("recommended_steps"),
            "all_datasets_summary": all_datasets_summary_str
        })

        return {
            "recommended_steps": format_recommended_steps(recommended_steps.content.strip(), heading="# Recommended Data Visualization Steps:"),
            "all_datasets_summary": all_datasets_summary_str
        }

    def chart_generator(state: GraphState):
        """Ask the LLM for the visualization function, post-process the code and log it."""

        print(" * CREATE DATA VISUALIZATION CODE")

        if bypass_recommended_steps:
            # No instructor node ran: summarize the data here and use the
            # user's instructions directly as chart instructions.
            print(format_agent_name(AGENT_NAME))

            all_datasets_summary_str = _summarize_data_raw(state)

            chart_generator_instructions = state.get("user_instructions")

        else:
            all_datasets_summary_str = state.get("all_datasets_summary")
            chart_generator_instructions = state.get("recommended_steps")

        prompt_template = PromptTemplate(
            template="""
            You are a chart generator agent that is an expert in generating plotly charts. You must use plotly or plotly.express to produce plots.

            Your job is to produce python code to generate visualizations with a function named {function_name}.

            You will take instructions from a Chart Instructor and generate a plotly chart from the data provided.

            CHART INSTRUCTIONS:
            {chart_generator_instructions}

            DATA:
            {all_datasets_summary}

            RETURN:

            Return Python code in ```python ``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.

            Return the plotly chart as a dictionary.

            Return code to provide the data visualization function:

            def {function_name}(data_raw):
                import pandas as pd
                import numpy as np
                import json
                import plotly.graph_objects as go
                import plotly.io as pio

                ...

                fig_json = pio.to_json(fig)
                fig_dict = json.loads(fig_json)

                return fig_dict

            Avoid these:
            1. Do not include steps to save files.
            2. Do not include unrelated user instructions that are not related to the chart generation.

            """,
            input_variables=["chart_generator_instructions", "all_datasets_summary", "function_name"]
        )

        data_visualization_agent = prompt_template | llm | PythonOutputParser()

        response = data_visualization_agent.invoke({
            "chart_generator_instructions": chart_generator_instructions,
            "all_datasets_summary": all_datasets_summary_str,
            "function_name": function_name
        })

        response = relocate_imports_inside_function(response)
        response = add_comments_to_top(response, agent_name=AGENT_NAME)

        # For logging: store the code generated:
        file_path, file_name_2 = log_ai_function(
            response=response,
            file_name=file_name,
            log=log,
            log_path=log_path,
            overwrite=overwrite
        )

        return {
            "data_visualization_function": response,
            "data_visualization_function_path": file_path,
            "data_visualization_function_file_name": file_name_2,
            "data_visualization_function_name": function_name,
            "all_datasets_summary": all_datasets_summary_str
        }

    # Human Review

    prompt_text_human_review = "Are the following data visualization instructions correct? (Answer 'yes' or provide modifications)\n{steps}"

    # Two definitions are kept on purpose: the Command[Literal[...]] return
    # annotation names the possible next nodes, and it differs depending on
    # whether the explain step is part of the graph.
    if not bypass_explain_code:
        def human_review(state: GraphState) -> Command[Literal["chart_instructor", "explain_data_visualization_code"]]:
            return node_func_human_review(
                state=state,
                prompt_text=prompt_text_human_review,
                yes_goto='explain_data_visualization_code',
                no_goto="chart_instructor",
                user_instructions_key="user_instructions",
                recommended_steps_key="recommended_steps",
                code_snippet_key="data_visualization_function",
            )
    else:
        def human_review(state: GraphState) -> Command[Literal["chart_instructor", "__end__"]]:
            return node_func_human_review(
                state=state,
                prompt_text=prompt_text_human_review,
                yes_goto='__end__',
                no_goto="chart_instructor",
                user_instructions_key="user_instructions",
                recommended_steps_key="recommended_steps",
                code_snippet_key="data_visualization_function",
            )

    def execute_data_visualization_code(state):
        """Run the generated function on the raw data and store the result under "plotly_graph"."""
        return node_func_execute_agent_code_on_data(
            state=state,
            data_key="data_raw",
            result_key="plotly_graph",
            error_key="data_visualization_error",
            code_snippet_key="data_visualization_function",
            agent_function_name=state.get("data_visualization_function_name"),
            # data_raw is stored as a dict; the generated function receives a DataFrame.
            pre_processing=lambda data: pd.DataFrame.from_dict(data),
            error_message_prefix="An error occurred during data visualization: "
        )

    def fix_data_visualization_code(state: GraphState):
        """Ask the LLM to repair the generated function using the last recorded error."""
        prompt = """
        You are a Data Visualization Agent. Your job is to create a {function_name}() function that can be run on the data provided. The function is currently broken and needs to be fixed.

        Make sure to only return the function definition for {function_name}().

        Return Python code in ```python``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.

        This is the broken code (please fix):
        {code_snippet}

        Last Known Error:
        {error}
        """

        return node_func_fix_agent_code(
            state=state,
            code_snippet_key="data_visualization_function",
            error_key="data_visualization_error",
            llm=llm,
            prompt_template=prompt,
            agent_name=AGENT_NAME,
            log=log,
            file_path=state.get("data_visualization_function_path"),
            function_name=state.get("data_visualization_function_name"),
        )

    def explain_data_visualization_code(state: GraphState):
        """Ask the LLM to explain the generated function; the result is stored under "messages"."""
        return node_func_explain_agent_code(
            state=state,
            code_snippet_key="data_visualization_function",
            result_key="messages",
            error_key="data_visualization_error",
            llm=llm,
            role=AGENT_NAME,
            explanation_prompt_template="""
            Explain the data visualization steps that the data visualization agent performed in this function.
            Keep the summary succinct and to the point.\n\n# Data Visualization Agent:\n\n{code}
            """,
            success_prefix="# Data Visualization Agent:\n\n ",
            error_message="The Data Visualization Agent encountered an error during data visualization. No explanation could be provided."
        )

    # Define the graph
    node_functions = {
        "chart_instructor": chart_instructor,
        "human_review": human_review,
        "chart_generator": chart_generator,
        "execute_data_visualization_code": execute_data_visualization_code,
        "fix_data_visualization_code": fix_data_visualization_code,
        "explain_data_visualization_code": explain_data_visualization_code
    }

    app = create_coding_agent_graph(
        GraphState=GraphState,
        node_functions=node_functions,
        recommended_steps_node_name="chart_instructor",
        create_code_node_name="chart_generator",
        execute_code_node_name="execute_data_visualization_code",
        fix_code_node_name="fix_data_visualization_code",
        explain_code_node_name="explain_data_visualization_code",
        error_key="data_visualization_error",
        human_in_the_loop=human_in_the_loop,
        human_review_node_name="human_review",
        # A checkpointer is only supplied when a human review step is enabled.
        checkpointer=MemorySaver() if human_in_the_loop else None,
        bypass_recommended_steps=bypass_recommended_steps,
        bypass_explain_code=bypass_explain_code,
    )

    return app