ai-data-science-team 0.0.0.9005__py3-none-any.whl → 0.0.0.9007__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +3 -1
- ai_data_science_team/agents/data_cleaning_agent.py +213 -20
- ai_data_science_team/agents/data_visualization_agent.py +331 -0
- ai_data_science_team/agents/data_wrangling_agent.py +66 -24
- ai_data_science_team/agents/feature_engineering_agent.py +50 -13
- ai_data_science_team/agents/sql_database_agent.py +397 -0
- ai_data_science_team/templates/__init__.py +8 -0
- ai_data_science_team/templates/agent_templates.py +154 -37
- ai_data_science_team/tools/logging.py +1 -1
- ai_data_science_team/tools/metadata.py +230 -0
- ai_data_science_team/tools/regex.py +7 -1
- {ai_data_science_team-0.0.0.9005.dist-info → ai_data_science_team-0.0.0.9007.dist-info}/METADATA +43 -22
- ai_data_science_team-0.0.0.9007.dist-info/RECORD +21 -0
- {ai_data_science_team-0.0.0.9005.dist-info → ai_data_science_team-0.0.0.9007.dist-info}/WHEEL +1 -1
- ai_data_science_team/tools/data_analysis.py +0 -116
- ai_data_science_team-0.0.0.9005.dist-info/RECORD +0 -19
- {ai_data_science_team-0.0.0.9005.dist-info → ai_data_science_team-0.0.0.9007.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9005.dist-info → ai_data_science_team-0.0.0.9007.dist-info}/top_level.txt +0 -0
ai_data_science_team/agents/data_visualization_agent.py (new file):

@@ -0,0 +1,331 @@
+# BUSINESS SCIENCE UNIVERSITY
+# AI DATA SCIENCE TEAM
+# ***
+# * Agents: Data Visualization Agent
+
+
+
+# Libraries
+from typing import TypedDict, Annotated, Sequence, Literal
+import operator
+
+from langchain.prompts import PromptTemplate
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.messages import BaseMessage
+
+from langgraph.types import Command
+from langgraph.checkpoint.memory import MemorySaver
+
+import os
+import io
+import pandas as pd
+
+from ai_data_science_team.templates import(
+    node_func_execute_agent_code_on_data,
+    node_func_human_review,
+    node_func_fix_agent_code,
+    node_func_explain_agent_code,
+    create_coding_agent_graph
+)
+from ai_data_science_team.tools.parsers import PythonOutputParser
+from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
+from ai_data_science_team.tools.metadata import get_dataframe_summary
+from ai_data_science_team.tools.logging import log_ai_function
+
+# Setup
+AGENT_NAME = "data_visualization_agent"
+LOG_PATH = os.path.join(os.getcwd(), "logs/")
+
+# Agent
+
+def make_data_visualization_agent(
+    model,
+    n_samples=30,
+    log=False,
+    log_path=None,
+    file_name="data_visualization.py",
+    overwrite = True,
+    human_in_the_loop=False,
+    bypass_recommended_steps=False,
+    bypass_explain_code=False
+):
+
+    llm = model
+
+    # Setup Log Directory
+    if log:
+        if log_path is None:
+            log_path = LOG_PATH
+        if not os.path.exists(log_path):
+            os.makedirs(log_path)
+
+    # Define GraphState for the router
+    class GraphState(TypedDict):
+        messages: Annotated[Sequence[BaseMessage], operator.add]
+        user_instructions: str
+        user_instructions_processed: str
+        recommended_steps: str
+        data_raw: dict
+        plotly_graph: dict
+        all_datasets_summary: str
+        data_visualization_function: str
+        data_visualization_function_path: str
+        data_visualization_function_name: str
+        data_visualization_error: str
+        max_retries: int
+        retry_count: int
+
+    def chart_instructor(state: GraphState):
+
+        print(format_agent_name(AGENT_NAME))
+        print(" * CREATE CHART GENERATOR INSTRUCTIONS")
+
+        recommend_steps_prompt = PromptTemplate(
+            template="""
+            You are a supervisor that is an expert in providing instructions to a chart generator agent for plotting.
+
+            You will take a question that a user has and the data that was generated to answer the question, and create instructions to create a chart from the data that will be passed to a chart generator agent.
+
+            USER QUESTION / INSTRUCTIONS:
+            {user_instructions}
+
+            Previously Recommended Instructions (if any):
+            {recommended_steps}
+
+            DATA:
+            {all_datasets_summary}
+
+            Formulate chart generator instructions by informing the chart generator of what type of plotly plot to use (e.g. bar, line, scatter, etc) to best represent the data.
+
+            Come up with an informative title from the user's question and data provided. Also provide X and Y axis titles.
+
+            Instruct the chart generator to use the following theme colors, sizes, etc:
+
+            - Start with the "plotly_white" template
+            - Use a white background
+            - Use this color for bars and lines:
+                'blue': '#3381ff',
+            - Base Font Size: 8.8 (Used for x and y axes tickfont, any annotations, hovertips)
+            - Title Font Size: 13.2
+            - Line Size: 0.65 (specify these within the xaxis and yaxis dictionaries)
+            - Add smoothers or trendlines to scatter plots unless not desired by the user
+            - Do not use color_discrete_map (this will result in an error)
+            - Hover tip size: 8.8
+
+            Return your instructions in the following format:
+            CHART GENERATOR INSTRUCTIONS:
+            FILL IN THE INSTRUCTIONS HERE
+
+            Avoid these:
+            1. Do not include steps to save files.
+            2. Do not include unrelated user instructions that are not related to the chart generation.
+            """,
+            input_variables=["user_instructions", "recommended_steps", "all_datasets_summary"]
+
+        )
+
+        data_raw = state.get("data_raw")
+        df = pd.DataFrame.from_dict(data_raw)
+
+        all_datasets_summary = get_dataframe_summary([df], n_sample=n_samples, skip_stats=False)
+
+        all_datasets_summary_str = "\n\n".join(all_datasets_summary)
+
+        chart_instructor = recommend_steps_prompt | llm
+
+        recommended_steps = chart_instructor.invoke({
+            "user_instructions": state.get("user_instructions"),
+            "recommended_steps": state.get("recommended_steps"),
+            "all_datasets_summary": all_datasets_summary_str
+        })
+
+        return {
+            "recommended_steps": "\n\n# Recommended Data Cleaning Steps:\n" + recommended_steps.content.strip(),
+            "all_datasets_summary": all_datasets_summary_str
+        }
+
+    def chart_generator(state: GraphState):
+
+        print(" * CREATE DATA VISUALIZATION CODE")
+
+
+        if bypass_recommended_steps:
+            print(format_agent_name(AGENT_NAME))
+
+            data_raw = state.get("data_raw")
+            df = pd.DataFrame.from_dict(data_raw)
+
+            all_datasets_summary = get_dataframe_summary([df], n_sample=n_samples, skip_stats=False)
+
+            all_datasets_summary_str = "\n\n".join(all_datasets_summary)
+
+            chart_generator_instructions = state.get("user_instructions")
+
+        else:
+            all_datasets_summary_str = state.get("all_datasets_summary")
+            chart_generator_instructions = state.get("recommended_steps")
+
+        prompt_template = PromptTemplate(
+            template="""
+            You are a chart generator agent that is an expert in generating plotly charts. You must use plotly or plotly.express to produce plots.
+
+            Your job is to produce python code to generate visualizations.
+
+            You will take instructions from a Chart Instructor and generate a plotly chart from the data provided.
+
+            CHART INSTRUCTIONS:
+            {chart_generator_instructions}
+
+            DATA:
+            {all_datasets_summary}
+
+            RETURN:
+
+            Return Python code in ```python ``` format with a single function definition, data_visualization(data_raw), that includes all imports inside the function.
+
+            Return the plotly chart as a dictionary.
+
+            Return code to provide the data visualization function:
+
+            def data_visualization(data_raw):
+                import pandas as pd
+                import numpy as np
+                import json
+                import plotly.graph_objects as go
+                import plotly.io as pio
+
+                ...
+
+                fig_json = pio.to_json(fig)
+                fig_dict = json.loads(fig_json)
+
+                return fig_dict
+
+            Avoid these:
+            1. Do not include steps to save files.
+            2. Do not include unrelated user instructions that are not related to the chart generation.
+
+            """,
+            input_variables=["chart_generator_instructions", "all_datasets_summary"]
+        )
+
+        data_visualization_agent = prompt_template | llm | PythonOutputParser()
+
+        response = data_visualization_agent.invoke({
+            "chart_generator_instructions": chart_generator_instructions,
+            "all_datasets_summary": all_datasets_summary_str
+        })
+
+        response = relocate_imports_inside_function(response)
+        response = add_comments_to_top(response, agent_name=AGENT_NAME)
+
+        # For logging: store the code generated:
+        file_path, file_name_2 = log_ai_function(
+            response=response,
+            file_name=file_name,
+            log=log,
+            log_path=log_path,
+            overwrite=overwrite
+        )
+
+        return {
+            "data_visualization_function": response,
+            "data_visualization_function_path": file_path,
+            "data_visualization_function_name": file_name_2,
+            "all_datasets_summary": all_datasets_summary_str
+        }
+
+    def human_review(state: GraphState) -> Command[Literal["chart_instructor", "chart_generator"]]:
+        return node_func_human_review(
+            state=state,
+            prompt_text="Is the following data visualization instructions correct? (Answer 'yes' or provide modifications)\n{steps}",
+            yes_goto="chart_generator",
+            no_goto="chart_instructor",
+            user_instructions_key="user_instructions",
+            recommended_steps_key="recommended_steps"
+        )
+
+
+    def execute_data_visualization_code(state):
+        return node_func_execute_agent_code_on_data(
+            state=state,
+            data_key="data_raw",
+            result_key="plotly_graph",
+            error_key="data_visualization_error",
+            code_snippet_key="data_visualization_function",
+            agent_function_name="data_visualization",
+            pre_processing=lambda data: pd.DataFrame.from_dict(data),
+            # post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
+            error_message_prefix="An error occurred during data visualization: "
+        )
+
+    def fix_data_visualization_code(state: GraphState):
+        prompt = """
+        You are a Data Visualization Agent. Your job is to create a data_visualization() function that can be run on the data provided. The function is currently broken and needs to be fixed.
+
+        Make sure to only return the function definition for data_visualization().
+
+        Return Python code in ```python``` format with a single function definition, data_visualization(data_raw), that includes all imports inside the function.
+
+        This is the broken code (please fix):
+        {code_snippet}
+
+        Last Known Error:
+        {error}
+        """
+
+        return node_func_fix_agent_code(
+            state=state,
+            code_snippet_key="data_visualization_function",
+            error_key="data_visualization_error",
+            llm=llm,
+            prompt_template=prompt,
+            agent_name=AGENT_NAME,
+            log=log,
+            file_path=state.get("data_visualization_function_path"),
+        )
+
+    def explain_data_visualization_code(state: GraphState):
+        return node_func_explain_agent_code(
+            state=state,
+            code_snippet_key="data_visualization_function",
+            result_key="messages",
+            error_key="data_visualization_error",
+            llm=llm,
+            role=AGENT_NAME,
+            explanation_prompt_template="""
+            Explain the data visualization steps that the data visualization agent performed in this function.
+            Keep the summary succinct and to the point.\n\n# Data Visualization Agent:\n\n{code}
+            """,
+            success_prefix="# Data Visualization Agent:\n\n ",
+            error_message="The Data Visualization Agent encountered an error during data visualization. No explanation could be provided."
+        )
+
+    # Define the graph
+    node_functions = {
+        "chart_instructor": chart_instructor,
+        "human_review": human_review,
+        "chart_generator": chart_generator,
+        "execute_data_visualization_code": execute_data_visualization_code,
+        "fix_data_visualization_code": fix_data_visualization_code,
+        "explain_data_visualization_code": explain_data_visualization_code
+    }
+
+    app = create_coding_agent_graph(
+        GraphState=GraphState,
+        node_functions=node_functions,
+        recommended_steps_node_name="chart_instructor",
+        create_code_node_name="chart_generator",
+        execute_code_node_name="execute_data_visualization_code",
+        fix_code_node_name="fix_data_visualization_code",
+        explain_code_node_name="explain_data_visualization_code",
+        error_key="data_visualization_error",
+        human_in_the_loop=human_in_the_loop,  # or False
+        human_review_node_name="human_review",
+        checkpointer=MemorySaver() if human_in_the_loop else None,
+        bypass_recommended_steps=bypass_recommended_steps,
+        bypass_explain_code=bypass_explain_code,
+    )
+
+    return app
+
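For orientation, here is a minimal usage sketch of the new factory. It is illustrative only and not part of the diff: it assumes the updated agents package re-exports make_data_visualization_agent, that langchain_openai is installed with an API key configured, and it uses a placeholder model name and toy DataFrame. The state keys come from the GraphState definition above.

# Minimal usage sketch -- illustrative, not part of the released diff.
# Assumptions: langchain_openai is installed and OPENAI_API_KEY is set;
# the model name and the toy DataFrame are placeholders.
import json
import pandas as pd
import plotly.io as pio
from langchain_openai import ChatOpenAI
from ai_data_science_team.agents import make_data_visualization_agent

llm = ChatOpenAI(model="gpt-4o-mini")

app = make_data_visualization_agent(
    model=llm,
    n_samples=30,        # rows shown to the LLM in the dataset summary
    log=False,
    human_in_the_loop=False,
)

df = pd.DataFrame({"month": ["Jan", "Feb", "Mar"], "sales": [120, 150, 170]})

result = app.invoke({
    "user_instructions": "Plot sales by month as a bar chart.",
    "data_raw": df.to_dict(),
    "max_retries": 3,
    "retry_count": 0,
})

# execute_data_visualization_code() stores the figure as a dict under "plotly_graph"
# (the generated data_visualization() serializes it via pio.to_json / json.loads),
# so rebuild a Figure from that dict to display it.
fig = pio.from_json(json.dumps(result["plotly_graph"]))
fig.show()

If the agents __init__ does not re-export the factory, importing it directly from ai_data_science_team.agents.data_visualization_agent should behave the same way.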
ai_data_science_team/agents/data_wrangling_agent.py:

@@ -15,7 +15,7 @@ from langchain_core.messages import BaseMessage
 from langgraph.types import Command
 from langgraph.checkpoint.memory import MemorySaver

-from ai_data_science_team.templates
+from ai_data_science_team.templates import(
     node_func_execute_agent_code_on_data,
     node_func_human_review,
     node_func_fix_agent_code,
@@ -23,15 +23,25 @@ from ai_data_science_team.templates.agent_templates import(
     create_coding_agent_graph
 )
 from ai_data_science_team.tools.parsers import PythonOutputParser
-from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
-from ai_data_science_team.tools.
+from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
+from ai_data_science_team.tools.metadata import get_dataframe_summary
 from ai_data_science_team.tools.logging import log_ai_function

 # Setup Logging Path
 AGENT_NAME = "data_wrangling_agent"
 LOG_PATH = os.path.join(os.getcwd(), "logs/")

-def make_data_wrangling_agent(
+def make_data_wrangling_agent(
+    model,
+    n_samples=30,
+    log=False,
+    log_path=None,
+    file_name="data_wrangler.py",
+    overwrite = True,
+    human_in_the_loop=False,
+    bypass_recommended_steps=False,
+    bypass_explain_code=False
+):
     """
     Creates a data wrangling agent that can be run on one or more datasets. The agent can be
     instructed to perform common data wrangling steps such as:
@@ -52,17 +62,27 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
     ----------
     model : langchain.llms.base.LLM
         The language model to use to generate code.
+    n_samples : int, optional
+        The number of samples to show in the data summary. Defaults to 30.
+        If you get an error due to maximum tokens, try reducing this number.
+        > "This model's maximum context length is 128000 tokens. However, your messages resulted in 333858 tokens. Please reduce the length of the messages."
     log : bool, optional
         Whether or not to log the code generated and any errors that occur.
        Defaults to False.
     log_path : str, optional
        The path to the directory where the log files should be stored. Defaults to "logs/".
+    file_name : str, optional
+        The name of the file to save the response to. Defaults to "data_wrangler.py".
     overwrite : bool, optional
        Whether or not to overwrite the log file if it already exists. If False, a unique file name will be created.
        Defaults to True.
     human_in_the_loop : bool, optional
        Whether or not to use human in the loop. If True, adds an interrupt and human-in-the-loop
        step that asks the user to review the data wrangling instructions. Defaults to False.
+    bypass_recommended_steps : bool, optional
+        Bypass the recommendation step, by default False
+    bypass_explain_code : bool, optional
+        Bypass the code explanation step, by default False.

     Example
     -------
@@ -90,7 +110,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,

     Returns
     -------
-    app : langchain.graphs.
+    app : langchain.graphs.CompiledStateGraph
        The data wrangling agent as a state graph.
    """
    llm = model
@@ -118,7 +138,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
         retry_count: int

    def recommend_wrangling_steps(state: GraphState):
-        print(
+        print(format_agent_name(AGENT_NAME))
         print(" * RECOMMEND WRANGLING STEPS")

         data_raw = state.get("data_raw")
@@ -139,7 +159,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,

         # Create a summary for all datasets
         # We'll include a short sample and info for each dataset
-        all_datasets_summary =
+        all_datasets_summary = get_dataframe_summary(dataframes, n_sample=n_samples)

         # Join all datasets summaries into one big text block
         all_datasets_summary_str = "\n\n".join(all_datasets_summary)
@@ -172,6 +192,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,

             Avoid these:
             1. Do not include steps to save files.
+            2. Do not include unrelated user instructions that are not related to the data wrangling.
             """,
             input_variables=["user_instructions", "recommended_steps", "all_datasets_summary"]
         )
@@ -190,6 +211,35 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,


    def create_data_wrangler_code(state: GraphState):
+        if bypass_recommended_steps:
+            print(format_agent_name(AGENT_NAME))
+
+            data_raw = state.get("data_raw")
+
+            if isinstance(data_raw, dict):
+                # Single dataset scenario
+                primary_dataset_name = "main"
+                datasets = {primary_dataset_name: data_raw}
+            elif isinstance(data_raw, list) and all(isinstance(item, dict) for item in data_raw):
+                # Multiple datasets scenario
+                datasets = {f"dataset_{i}": d for i, d in enumerate(data_raw, start=1)}
+                primary_dataset_name = "dataset_1"
+            else:
+                raise ValueError("data_raw must be a dict or a list of dicts.")
+
+            # Convert all datasets to DataFrames for inspection
+            dataframes = {name: pd.DataFrame.from_dict(d) for name, d in datasets.items()}
+
+            # Create a summary for all datasets
+            # We'll include a short sample and info for each dataset
+            all_datasets_summary = get_dataframe_summary(dataframes, n_sample=n_samples)
+
+            # Join all datasets summaries into one big text block
+            all_datasets_summary_str = "\n\n".join(all_datasets_summary)
+
+        else:
+            all_datasets_summary_str = state.get("all_datasets_summary")
+
         print(" * CREATE DATA WRANGLER CODE")

         data_wrangling_prompt = PromptTemplate(
@@ -236,16 +286,16 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,

         response = data_wrangling_agent.invoke({
             "recommended_steps": state.get("recommended_steps"),
-            "all_datasets_summary":
+            "all_datasets_summary": all_datasets_summary_str
         })

         response = relocate_imports_inside_function(response)
         response = add_comments_to_top(response, agent_name=AGENT_NAME)

         # For logging: store the code generated
-        file_path,
+        file_path, file_name_2 = log_ai_function(
             response=response,
-            file_name=
+            file_name=file_name,
             log=log,
             log_path=log_path,
             overwrite=overwrite
@@ -254,7 +304,8 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
         return {
             "data_wrangler_function" : response,
             "data_wrangler_function_path": file_path,
-            "data_wrangler_function_name":
+            "data_wrangler_function_name": file_name_2,
+            "all_datasets_summary": all_datasets_summary_str
         }


@@ -269,17 +320,6 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
         )

    def execute_data_wrangler_code(state: GraphState):
-
-        # Handle multiple datasets as lists
-        # def pre_processing(data):
-        #     df = []
-        #     for i in range(len(data)):
-        #         df[i] = pd.DataFrame.from_dict(data[i])
-        #     return df
-
-        # def post_processing(df):
-        #     return df.to_dict()
-
         return node_func_execute_agent_code_on_data(
             state=state,
             data_key="data_raw",
@@ -288,7 +328,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
             code_snippet_key="data_wrangler_function",
             agent_function_name="data_wrangler",
             # pre_processing=pre_processing,
-
+            post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
             error_message_prefix="An error occurred during data wrangling: "
         )

@@ -355,7 +395,9 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
         error_key="data_wrangler_error",
         human_in_the_loop=human_in_the_loop,
         human_review_node_name="human_review",
-        checkpointer=MemorySaver() if human_in_the_loop else None
+        checkpointer=MemorySaver() if human_in_the_loop else None,
+        bypass_recommended_steps=bypass_recommended_steps,
+        bypass_explain_code=bypass_explain_code,
     )

    return app
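A companion sketch for the updated wrangling agent follows. Again, this is illustrative and not part of the diff: the model and sample frames are placeholders, the multi-dataset input mirrors the new list-of-dicts branch in create_data_wrangler_code(), and only result keys visible in these hunks are accessed (the key holding the wrangled data itself is set elsewhere in the module and is not shown here).

# Usage sketch for the updated wrangling agent -- illustrative only.
# Assumptions: same placeholder model setup as above; state keys mirror the
# coding-agent template (retry_count appears in the hunks, max_retries is assumed).
import pandas as pd
from langchain_openai import ChatOpenAI
from ai_data_science_team.agents import make_data_wrangling_agent

llm = ChatOpenAI(model="gpt-4o-mini")

app = make_data_wrangling_agent(
    model=llm,
    n_samples=10,                   # lower this if the dataset summary exceeds the model's context window
    file_name="data_wrangler.py",   # new: name used when logging the generated code
    log=False,
    bypass_recommended_steps=True,  # new: skip the recommendation step
    bypass_explain_code=True,       # new: skip the code explanation step
)

orders = pd.DataFrame({"order_id": [1, 2, 3], "customer_id": [10, 11, 10], "amount": [25.0, 40.0, 12.5]})
customers = pd.DataFrame({"customer_id": [10, 11], "region": ["East", "West"]})

# data_raw may be a single dict or a list of dicts (multiple datasets),
# matching the new branch added in create_data_wrangler_code().
result = app.invoke({
    "user_instructions": "Join the datasets on customer_id and total the amount by region.",
    "data_raw": [orders.to_dict(), customers.to_dict()],
    "max_retries": 3,
    "retry_count": 0,
})

print(result["data_wrangler_function"])        # the generated data_wrangler() code
print(result["data_wrangler_function_name"])   # file name reported by log_ai_function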