ai-data-science-team 0.0.0.9007__py3-none-any.whl → 0.0.0.9009__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +4 -5
- ai_data_science_team/agents/data_cleaning_agent.py +268 -116
- ai_data_science_team/agents/data_visualization_agent.py +470 -41
- ai_data_science_team/agents/data_wrangling_agent.py +471 -31
- ai_data_science_team/agents/feature_engineering_agent.py +426 -41
- ai_data_science_team/agents/sql_database_agent.py +458 -58
- ai_data_science_team/ml_agents/__init__.py +1 -0
- ai_data_science_team/ml_agents/h2o_ml_agent.py +1032 -0
- ai_data_science_team/multiagents/__init__.py +1 -0
- ai_data_science_team/multiagents/sql_data_analyst.py +398 -0
- ai_data_science_team/multiagents/supervised_data_analyst.py +2 -0
- ai_data_science_team/templates/__init__.py +3 -1
- ai_data_science_team/templates/agent_templates.py +319 -43
- ai_data_science_team/tools/metadata.py +94 -62
- ai_data_science_team/tools/regex.py +86 -1
- ai_data_science_team/utils/__init__.py +0 -0
- ai_data_science_team/utils/plotly.py +24 -0
- ai_data_science_team-0.0.0.9009.dist-info/METADATA +245 -0
- ai_data_science_team-0.0.0.9009.dist-info/RECORD +28 -0
- ai_data_science_team-0.0.0.9007.dist-info/METADATA +0 -183
- ai_data_science_team-0.0.0.9007.dist-info/RECORD +0 -21
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/WHEEL +0 -0
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/top_level.txt +0 -0
@@ -17,25 +17,363 @@ from langgraph.types import Command
|
|
17
17
|
from langgraph.checkpoint.memory import MemorySaver
|
18
18
|
|
19
19
|
import os
|
20
|
-
import
|
20
|
+
import json
|
21
21
|
import pandas as pd
|
22
22
|
|
23
|
+
from IPython.display import Markdown
|
24
|
+
|
23
25
|
from ai_data_science_team.templates import(
|
24
26
|
node_func_execute_agent_code_on_data,
|
25
27
|
node_func_human_review,
|
26
28
|
node_func_fix_agent_code,
|
27
|
-
|
28
|
-
create_coding_agent_graph
|
29
|
+
node_func_report_agent_outputs,
|
30
|
+
create_coding_agent_graph,
|
31
|
+
BaseAgent,
|
29
32
|
)
|
30
33
|
from ai_data_science_team.tools.parsers import PythonOutputParser
|
31
|
-
from ai_data_science_team.tools.regex import
|
34
|
+
from ai_data_science_team.tools.regex import (
|
35
|
+
relocate_imports_inside_function,
|
36
|
+
add_comments_to_top,
|
37
|
+
format_agent_name,
|
38
|
+
format_recommended_steps,
|
39
|
+
get_generic_summary,
|
40
|
+
)
|
32
41
|
from ai_data_science_team.tools.metadata import get_dataframe_summary
|
33
42
|
from ai_data_science_team.tools.logging import log_ai_function
|
43
|
+
from ai_data_science_team.utils.plotly import plotly_from_dict
|
34
44
|
|
35
45
|
# Setup
# Identifier for this agent; passed as `agent_name` to the logging helpers and
# as the `role` of the final reporting node further down in this module.
AGENT_NAME: str = "data_visualization_agent"
# Absolute "logs/" directory under the working directory at import time.
# NOTE(review): presumably the fallback used when `log_path` is None — the
# code that consumes it is not visible in this chunk.
LOG_PATH: str = os.path.join(os.getcwd(), "logs/")
|
38
48
|
|
49
|
+
# Class
|
50
|
+
|
51
|
+
class DataVisualizationAgent(BaseAgent):
    """
    Creates a data visualization agent that can generate Plotly charts based on user-defined instructions or
    default visualization steps (if any). The agent generates a Python function to produce the visualization,
    executes it, and logs the process, including code and errors. It is designed to facilitate reproducible
    and customizable data visualization workflows.

    The agent may use default instructions for creating charts unless instructed otherwise, such as:
    - Generating a recommended chart type (bar, scatter, line, etc.)
    - Creating user-friendly titles and axis labels
    - Applying consistent styling (template, font sizes, color themes)
    - Handling theme details (white background, base font size, line size, etc.)

    User instructions can modify, add, or remove any of these steps to tailor the visualization process.

    Parameters
    ----------
    model : langchain.llms.base.LLM
        The language model used to generate the data visualization function.
    n_samples : int, optional
        Number of samples used when summarizing the dataset for chart instructions. Defaults to 30.
        Reducing this number can help avoid exceeding the model's token limits.
    log : bool, optional
        Whether to log the generated code and errors. Defaults to False.
    log_path : str, optional
        Directory path for storing log files. Defaults to None.
    file_name : str, optional
        Name of the file for saving the generated response. Defaults to "data_visualization.py".
    function_name : str, optional
        Name of the function for data visualization. Defaults to "data_visualization".
    overwrite : bool, optional
        Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
    human_in_the_loop : bool, optional
        Enables user review of data visualization instructions. Defaults to False.
    bypass_recommended_steps : bool, optional
        If True, skips the default recommended visualization steps. Defaults to False.
    bypass_explain_code : bool, optional
        If True, skips the step that provides code explanations. Defaults to False.

    Attributes
    ----------
    response : dict or None
        The final graph state from the most recent invocation; reset to None whenever the graph is rebuilt.

    Methods
    -------
    update_params(**kwargs)
        Updates the agent's parameters and rebuilds the compiled state graph.
    ainvoke_agent(data_raw: pd.DataFrame, user_instructions: str = None, max_retries=3, retry_count=0)
        Coroutine: asynchronously generates a visualization. Must be awaited.
    invoke_agent(data_raw: pd.DataFrame, user_instructions: str = None, max_retries=3, retry_count=0)
        Synchronously generates a visualization based on user instructions.
    get_workflow_summary()
        Retrieves a summary of the agent's workflow.
    get_log_summary()
        Retrieves a summary of logged operations if logging is enabled.
    get_plotly_graph()
        Retrieves the Plotly graph produced by the agent, converted with `plotly_from_dict`.
    get_data_raw()
        Retrieves the raw dataset as a pandas DataFrame (based on the last response).
    get_data_visualization_function()
        Retrieves the generated Python function used for data visualization.
    get_recommended_visualization_steps()
        Retrieves the agent's recommended visualization steps.
    get_response()
        Returns the response from the agent as a dictionary.
    show()
        Displays the agent's mermaid diagram.

    Examples
    --------
    ```python
    import pandas as pd
    from langchain_openai import ChatOpenAI
    from ai_data_science_team.agents import DataVisualizationAgent

    llm = ChatOpenAI(model="gpt-4o-mini")

    data_visualization_agent = DataVisualizationAgent(
        model=llm,
        n_samples=30,
        log=True,
        log_path="logs",
        human_in_the_loop=True
    )

    df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")

    data_visualization_agent.invoke_agent(
        user_instructions="Generate a scatter plot of age vs. total charges with a trend line.",
        data_raw=df,
        max_retries=3,
        retry_count=0
    )

    # The stored chart dictionary is passed through `plotly_from_dict` before being returned.
    fig = data_visualization_agent.get_plotly_graph()

    response = data_visualization_agent.get_response()

    # Async usage (inside a coroutine / running event loop):
    # await data_visualization_agent.ainvoke_agent(data_raw=df, user_instructions="...")
    ```
    """

    def __init__(
        self,
        model,
        n_samples=30,
        log=False,
        log_path=None,
        file_name="data_visualization.py",
        function_name="data_visualization",
        overwrite=True,
        human_in_the_loop=False,
        bypass_recommended_steps=False,
        bypass_explain_code=False
    ):
        # All constructor arguments are kept so update_params() can rebuild the graph.
        self._params = {
            "model": model,
            "n_samples": n_samples,
            "log": log,
            "log_path": log_path,
            "file_name": file_name,
            "function_name": function_name,
            "overwrite": overwrite,
            "human_in_the_loop": human_in_the_loop,
            "bypass_recommended_steps": bypass_recommended_steps,
            "bypass_explain_code": bypass_explain_code,
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Create the compiled graph for the data visualization agent.
        Running this method will reset the response to None.
        """
        self.response = None
        return make_data_visualization_agent(**self._params)

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters and rebuilds the compiled graph.

        If rebuilding fails (for example because of an unknown keyword), the previous
        parameters are restored and the exception is re-raised, so the agent is not left
        holding a parameter set that can never compile.
        """
        previous_params = self._params
        self._params = {**previous_params, **kwargs}
        try:
            self._compiled_graph = self._make_compiled_graph()
        except Exception:
            # Roll back so a bad keyword does not poison every later rebuild.
            self._params = previous_params
            raise

    @staticmethod
    def _build_graph_input(data_raw, user_instructions, max_retries, retry_count):
        """Build the initial graph state shared by invoke_agent and ainvoke_agent."""
        return {
            "user_instructions": user_instructions,
            # The graph state stores the dataset as a plain dict.
            "data_raw": data_raw.to_dict(),
            "max_retries": max_retries,
            "retry_count": retry_count,
        }

    async def ainvoke_agent(self, data_raw: pd.DataFrame, user_instructions: str=None, max_retries:int=3, retry_count:int=0, **kwargs):
        """
        Asynchronously invokes the agent to generate a visualization.
        The response is stored in the 'response' attribute.

        This is a coroutine and must be awaited; the compiled graph's `ainvoke`
        is awaited so that `response` holds the final state, not a pending coroutine.

        Parameters
        ----------
        data_raw : pd.DataFrame
            The raw dataset to be visualized.
        user_instructions : str
            Instructions for data visualization.
        max_retries : int
            Maximum retry attempts.
        retry_count : int
            Current retry attempt count.
        **kwargs : dict
            Additional keyword arguments passed to ainvoke().

        Returns
        -------
        None
        """
        response = await self._compiled_graph.ainvoke(
            self._build_graph_input(data_raw, user_instructions, max_retries, retry_count),
            **kwargs,
        )
        self.response = response
        return None

    def invoke_agent(self, data_raw: pd.DataFrame, user_instructions: str=None, max_retries:int=3, retry_count:int=0, **kwargs):
        """
        Synchronously invokes the agent to generate a visualization.
        The response is stored in the 'response' attribute.

        Parameters
        ----------
        data_raw : pd.DataFrame
            The raw dataset to be visualized.
        user_instructions : str
            Instructions for data visualization agent.
        max_retries : int
            Maximum retry attempts.
        retry_count : int
            Current retry attempt count.
        **kwargs : dict
            Additional keyword arguments passed to invoke().

        Returns
        -------
        None
        """
        response = self._compiled_graph.invoke(
            self._build_graph_input(data_raw, user_instructions, max_retries, retry_count),
            **kwargs,
        )
        self.response = response
        return None

    def get_workflow_summary(self, markdown=False):
        """
        Retrieves the agent's workflow summary, built from the JSON content of the
        last message in the response.

        Returns None if there is no response or it contains no messages.
        """
        if self.response and self.response.get("messages"):
            summary = get_generic_summary(json.loads(self.response.get("messages")[-1].content))
            if markdown:
                return Markdown(summary)
            return summary
        return None

    def get_log_summary(self, markdown=False):
        """
        Returns a summary of the logged function file (path and function name).

        Returns None unless a response exists and it records a function path,
        which is only the case when logging was enabled.
        """
        if self.response and self.response.get('data_visualization_function_path'):
            log_details = f"""
## Data Visualization Agent Log Summary:

Function Path: {self.response.get('data_visualization_function_path')}

Function Name: {self.response.get('data_visualization_function_name')}
"""
            if markdown:
                return Markdown(log_details)
            return log_details
        return None

    def get_plotly_graph(self):
        """
        Retrieves the Plotly graph produced by the agent.

        The stored chart dictionary is passed through `plotly_from_dict`.

        Returns
        -------
        object or None
            The result of `plotly_from_dict` on the stored graph, or None if there is
            no response or no chart was produced (e.g. the generated code failed).
        """
        if not self.response:
            return None
        plotly_graph = self.response.get("plotly_graph")
        if plotly_graph is None:
            # Nothing to convert; avoid handing None to plotly_from_dict.
            return None
        return plotly_from_dict(plotly_graph)

    def get_data_raw(self):
        """
        Retrieves the raw dataset used in the last invocation.

        Returns
        -------
        pd.DataFrame or None
            The raw dataset as a DataFrame if available, otherwise None.
        """
        if self.response and self.response.get("data_raw"):
            return pd.DataFrame(self.response.get("data_raw"))
        return None

    def get_data_visualization_function(self, markdown=False):
        """
        Retrieves the generated Python function used for data visualization.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the function in Markdown code block format.

        Returns
        -------
        str or None
            The Python function code as a string if available, otherwise None.
        """
        if self.response:
            func_code = self.response.get("data_visualization_function", "")
            if markdown:
                return Markdown(f"```python\n{func_code}\n```")
            return func_code
        return None

    def get_recommended_visualization_steps(self, markdown=False):
        """
        Retrieves the agent's recommended visualization steps.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the steps in Markdown format.

        Returns
        -------
        str or None
            The recommended steps if available, otherwise None.
        """
        if self.response:
            steps = self.response.get("recommended_steps", "")
            if markdown:
                return Markdown(steps)
            return steps
        return None

    def get_response(self):
        """
        Returns the agent's full response dictionary.

        Returns
        -------
        dict or None
            The response dictionary if available, otherwise None.
        """
        return self.response

    def show(self):
        """
        Displays the agent's mermaid diagram for visual inspection of the compiled graph.
        """
        return self._compiled_graph.show()
|
375
|
+
|
376
|
+
|
39
377
|
# Agent
|
40
378
|
|
41
379
|
def make_data_visualization_agent(
|
@@ -44,14 +382,85 @@ def make_data_visualization_agent(
|
|
44
382
|
log=False,
|
45
383
|
log_path=None,
|
46
384
|
file_name="data_visualization.py",
|
47
|
-
|
385
|
+
function_name="data_visualization",
|
386
|
+
overwrite=True,
|
48
387
|
human_in_the_loop=False,
|
49
388
|
bypass_recommended_steps=False,
|
50
389
|
bypass_explain_code=False
|
51
390
|
):
|
391
|
+
"""
|
392
|
+
Creates a data visualization agent that can generate Plotly charts based on user-defined instructions or
|
393
|
+
default visualization steps. The agent generates a Python function to produce the visualization, executes it,
|
394
|
+
and logs the process, including code and errors. It is designed to facilitate reproducible and customizable
|
395
|
+
data visualization workflows.
|
396
|
+
|
397
|
+
The agent can perform the following default visualization steps unless instructed otherwise:
|
398
|
+
- Generating a recommended chart type (bar, scatter, line, etc.)
|
399
|
+
- Creating user-friendly titles and axis labels
|
400
|
+
- Applying consistent styling (template, font sizes, color themes)
|
401
|
+
- Handling theme details (white background, base font size, line size, etc.)
|
402
|
+
|
403
|
+
User instructions can modify, add, or remove any of these steps to tailor the visualization process.
|
404
|
+
|
405
|
+
Parameters
|
406
|
+
----------
|
407
|
+
model : langchain.llms.base.LLM
|
408
|
+
The language model used to generate the data visualization function.
|
409
|
+
n_samples : int, optional
|
410
|
+
Number of samples used when summarizing the dataset for chart instructions. Defaults to 30.
|
411
|
+
log : bool, optional
|
412
|
+
Whether to log the generated code and errors. Defaults to False.
|
413
|
+
log_path : str, optional
|
414
|
+
Directory path for storing log files. Defaults to None.
|
415
|
+
file_name : str, optional
|
416
|
+
Name of the file for saving the generated response. Defaults to "data_visualization.py".
|
417
|
+
function_name : str, optional
|
418
|
+
Name of the function for data visualization. Defaults to "data_visualization".
|
419
|
+
overwrite : bool, optional
|
420
|
+
Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
|
421
|
+
human_in_the_loop : bool, optional
|
422
|
+
Enables user review of data visualization instructions. Defaults to False.
|
423
|
+
bypass_recommended_steps : bool, optional
|
424
|
+
If True, skips the default recommended visualization steps. Defaults to False.
|
425
|
+
bypass_explain_code : bool, optional
|
426
|
+
If True, skips the step that provides code explanations. Defaults to False.
|
427
|
+
|
428
|
+
Examples
|
429
|
+
--------
|
430
|
+
``` python
|
431
|
+
import pandas as pd
|
432
|
+
from langchain_openai import ChatOpenAI
|
433
|
+
from ai_data_science_team.agents import data_visualization_agent
|
434
|
+
|
435
|
+
llm = ChatOpenAI(model="gpt-4o-mini")
|
436
|
+
|
437
|
+
data_visualization_agent = make_data_visualization_agent(llm)
|
438
|
+
|
439
|
+
df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")
|
440
|
+
|
441
|
+
response = data_visualization_agent.invoke({
|
442
|
+
"user_instructions": "Generate a scatter plot of tenure vs. total charges with a trend line.",
|
443
|
+
"data_raw": df.to_dict(),
|
444
|
+
"max_retries": 3,
|
445
|
+
"retry_count": 0
|
446
|
+
})
|
447
|
+
|
448
|
+
pd.DataFrame(response['plotly_graph'])
|
449
|
+
```
|
450
|
+
|
451
|
+
Returns
|
452
|
+
-------
|
453
|
+
app : langchain.graphs.CompiledStateGraph
|
454
|
+
The data visualization agent as a state graph.
|
455
|
+
"""
|
52
456
|
|
53
457
|
llm = model
|
54
458
|
|
459
|
+
# Human in th loop requires recommended steps
|
460
|
+
if bypass_recommended_steps and human_in_the_loop:
|
461
|
+
bypass_recommended_steps = False
|
462
|
+
print("Bypass recommended steps set to False to enable human in the loop.")
|
463
|
+
|
55
464
|
# Setup Log Directory
|
56
465
|
if log:
|
57
466
|
if log_path is None:
|
@@ -70,6 +479,7 @@ def make_data_visualization_agent(
|
|
70
479
|
all_datasets_summary: str
|
71
480
|
data_visualization_function: str
|
72
481
|
data_visualization_function_path: str
|
482
|
+
data_visualization_function_file_name: str
|
73
483
|
data_visualization_function_name: str
|
74
484
|
data_visualization_error: str
|
75
485
|
max_retries: int
|
@@ -140,7 +550,7 @@ def make_data_visualization_agent(
|
|
140
550
|
})
|
141
551
|
|
142
552
|
return {
|
143
|
-
"recommended_steps": "
|
553
|
+
"recommended_steps": format_recommended_steps(recommended_steps.content.strip(), heading="# Recommended Data Cleaning Steps:"),
|
144
554
|
"all_datasets_summary": all_datasets_summary_str
|
145
555
|
}
|
146
556
|
|
@@ -169,7 +579,7 @@ def make_data_visualization_agent(
|
|
169
579
|
template="""
|
170
580
|
You are a chart generator agent that is an expert in generating plotly charts. You must use plotly or plotly.express to produce plots.
|
171
581
|
|
172
|
-
Your job is to produce python code to generate visualizations.
|
582
|
+
Your job is to produce python code to generate visualizations with a function named {function_name}.
|
173
583
|
|
174
584
|
You will take instructions from a Chart Instructor and generate a plotly chart from the data provided.
|
175
585
|
|
@@ -181,13 +591,13 @@ def make_data_visualization_agent(
|
|
181
591
|
|
182
592
|
RETURN:
|
183
593
|
|
184
|
-
Return Python code in ```python ``` format with a single function definition,
|
594
|
+
Return Python code in ```python ``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.
|
185
595
|
|
186
596
|
Return the plotly chart as a dictionary.
|
187
597
|
|
188
598
|
Return code to provide the data visualization function:
|
189
599
|
|
190
|
-
def
|
600
|
+
def {function_name}(data_raw):
|
191
601
|
import pandas as pd
|
192
602
|
import numpy as np
|
193
603
|
import json
|
@@ -206,14 +616,15 @@ def make_data_visualization_agent(
|
|
206
616
|
2. Do not include unrelated user instructions that are not related to the chart generation.
|
207
617
|
|
208
618
|
""",
|
209
|
-
input_variables=["chart_generator_instructions", "all_datasets_summary"]
|
619
|
+
input_variables=["chart_generator_instructions", "all_datasets_summary", "function_name"]
|
210
620
|
)
|
211
|
-
|
621
|
+
|
212
622
|
data_visualization_agent = prompt_template | llm | PythonOutputParser()
|
213
623
|
|
214
624
|
response = data_visualization_agent.invoke({
|
215
625
|
"chart_generator_instructions": chart_generator_instructions,
|
216
|
-
"all_datasets_summary": all_datasets_summary_str
|
626
|
+
"all_datasets_summary": all_datasets_summary_str,
|
627
|
+
"function_name": function_name
|
217
628
|
})
|
218
629
|
|
219
630
|
response = relocate_imports_inside_function(response)
|
@@ -231,19 +642,37 @@ def make_data_visualization_agent(
|
|
231
642
|
return {
|
232
643
|
"data_visualization_function": response,
|
233
644
|
"data_visualization_function_path": file_path,
|
234
|
-
"
|
645
|
+
"data_visualization_function_file_name": file_name_2,
|
646
|
+
"data_visualization_function_name": function_name,
|
235
647
|
"all_datasets_summary": all_datasets_summary_str
|
236
648
|
}
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
649
|
+
|
650
|
+
# Human Review
|
651
|
+
|
652
|
+
prompt_text_human_review = "Are the following data visualization instructions correct? (Answer 'yes' or provide modifications)\n{steps}"
|
653
|
+
|
654
|
+
if not bypass_explain_code:
|
655
|
+
def human_review(state: GraphState) -> Command[Literal["chart_instructor", "explain_data_visualization_code"]]:
|
656
|
+
return node_func_human_review(
|
657
|
+
state=state,
|
658
|
+
prompt_text=prompt_text_human_review,
|
659
|
+
yes_goto= 'explain_data_visualization_code',
|
660
|
+
no_goto="chart_instructor",
|
661
|
+
user_instructions_key="user_instructions",
|
662
|
+
recommended_steps_key="recommended_steps",
|
663
|
+
code_snippet_key="data_visualization_function",
|
664
|
+
)
|
665
|
+
else:
|
666
|
+
def human_review(state: GraphState) -> Command[Literal["chart_instructor", "__end__"]]:
|
667
|
+
return node_func_human_review(
|
668
|
+
state=state,
|
669
|
+
prompt_text=prompt_text_human_review,
|
670
|
+
yes_goto= '__end__',
|
671
|
+
no_goto="chart_instructor",
|
672
|
+
user_instructions_key="user_instructions",
|
673
|
+
recommended_steps_key="recommended_steps",
|
674
|
+
code_snippet_key="data_visualization_function",
|
675
|
+
)
|
247
676
|
|
248
677
|
|
249
678
|
def execute_data_visualization_code(state):
|
@@ -253,7 +682,7 @@ def make_data_visualization_agent(
|
|
253
682
|
result_key="plotly_graph",
|
254
683
|
error_key="data_visualization_error",
|
255
684
|
code_snippet_key="data_visualization_function",
|
256
|
-
agent_function_name="
|
685
|
+
agent_function_name=state.get("data_visualization_function_name"),
|
257
686
|
pre_processing=lambda data: pd.DataFrame.from_dict(data),
|
258
687
|
# post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
|
259
688
|
error_message_prefix="An error occurred during data visualization: "
|
@@ -261,11 +690,11 @@ def make_data_visualization_agent(
|
|
261
690
|
|
262
691
|
def fix_data_visualization_code(state: GraphState):
|
263
692
|
prompt = """
|
264
|
-
You are a Data Visualization Agent. Your job is to create a
|
693
|
+
You are a Data Visualization Agent. Your job is to create a {function_name}() function that can be run on the data provided. The function is currently broken and needs to be fixed.
|
265
694
|
|
266
|
-
Make sure to only return the function definition for
|
695
|
+
Make sure to only return the function definition for {function_name}().
|
267
696
|
|
268
|
-
Return Python code in ```python``` format with a single function definition,
|
697
|
+
Return Python code in ```python``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.
|
269
698
|
|
270
699
|
This is the broken code (please fix):
|
271
700
|
{code_snippet}
|
@@ -283,22 +712,23 @@ def make_data_visualization_agent(
|
|
283
712
|
agent_name=AGENT_NAME,
|
284
713
|
log=log,
|
285
714
|
file_path=state.get("data_visualization_function_path"),
|
715
|
+
function_name=state.get("data_visualization_function_name"),
|
286
716
|
)
|
287
717
|
|
288
|
-
|
289
|
-
|
718
|
+
# Final reporting node
|
719
|
+
def report_agent_outputs(state: GraphState):
|
720
|
+
return node_func_report_agent_outputs(
|
290
721
|
state=state,
|
291
|
-
|
722
|
+
keys_to_include=[
|
723
|
+
"recommended_steps",
|
724
|
+
"data_visualization_function",
|
725
|
+
"data_visualization_function_path",
|
726
|
+
"data_visualization_function_name",
|
727
|
+
"data_visualization_error",
|
728
|
+
],
|
292
729
|
result_key="messages",
|
293
|
-
error_key="data_visualization_error",
|
294
|
-
llm=llm,
|
295
730
|
role=AGENT_NAME,
|
296
|
-
|
297
|
-
Explain the data visualization steps that the data visualization agent performed in this function.
|
298
|
-
Keep the summary succinct and to the point.\n\n# Data Visualization Agent:\n\n{code}
|
299
|
-
""",
|
300
|
-
success_prefix="# Data Visualization Agent:\n\n ",
|
301
|
-
error_message="The Data Visualization Agent encountered an error during data visualization. No explanation could be provided."
|
731
|
+
custom_title="Data Visualization Agent Outputs"
|
302
732
|
)
|
303
733
|
|
304
734
|
# Define the graph
|
@@ -308,7 +738,7 @@ def make_data_visualization_agent(
|
|
308
738
|
"chart_generator": chart_generator,
|
309
739
|
"execute_data_visualization_code": execute_data_visualization_code,
|
310
740
|
"fix_data_visualization_code": fix_data_visualization_code,
|
311
|
-
"
|
741
|
+
"report_agent_outputs": report_agent_outputs,
|
312
742
|
}
|
313
743
|
|
314
744
|
app = create_coding_agent_graph(
|
@@ -318,7 +748,7 @@ def make_data_visualization_agent(
|
|
318
748
|
create_code_node_name="chart_generator",
|
319
749
|
execute_code_node_name="execute_data_visualization_code",
|
320
750
|
fix_code_node_name="fix_data_visualization_code",
|
321
|
-
explain_code_node_name="
|
751
|
+
explain_code_node_name="report_agent_outputs",
|
322
752
|
error_key="data_visualization_error",
|
323
753
|
human_in_the_loop=human_in_the_loop, # or False
|
324
754
|
human_review_node_name="human_review",
|
@@ -328,4 +758,3 @@ def make_data_visualization_agent(
|
|
328
758
|
)
|
329
759
|
|
330
760
|
return app
|
331
|
-
|