ai-data-science-team 0.0.0.9006__py3-none-any.whl → 0.0.0.9008__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +5 -4
- ai_data_science_team/agents/data_cleaning_agent.py +371 -45
- ai_data_science_team/agents/data_visualization_agent.py +764 -0
- ai_data_science_team/agents/data_wrangling_agent.py +507 -23
- ai_data_science_team/agents/feature_engineering_agent.py +467 -34
- ai_data_science_team/agents/sql_database_agent.py +394 -30
- ai_data_science_team/multiagents/__init__.py +1 -0
- ai_data_science_team/multiagents/sql_data_analyst.py +286 -0
- ai_data_science_team/multiagents/supervised_data_analyst.py +2 -0
- ai_data_science_team/templates/__init__.py +9 -0
- ai_data_science_team/templates/agent_templates.py +247 -42
- ai_data_science_team/tools/metadata.py +110 -47
- ai_data_science_team/tools/regex.py +33 -0
- ai_data_science_team/utils/__init__.py +0 -0
- ai_data_science_team/utils/plotly.py +24 -0
- ai_data_science_team-0.0.0.9008.dist-info/METADATA +231 -0
- ai_data_science_team-0.0.0.9008.dist-info/RECORD +26 -0
- {ai_data_science_team-0.0.0.9006.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/WHEEL +1 -1
- ai_data_science_team-0.0.0.9006.dist-info/METADATA +0 -165
- ai_data_science_team-0.0.0.9006.dist-info/RECORD +0 -20
- {ai_data_science_team-0.0.0.9006.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9006.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,764 @@
|
|
1
|
+
# BUSINESS SCIENCE UNIVERSITY
|
2
|
+
# AI DATA SCIENCE TEAM
|
3
|
+
# ***
|
4
|
+
# * Agents: Data Visualization Agent
|
5
|
+
|
6
|
+
|
7
|
+
|
8
|
+
# Libraries
|
9
|
+
from typing import TypedDict, Annotated, Sequence, Literal
|
10
|
+
import operator
|
11
|
+
|
12
|
+
from langchain.prompts import PromptTemplate
|
13
|
+
from langchain_core.output_parsers import StrOutputParser
|
14
|
+
from langchain_core.messages import BaseMessage
|
15
|
+
|
16
|
+
from langgraph.types import Command
|
17
|
+
from langgraph.checkpoint.memory import MemorySaver
|
18
|
+
|
19
|
+
import os
|
20
|
+
import pandas as pd
|
21
|
+
|
22
|
+
from IPython.display import Markdown
|
23
|
+
|
24
|
+
from ai_data_science_team.templates import(
|
25
|
+
node_func_execute_agent_code_on_data,
|
26
|
+
node_func_human_review,
|
27
|
+
node_func_fix_agent_code,
|
28
|
+
node_func_explain_agent_code,
|
29
|
+
create_coding_agent_graph,
|
30
|
+
BaseAgent,
|
31
|
+
)
|
32
|
+
from ai_data_science_team.tools.parsers import PythonOutputParser
|
33
|
+
from ai_data_science_team.tools.regex import (
|
34
|
+
relocate_imports_inside_function,
|
35
|
+
add_comments_to_top,
|
36
|
+
format_agent_name,
|
37
|
+
format_recommended_steps
|
38
|
+
)
|
39
|
+
from ai_data_science_team.tools.metadata import get_dataframe_summary
|
40
|
+
from ai_data_science_team.tools.logging import log_ai_function
|
41
|
+
from ai_data_science_team.utils.plotly import plotly_from_dict
|
42
|
+
|
43
|
+
# Setup
|
44
|
+
# Identifier for this agent; passed to format_agent_name(), add_comments_to_top(),
# and the shared fix/explain node helpers below.
AGENT_NAME = "data_visualization_agent"
|
45
|
+
# Default log directory, used by make_data_visualization_agent() when log=True and no
# log_path is supplied. Resolved once at import time from the current working directory.
LOG_PATH = os.path.join(os.getcwd(), "logs/")
|
46
|
+
|
47
|
+
# Class
|
48
|
+
|
49
|
+
class DataVisualizationAgent(BaseAgent):
    """
    Creates a data visualization agent that can generate Plotly charts based on user-defined instructions or
    default visualization steps (if any). The agent generates a Python function to produce the visualization,
    executes it, and logs the process, including code and errors. It is designed to facilitate reproducible
    and customizable data visualization workflows.

    The agent may use default instructions for creating charts unless instructed otherwise, such as:
    - Generating a recommended chart type (bar, scatter, line, etc.)
    - Creating user-friendly titles and axis labels
    - Applying consistent styling (template, font sizes, color themes)
    - Handling theme details (white background, base font size, line size, etc.)

    User instructions can modify, add, or remove any of these steps to tailor the visualization process.

    Parameters
    ----------
    model : langchain.llms.base.LLM
        The language model used to generate the data visualization function.
    n_samples : int, optional
        Number of samples used when summarizing the dataset for chart instructions. Defaults to 30.
        Reducing this number can help avoid exceeding the model's token limits.
    log : bool, optional
        Whether to log the generated code and errors. Defaults to False.
    log_path : str, optional
        Directory path for storing log files. Defaults to None.
    file_name : str, optional
        Name of the file for saving the generated response. Defaults to "data_visualization.py".
    function_name : str, optional
        Name of the function for data visualization. Defaults to "data_visualization".
    overwrite : bool, optional
        Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
    human_in_the_loop : bool, optional
        Enables user review of data visualization instructions. Defaults to False.
    bypass_recommended_steps : bool, optional
        If True, skips the default recommended visualization steps. Defaults to False.
    bypass_explain_code : bool, optional
        If True, skips the step that provides code explanations. Defaults to False.

    Methods
    -------
    update_params(**kwargs)
        Updates the agent's parameters and rebuilds the compiled state graph.
    ainvoke_agent(data_raw: pd.DataFrame, user_instructions: str=None, max_retries=3, retry_count=0, **kwargs)
        Asynchronously generates a visualization based on user instructions (coroutine; must be awaited).
    invoke_agent(data_raw: pd.DataFrame, user_instructions: str=None, max_retries=3, retry_count=0, **kwargs)
        Synchronously generates a visualization based on user instructions.
    explain_visualization_steps()
        Returns the messages explaining the visualization steps performed by the agent.
    get_log_summary(markdown=False)
        Retrieves a summary of logged operations if logging is enabled.
    get_plotly_graph()
        Retrieves the Plotly graph produced by the agent, converted with plotly_from_dict().
    get_data_raw()
        Retrieves the raw dataset as a pandas DataFrame (based on the last response).
    get_data_visualization_function(markdown=False)
        Retrieves the generated Python function used for data visualization.
    get_recommended_visualization_steps(markdown=False)
        Retrieves the agent's recommended visualization steps.
    get_response()
        Returns the response from the agent as a dictionary.
    show()
        Displays the agent's mermaid diagram.

    Examples
    --------
    ```python
    import pandas as pd
    from langchain_openai import ChatOpenAI
    from ai_data_science_team.agents import DataVisualizationAgent

    llm = ChatOpenAI(model="gpt-4o-mini")

    data_visualization_agent = DataVisualizationAgent(
        model=llm,
        n_samples=30,
        log=True,
        log_path="logs",
        human_in_the_loop=True
    )

    df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")

    data_visualization_agent.invoke_agent(
        user_instructions="Generate a scatter plot of age vs. total charges with a trend line.",
        data_raw=df,
        max_retries=3,
        retry_count=0
    )

    # get_plotly_graph() passes the stored graph dictionary through plotly_from_dict();
    # the raw dictionary is available as get_response()["plotly_graph"].
    plotly_graph = data_visualization_agent.get_plotly_graph()

    response = data_visualization_agent.get_response()
    ```

    Returns
    -------
    DataVisualizationAgent : langchain.graphs.CompiledStateGraph
        A data visualization agent implemented as a compiled state graph.
    """

    def __init__(
        self,
        model,
        n_samples=30,
        log=False,
        log_path=None,
        file_name="data_visualization.py",
        function_name="data_visualization",
        overwrite=True,
        human_in_the_loop=False,
        bypass_recommended_steps=False,
        bypass_explain_code=False
    ):
        # The parameters are kept as a dict so update_params() can change any of them
        # and rebuild the graph with make_data_visualization_agent(**self._params).
        self._params = {
            "model": model,
            "n_samples": n_samples,
            "log": log,
            "log_path": log_path,
            "file_name": file_name,
            "function_name": function_name,
            "overwrite": overwrite,
            "human_in_the_loop": human_in_the_loop,
            "bypass_recommended_steps": bypass_recommended_steps,
            "bypass_explain_code": bypass_explain_code,
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Create the compiled graph for the data visualization agent.
        Running this method will reset the response to None.
        """
        self.response = None
        return make_data_visualization_agent(**self._params)

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters and rebuilds the compiled graph.

        Parameters
        ----------
        **kwargs : dict
            Parameter names and new values; they are merged into the stored parameters.
            Rebuilding the graph also resets the stored response to None.
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    async def ainvoke_agent(self, data_raw: pd.DataFrame, user_instructions: str = None, max_retries: int = 3, retry_count: int = 0, **kwargs):
        """
        Asynchronously invokes the agent to generate a visualization.
        The response is stored in the 'response' attribute.

        This is a coroutine and must be awaited, e.g. ``await agent.ainvoke_agent(df, "...")``.

        Parameters
        ----------
        data_raw : pd.DataFrame
            The raw dataset to be visualized.
        user_instructions : str
            Instructions for data visualization.
        max_retries : int
            Maximum retry attempts.
        retry_count : int
            Current retry attempt count.
        **kwargs : dict
            Additional keyword arguments passed to ainvoke().

        Returns
        -------
        None
        """
        # The graph's ainvoke() must be awaited; otherwise self.response would hold an
        # un-awaited coroutine and every getter below would fail on it.
        response = await self._compiled_graph.ainvoke(
            {
                "user_instructions": user_instructions,
                "data_raw": data_raw.to_dict(),
                "max_retries": max_retries,
                "retry_count": retry_count,
            },
            **kwargs
        )
        self.response = response
        return None

    def invoke_agent(self, data_raw: pd.DataFrame, user_instructions: str = None, max_retries: int = 3, retry_count: int = 0, **kwargs):
        """
        Synchronously invokes the agent to generate a visualization.
        The response is stored in the 'response' attribute.

        Parameters
        ----------
        data_raw : pd.DataFrame
            The raw dataset to be visualized.
        user_instructions : str
            Instructions for data visualization agent.
        max_retries : int
            Maximum retry attempts.
        retry_count : int
            Current retry attempt count.
        **kwargs : dict
            Additional keyword arguments passed to invoke().

        Returns
        -------
        None
        """
        response = self._compiled_graph.invoke(
            {
                "user_instructions": user_instructions,
                "data_raw": data_raw.to_dict(),
                "max_retries": max_retries,
                "retry_count": retry_count,
            },
            **kwargs
        )
        self.response = response
        return None

    def explain_visualization_steps(self):
        """
        Provides an explanation of the visualization steps performed by the agent.

        Returns
        -------
        list
            The "messages" entry of the last response, or an empty list if the agent
            has not produced a response (or the response has no messages).
        """
        if self.response:
            return self.response.get("messages", [])
        return []

    def get_log_summary(self, markdown=False):
        """
        Summarizes where the generated function was logged, if logging is enabled.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns Markdown-formatted output.

        Returns
        -------
        str, Markdown, or None
            Summary of logs or None if no logs are available.
        """
        if self.response and self.response.get('data_visualization_function_path'):
            log_details = f"Log Path: {self.response.get('data_visualization_function_path')}"
            if markdown:
                return Markdown(log_details)
            return log_details
        return None

    def get_plotly_graph(self):
        """
        Retrieves the Plotly graph produced by the agent.

        Returns
        -------
        object or None
            The result of plotly_from_dict() applied to the stored "plotly_graph"
            dictionary (presumably a Plotly figure -- confirm against utils.plotly),
            or None if there is no response or no graph was produced.
        """
        if not self.response:
            return None
        graph_dict = self.response.get("plotly_graph")
        # Do not hand None to the converter: a run that ended in an error has no graph.
        if graph_dict is None:
            return None
        return plotly_from_dict(graph_dict)

    def get_data_raw(self):
        """
        Retrieves the raw dataset used in the last invocation.

        Returns
        -------
        pd.DataFrame or None
            The raw dataset as a DataFrame if available, otherwise None.
        """
        if self.response and self.response.get("data_raw"):
            return pd.DataFrame(self.response.get("data_raw"))
        return None

    def get_data_visualization_function(self, markdown=False):
        """
        Retrieves the generated Python function used for data visualization.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the function in Markdown code block format.

        Returns
        -------
        str, Markdown, or None
            The Python function code if available, otherwise None.
        """
        if self.response:
            func_code = self.response.get("data_visualization_function", "")
            if markdown:
                return Markdown(f"```python\n{func_code}\n```")
            return func_code
        return None

    def get_recommended_visualization_steps(self, markdown=False):
        """
        Retrieves the agent's recommended visualization steps.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the steps in Markdown format.

        Returns
        -------
        str, Markdown, or None
            The recommended steps if available, otherwise None.
        """
        if self.response:
            steps = self.response.get("recommended_steps", "")
            if markdown:
                return Markdown(steps)
            return steps
        return None

    def get_response(self):
        """
        Returns the agent's full response dictionary.

        Returns
        -------
        dict or None
            The response dictionary if available, otherwise None.
        """
        return self.response

    def show(self):
        """
        Displays the agent's mermaid diagram for visual inspection of the compiled graph.
        """
        # NOTE(review): relies on the compiled graph object exposing show(); confirm that
        # create_coding_agent_graph() returns an object with that method.
        return self._compiled_graph.show()
|
379
|
+
|
380
|
+
|
381
|
+
# Agent
|
382
|
+
|
383
|
+
def make_data_visualization_agent(
    model,
    n_samples=30,
    log=False,
    log_path=None,
    file_name="data_visualization.py",
    function_name="data_visualization",
    overwrite=True,
    human_in_the_loop=False,
    bypass_recommended_steps=False,
    bypass_explain_code=False
):
    """
    Creates a data visualization agent that can generate Plotly charts based on user-defined instructions or
    default visualization steps. The agent generates a Python function to produce the visualization, executes it,
    and logs the process, including code and errors. It is designed to facilitate reproducible and customizable
    data visualization workflows.

    The agent can perform the following default visualization steps unless instructed otherwise:
    - Generating a recommended chart type (bar, scatter, line, etc.)
    - Creating user-friendly titles and axis labels
    - Applying consistent styling (template, font sizes, color themes)
    - Handling theme details (white background, base font size, line size, etc.)

    User instructions can modify, add, or remove any of these steps to tailor the visualization process.

    Parameters
    ----------
    model : langchain.llms.base.LLM
        The language model used to generate the data visualization function.
    n_samples : int, optional
        Number of samples used when summarizing the dataset for chart instructions. Defaults to 30.
    log : bool, optional
        Whether to log the generated code and errors. Defaults to False.
    log_path : str, optional
        Directory path for storing log files. Defaults to None, in which case the
        module-level LOG_PATH is used when logging is enabled.
    file_name : str, optional
        Name of the file for saving the generated response. Defaults to "data_visualization.py".
    function_name : str, optional
        Name of the function for data visualization. Defaults to "data_visualization".
    overwrite : bool, optional
        Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
    human_in_the_loop : bool, optional
        Enables user review of data visualization instructions. Defaults to False.
    bypass_recommended_steps : bool, optional
        If True, skips the default recommended visualization steps. Defaults to False.
        Forced back to False when human_in_the_loop is True.
    bypass_explain_code : bool, optional
        If True, skips the step that provides code explanations. Defaults to False.

    Examples
    --------
    ```python
    import pandas as pd
    from langchain_openai import ChatOpenAI
    from ai_data_science_team.agents import make_data_visualization_agent

    llm = ChatOpenAI(model="gpt-4o-mini")

    data_visualization_agent = make_data_visualization_agent(llm)

    df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")

    response = data_visualization_agent.invoke({
        "user_instructions": "Generate a scatter plot of tenure vs. total charges with a trend line.",
        "data_raw": df.to_dict(),
        "max_retries": 3,
        "retry_count": 0
    })

    response['plotly_graph']  # the Plotly chart in dictionary form
    ```

    Returns
    -------
    app : langchain.graphs.CompiledStateGraph
        The data visualization agent as a state graph.
    """

    llm = model

    # Human in the loop requires recommended steps
    if bypass_recommended_steps and human_in_the_loop:
        bypass_recommended_steps = False
        print("Bypass recommended steps set to False to enable human in the loop.")

    # Setup Log Directory
    if log:
        if log_path is None:
            log_path = LOG_PATH
        # exist_ok avoids the check-then-create race of os.path.exists() + os.makedirs()
        os.makedirs(log_path, exist_ok=True)

    # Define GraphState for the router
    class GraphState(TypedDict):
        messages: Annotated[Sequence[BaseMessage], operator.add]
        user_instructions: str
        user_instructions_processed: str
        recommended_steps: str
        data_raw: dict
        plotly_graph: dict
        all_datasets_summary: str
        data_visualization_function: str
        data_visualization_function_path: str
        data_visualization_function_file_name: str
        data_visualization_function_name: str
        data_visualization_error: str
        max_retries: int
        retry_count: int

    def _summarize_data_raw(state):
        """Return the text summary of the raw dataset held in the graph state."""
        df = pd.DataFrame.from_dict(state.get("data_raw"))
        summaries = get_dataframe_summary([df], n_sample=n_samples, skip_stats=False)
        return "\n\n".join(summaries)

    def chart_instructor(state: GraphState):
        """Ask the LLM for chart-generation instructions based on the user request and data summary."""

        print(format_agent_name(AGENT_NAME))
        print(" * CREATE CHART GENERATOR INSTRUCTIONS")

        recommend_steps_prompt = PromptTemplate(
            template="""
            You are a supervisor that is an expert in providing instructions to a chart generator agent for plotting.

            You will take a question that a user has and the data that was generated to answer the question, and create instructions to create a chart from the data that will be passed to a chart generator agent.

            USER QUESTION / INSTRUCTIONS:
            {user_instructions}

            Previously Recommended Instructions (if any):
            {recommended_steps}

            DATA:
            {all_datasets_summary}

            Formulate chart generator instructions by informing the chart generator of what type of plotly plot to use (e.g. bar, line, scatter, etc) to best represent the data.

            Come up with an informative title from the user's question and data provided. Also provide X and Y axis titles.

            Instruct the chart generator to use the following theme colors, sizes, etc:

            - Start with the "plotly_white" template
            - Use a white background
            - Use this color for bars and lines:
                'blue': '#3381ff',
            - Base Font Size: 8.8 (Used for x and y axes tickfont, any annotations, hovertips)
            - Title Font Size: 13.2
            - Line Size: 0.65 (specify these within the xaxis and yaxis dictionaries)
            - Add smoothers or trendlines to scatter plots unless not desired by the user
            - Do not use color_discrete_map (this will result in an error)
            - Hover tip size: 8.8

            Return your instructions in the following format:
            CHART GENERATOR INSTRUCTIONS:
            FILL IN THE INSTRUCTIONS HERE

            Avoid these:
            1. Do not include steps to save files.
            2. Do not include unrelated user instructions that are not related to the chart generation.
            """,
            input_variables=["user_instructions", "recommended_steps", "all_datasets_summary"]
        )

        all_datasets_summary_str = _summarize_data_raw(state)

        # Named differently from the enclosing node function to avoid shadowing it.
        instructor_chain = recommend_steps_prompt | llm

        recommended_steps = instructor_chain.invoke({
            "user_instructions": state.get("user_instructions"),
            "recommended_steps": state.get("recommended_steps"),
            "all_datasets_summary": all_datasets_summary_str
        })

        return {
            "recommended_steps": format_recommended_steps(
                recommended_steps.content.strip(),
                heading="# Recommended Data Visualization Steps:"
            ),
            "all_datasets_summary": all_datasets_summary_str
        }

    def chart_generator(state: GraphState):
        """Ask the LLM to write the plotting function, then log it and store it in the state."""

        if bypass_recommended_steps:
            # chart_instructor normally prints the agent banner; when it is bypassed,
            # print it here so it still precedes the step line.
            print(format_agent_name(AGENT_NAME))

        print(" * CREATE DATA VISUALIZATION CODE")

        if bypass_recommended_steps:
            # No instructor ran, so build the data summary here and use the raw user instructions.
            all_datasets_summary_str = _summarize_data_raw(state)
            chart_generator_instructions = state.get("user_instructions")
        else:
            all_datasets_summary_str = state.get("all_datasets_summary")
            chart_generator_instructions = state.get("recommended_steps")

        prompt_template = PromptTemplate(
            template="""
            You are a chart generator agent that is an expert in generating plotly charts. You must use plotly or plotly.express to produce plots.

            Your job is to produce python code to generate visualizations with a function named {function_name}.

            You will take instructions from a Chart Instructor and generate a plotly chart from the data provided.

            CHART INSTRUCTIONS:
            {chart_generator_instructions}

            DATA:
            {all_datasets_summary}

            RETURN:

            Return Python code in ```python ``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.

            Return the plotly chart as a dictionary.

            Return code to provide the data visualization function:

            def {function_name}(data_raw):
                import pandas as pd
                import numpy as np
                import json
                import plotly.graph_objects as go
                import plotly.io as pio

                ...

                fig_json = pio.to_json(fig)
                fig_dict = json.loads(fig_json)

                return fig_dict

            Avoid these:
            1. Do not include steps to save files.
            2. Do not include unrelated user instructions that are not related to the chart generation.

            """,
            input_variables=["chart_generator_instructions", "all_datasets_summary", "function_name"]
        )

        data_visualization_agent = prompt_template | llm | PythonOutputParser()

        response = data_visualization_agent.invoke({
            "chart_generator_instructions": chart_generator_instructions,
            "all_datasets_summary": all_datasets_summary_str,
            "function_name": function_name
        })

        response = relocate_imports_inside_function(response)
        response = add_comments_to_top(response, agent_name=AGENT_NAME)

        # For logging: store the code generated:
        file_path, file_name_2 = log_ai_function(
            response=response,
            file_name=file_name,
            log=log,
            log_path=log_path,
            overwrite=overwrite
        )

        return {
            "data_visualization_function": response,
            "data_visualization_function_path": file_path,
            "data_visualization_function_file_name": file_name_2,
            "data_visualization_function_name": function_name,
            "all_datasets_summary": all_datasets_summary_str
        }

    # Human Review

    prompt_text_human_review = "Are the following data visualization instructions correct? (Answer 'yes' or provide modifications)\n{steps}"

    # Two variants are defined because the "yes" destination (and therefore the
    # Command[Literal[...]] return annotation) differs when code explanation is bypassed.
    if not bypass_explain_code:
        def human_review(state: GraphState) -> Command[Literal["chart_instructor", "explain_data_visualization_code"]]:
            return node_func_human_review(
                state=state,
                prompt_text=prompt_text_human_review,
                yes_goto='explain_data_visualization_code',
                no_goto="chart_instructor",
                user_instructions_key="user_instructions",
                recommended_steps_key="recommended_steps",
                code_snippet_key="data_visualization_function",
            )
    else:
        def human_review(state: GraphState) -> Command[Literal["chart_instructor", "__end__"]]:
            return node_func_human_review(
                state=state,
                prompt_text=prompt_text_human_review,
                yes_goto='__end__',
                no_goto="chart_instructor",
                user_instructions_key="user_instructions",
                recommended_steps_key="recommended_steps",
                code_snippet_key="data_visualization_function",
            )

    def execute_data_visualization_code(state):
        """Run the generated function on the raw data and store the result under "plotly_graph"."""
        return node_func_execute_agent_code_on_data(
            state=state,
            data_key="data_raw",
            result_key="plotly_graph",
            error_key="data_visualization_error",
            code_snippet_key="data_visualization_function",
            agent_function_name=state.get("data_visualization_function_name"),
            # The state stores the dataset as a dict; the generated function receives a DataFrame.
            pre_processing=lambda data: pd.DataFrame.from_dict(data),
            error_message_prefix="An error occurred during data visualization: "
        )

    def fix_data_visualization_code(state: GraphState):
        """Ask the LLM to repair the generated function using the last recorded error."""
        prompt = """
        You are a Data Visualization Agent. Your job is to create a {function_name}() function that can be run on the data provided. The function is currently broken and needs to be fixed.

        Make sure to only return the function definition for {function_name}().

        Return Python code in ```python``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.

        This is the broken code (please fix):
        {code_snippet}

        Last Known Error:
        {error}
        """

        return node_func_fix_agent_code(
            state=state,
            code_snippet_key="data_visualization_function",
            error_key="data_visualization_error",
            llm=llm,
            prompt_template=prompt,
            agent_name=AGENT_NAME,
            log=log,
            file_path=state.get("data_visualization_function_path"),
            function_name=state.get("data_visualization_function_name"),
        )

    def explain_data_visualization_code(state: GraphState):
        """Ask the LLM to explain the generated function; the result is stored under "messages"."""
        return node_func_explain_agent_code(
            state=state,
            code_snippet_key="data_visualization_function",
            result_key="messages",
            error_key="data_visualization_error",
            llm=llm,
            role=AGENT_NAME,
            explanation_prompt_template="""
            Explain the data visualization steps that the data visualization agent performed in this function.
            Keep the summary succinct and to the point.\n\n# Data Visualization Agent:\n\n{code}
            """,
            success_prefix="# Data Visualization Agent:\n\n ",
            error_message="The Data Visualization Agent encountered an error during data visualization. No explanation could be provided."
        )

    # Define the graph
    node_functions = {
        "chart_instructor": chart_instructor,
        "human_review": human_review,
        "chart_generator": chart_generator,
        "execute_data_visualization_code": execute_data_visualization_code,
        "fix_data_visualization_code": fix_data_visualization_code,
        "explain_data_visualization_code": explain_data_visualization_code
    }

    app = create_coding_agent_graph(
        GraphState=GraphState,
        node_functions=node_functions,
        recommended_steps_node_name="chart_instructor",
        create_code_node_name="chart_generator",
        execute_code_node_name="execute_data_visualization_code",
        fix_code_node_name="fix_data_visualization_code",
        explain_code_node_name="explain_data_visualization_code",
        error_key="data_visualization_error",
        human_in_the_loop=human_in_the_loop,
        human_review_node_name="human_review",
        # A checkpointer is only supplied when human review is enabled.
        checkpointer=MemorySaver() if human_in_the_loop else None,
        bypass_recommended_steps=bypass_recommended_steps,
        bypass_explain_code=bypass_explain_code,
    )

    return app
|