ai-data-science-team 0.0.0.9007__py3-none-any.whl → 0.0.0.9008__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +4 -4
- ai_data_science_team/agents/data_cleaning_agent.py +225 -84
- ai_data_science_team/agents/data_visualization_agent.py +460 -27
- ai_data_science_team/agents/data_wrangling_agent.py +455 -16
- ai_data_science_team/agents/feature_engineering_agent.py +429 -25
- ai_data_science_team/agents/sql_database_agent.py +367 -21
- ai_data_science_team/multiagents/__init__.py +1 -0
- ai_data_science_team/multiagents/sql_data_analyst.py +286 -0
- ai_data_science_team/multiagents/supervised_data_analyst.py +2 -0
- ai_data_science_team/templates/__init__.py +2 -1
- ai_data_science_team/templates/agent_templates.py +247 -42
- ai_data_science_team/tools/regex.py +28 -1
- ai_data_science_team/utils/__init__.py +0 -0
- ai_data_science_team/utils/plotly.py +24 -0
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/METADATA +76 -28
- ai_data_science_team-0.0.0.9008.dist-info/RECORD +26 -0
- ai_data_science_team-0.0.0.9007.dist-info/RECORD +0 -21
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/WHEEL +0 -0
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/top_level.txt +0 -0
@@ -17,25 +17,367 @@ from langgraph.types import Command
|
|
17
17
|
from langgraph.checkpoint.memory import MemorySaver
|
18
18
|
|
19
19
|
import os
|
20
|
-
import io
|
21
20
|
import pandas as pd
|
22
21
|
|
22
|
+
from IPython.display import Markdown
|
23
|
+
|
23
24
|
from ai_data_science_team.templates import(
|
24
25
|
node_func_execute_agent_code_on_data,
|
25
26
|
node_func_human_review,
|
26
27
|
node_func_fix_agent_code,
|
27
28
|
node_func_explain_agent_code,
|
28
|
-
create_coding_agent_graph
|
29
|
+
create_coding_agent_graph,
|
30
|
+
BaseAgent,
|
29
31
|
)
|
30
32
|
from ai_data_science_team.tools.parsers import PythonOutputParser
|
31
|
-
from ai_data_science_team.tools.regex import
|
33
|
+
from ai_data_science_team.tools.regex import (
|
34
|
+
relocate_imports_inside_function,
|
35
|
+
add_comments_to_top,
|
36
|
+
format_agent_name,
|
37
|
+
format_recommended_steps
|
38
|
+
)
|
32
39
|
from ai_data_science_team.tools.metadata import get_dataframe_summary
|
33
40
|
from ai_data_science_team.tools.logging import log_ai_function
|
41
|
+
from ai_data_science_team.utils.plotly import plotly_from_dict
|
34
42
|
|
35
43
|
# Setup
|
36
44
|
AGENT_NAME = "data_visualization_agent"
|
37
45
|
LOG_PATH = os.path.join(os.getcwd(), "logs/")
|
38
46
|
|
47
|
+
# Class
|
48
|
+
|
49
|
+
class DataVisualizationAgent(BaseAgent):
    """
    Creates a data visualization agent that can generate Plotly charts based on user-defined instructions or
    default visualization steps (if any). The agent generates a Python function to produce the visualization,
    executes it, and logs the process, including code and errors. It is designed to facilitate reproducible
    and customizable data visualization workflows.

    The agent may use default instructions for creating charts unless instructed otherwise, such as:
    - Generating a recommended chart type (bar, scatter, line, etc.)
    - Creating user-friendly titles and axis labels
    - Applying consistent styling (template, font sizes, color themes)
    - Handling theme details (white background, base font size, line size, etc.)

    User instructions can modify, add, or remove any of these steps to tailor the visualization process.

    Parameters
    ----------
    model : langchain.llms.base.LLM
        The language model used to generate the data visualization function.
    n_samples : int, optional
        Number of samples used when summarizing the dataset for chart instructions. Defaults to 30.
        Reducing this number can help avoid exceeding the model's token limits.
    log : bool, optional
        Whether to log the generated code and errors. Defaults to False.
    log_path : str, optional
        Directory path for storing log files. Defaults to None.
    file_name : str, optional
        Name of the file for saving the generated response. Defaults to "data_visualization.py".
    function_name : str, optional
        Name of the function for data visualization. Defaults to "data_visualization".
    overwrite : bool, optional
        Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
    human_in_the_loop : bool, optional
        Enables user review of data visualization instructions. Defaults to False.
    bypass_recommended_steps : bool, optional
        If True, skips the default recommended visualization steps. Defaults to False.
    bypass_explain_code : bool, optional
        If True, skips the step that provides code explanations. Defaults to False.

    Attributes
    ----------
    response : dict or None
        The final graph state returned by the most recent invocation, or None if the agent has not
        been invoked since the graph was (re)built.

    Methods
    -------
    update_params(**kwargs)
        Updates the agent's parameters and rebuilds the compiled state graph (this clears `response`).
    ainvoke_agent(data_raw: pd.DataFrame, user_instructions: str = None, max_retries=3, retry_count=0, **kwargs)
        Coroutine that asynchronously generates a visualization; it must be awaited.
    invoke_agent(data_raw: pd.DataFrame, user_instructions: str = None, max_retries=3, retry_count=0, **kwargs)
        Synchronously generates a visualization based on user instructions.
    explain_visualization_steps()
        Returns the messages stored in the last response (an empty list if there is none).
    get_log_summary(markdown=False)
        Returns the path of the logged visualization function, if logging produced one.
    get_plotly_graph()
        Returns the Plotly graph from the last response, converted with `plotly_from_dict`.
    get_data_raw()
        Retrieves the raw dataset as a pandas DataFrame (based on the last response).
    get_data_visualization_function(markdown=False)
        Retrieves the generated Python function used for data visualization.
    get_recommended_visualization_steps(markdown=False)
        Retrieves the agent's recommended visualization steps.
    get_response()
        Returns the response from the agent as a dictionary.
    show()
        Displays the agent's mermaid diagram.

    Examples
    --------
    ```python
    import pandas as pd
    from langchain_openai import ChatOpenAI
    from ai_data_science_team.agents import DataVisualizationAgent

    llm = ChatOpenAI(model="gpt-4o-mini")

    data_visualization_agent = DataVisualizationAgent(
        model=llm,
        n_samples=30,
        log=True,
        log_path="logs",
        human_in_the_loop=True
    )

    df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")

    data_visualization_agent.invoke_agent(
        user_instructions="Generate a scatter plot of age vs. total charges with a trend line.",
        data_raw=df,
        max_retries=3,
        retry_count=0
    )

    # Inside an async context, use:
    # await data_visualization_agent.ainvoke_agent(data_raw=df, user_instructions="...")

    fig = data_visualization_agent.get_plotly_graph()

    response = data_visualization_agent.get_response()
    ```
    """

    def __init__(
        self,
        model,
        n_samples=30,
        log=False,
        log_path=None,
        file_name="data_visualization.py",
        function_name="data_visualization",
        overwrite=True,
        human_in_the_loop=False,
        bypass_recommended_steps=False,
        bypass_explain_code=False
    ):
        # All constructor options are kept in one dict so update_params() can rebuild
        # the graph from exactly the same keyword arguments.
        self._params = {
            "model": model,
            "n_samples": n_samples,
            "log": log,
            "log_path": log_path,
            "file_name": file_name,
            "function_name": function_name,
            "overwrite": overwrite,
            "human_in_the_loop": human_in_the_loop,
            "bypass_recommended_steps": bypass_recommended_steps,
            "bypass_explain_code": bypass_explain_code,
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Create the compiled graph for the data visualization agent.
        Running this method will reset the response to None.
        """
        self.response = None
        return make_data_visualization_agent(**self._params)

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters and rebuilds the compiled graph.

        Any previously stored response is discarded, because it was produced by the old graph.
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    @staticmethod
    def _build_input(data_raw, user_instructions, max_retries, retry_count):
        """Build the initial graph state shared by invoke_agent and ainvoke_agent."""
        return {
            "user_instructions": user_instructions,
            # The graph state carries the data as a plain dict rather than a DataFrame.
            "data_raw": data_raw.to_dict(),
            "max_retries": max_retries,
            "retry_count": retry_count,
        }

    async def ainvoke_agent(self, data_raw: pd.DataFrame, user_instructions: str=None, max_retries:int=3, retry_count:int=0, **kwargs):
        """
        Asynchronously invokes the agent to generate a visualization.
        The response is stored in the 'response' attribute.

        This is a coroutine and must be awaited; the compiled graph's `ainvoke` is awaited
        so that `response` holds the final state rather than an un-awaited coroutine.

        Parameters
        ----------
        data_raw : pd.DataFrame
            The raw dataset to be visualized.
        user_instructions : str
            Instructions for data visualization.
        max_retries : int
            Maximum retry attempts.
        retry_count : int
            Current retry attempt count.
        **kwargs : dict
            Additional keyword arguments passed to ainvoke().

        Returns
        -------
        None
        """
        response = await self._compiled_graph.ainvoke(
            self._build_input(data_raw, user_instructions, max_retries, retry_count),
            **kwargs,
        )
        self.response = response
        return None

    def invoke_agent(self, data_raw: pd.DataFrame, user_instructions: str=None, max_retries:int=3, retry_count:int=0, **kwargs):
        """
        Synchronously invokes the agent to generate a visualization.
        The response is stored in the 'response' attribute.

        Parameters
        ----------
        data_raw : pd.DataFrame
            The raw dataset to be visualized.
        user_instructions : str
            Instructions for data visualization agent.
        max_retries : int
            Maximum retry attempts.
        retry_count : int
            Current retry attempt count.
        **kwargs : dict
            Additional keyword arguments passed to invoke().

        Returns
        -------
        None
        """
        response = self._compiled_graph.invoke(
            self._build_input(data_raw, user_instructions, max_retries, retry_count),
            **kwargs,
        )
        self.response = response
        return None

    def explain_visualization_steps(self):
        """
        Provides the messages recorded by the agent during the last invocation.

        Returns
        -------
        list
            The "messages" entry of the last response, or an empty list if there is no
            response or it holds no messages.
        """
        if self.response:
            return self.response.get("messages", [])
        return []

    def get_log_summary(self, markdown=False):
        """
        Returns a summary of the agent's logging output, if logging is enabled.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns Markdown-formatted output.

        Returns
        -------
        str, Markdown or None
            A line with the log file path, or None if no log path is available.
        """
        if self.response and self.response.get('data_visualization_function_path'):
            log_details = f"Log Path: {self.response.get('data_visualization_function_path')}"
            if markdown:
                return Markdown(log_details)
            return log_details
        return None

    def get_plotly_graph(self):
        """
        Retrieves the Plotly graph produced by the agent.

        Returns
        -------
        object or None
            The "plotly_graph" entry of the last response converted with `plotly_from_dict`,
            or None if there is no response or no graph was produced (e.g. the generated
            code failed).
        """
        if not self.response:
            return None
        plotly_graph = self.response.get("plotly_graph")
        # Don't hand None to the converter when execution produced no graph.
        if plotly_graph is None:
            return None
        return plotly_from_dict(plotly_graph)

    def get_data_raw(self):
        """
        Retrieves the raw dataset used in the last invocation.

        Returns
        -------
        pd.DataFrame or None
            The raw dataset as a DataFrame if available, otherwise None.
        """
        if not self.response:
            return None
        data_raw = self.response.get("data_raw")
        # `is not None` so that an empty dataset ({}) round-trips to an empty DataFrame.
        if data_raw is not None:
            return pd.DataFrame(data_raw)
        return None

    def get_data_visualization_function(self, markdown=False):
        """
        Retrieves the generated Python function used for data visualization.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the function in Markdown code block format.

        Returns
        -------
        str, Markdown or None
            The Python function code as a string if available, otherwise None.
        """
        if self.response:
            func_code = self.response.get("data_visualization_function", "")
            if markdown:
                return Markdown(f"```python\n{func_code}\n```")
            return func_code
        return None

    def get_recommended_visualization_steps(self, markdown=False):
        """
        Retrieves the agent's recommended visualization steps.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the steps in Markdown format.

        Returns
        -------
        str, Markdown or None
            The recommended steps if available, otherwise None.
        """
        if self.response:
            steps = self.response.get("recommended_steps", "")
            if markdown:
                return Markdown(steps)
            return steps
        return None

    def get_response(self):
        """
        Returns the agent's full response dictionary.

        Returns
        -------
        dict or None
            The response dictionary if available, otherwise None.
        """
        return self.response

    def show(self):
        """
        Displays the agent's mermaid diagram for visual inspection of the compiled graph.
        """
        return self._compiled_graph.show()
|
379
|
+
|
380
|
+
|
39
381
|
# Agent
|
40
382
|
|
41
383
|
def make_data_visualization_agent(
|
@@ -44,14 +386,85 @@ def make_data_visualization_agent(
|
|
44
386
|
log=False,
|
45
387
|
log_path=None,
|
46
388
|
file_name="data_visualization.py",
|
47
|
-
|
389
|
+
function_name="data_visualization",
|
390
|
+
overwrite=True,
|
48
391
|
human_in_the_loop=False,
|
49
392
|
bypass_recommended_steps=False,
|
50
393
|
bypass_explain_code=False
|
51
394
|
):
|
395
|
+
"""
|
396
|
+
Creates a data visualization agent that can generate Plotly charts based on user-defined instructions or
|
397
|
+
default visualization steps. The agent generates a Python function to produce the visualization, executes it,
|
398
|
+
and logs the process, including code and errors. It is designed to facilitate reproducible and customizable
|
399
|
+
data visualization workflows.
|
400
|
+
|
401
|
+
The agent can perform the following default visualization steps unless instructed otherwise:
|
402
|
+
- Generating a recommended chart type (bar, scatter, line, etc.)
|
403
|
+
- Creating user-friendly titles and axis labels
|
404
|
+
- Applying consistent styling (template, font sizes, color themes)
|
405
|
+
- Handling theme details (white background, base font size, line size, etc.)
|
406
|
+
|
407
|
+
User instructions can modify, add, or remove any of these steps to tailor the visualization process.
|
408
|
+
|
409
|
+
Parameters
|
410
|
+
----------
|
411
|
+
model : langchain.llms.base.LLM
|
412
|
+
The language model used to generate the data visualization function.
|
413
|
+
n_samples : int, optional
|
414
|
+
Number of samples used when summarizing the dataset for chart instructions. Defaults to 30.
|
415
|
+
log : bool, optional
|
416
|
+
Whether to log the generated code and errors. Defaults to False.
|
417
|
+
log_path : str, optional
|
418
|
+
Directory path for storing log files. Defaults to None.
|
419
|
+
file_name : str, optional
|
420
|
+
Name of the file for saving the generated response. Defaults to "data_visualization.py".
|
421
|
+
function_name : str, optional
|
422
|
+
Name of the function for data visualization. Defaults to "data_visualization".
|
423
|
+
overwrite : bool, optional
|
424
|
+
Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
|
425
|
+
human_in_the_loop : bool, optional
|
426
|
+
Enables user review of data visualization instructions. Defaults to False.
|
427
|
+
bypass_recommended_steps : bool, optional
|
428
|
+
If True, skips the default recommended visualization steps. Defaults to False.
|
429
|
+
bypass_explain_code : bool, optional
|
430
|
+
If True, skips the step that provides code explanations. Defaults to False.
|
431
|
+
|
432
|
+
Examples
|
433
|
+
--------
|
434
|
+
``` python
|
435
|
+
import pandas as pd
|
436
|
+
from langchain_openai import ChatOpenAI
|
437
|
+
from ai_data_science_team.agents import make_data_visualization_agent
|
438
|
+
|
439
|
+
llm = ChatOpenAI(model="gpt-4o-mini")
|
440
|
+
|
441
|
+
data_visualization_agent = make_data_visualization_agent(llm)
|
442
|
+
|
443
|
+
df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")
|
444
|
+
|
445
|
+
response = data_visualization_agent.invoke({
|
446
|
+
"user_instructions": "Generate a scatter plot of tenure vs. total charges with a trend line.",
|
447
|
+
"data_raw": df.to_dict(),
|
448
|
+
"max_retries": 3,
|
449
|
+
"retry_count": 0
|
450
|
+
})
|
451
|
+
|
452
|
+
plotly_graph_dict = response['plotly_graph']
|
453
|
+
```
|
454
|
+
|
455
|
+
Returns
|
456
|
+
-------
|
457
|
+
app : langchain.graphs.CompiledStateGraph
|
458
|
+
The data visualization agent as a state graph.
|
459
|
+
"""
|
52
460
|
|
53
461
|
llm = model
|
54
462
|
|
463
|
+
# Human in the loop requires recommended steps
|
464
|
+
if bypass_recommended_steps and human_in_the_loop:
|
465
|
+
bypass_recommended_steps = False
|
466
|
+
print("Bypass recommended steps set to False to enable human in the loop.")
|
467
|
+
|
55
468
|
# Setup Log Directory
|
56
469
|
if log:
|
57
470
|
if log_path is None:
|
@@ -70,6 +483,7 @@ def make_data_visualization_agent(
|
|
70
483
|
all_datasets_summary: str
|
71
484
|
data_visualization_function: str
|
72
485
|
data_visualization_function_path: str
|
486
|
+
data_visualization_function_file_name: str
|
73
487
|
data_visualization_function_name: str
|
74
488
|
data_visualization_error: str
|
75
489
|
max_retries: int
|
@@ -140,7 +554,7 @@ def make_data_visualization_agent(
|
|
140
554
|
})
|
141
555
|
|
142
556
|
return {
|
143
|
-
"recommended_steps": "
|
557
|
+
"recommended_steps": format_recommended_steps(recommended_steps.content.strip(), heading="# Recommended Data Cleaning Steps:"),
|
144
558
|
"all_datasets_summary": all_datasets_summary_str
|
145
559
|
}
|
146
560
|
|
@@ -169,7 +583,7 @@ def make_data_visualization_agent(
|
|
169
583
|
template="""
|
170
584
|
You are a chart generator agent that is an expert in generating plotly charts. You must use plotly or plotly.express to produce plots.
|
171
585
|
|
172
|
-
Your job is to produce python code to generate visualizations.
|
586
|
+
Your job is to produce python code to generate visualizations with a function named {function_name}.
|
173
587
|
|
174
588
|
You will take instructions from a Chart Instructor and generate a plotly chart from the data provided.
|
175
589
|
|
@@ -181,13 +595,13 @@ def make_data_visualization_agent(
|
|
181
595
|
|
182
596
|
RETURN:
|
183
597
|
|
184
|
-
Return Python code in ```python ``` format with a single function definition,
|
598
|
+
Return Python code in ```python ``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.
|
185
599
|
|
186
600
|
Return the plotly chart as a dictionary.
|
187
601
|
|
188
602
|
Return code to provide the data visualization function:
|
189
603
|
|
190
|
-
def
|
604
|
+
def {function_name}(data_raw):
|
191
605
|
import pandas as pd
|
192
606
|
import numpy as np
|
193
607
|
import json
|
@@ -206,14 +620,15 @@ def make_data_visualization_agent(
|
|
206
620
|
2. Do not include unrelated user instructions that are not related to the chart generation.
|
207
621
|
|
208
622
|
""",
|
209
|
-
input_variables=["chart_generator_instructions", "all_datasets_summary"]
|
623
|
+
input_variables=["chart_generator_instructions", "all_datasets_summary", "function_name"]
|
210
624
|
)
|
211
|
-
|
625
|
+
|
212
626
|
data_visualization_agent = prompt_template | llm | PythonOutputParser()
|
213
627
|
|
214
628
|
response = data_visualization_agent.invoke({
|
215
629
|
"chart_generator_instructions": chart_generator_instructions,
|
216
|
-
"all_datasets_summary": all_datasets_summary_str
|
630
|
+
"all_datasets_summary": all_datasets_summary_str,
|
631
|
+
"function_name": function_name
|
217
632
|
})
|
218
633
|
|
219
634
|
response = relocate_imports_inside_function(response)
|
@@ -231,19 +646,37 @@ def make_data_visualization_agent(
|
|
231
646
|
return {
|
232
647
|
"data_visualization_function": response,
|
233
648
|
"data_visualization_function_path": file_path,
|
234
|
-
"
|
649
|
+
"data_visualization_function_file_name": file_name_2,
|
650
|
+
"data_visualization_function_name": function_name,
|
235
651
|
"all_datasets_summary": all_datasets_summary_str
|
236
652
|
}
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
653
|
+
|
654
|
+
# Human Review
|
655
|
+
|
656
|
+
prompt_text_human_review = "Are the following data visualization instructions correct? (Answer 'yes' or provide modifications)\n{steps}"
|
657
|
+
|
658
|
+
if not bypass_explain_code:
|
659
|
+
def human_review(state: GraphState) -> Command[Literal["chart_instructor", "explain_data_visualization_code"]]:
|
660
|
+
return node_func_human_review(
|
661
|
+
state=state,
|
662
|
+
prompt_text=prompt_text_human_review,
|
663
|
+
yes_goto= 'explain_data_visualization_code',
|
664
|
+
no_goto="chart_instructor",
|
665
|
+
user_instructions_key="user_instructions",
|
666
|
+
recommended_steps_key="recommended_steps",
|
667
|
+
code_snippet_key="data_visualization_function",
|
668
|
+
)
|
669
|
+
else:
|
670
|
+
def human_review(state: GraphState) -> Command[Literal["chart_instructor", "__end__"]]:
|
671
|
+
return node_func_human_review(
|
672
|
+
state=state,
|
673
|
+
prompt_text=prompt_text_human_review,
|
674
|
+
yes_goto= '__end__',
|
675
|
+
no_goto="chart_instructor",
|
676
|
+
user_instructions_key="user_instructions",
|
677
|
+
recommended_steps_key="recommended_steps",
|
678
|
+
code_snippet_key="data_visualization_function",
|
679
|
+
)
|
247
680
|
|
248
681
|
|
249
682
|
def execute_data_visualization_code(state):
|
@@ -253,7 +686,7 @@ def make_data_visualization_agent(
|
|
253
686
|
result_key="plotly_graph",
|
254
687
|
error_key="data_visualization_error",
|
255
688
|
code_snippet_key="data_visualization_function",
|
256
|
-
agent_function_name="
|
689
|
+
agent_function_name=state.get("data_visualization_function_name"),
|
257
690
|
pre_processing=lambda data: pd.DataFrame.from_dict(data),
|
258
691
|
# post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
|
259
692
|
error_message_prefix="An error occurred during data visualization: "
|
@@ -261,11 +694,11 @@ def make_data_visualization_agent(
|
|
261
694
|
|
262
695
|
def fix_data_visualization_code(state: GraphState):
|
263
696
|
prompt = """
|
264
|
-
You are a Data Visualization Agent. Your job is to create a
|
697
|
+
You are a Data Visualization Agent. Your job is to create a {function_name}() function that can be run on the data provided. The function is currently broken and needs to be fixed.
|
265
698
|
|
266
|
-
Make sure to only return the function definition for
|
699
|
+
Make sure to only return the function definition for {function_name}().
|
267
700
|
|
268
|
-
Return Python code in ```python``` format with a single function definition,
|
701
|
+
Return Python code in ```python``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.
|
269
702
|
|
270
703
|
This is the broken code (please fix):
|
271
704
|
{code_snippet}
|
@@ -283,6 +716,7 @@ def make_data_visualization_agent(
|
|
283
716
|
agent_name=AGENT_NAME,
|
284
717
|
log=log,
|
285
718
|
file_path=state.get("data_visualization_function_path"),
|
719
|
+
function_name=state.get("data_visualization_function_name"),
|
286
720
|
)
|
287
721
|
|
288
722
|
def explain_data_visualization_code(state: GraphState):
|
@@ -328,4 +762,3 @@ def make_data_visualization_agent(
|
|
328
762
|
)
|
329
763
|
|
330
764
|
return app
|
331
|
-
|