ai-data-science-team 0.0.0.9013__py3-none-any.whl → 0.0.0.9015__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_data_science_team/__init__.py +22 -0
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/data_cleaning_agent.py +17 -3
- ai_data_science_team/agents/data_loader_tools_agent.py +13 -1
- ai_data_science_team/agents/data_visualization_agent.py +187 -130
- ai_data_science_team/agents/data_wrangling_agent.py +31 -10
- ai_data_science_team/agents/feature_engineering_agent.py +17 -4
- ai_data_science_team/agents/sql_database_agent.py +15 -2
- ai_data_science_team/ds_agents/eda_tools_agent.py +15 -6
- ai_data_science_team/ml_agents/h2o_ml_agent.py +15 -3
- ai_data_science_team/ml_agents/mlflow_tools_agent.py +13 -1
- ai_data_science_team/multiagents/__init__.py +2 -1
- ai_data_science_team/multiagents/pandas_data_analyst.py +305 -0
- ai_data_science_team/multiagents/sql_data_analyst.py +126 -48
- ai_data_science_team/templates/agent_templates.py +41 -5
- ai_data_science_team/tools/eda.py +2 -0
- {ai_data_science_team-0.0.0.9013.dist-info → ai_data_science_team-0.0.0.9015.dist-info}/METADATA +6 -5
- {ai_data_science_team-0.0.0.9013.dist-info → ai_data_science_team-0.0.0.9015.dist-info}/RECORD +21 -20
- {ai_data_science_team-0.0.0.9013.dist-info → ai_data_science_team-0.0.0.9015.dist-info}/WHEEL +1 -1
- {ai_data_science_team-0.0.0.9013.dist-info → ai_data_science_team-0.0.0.9015.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9013.dist-info → ai_data_science_team-0.0.0.9015.dist-info}/top_level.txt +0 -0
@@ -4,7 +4,6 @@
|
|
4
4
|
# * Agents: Data Visualization Agent
|
5
5
|
|
6
6
|
|
7
|
-
|
8
7
|
# Libraries
|
9
8
|
from typing import TypedDict, Annotated, Sequence, Literal
|
10
9
|
import operator
|
@@ -14,27 +13,28 @@ from langchain_core.messages import BaseMessage
|
|
14
13
|
|
15
14
|
from langgraph.types import Command
|
16
15
|
from langgraph.checkpoint.memory import MemorySaver
|
16
|
+
from langgraph.types import Checkpointer
|
17
17
|
|
18
18
|
import os
|
19
|
-
import json
|
19
|
+
import json
|
20
20
|
import pandas as pd
|
21
21
|
|
22
22
|
from IPython.display import Markdown
|
23
23
|
|
24
|
-
from ai_data_science_team.templates import(
|
25
|
-
node_func_execute_agent_code_on_data,
|
24
|
+
from ai_data_science_team.templates import (
|
25
|
+
node_func_execute_agent_code_on_data,
|
26
26
|
node_func_human_review,
|
27
|
-
node_func_fix_agent_code,
|
27
|
+
node_func_fix_agent_code,
|
28
28
|
node_func_report_agent_outputs,
|
29
29
|
create_coding_agent_graph,
|
30
30
|
BaseAgent,
|
31
31
|
)
|
32
32
|
from ai_data_science_team.parsers.parsers import PythonOutputParser
|
33
33
|
from ai_data_science_team.utils.regex import (
|
34
|
-
relocate_imports_inside_function,
|
35
|
-
add_comments_to_top,
|
36
|
-
format_agent_name,
|
37
|
-
format_recommended_steps,
|
34
|
+
relocate_imports_inside_function,
|
35
|
+
add_comments_to_top,
|
36
|
+
format_agent_name,
|
37
|
+
format_recommended_steps,
|
38
38
|
get_generic_summary,
|
39
39
|
)
|
40
40
|
from ai_data_science_team.tools.dataframe import get_dataframe_summary
|
@@ -47,11 +47,12 @@ LOG_PATH = os.path.join(os.getcwd(), "logs/")
|
|
47
47
|
|
48
48
|
# Class
|
49
49
|
|
50
|
+
|
50
51
|
class DataVisualizationAgent(BaseAgent):
|
51
52
|
"""
|
52
53
|
Creates a data visualization agent that can generate Plotly charts based on user-defined instructions or
|
53
|
-
default visualization steps (if any). The agent generates a Python function to produce the visualization,
|
54
|
-
executes it, and logs the process, including code and errors. It is designed to facilitate reproducible
|
54
|
+
default visualization steps (if any). The agent generates a Python function to produce the visualization,
|
55
|
+
executes it, and logs the process, including code and errors. It is designed to facilitate reproducible
|
55
56
|
and customizable data visualization workflows.
|
56
57
|
|
57
58
|
The agent may use default instructions for creating charts unless instructed otherwise, such as:
|
@@ -85,6 +86,8 @@ class DataVisualizationAgent(BaseAgent):
|
|
85
86
|
If True, skips the default recommended visualization steps. Defaults to False.
|
86
87
|
bypass_explain_code : bool, optional
|
87
88
|
If True, skips the step that provides code explanations. Defaults to False.
|
89
|
+
checkpointer : langgraph.types.Checkpointer
|
90
|
+
A checkpointer to use for saving and loading the agent. Defaults to None.
|
88
91
|
|
89
92
|
Methods
|
90
93
|
-------
|
@@ -121,10 +124,10 @@ class DataVisualizationAgent(BaseAgent):
|
|
121
124
|
llm = ChatOpenAI(model="gpt-4o-mini")
|
122
125
|
|
123
126
|
data_visualization_agent = DataVisualizationAgent(
|
124
|
-
model=llm,
|
127
|
+
model=llm,
|
125
128
|
n_samples=30,
|
126
|
-
log=True,
|
127
|
-
log_path="logs",
|
129
|
+
log=True,
|
130
|
+
log_path="logs",
|
128
131
|
human_in_the_loop=True
|
129
132
|
)
|
130
133
|
|
@@ -138,7 +141,7 @@ class DataVisualizationAgent(BaseAgent):
|
|
138
141
|
)
|
139
142
|
|
140
143
|
plotly_graph_dict = data_visualization_agent.get_plotly_graph()
|
141
|
-
# You can render plotly_graph_dict with plotly.io.from_json or
|
144
|
+
# You can render plotly_graph_dict with plotly.io.from_json or
|
142
145
|
# something similar in a Jupyter Notebook.
|
143
146
|
|
144
147
|
response = data_visualization_agent.get_response()
|
@@ -146,22 +149,23 @@ class DataVisualizationAgent(BaseAgent):
|
|
146
149
|
|
147
150
|
Returns
|
148
151
|
--------
|
149
|
-
DataVisualizationAgent : langchain.graphs.CompiledStateGraph
|
150
|
-
A data visualization agent implemented as a compiled state graph.
|
152
|
+
DataVisualizationAgent : langchain.graphs.CompiledStateGraph
|
153
|
+
A data visualization agent implemented as a compiled state graph.
|
151
154
|
"""
|
152
155
|
|
153
156
|
def __init__(
|
154
|
-
self,
|
155
|
-
model,
|
156
|
-
n_samples=30,
|
157
|
-
log=False,
|
158
|
-
log_path=None,
|
159
|
-
file_name="data_visualization.py",
|
157
|
+
self,
|
158
|
+
model,
|
159
|
+
n_samples=30,
|
160
|
+
log=False,
|
161
|
+
log_path=None,
|
162
|
+
file_name="data_visualization.py",
|
160
163
|
function_name="data_visualization",
|
161
|
-
overwrite=True,
|
162
|
-
human_in_the_loop=False,
|
163
|
-
bypass_recommended_steps=False,
|
164
|
-
bypass_explain_code=False
|
164
|
+
overwrite=True,
|
165
|
+
human_in_the_loop=False,
|
166
|
+
bypass_recommended_steps=False,
|
167
|
+
bypass_explain_code=False,
|
168
|
+
checkpointer=None,
|
165
169
|
):
|
166
170
|
self._params = {
|
167
171
|
"model": model,
|
@@ -174,13 +178,14 @@ class DataVisualizationAgent(BaseAgent):
|
|
174
178
|
"human_in_the_loop": human_in_the_loop,
|
175
179
|
"bypass_recommended_steps": bypass_recommended_steps,
|
176
180
|
"bypass_explain_code": bypass_explain_code,
|
181
|
+
"checkpointer": checkpointer,
|
177
182
|
}
|
178
183
|
self._compiled_graph = self._make_compiled_graph()
|
179
184
|
self.response = None
|
180
185
|
|
181
186
|
def _make_compiled_graph(self):
|
182
187
|
"""
|
183
|
-
Create the compiled graph for the data visualization agent.
|
188
|
+
Create the compiled graph for the data visualization agent.
|
184
189
|
Running this method will reset the response to None.
|
185
190
|
"""
|
186
191
|
self.response = None
|
@@ -196,9 +201,16 @@ class DataVisualizationAgent(BaseAgent):
|
|
196
201
|
# Rebuild the compiled graph
|
197
202
|
self._compiled_graph = self._make_compiled_graph()
|
198
203
|
|
199
|
-
async def ainvoke_agent(
|
204
|
+
async def ainvoke_agent(
|
205
|
+
self,
|
206
|
+
data_raw: pd.DataFrame,
|
207
|
+
user_instructions: str = None,
|
208
|
+
max_retries: int = 3,
|
209
|
+
retry_count: int = 0,
|
210
|
+
**kwargs,
|
211
|
+
):
|
200
212
|
"""
|
201
|
-
Asynchronously invokes the agent to generate a visualization.
|
213
|
+
Asynchronously invokes the agent to generate a visualization.
|
202
214
|
The response is stored in the 'response' attribute.
|
203
215
|
|
204
216
|
Parameters
|
@@ -218,18 +230,28 @@ class DataVisualizationAgent(BaseAgent):
|
|
218
230
|
-------
|
219
231
|
None
|
220
232
|
"""
|
221
|
-
response = await self._compiled_graph.ainvoke(
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
233
|
+
response = await self._compiled_graph.ainvoke(
|
234
|
+
{
|
235
|
+
"user_instructions": user_instructions,
|
236
|
+
"data_raw": data_raw.to_dict(),
|
237
|
+
"max_retries": max_retries,
|
238
|
+
"retry_count": retry_count,
|
239
|
+
},
|
240
|
+
**kwargs,
|
241
|
+
)
|
227
242
|
self.response = response
|
228
243
|
return None
|
229
244
|
|
230
|
-
def invoke_agent(
|
245
|
+
def invoke_agent(
|
246
|
+
self,
|
247
|
+
data_raw: pd.DataFrame,
|
248
|
+
user_instructions: str = None,
|
249
|
+
max_retries: int = 3,
|
250
|
+
retry_count: int = 0,
|
251
|
+
**kwargs,
|
252
|
+
):
|
231
253
|
"""
|
232
|
-
Synchronously invokes the agent to generate a visualization.
|
254
|
+
Synchronously invokes the agent to generate a visualization.
|
233
255
|
The response is stored in the 'response' attribute.
|
234
256
|
|
235
257
|
Parameters
|
@@ -249,12 +271,15 @@ class DataVisualizationAgent(BaseAgent):
|
|
249
271
|
-------
|
250
272
|
None
|
251
273
|
"""
|
252
|
-
response = self._compiled_graph.invoke(
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
274
|
+
response = self._compiled_graph.invoke(
|
275
|
+
{
|
276
|
+
"user_instructions": user_instructions,
|
277
|
+
"data_raw": data_raw.to_dict(),
|
278
|
+
"max_retries": max_retries,
|
279
|
+
"retry_count": retry_count,
|
280
|
+
},
|
281
|
+
**kwargs,
|
282
|
+
)
|
258
283
|
self.response = response
|
259
284
|
return None
|
260
285
|
|
@@ -263,7 +288,9 @@ class DataVisualizationAgent(BaseAgent):
|
|
263
288
|
Retrieves the agent's workflow summary, if logging is enabled.
|
264
289
|
"""
|
265
290
|
if self.response and self.response.get("messages"):
|
266
|
-
summary = get_generic_summary(
|
291
|
+
summary = get_generic_summary(
|
292
|
+
json.loads(self.response.get("messages")[-1].content)
|
293
|
+
)
|
267
294
|
if markdown:
|
268
295
|
return Markdown(summary)
|
269
296
|
else:
|
@@ -274,7 +301,7 @@ class DataVisualizationAgent(BaseAgent):
|
|
274
301
|
Logs a summary of the agent's operations, if logging is enabled.
|
275
302
|
"""
|
276
303
|
if self.response:
|
277
|
-
if self.response.get(
|
304
|
+
if self.response.get("data_visualization_function_path"):
|
278
305
|
log_details = f"""
|
279
306
|
## Data Visualization Agent Log Summary:
|
280
307
|
|
@@ -283,7 +310,7 @@ Function Path: {self.response.get('data_visualization_function_path')}
|
|
283
310
|
Function Name: {self.response.get('data_visualization_function_name')}
|
284
311
|
"""
|
285
312
|
if markdown:
|
286
|
-
return Markdown(log_details)
|
313
|
+
return Markdown(log_details)
|
287
314
|
else:
|
288
315
|
return log_details
|
289
316
|
|
@@ -375,17 +402,19 @@ Function Name: {self.response.get('data_visualization_function_name')}
|
|
375
402
|
|
376
403
|
# Agent
|
377
404
|
|
405
|
+
|
378
406
|
def make_data_visualization_agent(
|
379
|
-
model,
|
407
|
+
model,
|
380
408
|
n_samples=30,
|
381
|
-
log=False,
|
382
|
-
log_path=None,
|
409
|
+
log=False,
|
410
|
+
log_path=None,
|
383
411
|
file_name="data_visualization.py",
|
384
412
|
function_name="data_visualization",
|
385
|
-
overwrite=True,
|
386
|
-
human_in_the_loop=False,
|
387
|
-
bypass_recommended_steps=False,
|
388
|
-
bypass_explain_code=False
|
413
|
+
overwrite=True,
|
414
|
+
human_in_the_loop=False,
|
415
|
+
bypass_recommended_steps=False,
|
416
|
+
bypass_explain_code=False,
|
417
|
+
checkpointer=None,
|
389
418
|
):
|
390
419
|
"""
|
391
420
|
Creates a data visualization agent that can generate Plotly charts based on user-defined instructions or
|
@@ -423,6 +452,8 @@ def make_data_visualization_agent(
|
|
423
452
|
If True, skips the default recommended visualization steps. Defaults to False.
|
424
453
|
bypass_explain_code : bool, optional
|
425
454
|
If True, skips the step that provides code explanations. Defaults to False.
|
455
|
+
checkpointer : langgraph.types.Checkpointer
|
456
|
+
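A checkpointer to use for saving and loading the agent. Defaults to None; if human_in_the_loop is True and no checkpointer is given, MemorySaver() is used.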
|
426
457
|
|
427
458
|
Examples
|
428
459
|
--------
|
@@ -452,20 +483,27 @@ def make_data_visualization_agent(
|
|
452
483
|
app : langchain.graphs.CompiledStateGraph
|
453
484
|
The data visualization agent as a state graph.
|
454
485
|
"""
|
455
|
-
|
486
|
+
|
456
487
|
llm = model
|
457
|
-
|
488
|
+
|
489
|
+
if human_in_the_loop:
|
490
|
+
if checkpointer is None:
|
491
|
+
print(
|
492
|
+
"Human in the loop is enabled. A checkpointer is required. Setting to MemorySaver()."
|
493
|
+
)
|
494
|
+
checkpointer = MemorySaver()
|
495
|
+
|
458
496
|
# Human in the loop requires recommended steps
|
459
497
|
if bypass_recommended_steps and human_in_the_loop:
|
460
498
|
bypass_recommended_steps = False
|
461
499
|
print("Bypass recommended steps set to False to enable human in the loop.")
|
462
|
-
|
500
|
+
|
463
501
|
# Setup Log Directory
|
464
502
|
if log:
|
465
503
|
if log_path is None:
|
466
504
|
log_path = LOG_PATH
|
467
505
|
if not os.path.exists(log_path):
|
468
|
-
os.makedirs(log_path)
|
506
|
+
os.makedirs(log_path)
|
469
507
|
|
470
508
|
# Define GraphState for the router
|
471
509
|
class GraphState(TypedDict):
|
@@ -483,12 +521,11 @@ def make_data_visualization_agent(
|
|
483
521
|
data_visualization_error: str
|
484
522
|
max_retries: int
|
485
523
|
retry_count: int
|
486
|
-
|
524
|
+
|
487
525
|
def chart_instructor(state: GraphState):
|
488
|
-
|
489
526
|
print(format_agent_name(AGENT_NAME))
|
490
527
|
print(" * CREATE CHART GENERATOR INSTRUCTIONS")
|
491
|
-
|
528
|
+
|
492
529
|
recommend_steps_prompt = PromptTemplate(
|
493
530
|
template="""
|
494
531
|
You are a supervisor that is an expert in providing instructions to a chart generator agent for plotting.
|
@@ -501,25 +538,23 @@ def make_data_visualization_agent(
|
|
501
538
|
Previously Recommended Instructions (if any):
|
502
539
|
{recommended_steps}
|
503
540
|
|
504
|
-
DATA:
|
541
|
+
DATA SUMMARY:
|
505
542
|
{all_datasets_summary}
|
506
543
|
|
507
|
-
|
544
|
+
IMPORTANT:
|
545
|
+
|
546
|
+
- Formulate chart generator instructions by informing the chart generator of what type of plotly plot to use (e.g. bar, line, scatter, etc) to best represent the data.
|
547
|
+
- Think about how best to convey the information in the data to the user.
|
548
|
+
- If the user does not specify a type of plot, select the appropriate chart type based on the data summary provided and the user's question and how best to show the results.
|
549
|
+
- Come up with an informative title from the user's question and data provided. Also provide X and Y axis titles.
|
550
|
+
|
551
|
+
CHART TYPE SELECTION TIPS:
|
508
552
|
|
509
|
-
|
553
|
+
- If a numeric column has less than 10 unique values, consider this column to be treated as a categorical column. Pick a chart that is appropriate for categorical data.
|
554
|
+
- If a numeric column has more than 10 unique values, consider this column to be treated as a continuous column. Pick a chart that is appropriate for continuous data.
|
510
555
|
|
511
|
-
Instruct the chart generator to use the following theme colors, sizes, etc:
|
512
556
|
|
513
|
-
|
514
|
-
- Use a white background
|
515
|
-
- Use this color for bars and lines:
|
516
|
-
'blue': '#3381ff',
|
517
|
-
- Base Font Size: 8.8 (Used for x and y axes tickfont, any annotations, hovertips)
|
518
|
-
- Title Font Size: 13.2
|
519
|
-
- Line Size: 0.65 (specify these within the xaxis and yaxis dictionaries)
|
520
|
-
- Add smoothers or trendlines to scatter plots unless not desired by the user
|
521
|
-
- Do not use color_discrete_map (this will result in an error)
|
522
|
-
- Hover tip size: 8.8
|
557
|
+
RETURN FORMAT:
|
523
558
|
|
524
559
|
Return your instructions in the following format:
|
525
560
|
CHART GENERATOR INSTRUCTIONS:
|
@@ -529,51 +564,61 @@ def make_data_visualization_agent(
|
|
529
564
|
1. Do not include steps to save files.
|
530
565
|
2. Do not include unrelated user instructions that are not related to the chart generation.
|
531
566
|
""",
|
532
|
-
input_variables=[
|
533
|
-
|
567
|
+
input_variables=[
|
568
|
+
"user_instructions",
|
569
|
+
"recommended_steps",
|
570
|
+
"all_datasets_summary",
|
571
|
+
],
|
534
572
|
)
|
535
|
-
|
573
|
+
|
536
574
|
data_raw = state.get("data_raw")
|
537
575
|
df = pd.DataFrame.from_dict(data_raw)
|
538
576
|
|
539
|
-
all_datasets_summary = get_dataframe_summary(
|
540
|
-
|
577
|
+
all_datasets_summary = get_dataframe_summary(
|
578
|
+
[df], n_sample=n_samples, skip_stats=False
|
579
|
+
)
|
580
|
+
|
541
581
|
all_datasets_summary_str = "\n\n".join(all_datasets_summary)
|
542
582
|
|
543
|
-
chart_instructor = recommend_steps_prompt | llm
|
544
|
-
|
545
|
-
recommended_steps = chart_instructor.invoke(
|
546
|
-
|
547
|
-
|
548
|
-
|
549
|
-
|
550
|
-
|
583
|
+
chart_instructor = recommend_steps_prompt | llm
|
584
|
+
|
585
|
+
recommended_steps = chart_instructor.invoke(
|
586
|
+
{
|
587
|
+
"user_instructions": state.get("user_instructions"),
|
588
|
+
"recommended_steps": state.get("recommended_steps"),
|
589
|
+
"all_datasets_summary": all_datasets_summary_str,
|
590
|
+
}
|
591
|
+
)
|
592
|
+
|
551
593
|
return {
|
552
|
-
"recommended_steps": format_recommended_steps(
|
553
|
-
|
594
|
+
"recommended_steps": format_recommended_steps(
|
595
|
+
recommended_steps.content.strip(),
|
596
|
+
heading="# Recommended Data Cleaning Steps:",
|
597
|
+
),
|
598
|
+
"all_datasets_summary": all_datasets_summary_str,
|
554
599
|
}
|
555
|
-
|
600
|
+
|
556
601
|
def chart_generator(state: GraphState):
|
557
|
-
|
558
602
|
print(" * CREATE DATA VISUALIZATION CODE")
|
559
603
|
|
560
|
-
|
561
604
|
if bypass_recommended_steps:
|
562
605
|
print(format_agent_name(AGENT_NAME))
|
563
|
-
|
606
|
+
|
564
607
|
data_raw = state.get("data_raw")
|
565
608
|
df = pd.DataFrame.from_dict(data_raw)
|
566
609
|
|
567
|
-
all_datasets_summary = get_dataframe_summary(
|
568
|
-
|
610
|
+
all_datasets_summary = get_dataframe_summary(
|
611
|
+
[df], n_sample=n_samples, skip_stats=False
|
612
|
+
)
|
613
|
+
|
569
614
|
all_datasets_summary_str = "\n\n".join(all_datasets_summary)
|
570
|
-
|
615
|
+
|
571
616
|
chart_generator_instructions = state.get("user_instructions")
|
572
|
-
|
617
|
+
|
573
618
|
else:
|
574
619
|
all_datasets_summary_str = state.get("all_datasets_summary")
|
575
620
|
chart_generator_instructions = state.get("recommended_steps")
|
576
|
-
|
621
|
+
|
577
622
|
prompt_template = PromptTemplate(
|
578
623
|
template="""
|
579
624
|
You are a chart generator agent that is an expert in generating plotly charts. You must use plotly or plotly.express to produce plots.
|
@@ -615,65 +660,76 @@ def make_data_visualization_agent(
|
|
615
660
|
2. Do not include unrelated user instructions that are not related to the chart generation.
|
616
661
|
|
617
662
|
""",
|
618
|
-
input_variables=[
|
663
|
+
input_variables=[
|
664
|
+
"chart_generator_instructions",
|
665
|
+
"all_datasets_summary",
|
666
|
+
"function_name",
|
667
|
+
],
|
619
668
|
)
|
620
669
|
|
621
670
|
data_visualization_agent = prompt_template | llm | PythonOutputParser()
|
622
|
-
|
623
|
-
response = data_visualization_agent.invoke(
|
624
|
-
|
625
|
-
|
626
|
-
|
627
|
-
|
628
|
-
|
671
|
+
|
672
|
+
response = data_visualization_agent.invoke(
|
673
|
+
{
|
674
|
+
"chart_generator_instructions": chart_generator_instructions,
|
675
|
+
"all_datasets_summary": all_datasets_summary_str,
|
676
|
+
"function_name": function_name,
|
677
|
+
}
|
678
|
+
)
|
679
|
+
|
629
680
|
response = relocate_imports_inside_function(response)
|
630
681
|
response = add_comments_to_top(response, agent_name=AGENT_NAME)
|
631
|
-
|
682
|
+
|
632
683
|
# For logging: store the code generated:
|
633
684
|
file_path, file_name_2 = log_ai_function(
|
634
685
|
response=response,
|
635
686
|
file_name=file_name,
|
636
687
|
log=log,
|
637
688
|
log_path=log_path,
|
638
|
-
overwrite=overwrite
|
689
|
+
overwrite=overwrite,
|
639
690
|
)
|
640
|
-
|
691
|
+
|
641
692
|
return {
|
642
693
|
"data_visualization_function": response,
|
643
694
|
"data_visualization_function_path": file_path,
|
644
695
|
"data_visualization_function_file_name": file_name_2,
|
645
696
|
"data_visualization_function_name": function_name,
|
646
|
-
"all_datasets_summary": all_datasets_summary_str
|
697
|
+
"all_datasets_summary": all_datasets_summary_str,
|
647
698
|
}
|
648
|
-
|
699
|
+
|
649
700
|
# Human Review
|
650
|
-
|
701
|
+
|
651
702
|
prompt_text_human_review = "Are the following data visualization instructions correct? (Answer 'yes' or provide modifications)\n{steps}"
|
652
|
-
|
703
|
+
|
653
704
|
if not bypass_explain_code:
|
654
|
-
|
705
|
+
|
706
|
+
def human_review(
|
707
|
+
state: GraphState,
|
708
|
+
) -> Command[Literal["chart_instructor", "explain_data_visualization_code"]]:
|
655
709
|
return node_func_human_review(
|
656
710
|
state=state,
|
657
711
|
prompt_text=prompt_text_human_review,
|
658
|
-
yes_goto=
|
712
|
+
yes_goto="explain_data_visualization_code",
|
659
713
|
no_goto="chart_instructor",
|
660
714
|
user_instructions_key="user_instructions",
|
661
715
|
recommended_steps_key="recommended_steps",
|
662
716
|
code_snippet_key="data_visualization_function",
|
663
717
|
)
|
664
718
|
else:
|
665
|
-
|
719
|
+
|
720
|
+
def human_review(
|
721
|
+
state: GraphState,
|
722
|
+
) -> Command[Literal["chart_instructor", "__end__"]]:
|
666
723
|
return node_func_human_review(
|
667
724
|
state=state,
|
668
725
|
prompt_text=prompt_text_human_review,
|
669
|
-
yes_goto=
|
726
|
+
yes_goto="__end__",
|
670
727
|
no_goto="chart_instructor",
|
671
728
|
user_instructions_key="user_instructions",
|
672
729
|
recommended_steps_key="recommended_steps",
|
673
|
-
code_snippet_key="data_visualization_function",
|
730
|
+
code_snippet_key="data_visualization_function",
|
674
731
|
)
|
675
|
-
|
676
|
-
|
732
|
+
|
677
733
|
def execute_data_visualization_code(state):
|
678
734
|
return node_func_execute_agent_code_on_data(
|
679
735
|
state=state,
|
@@ -684,9 +740,9 @@ def make_data_visualization_agent(
|
|
684
740
|
agent_function_name=state.get("data_visualization_function_name"),
|
685
741
|
pre_processing=lambda data: pd.DataFrame.from_dict(data),
|
686
742
|
# post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
|
687
|
-
error_message_prefix="An error occurred during data visualization: "
|
743
|
+
error_message_prefix="An error occurred during data visualization: ",
|
688
744
|
)
|
689
|
-
|
745
|
+
|
690
746
|
def fix_data_visualization_code(state: GraphState):
|
691
747
|
prompt = """
|
692
748
|
You are a Data Visualization Agent. Your job is to create a {function_name}() function that can be run on the data provided. The function is currently broken and needs to be fixed.
|
@@ -706,14 +762,14 @@ def make_data_visualization_agent(
|
|
706
762
|
state=state,
|
707
763
|
code_snippet_key="data_visualization_function",
|
708
764
|
error_key="data_visualization_error",
|
709
|
-
llm=llm,
|
765
|
+
llm=llm,
|
710
766
|
prompt_template=prompt,
|
711
767
|
agent_name=AGENT_NAME,
|
712
768
|
log=log,
|
713
769
|
file_path=state.get("data_visualization_function_path"),
|
714
770
|
function_name=state.get("data_visualization_function_name"),
|
715
771
|
)
|
716
|
-
|
772
|
+
|
717
773
|
# Final reporting node
|
718
774
|
def report_agent_outputs(state: GraphState):
|
719
775
|
return node_func_report_agent_outputs(
|
@@ -727,9 +783,9 @@ def make_data_visualization_agent(
|
|
727
783
|
],
|
728
784
|
result_key="messages",
|
729
785
|
role=AGENT_NAME,
|
730
|
-
custom_title="Data Visualization Agent Outputs"
|
786
|
+
custom_title="Data Visualization Agent Outputs",
|
731
787
|
)
|
732
|
-
|
788
|
+
|
733
789
|
# Define the graph
|
734
790
|
node_functions = {
|
735
791
|
"chart_instructor": chart_instructor,
|
@@ -739,7 +795,7 @@ def make_data_visualization_agent(
|
|
739
795
|
"fix_data_visualization_code": fix_data_visualization_code,
|
740
796
|
"report_agent_outputs": report_agent_outputs,
|
741
797
|
}
|
742
|
-
|
798
|
+
|
743
799
|
app = create_coding_agent_graph(
|
744
800
|
GraphState=GraphState,
|
745
801
|
node_functions=node_functions,
|
@@ -751,9 +807,10 @@ def make_data_visualization_agent(
|
|
751
807
|
error_key="data_visualization_error",
|
752
808
|
human_in_the_loop=human_in_the_loop, # or False
|
753
809
|
human_review_node_name="human_review",
|
754
|
-
checkpointer=
|
810
|
+
checkpointer=checkpointer,
|
755
811
|
bypass_recommended_steps=bypass_recommended_steps,
|
756
812
|
bypass_explain_code=bypass_explain_code,
|
813
|
+
agent_name=AGENT_NAME,
|
757
814
|
)
|
758
|
-
|
815
|
+
|
759
816
|
return app
|