ai-data-science-team 0.0.0.9006__py3-none-any.whl → 0.0.0.9008__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +5 -4
- ai_data_science_team/agents/data_cleaning_agent.py +371 -45
- ai_data_science_team/agents/data_visualization_agent.py +764 -0
- ai_data_science_team/agents/data_wrangling_agent.py +507 -23
- ai_data_science_team/agents/feature_engineering_agent.py +467 -34
- ai_data_science_team/agents/sql_database_agent.py +394 -30
- ai_data_science_team/multiagents/__init__.py +1 -0
- ai_data_science_team/multiagents/sql_data_analyst.py +286 -0
- ai_data_science_team/multiagents/supervised_data_analyst.py +2 -0
- ai_data_science_team/templates/__init__.py +9 -0
- ai_data_science_team/templates/agent_templates.py +247 -42
- ai_data_science_team/tools/metadata.py +110 -47
- ai_data_science_team/tools/regex.py +33 -0
- ai_data_science_team/utils/__init__.py +0 -0
- ai_data_science_team/utils/plotly.py +24 -0
- ai_data_science_team-0.0.0.9008.dist-info/METADATA +231 -0
- ai_data_science_team-0.0.0.9008.dist-info/RECORD +26 -0
- {ai_data_science_team-0.0.0.9006.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/WHEEL +1 -1
- ai_data_science_team-0.0.0.9006.dist-info/METADATA +0 -165
- ai_data_science_team-0.0.0.9006.dist-info/RECORD +0 -20
- {ai_data_science_team-0.0.0.9006.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9006.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,764 @@
|
|
1
|
+
# BUSINESS SCIENCE UNIVERSITY
|
2
|
+
# AI DATA SCIENCE TEAM
|
3
|
+
# ***
|
4
|
+
# * Agents: Data Visualization Agent
|
5
|
+
|
6
|
+
|
7
|
+
|
8
|
+
# Libraries
|
9
|
+
from typing import TypedDict, Annotated, Sequence, Literal
|
10
|
+
import operator
|
11
|
+
|
12
|
+
from langchain.prompts import PromptTemplate
|
13
|
+
from langchain_core.output_parsers import StrOutputParser
|
14
|
+
from langchain_core.messages import BaseMessage
|
15
|
+
|
16
|
+
from langgraph.types import Command
|
17
|
+
from langgraph.checkpoint.memory import MemorySaver
|
18
|
+
|
19
|
+
import os
|
20
|
+
import pandas as pd
|
21
|
+
|
22
|
+
from IPython.display import Markdown
|
23
|
+
|
24
|
+
from ai_data_science_team.templates import(
|
25
|
+
node_func_execute_agent_code_on_data,
|
26
|
+
node_func_human_review,
|
27
|
+
node_func_fix_agent_code,
|
28
|
+
node_func_explain_agent_code,
|
29
|
+
create_coding_agent_graph,
|
30
|
+
BaseAgent,
|
31
|
+
)
|
32
|
+
from ai_data_science_team.tools.parsers import PythonOutputParser
|
33
|
+
from ai_data_science_team.tools.regex import (
|
34
|
+
relocate_imports_inside_function,
|
35
|
+
add_comments_to_top,
|
36
|
+
format_agent_name,
|
37
|
+
format_recommended_steps
|
38
|
+
)
|
39
|
+
from ai_data_science_team.tools.metadata import get_dataframe_summary
|
40
|
+
from ai_data_science_team.tools.logging import log_ai_function
|
41
|
+
from ai_data_science_team.utils.plotly import plotly_from_dict
|
42
|
+
|
43
|
+
# Setup

# Identifier handed to the console-banner, code-header and fix/explain helpers.
AGENT_NAME = "data_visualization_agent"

# Fallback log directory, used when logging is enabled without a log_path.
# Resolved against the working directory at import time.
LOG_PATH = os.path.join(os.getcwd(), "logs/")
|
46
|
+
|
47
|
+
# Class
|
48
|
+
|
49
|
+
class DataVisualizationAgent(BaseAgent):
    """
    Creates a data visualization agent that can generate Plotly charts based on user-defined instructions or
    default visualization steps (if any). The agent generates a Python function to produce the visualization,
    executes it, and logs the process, including code and errors. It is designed to facilitate reproducible
    and customizable data visualization workflows.

    The agent may use default instructions for creating charts unless instructed otherwise, such as:
    - Generating a recommended chart type (bar, scatter, line, etc.)
    - Creating user-friendly titles and axis labels
    - Applying consistent styling (template, font sizes, color themes)
    - Handling theme details (white background, base font size, line size, etc.)

    User instructions can modify, add, or remove any of these steps to tailor the visualization process.

    Parameters
    ----------
    model : langchain.llms.base.LLM
        The language model used to generate the data visualization function.
    n_samples : int, optional
        Number of samples used when summarizing the dataset for chart instructions. Defaults to 30.
        Reducing this number can help avoid exceeding the model's token limits.
    log : bool, optional
        Whether to log the generated code and errors. Defaults to False.
    log_path : str, optional
        Directory path for storing log files. Defaults to None.
    file_name : str, optional
        Name of the file for saving the generated response. Defaults to "data_visualization.py".
    function_name : str, optional
        Name of the function for data visualization. Defaults to "data_visualization".
    overwrite : bool, optional
        Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
    human_in_the_loop : bool, optional
        Enables user review of data visualization instructions. Defaults to False.
    bypass_recommended_steps : bool, optional
        If True, skips the default recommended visualization steps. Defaults to False.
    bypass_explain_code : bool, optional
        If True, skips the step that provides code explanations. Defaults to False.

    Methods
    -------
    update_params(**kwargs)
        Updates the agent's parameters and rebuilds the compiled state graph.
    ainvoke_agent(data_raw: pd.DataFrame, user_instructions: str=None, max_retries=3, retry_count=0)
        Asynchronously generates a visualization based on user instructions (coroutine; must be awaited).
    invoke_agent(data_raw: pd.DataFrame, user_instructions: str=None, max_retries=3, retry_count=0)
        Synchronously generates a visualization based on user instructions.
    explain_visualization_steps()
        Returns the messages stored in the last response (the explanation of the visualization steps).
    get_log_summary()
        Returns the path of the logged function file if one was recorded.
    get_plotly_graph()
        Retrieves the Plotly graph produced by the agent, converted with plotly_from_dict().
    get_data_raw()
        Retrieves the raw dataset as a pandas DataFrame (based on the last response).
    get_data_visualization_function()
        Retrieves the generated Python function used for data visualization.
    get_recommended_visualization_steps()
        Retrieves the agent's recommended visualization steps.
    get_response()
        Returns the response from the agent as a dictionary.
    show()
        Displays the agent's mermaid diagram.

    Examples
    --------
    ```python
    import pandas as pd
    from langchain_openai import ChatOpenAI
    from ai_data_science_team.agents import DataVisualizationAgent

    llm = ChatOpenAI(model="gpt-4o-mini")

    data_visualization_agent = DataVisualizationAgent(
        model=llm,
        n_samples=30,
        log=True,
        log_path="logs",
        human_in_the_loop=True
    )

    df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")

    data_visualization_agent.invoke_agent(
        user_instructions="Generate a scatter plot of age vs. total charges with a trend line.",
        data_raw=df,
        max_retries=3,
        retry_count=0
    )

    # Result of plotly_from_dict() applied to the figure dictionary stored in the response.
    plotly_graph = data_visualization_agent.get_plotly_graph()

    # The full response (including the raw figure dictionary under "plotly_graph").
    response = data_visualization_agent.get_response()
    ```

    Returns
    --------
    DataVisualizationAgent : langchain.graphs.CompiledStateGraph
        A data visualization agent implemented as a compiled state graph.
    """

    def __init__(
        self,
        model,
        n_samples=30,
        log=False,
        log_path=None,
        file_name="data_visualization.py",
        function_name="data_visualization",
        overwrite=True,
        human_in_the_loop=False,
        bypass_recommended_steps=False,
        bypass_explain_code=False
    ):
        # Keep the constructor arguments so update_params() can rebuild the graph later.
        self._params = {
            "model": model,
            "n_samples": n_samples,
            "log": log,
            "log_path": log_path,
            "file_name": file_name,
            "function_name": function_name,
            "overwrite": overwrite,
            "human_in_the_loop": human_in_the_loop,
            "bypass_recommended_steps": bypass_recommended_steps,
            "bypass_explain_code": bypass_explain_code,
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Create the compiled graph for the data visualization agent.
        Running this method will reset the response to None.
        """
        self.response = None
        return make_data_visualization_agent(**self._params)

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters and rebuilds the compiled graph.

        Parameters
        ----------
        **kwargs : dict
            Parameter names and new values; they are merged into the stored
            parameters and forwarded to make_data_visualization_agent().
        """
        self._params.update(kwargs)
        # Rebuild the compiled graph (this also resets the stored response).
        self._compiled_graph = self._make_compiled_graph()

    async def ainvoke_agent(self, data_raw: pd.DataFrame, user_instructions: str=None, max_retries:int=3, retry_count:int=0, **kwargs):
        """
        Asynchronously invokes the agent to generate a visualization.
        The response is stored in the 'response' attribute.

        This is a coroutine and must be awaited; the graph's ainvoke() result is
        awaited here so that 'response' holds the final state rather than a
        pending coroutine object.

        Parameters
        ----------
        data_raw : pd.DataFrame
            The raw dataset to be visualized.
        user_instructions : str
            Instructions for data visualization.
        max_retries : int
            Maximum retry attempts.
        retry_count : int
            Current retry attempt count.
        **kwargs : dict
            Additional keyword arguments passed to ainvoke().

        Returns
        -------
        None
        """
        response = await self._compiled_graph.ainvoke({
            "user_instructions": user_instructions,
            "data_raw": data_raw.to_dict(),
            "max_retries": max_retries,
            "retry_count": retry_count,
        }, **kwargs)
        self.response = response
        return None

    def invoke_agent(self, data_raw: pd.DataFrame, user_instructions: str=None, max_retries:int=3, retry_count:int=0, **kwargs):
        """
        Synchronously invokes the agent to generate a visualization.
        The response is stored in the 'response' attribute.

        Parameters
        ----------
        data_raw : pd.DataFrame
            The raw dataset to be visualized.
        user_instructions : str
            Instructions for data visualization agent.
        max_retries : int
            Maximum retry attempts.
        retry_count : int
            Current retry attempt count.
        **kwargs : dict
            Additional keyword arguments passed to invoke().

        Returns
        -------
        None
        """
        response = self._compiled_graph.invoke({
            "user_instructions": user_instructions,
            # The graph state carries the dataset as a plain dict.
            "data_raw": data_raw.to_dict(),
            "max_retries": max_retries,
            "retry_count": retry_count,
        }, **kwargs)
        self.response = response
        return None

    def explain_visualization_steps(self):
        """
        Provides the explanation of the visualization steps performed by the agent.

        Returns
        -------
        list
            The "messages" entry of the last response, or an empty list when
            the agent has not produced a response (or it has no messages).
        """
        if self.response:
            return self.response.get("messages", [])
        return []

    def get_log_summary(self, markdown=False):
        """
        Returns a one-line summary pointing at the logged function file.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns Markdown-formatted output.

        Returns
        -------
        str, Markdown or None
            "Log Path: <path>" when the response records a function path,
            otherwise None.
        """
        if not self.response:
            return None
        function_path = self.response.get('data_visualization_function_path')
        if not function_path:
            return None
        log_details = f"Log Path: {function_path}"
        return Markdown(log_details) if markdown else log_details

    def get_plotly_graph(self):
        """
        Retrieves the Plotly graph produced by the agent.

        Returns
        -------
        object or None
            The result of plotly_from_dict() applied to the "plotly_graph"
            dictionary of the last response, or None when there is no
            response or no graph was produced.
        """
        if not self.response:
            return None
        graph_dict = self.response.get("plotly_graph", None)
        # Honor the documented "otherwise None" contract instead of handing
        # None to the converter.
        if graph_dict is None:
            return None
        return plotly_from_dict(graph_dict)

    def get_data_raw(self):
        """
        Retrieves the raw dataset used in the last invocation.

        Returns
        -------
        pd.DataFrame or None
            The raw dataset as a DataFrame if available, otherwise None.
        """
        if self.response and self.response.get("data_raw"):
            return pd.DataFrame(self.response.get("data_raw"))
        return None

    def get_data_visualization_function(self, markdown=False):
        """
        Retrieves the generated Python function used for data visualization.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the function in Markdown code block format.

        Returns
        -------
        str, Markdown or None
            The Python function code if a response is available, otherwise None.
        """
        if not self.response:
            return None
        func_code = self.response.get("data_visualization_function", "")
        if markdown:
            return Markdown(f"```python\n{func_code}\n```")
        return func_code

    def get_recommended_visualization_steps(self, markdown=False):
        """
        Retrieves the agent's recommended visualization steps.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the steps in Markdown format.

        Returns
        -------
        str, Markdown or None
            The recommended steps if a response is available, otherwise None.
        """
        if not self.response:
            return None
        steps = self.response.get("recommended_steps", "")
        if markdown:
            return Markdown(steps)
        return steps

    def get_response(self):
        """
        Returns the agent's full response dictionary.

        Returns
        -------
        dict or None
            The response dictionary if available, otherwise None.
        """
        return self.response

    def show(self):
        """
        Displays the agent's mermaid diagram for visual inspection of the compiled graph.
        """
        # NOTE(review): this delegates to a show() method on the compiled graph
        # object itself; confirm that the object returned by
        # create_coding_agent_graph() actually provides one.
        return self._compiled_graph.show()
|
379
|
+
|
380
|
+
|
381
|
+
# Agent
|
382
|
+
|
383
|
+
def make_data_visualization_agent(
    model,
    n_samples=30,
    log=False,
    log_path=None,
    file_name="data_visualization.py",
    function_name="data_visualization",
    overwrite=True,
    human_in_the_loop=False,
    bypass_recommended_steps=False,
    bypass_explain_code=False
):
    """
    Creates a data visualization agent that can generate Plotly charts based on user-defined instructions or
    default visualization steps. The agent generates a Python function to produce the visualization, executes it,
    and logs the process, including code and errors. It is designed to facilitate reproducible and customizable
    data visualization workflows.

    The agent can perform the following default visualization steps unless instructed otherwise:
    - Generating a recommended chart type (bar, scatter, line, etc.)
    - Creating user-friendly titles and axis labels
    - Applying consistent styling (template, font sizes, color themes)
    - Handling theme details (white background, base font size, line size, etc.)

    User instructions can modify, add, or remove any of these steps to tailor the visualization process.

    Parameters
    ----------
    model : langchain.llms.base.LLM
        The language model used to generate the data visualization function.
    n_samples : int, optional
        Number of samples used when summarizing the dataset for chart instructions. Defaults to 30.
    log : bool, optional
        Whether to log the generated code and errors. Defaults to False.
    log_path : str, optional
        Directory path for storing log files. Defaults to None, in which case
        the module-level LOG_PATH is used when logging is enabled.
    file_name : str, optional
        Name of the file for saving the generated response. Defaults to "data_visualization.py".
    function_name : str, optional
        Name of the function for data visualization. Defaults to "data_visualization".
    overwrite : bool, optional
        Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
    human_in_the_loop : bool, optional
        Enables user review of data visualization instructions. Defaults to False.
    bypass_recommended_steps : bool, optional
        If True, skips the default recommended visualization steps. Defaults to False.
        Forced back to False when human_in_the_loop is True.
    bypass_explain_code : bool, optional
        If True, skips the step that provides code explanations. Defaults to False.

    Examples
    --------
    ``` python
    import pandas as pd
    from langchain_openai import ChatOpenAI
    from ai_data_science_team.agents.data_visualization_agent import make_data_visualization_agent

    llm = ChatOpenAI(model="gpt-4o-mini")

    data_visualization_agent = make_data_visualization_agent(llm)

    df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")

    response = data_visualization_agent.invoke({
        "user_instructions": "Generate a scatter plot of tenure vs. total charges with a trend line.",
        "data_raw": df.to_dict(),
        "max_retries": 3,
        "retry_count": 0
    })

    # The Plotly figure, as the dictionary returned by the generated function.
    plotly_graph_dict = response['plotly_graph']
    ```

    Returns
    -------
    app : langchain.graphs.CompiledStateGraph
        The data visualization agent as a state graph.
    """

    llm = model

    # Human in the loop requires recommended steps
    if bypass_recommended_steps and human_in_the_loop:
        bypass_recommended_steps = False
        print("Bypass recommended steps set to False to enable human in the loop.")

    # Setup Log Directory
    if log:
        if log_path is None:
            log_path = LOG_PATH
        # exist_ok avoids a check-then-create race on the directory.
        os.makedirs(log_path, exist_ok=True)

    # Define GraphState for the router
    class GraphState(TypedDict):
        messages: Annotated[Sequence[BaseMessage], operator.add]
        user_instructions: str
        user_instructions_processed: str
        recommended_steps: str
        data_raw: dict
        plotly_graph: dict
        all_datasets_summary: str
        data_visualization_function: str
        data_visualization_function_path: str
        data_visualization_function_file_name: str
        data_visualization_function_name: str
        data_visualization_error: str
        max_retries: int
        retry_count: int

    def _summarize_raw_data(state):
        # Rebuild the DataFrame from the dict held in state and join the
        # per-dataset summaries into the single string the prompts expect.
        df = pd.DataFrame.from_dict(state.get("data_raw"))
        all_datasets_summary = get_dataframe_summary([df], n_sample=n_samples, skip_stats=False)
        return "\n\n".join(all_datasets_summary)

    def chart_instructor(state: GraphState):
        # Node: ask the LLM for chart-generator instructions (the "recommended steps").

        print(format_agent_name(AGENT_NAME))
        print(" * CREATE CHART GENERATOR INSTRUCTIONS")

        recommend_steps_prompt = PromptTemplate(
            template="""
            You are a supervisor that is an expert in providing instructions to a chart generator agent for plotting.

            You will take a question that a user has and the data that was generated to answer the question, and create instructions to create a chart from the data that will be passed to a chart generator agent.

            USER QUESTION / INSTRUCTIONS:
            {user_instructions}

            Previously Recommended Instructions (if any):
            {recommended_steps}

            DATA:
            {all_datasets_summary}

            Formulate chart generator instructions by informing the chart generator of what type of plotly plot to use (e.g. bar, line, scatter, etc) to best represent the data.

            Come up with an informative title from the user's question and data provided. Also provide X and Y axis titles.

            Instruct the chart generator to use the following theme colors, sizes, etc:

            - Start with the "plotly_white" template
            - Use a white background
            - Use this color for bars and lines:
                'blue': '#3381ff',
            - Base Font Size: 8.8 (Used for x and y axes tickfont, any annotations, hovertips)
            - Title Font Size: 13.2
            - Line Size: 0.65 (specify these within the xaxis and yaxis dictionaries)
            - Add smoothers or trendlines to scatter plots unless not desired by the user
            - Do not use color_discrete_map (this will result in an error)
            - Hover tip size: 8.8

            Return your instructions in the following format:
            CHART GENERATOR INSTRUCTIONS:
            FILL IN THE INSTRUCTIONS HERE

            Avoid these:
            1. Do not include steps to save files.
            2. Do not include unrelated user instructions that are not related to the chart generation.
            """,
            input_variables=["user_instructions", "recommended_steps", "all_datasets_summary"]
        )

        all_datasets_summary_str = _summarize_raw_data(state)

        # Named distinctly so the chain does not shadow this node function.
        instructor_chain = recommend_steps_prompt | llm

        recommended_steps = instructor_chain.invoke({
            "user_instructions": state.get("user_instructions"),
            "recommended_steps": state.get("recommended_steps"),
            "all_datasets_summary": all_datasets_summary_str
        })

        return {
            # The heading names this agent's task (visualization).
            "recommended_steps": format_recommended_steps(recommended_steps.content.strip(), heading="# Recommended Data Visualization Steps:"),
            "all_datasets_summary": all_datasets_summary_str
        }

    def chart_generator(state: GraphState):
        # Node: ask the LLM to write the plotting function and optionally log it.

        if bypass_recommended_steps:
            # chart_instructor did not run, so print the agent banner here,
            # before the step line, and build the summary ourselves.
            print(format_agent_name(AGENT_NAME))
            print(" * CREATE DATA VISUALIZATION CODE")

            all_datasets_summary_str = _summarize_raw_data(state)
            chart_generator_instructions = state.get("user_instructions")
        else:
            print(" * CREATE DATA VISUALIZATION CODE")

            all_datasets_summary_str = state.get("all_datasets_summary")
            chart_generator_instructions = state.get("recommended_steps")

        prompt_template = PromptTemplate(
            template="""
            You are a chart generator agent that is an expert in generating plotly charts. You must use plotly or plotly.express to produce plots.

            Your job is to produce python code to generate visualizations with a function named {function_name}.

            You will take instructions from a Chart Instructor and generate a plotly chart from the data provided.

            CHART INSTRUCTIONS:
            {chart_generator_instructions}

            DATA:
            {all_datasets_summary}

            RETURN:

            Return Python code in ```python ``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.

            Return the plotly chart as a dictionary.

            Return code to provide the data visualization function:

            def {function_name}(data_raw):
                import pandas as pd
                import numpy as np
                import json
                import plotly.graph_objects as go
                import plotly.io as pio

                ...

                fig_json = pio.to_json(fig)
                fig_dict = json.loads(fig_json)

                return fig_dict

            Avoid these:
            1. Do not include steps to save files.
            2. Do not include unrelated user instructions that are not related to the chart generation.

            """,
            input_variables=["chart_generator_instructions", "all_datasets_summary", "function_name"]
        )

        data_visualization_agent = prompt_template | llm | PythonOutputParser()

        response = data_visualization_agent.invoke({
            "chart_generator_instructions": chart_generator_instructions,
            "all_datasets_summary": all_datasets_summary_str,
            "function_name": function_name
        })

        response = relocate_imports_inside_function(response)
        response = add_comments_to_top(response, agent_name=AGENT_NAME)

        # For logging: store the code generated:
        file_path, file_name_2 = log_ai_function(
            response=response,
            file_name=file_name,
            log=log,
            log_path=log_path,
            overwrite=overwrite
        )

        return {
            "data_visualization_function": response,
            "data_visualization_function_path": file_path,
            "data_visualization_function_file_name": file_name_2,
            "data_visualization_function_name": function_name,
            "all_datasets_summary": all_datasets_summary_str
        }

    # Human Review

    prompt_text_human_review = "Are the following data visualization instructions correct? (Answer 'yes' or provide modifications)\n{steps}"

    # Two variants are defined because the return annotation names the
    # possible destinations, which differ when the explain step is bypassed.
    if not bypass_explain_code:
        def human_review(state: GraphState) -> Command[Literal["chart_instructor", "explain_data_visualization_code"]]:
            return node_func_human_review(
                state=state,
                prompt_text=prompt_text_human_review,
                yes_goto= 'explain_data_visualization_code',
                no_goto="chart_instructor",
                user_instructions_key="user_instructions",
                recommended_steps_key="recommended_steps",
                code_snippet_key="data_visualization_function",
            )
    else:
        def human_review(state: GraphState) -> Command[Literal["chart_instructor", "__end__"]]:
            return node_func_human_review(
                state=state,
                prompt_text=prompt_text_human_review,
                yes_goto= '__end__',
                no_goto="chart_instructor",
                user_instructions_key="user_instructions",
                recommended_steps_key="recommended_steps",
                code_snippet_key="data_visualization_function",
            )

    def execute_data_visualization_code(state):
        # Node: run the generated function on the raw data.
        return node_func_execute_agent_code_on_data(
            state=state,
            data_key="data_raw",
            result_key="plotly_graph",
            error_key="data_visualization_error",
            code_snippet_key="data_visualization_function",
            agent_function_name=state.get("data_visualization_function_name"),
            pre_processing=lambda data: pd.DataFrame.from_dict(data),
            # No post_processing: the prompt asks the generated function to
            # return the figure as a dictionary already.
            error_message_prefix="An error occurred during data visualization: "
        )

    def fix_data_visualization_code(state: GraphState):
        # Node: hand the broken function and its last error back to the LLM.
        prompt = """
        You are a Data Visualization Agent. Your job is to create a {function_name}() function that can be run on the data provided. The function is currently broken and needs to be fixed.

        Make sure to only return the function definition for {function_name}().

        Return Python code in ```python``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.

        This is the broken code (please fix):
        {code_snippet}

        Last Known Error:
        {error}
        """

        return node_func_fix_agent_code(
            state=state,
            code_snippet_key="data_visualization_function",
            error_key="data_visualization_error",
            llm=llm,
            prompt_template=prompt,
            agent_name=AGENT_NAME,
            log=log,
            file_path=state.get("data_visualization_function_path"),
            function_name=state.get("data_visualization_function_name"),
        )

    def explain_data_visualization_code(state: GraphState):
        # Node: ask the LLM to summarize what the generated function does.
        return node_func_explain_agent_code(
            state=state,
            code_snippet_key="data_visualization_function",
            result_key="messages",
            error_key="data_visualization_error",
            llm=llm,
            role=AGENT_NAME,
            explanation_prompt_template="""
            Explain the data visualization steps that the data visualization agent performed in this function.
            Keep the summary succinct and to the point.\n\n# Data Visualization Agent:\n\n{code}
            """,
            success_prefix="# Data Visualization Agent:\n\n ",
            error_message="The Data Visualization Agent encountered an error during data visualization. No explanation could be provided."
        )

    # Define the graph
    node_functions = {
        "chart_instructor": chart_instructor,
        "human_review": human_review,
        "chart_generator": chart_generator,
        "execute_data_visualization_code": execute_data_visualization_code,
        "fix_data_visualization_code": fix_data_visualization_code,
        "explain_data_visualization_code": explain_data_visualization_code
    }

    app = create_coding_agent_graph(
        GraphState=GraphState,
        node_functions=node_functions,
        recommended_steps_node_name="chart_instructor",
        create_code_node_name="chart_generator",
        execute_code_node_name="execute_data_visualization_code",
        fix_code_node_name="fix_data_visualization_code",
        explain_code_node_name="explain_data_visualization_code",
        error_key="data_visualization_error",
        human_in_the_loop=human_in_the_loop,
        human_review_node_name="human_review",
        # A checkpointer is supplied only when human review is enabled.
        checkpointer=MemorySaver() if human_in_the_loop else None,
        bypass_recommended_steps=bypass_recommended_steps,
        bypass_explain_code=bypass_explain_code,
    )

    return app
|