ai-data-science-team 0.0.0.9007__py3-none-any.whl → 0.0.0.9008__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +4 -4
- ai_data_science_team/agents/data_cleaning_agent.py +225 -84
- ai_data_science_team/agents/data_visualization_agent.py +460 -27
- ai_data_science_team/agents/data_wrangling_agent.py +455 -16
- ai_data_science_team/agents/feature_engineering_agent.py +429 -25
- ai_data_science_team/agents/sql_database_agent.py +367 -21
- ai_data_science_team/multiagents/__init__.py +1 -0
- ai_data_science_team/multiagents/sql_data_analyst.py +286 -0
- ai_data_science_team/multiagents/supervised_data_analyst.py +2 -0
- ai_data_science_team/templates/__init__.py +2 -1
- ai_data_science_team/templates/agent_templates.py +247 -42
- ai_data_science_team/tools/regex.py +28 -1
- ai_data_science_team/utils/__init__.py +0 -0
- ai_data_science_team/utils/plotly.py +24 -0
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/METADATA +76 -28
- ai_data_science_team-0.0.0.9008.dist-info/RECORD +26 -0
- ai_data_science_team-0.0.0.9007.dist-info/RECORD +0 -21
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/WHEEL +0 -0
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/top_level.txt +0 -0
@@ -4,11 +4,11 @@
|
|
4
4
|
# * Agents: Data Wrangling Agent
|
5
5
|
|
6
6
|
# Libraries
|
7
|
-
from typing import TypedDict, Annotated, Sequence, Literal, Union
|
7
|
+
from typing import TypedDict, Annotated, Sequence, Literal, Union, Optional
|
8
8
|
import operator
|
9
9
|
import os
|
10
|
-
import io
|
11
10
|
import pandas as pd
|
11
|
+
from IPython.display import Markdown
|
12
12
|
|
13
13
|
from langchain.prompts import PromptTemplate
|
14
14
|
from langchain_core.messages import BaseMessage
|
@@ -20,10 +20,11 @@ from ai_data_science_team.templates import(
|
|
20
20
|
node_func_human_review,
|
21
21
|
node_func_fix_agent_code,
|
22
22
|
node_func_explain_agent_code,
|
23
|
-
create_coding_agent_graph
|
23
|
+
create_coding_agent_graph,
|
24
|
+
BaseAgent,
|
24
25
|
)
|
25
26
|
from ai_data_science_team.tools.parsers import PythonOutputParser
|
26
|
-
from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
|
27
|
+
from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name, format_recommended_steps
|
27
28
|
from ai_data_science_team.tools.metadata import get_dataframe_summary
|
28
29
|
from ai_data_science_team.tools.logging import log_ai_function
|
29
30
|
|
@@ -31,13 +32,414 @@ from ai_data_science_team.tools.logging import log_ai_function
|
|
31
32
|
AGENT_NAME = "data_wrangling_agent"
|
32
33
|
LOG_PATH = os.path.join(os.getcwd(), "logs/")
|
33
34
|
|
35
|
+
# Class
|
36
|
+
|
37
|
+
class DataWranglingAgent(BaseAgent):
|
38
|
+
"""
|
39
|
+
Creates a data wrangling agent that can work with one or more datasets, performing operations such as
|
40
|
+
joining/merging multiple datasets, reshaping, aggregating, encoding, creating computed features,
|
41
|
+
and ensuring consistent data types. The agent generates a Python function to wrangle the data,
|
42
|
+
executes the function, and logs the process (if enabled).
|
43
|
+
|
44
|
+
The agent can handle:
|
45
|
+
- A single dataset (provided as a pandas DataFrame or a dictionary of {column: list_of_values})
|
46
|
+
- Multiple datasets (provided as a list of such dictionaries)
|
47
|
+
|
48
|
+
Key wrangling steps can include:
|
49
|
+
- Merging or joining datasets
|
50
|
+
- Pivoting/melting data for reshaping
|
51
|
+
- GroupBy aggregations (sums, means, counts, etc.)
|
52
|
+
- Encoding categorical variables
|
53
|
+
- Computing new columns from existing ones
|
54
|
+
- Dropping or rearranging columns
|
55
|
+
- Any additional user instructions
|
56
|
+
|
57
|
+
Parameters
|
58
|
+
----------
|
59
|
+
model : langchain.llms.base.LLM
|
60
|
+
The language model used to generate the data wrangling function.
|
61
|
+
n_samples : int, optional
|
62
|
+
Number of samples to show in the data summary for wrangling. Defaults to 30.
|
63
|
+
log : bool, optional
|
64
|
+
Whether to log the generated code and errors. Defaults to False.
|
65
|
+
log_path : str, optional
|
66
|
+
Directory path for storing log files. Defaults to None.
|
67
|
+
file_name : str, optional
|
68
|
+
Name of the file for saving the generated response. Defaults to "data_wrangler.py".
|
69
|
+
function_name : str, optional
|
70
|
+
Name of the function to be generated. Defaults to "data_wrangler".
|
71
|
+
overwrite : bool, optional
|
72
|
+
Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
|
73
|
+
human_in_the_loop : bool, optional
|
74
|
+
Enables user review of data wrangling instructions. Defaults to False.
|
75
|
+
bypass_recommended_steps : bool, optional
|
76
|
+
If True, skips the step that generates recommended data wrangling steps. Defaults to False.
|
77
|
+
bypass_explain_code : bool, optional
|
78
|
+
If True, skips the step that provides code explanations. Defaults to False.
|
79
|
+
|
80
|
+
Methods
|
81
|
+
-------
|
82
|
+
update_params(**kwargs)
|
83
|
+
Updates the agent's parameters and rebuilds the compiled state graph.
|
84
|
+
|
85
|
+
ainvoke_agent(data_raw: Union[pd.DataFrame, dict, list], user_instructions: str=None, max_retries=3, retry_count=0, **kwargs)
|
86
|
+
Asynchronously wrangles the provided dataset(s) based on user instructions.
|
87
|
+
|
88
|
+
invoke_agent(data_raw: Union[pd.DataFrame, dict, list], user_instructions: str=None, max_retries=3, retry_count=0, **kwargs)
|
89
|
+
Synchronously wrangles the provided dataset(s) based on user instructions.
|
90
|
+
|
91
|
+
explain_wrangling_steps()
|
92
|
+
Returns an explanation of the wrangling steps performed by the agent.
|
93
|
+
|
94
|
+
get_log_summary()
|
95
|
+
Retrieves a summary of logged operations if logging is enabled.
|
96
|
+
|
97
|
+
get_data_wrangled()
|
98
|
+
Retrieves the final wrangled dataset (as a pandas DataFrame), or None if no response is available.
|
99
|
+
|
100
|
+
get_data_raw()
|
101
|
+
Retrieves the raw dataset(s).
|
102
|
+
|
103
|
+
get_data_wrangler_function()
|
104
|
+
Retrieves the generated Python function used for data wrangling.
|
105
|
+
|
106
|
+
get_recommended_wrangling_steps()
|
107
|
+
Retrieves the agent's recommended wrangling steps.
|
108
|
+
|
109
|
+
get_response()
|
110
|
+
Returns the full response dictionary from the agent.
|
111
|
+
|
112
|
+
show()
|
113
|
+
Displays the agent's mermaid diagram for visual inspection of the compiled graph.
|
114
|
+
|
115
|
+
Examples
|
116
|
+
--------
|
117
|
+
```python
|
118
|
+
import pandas as pd
|
119
|
+
from langchain_openai import ChatOpenAI
|
120
|
+
from ai_data_science_team.agents import DataWranglingAgent
|
121
|
+
|
122
|
+
# Single dataset example
|
123
|
+
llm = ChatOpenAI(model="gpt-4o-mini")
|
124
|
+
|
125
|
+
data_wrangling_agent = DataWranglingAgent(
|
126
|
+
model=llm,
|
127
|
+
n_samples=30,
|
128
|
+
log=True,
|
129
|
+
log_path="logs",
|
130
|
+
human_in_the_loop=True
|
131
|
+
)
|
132
|
+
|
133
|
+
df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")
|
134
|
+
|
135
|
+
data_wrangling_agent.invoke_agent(
|
136
|
+
user_instructions="Group by 'gender' and compute mean of 'tenure'.",
|
137
|
+
data_raw=df, # data_raw can be df.to_dict() or just a DataFrame
|
138
|
+
max_retries=3,
|
139
|
+
retry_count=0
|
140
|
+
)
|
141
|
+
|
142
|
+
data_wrangled = data_wrangling_agent.get_data_wrangled()
|
143
|
+
response = data_wrangling_agent.get_response()
|
144
|
+
|
145
|
+
# Multiple dataset example (list of DataFrames or dicts)
|
146
|
+
df1 = pd.DataFrame({'id': [1,2,3], 'val1': [10,20,30]})
|
147
|
+
df2 = pd.DataFrame({'id': [1,2,3], 'val2': [40,50,60]})
|
148
|
+
|
149
|
+
data_wrangling_agent.invoke_agent(
|
150
|
+
user_instructions="Merge these two datasets on 'id' and compute a new column 'val_sum' = val1+val2",
|
151
|
+
data_raw=[df1, df2], # multiple datasets
|
152
|
+
max_retries=3,
|
153
|
+
retry_count=0
|
154
|
+
)
|
155
|
+
|
156
|
+
data_wrangled = data_wrangling_agent.get_data_wrangled()
|
157
|
+
```
|
158
|
+
|
159
|
+
Returns
|
160
|
+
-------
|
161
|
+
DataWranglingAgent : langchain.graphs.CompiledStateGraph
|
162
|
+
A data wrangling agent implemented as a compiled state graph.
|
163
|
+
"""
|
164
|
+
|
165
|
+
def __init__(
|
166
|
+
self,
|
167
|
+
model,
|
168
|
+
n_samples=30,
|
169
|
+
log=False,
|
170
|
+
log_path=None,
|
171
|
+
file_name="data_wrangler.py",
|
172
|
+
function_name="data_wrangler",
|
173
|
+
overwrite=True,
|
174
|
+
human_in_the_loop=False,
|
175
|
+
bypass_recommended_steps=False,
|
176
|
+
bypass_explain_code=False
|
177
|
+
):
|
178
|
+
self._params = {
|
179
|
+
"model": model,
|
180
|
+
"n_samples": n_samples,
|
181
|
+
"log": log,
|
182
|
+
"log_path": log_path,
|
183
|
+
"file_name": file_name,
|
184
|
+
"function_name": function_name,
|
185
|
+
"overwrite": overwrite,
|
186
|
+
"human_in_the_loop": human_in_the_loop,
|
187
|
+
"bypass_recommended_steps": bypass_recommended_steps,
|
188
|
+
"bypass_explain_code": bypass_explain_code
|
189
|
+
}
|
190
|
+
self._compiled_graph = self._make_compiled_graph()
|
191
|
+
self.response = None
|
192
|
+
|
193
|
+
def _make_compiled_graph(self):
|
194
|
+
"""
|
195
|
+
Create the compiled graph for the data wrangling agent.
|
196
|
+
Running this method will reset the response to None.
|
197
|
+
"""
|
198
|
+
self.response = None
|
199
|
+
return make_data_wrangling_agent(**self._params)
|
200
|
+
|
201
|
+
def update_params(self, **kwargs):
|
202
|
+
"""
|
203
|
+
Updates the agent's parameters and rebuilds the compiled graph.
|
204
|
+
"""
|
205
|
+
for k, v in kwargs.items():
|
206
|
+
self._params[k] = v
|
207
|
+
self._compiled_graph = self._make_compiled_graph()
|
208
|
+
|
209
|
+
def ainvoke_agent(
|
210
|
+
self,
|
211
|
+
data_raw: Union[pd.DataFrame, dict, list],
|
212
|
+
user_instructions: str=None,
|
213
|
+
max_retries:int=3,
|
214
|
+
retry_count:int=0,
|
215
|
+
**kwargs
|
216
|
+
):
|
217
|
+
"""
|
218
|
+
Asynchronously wrangles the provided dataset(s) based on user instructions.
|
219
|
+
The response is stored in the 'response' attribute.
|
220
|
+
|
221
|
+
Parameters
|
222
|
+
----------
|
223
|
+
data_raw : Union[pd.DataFrame, dict, list]
|
224
|
+
The raw dataset(s) to be wrangled.
|
225
|
+
Can be a single DataFrame, a single dict ({col: list_of_values}),
|
226
|
+
or a list of DataFrames and/or dicts if multiple datasets are provided.
|
227
|
+
user_instructions : str
|
228
|
+
Instructions for data wrangling.
|
229
|
+
max_retries : int
|
230
|
+
Maximum retry attempts.
|
231
|
+
retry_count : int
|
232
|
+
Current retry attempt count.
|
233
|
+
**kwargs
|
234
|
+
Additional keyword arguments to pass to ainvoke().
|
235
|
+
|
236
|
+
Returns
|
237
|
+
-------
|
238
|
+
None
|
239
|
+
"""
|
240
|
+
data_input = self._convert_data_input(data_raw)
|
241
|
+
response = self._compiled_graph.ainvoke({
|
242
|
+
"user_instructions": user_instructions,
|
243
|
+
"data_raw": data_input,
|
244
|
+
"max_retries": max_retries,
|
245
|
+
"retry_count": retry_count
|
246
|
+
}, **kwargs)
|
247
|
+
self.response = response
|
248
|
+
return None
|
249
|
+
|
250
|
+
def invoke_agent(
|
251
|
+
self,
|
252
|
+
data_raw: Union[pd.DataFrame, dict, list],
|
253
|
+
user_instructions: str=None,
|
254
|
+
max_retries:int=3,
|
255
|
+
retry_count:int=0,
|
256
|
+
**kwargs
|
257
|
+
):
|
258
|
+
"""
|
259
|
+
Synchronously wrangles the provided dataset(s) based on user instructions.
|
260
|
+
The response is stored in the 'response' attribute.
|
261
|
+
|
262
|
+
Parameters
|
263
|
+
----------
|
264
|
+
data_raw : Union[pd.DataFrame, dict, list]
|
265
|
+
The raw dataset(s) to be wrangled.
|
266
|
+
Can be a single DataFrame, a single dict, or a list of DataFrames and/or dicts.
|
267
|
+
user_instructions : str
|
268
|
+
Instructions for data wrangling agent.
|
269
|
+
max_retries : int
|
270
|
+
Maximum retry attempts.
|
271
|
+
retry_count : int
|
272
|
+
Current retry attempt count.
|
273
|
+
**kwargs
|
274
|
+
Additional keyword arguments to pass to invoke().
|
275
|
+
|
276
|
+
Returns
|
277
|
+
-------
|
278
|
+
None
|
279
|
+
"""
|
280
|
+
data_input = self._convert_data_input(data_raw)
|
281
|
+
response = self._compiled_graph.invoke({
|
282
|
+
"user_instructions": user_instructions,
|
283
|
+
"data_raw": data_input,
|
284
|
+
"max_retries": max_retries,
|
285
|
+
"retry_count": retry_count
|
286
|
+
}, **kwargs)
|
287
|
+
self.response = response
|
288
|
+
return None
|
289
|
+
|
290
|
+
def explain_wrangling_steps(self):
|
291
|
+
"""
|
292
|
+
Provides an explanation of the wrangling steps performed by the agent.
|
293
|
+
|
294
|
+
Returns
|
295
|
+
-------
|
296
|
+
str or list
|
297
|
+
Explanation of the data wrangling steps.
|
298
|
+
"""
|
299
|
+
if self.response:
|
300
|
+
return self.response.get("messages", [])
|
301
|
+
return []
|
302
|
+
|
303
|
+
def get_log_summary(self, markdown=False):
|
304
|
+
"""
|
305
|
+
Returns a summary of the agent's log details (the path of the logged function file), if available.
|
306
|
+
|
307
|
+
Parameters
|
308
|
+
----------
|
309
|
+
markdown : bool, optional
|
310
|
+
If True, returns the summary in Markdown.
|
311
|
+
|
312
|
+
Returns
|
313
|
+
-------
|
314
|
+
str or None
|
315
|
+
The log details, or None if not available.
|
316
|
+
"""
|
317
|
+
if self.response and self.response.get("data_wrangler_function_path"):
|
318
|
+
log_details = f"Log Path: {self.response.get('data_wrangler_function_path')}"
|
319
|
+
if markdown:
|
320
|
+
return Markdown(log_details)
|
321
|
+
else:
|
322
|
+
return log_details
|
323
|
+
return None
|
324
|
+
|
325
|
+
def get_data_wrangled(self) -> Optional[pd.DataFrame]:
|
326
|
+
"""
|
327
|
+
Retrieves the wrangled data after running invoke_agent() or ainvoke_agent().
|
328
|
+
|
329
|
+
Returns
|
330
|
+
-------
|
331
|
+
pd.DataFrame or None
|
332
|
+
The wrangled dataset as a pandas DataFrame (if available).
|
333
|
+
"""
|
334
|
+
if self.response and "data_wrangled" in self.response:
|
335
|
+
return pd.DataFrame(self.response["data_wrangled"])
|
336
|
+
return None
|
337
|
+
|
338
|
+
def get_data_raw(self) -> Union[dict, list, None]:
|
339
|
+
"""
|
340
|
+
Retrieves the original raw data from the last invocation.
|
341
|
+
|
342
|
+
Returns
|
343
|
+
-------
|
344
|
+
Union[dict, list, None]
|
345
|
+
The original dataset(s) as a single dict or a list of dicts, or None if not available.
|
346
|
+
"""
|
347
|
+
if self.response and "data_raw" in self.response:
|
348
|
+
return self.response["data_raw"]
|
349
|
+
return None
|
350
|
+
|
351
|
+
def get_data_wrangler_function(self, markdown=False) -> Optional[str]:
|
352
|
+
"""
|
353
|
+
Retrieves the generated data wrangling function code.
|
354
|
+
|
355
|
+
Parameters
|
356
|
+
----------
|
357
|
+
markdown : bool, optional
|
358
|
+
If True, returns the function in Markdown code block format.
|
359
|
+
|
360
|
+
Returns
|
361
|
+
-------
|
362
|
+
str or None
|
363
|
+
The Python function code, or None if not available.
|
364
|
+
"""
|
365
|
+
if self.response and "data_wrangler_function" in self.response:
|
366
|
+
code = self.response["data_wrangler_function"]
|
367
|
+
if markdown:
|
368
|
+
return Markdown(f"```python\n{code}\n```")
|
369
|
+
return code
|
370
|
+
return None
|
371
|
+
|
372
|
+
def get_recommended_wrangling_steps(self, markdown=False) -> Optional[str]:
|
373
|
+
"""
|
374
|
+
Retrieves the agent's recommended data wrangling steps.
|
375
|
+
|
376
|
+
Parameters
|
377
|
+
----------
|
378
|
+
markdown : bool, optional
|
379
|
+
If True, returns the steps in Markdown format.
|
380
|
+
|
381
|
+
Returns
|
382
|
+
-------
|
383
|
+
str or None
|
384
|
+
The recommended steps, or None if not available.
|
385
|
+
"""
|
386
|
+
if self.response and "recommended_steps" in self.response:
|
387
|
+
steps = self.response["recommended_steps"]
|
388
|
+
if markdown:
|
389
|
+
return Markdown(steps)
|
390
|
+
return steps
|
391
|
+
return None
|
392
|
+
|
393
|
+
@staticmethod
|
394
|
+
def _convert_data_input(data_raw: Union[pd.DataFrame, dict, list]) -> Union[dict, list]:
|
395
|
+
"""
|
396
|
+
Internal utility to convert data_raw (which could be a DataFrame, dict, or list of dicts)
|
397
|
+
into the format expected by the underlying agent (dict or list of dicts).
|
398
|
+
|
399
|
+
Parameters
|
400
|
+
----------
|
401
|
+
data_raw : Union[pd.DataFrame, dict, list]
|
402
|
+
The raw input data to be converted.
|
403
|
+
|
404
|
+
Returns
|
405
|
+
-------
|
406
|
+
Union[dict, list]
|
407
|
+
The data in a dictionary or list-of-dictionaries format.
|
408
|
+
"""
|
409
|
+
# If a single DataFrame, convert to dict
|
410
|
+
if isinstance(data_raw, pd.DataFrame):
|
411
|
+
return data_raw.to_dict()
|
412
|
+
|
413
|
+
# If it's already a dict (single dataset)
|
414
|
+
if isinstance(data_raw, dict):
|
415
|
+
return data_raw
|
416
|
+
|
417
|
+
# If it's already a list, check if it's a list of DataFrames or dicts
|
418
|
+
if isinstance(data_raw, list):
|
419
|
+
# Convert any DataFrame item to dict
|
420
|
+
converted_list = []
|
421
|
+
for item in data_raw:
|
422
|
+
if isinstance(item, pd.DataFrame):
|
423
|
+
converted_list.append(item.to_dict())
|
424
|
+
elif isinstance(item, dict):
|
425
|
+
converted_list.append(item)
|
426
|
+
else:
|
427
|
+
raise ValueError("List must contain only DataFrames or dictionaries.")
|
428
|
+
return converted_list
|
429
|
+
|
430
|
+
raise ValueError("data_raw must be a DataFrame, a dict, or a list of dicts/DataFrames.")
|
431
|
+
|
432
|
+
|
433
|
+
# Function
|
434
|
+
|
34
435
|
def make_data_wrangling_agent(
|
35
436
|
model,
|
36
437
|
n_samples=30,
|
37
438
|
log=False,
|
38
439
|
log_path=None,
|
39
440
|
file_name="data_wrangler.py",
|
40
|
-
|
441
|
+
function_name="data_wrangler",
|
442
|
+
overwrite=True,
|
41
443
|
human_in_the_loop=False,
|
42
444
|
bypass_recommended_steps=False,
|
43
445
|
bypass_explain_code=False
|
@@ -73,6 +475,8 @@ def make_data_wrangling_agent(
|
|
73
475
|
The path to the directory where the log files should be stored. Defaults to "logs/".
|
74
476
|
file_name : str, optional
|
75
477
|
The name of the file to save the response to. Defaults to "data_wrangler.py".
|
478
|
+
function_name : str, optional
|
479
|
+
The name of the function to be generated. Defaults to "data_wrangler".
|
76
480
|
overwrite : bool, optional
|
77
481
|
Whether or not to overwrite the log file if it already exists. If False, a unique file name will be created.
|
78
482
|
Defaults to True.
|
@@ -115,6 +519,11 @@ def make_data_wrangling_agent(
|
|
115
519
|
"""
|
116
520
|
llm = model
|
117
521
|
|
522
|
+
# Human in the loop requires recommended steps
|
523
|
+
if bypass_recommended_steps and human_in_the_loop:
|
524
|
+
bypass_recommended_steps = False
|
525
|
+
print("Bypass recommended steps set to False to enable human in the loop.")
|
526
|
+
|
118
527
|
# Setup Log Directory
|
119
528
|
if log:
|
120
529
|
if log_path is None:
|
@@ -205,7 +614,7 @@ def make_data_wrangling_agent(
|
|
205
614
|
})
|
206
615
|
|
207
616
|
return {
|
208
|
-
"recommended_steps": "
|
617
|
+
"recommended_steps": format_recommended_steps(recommended_steps.content.strip(), heading="# Recommended Data Wrangling Steps:"),
|
209
618
|
"all_datasets_summary": all_datasets_summary_str,
|
210
619
|
}
|
211
620
|
|
@@ -244,7 +653,7 @@ def make_data_wrangling_agent(
|
|
244
653
|
|
245
654
|
data_wrangling_prompt = PromptTemplate(
|
246
655
|
template="""
|
247
|
-
You are a Data Wrangling Coding Agent. Your job is to create a
|
656
|
+
You are a Data Wrangling Coding Agent. Your job is to create a {function_name}() function that can be run on the provided data.
|
248
657
|
|
249
658
|
Follow these recommended steps:
|
250
659
|
{recommended_steps}
|
@@ -254,10 +663,10 @@ def make_data_wrangling_agent(
|
|
254
663
|
Below are summaries of all datasets provided. If more than one dataset is provided, you may need to merge or join them.:
|
255
664
|
{all_datasets_summary}
|
256
665
|
|
257
|
-
Return Python code in ```python``` format with a single function definition,
|
666
|
+
Return Python code in ```python``` format with a single function definition, {function_name}(), that includes all imports inside the function. And returns a single pandas data frame.
|
258
667
|
|
259
668
|
```python
|
260
|
-
def
|
669
|
+
def {function_name}(data_list):
|
261
670
|
'''
|
262
671
|
Wrangle the data provided in data.
|
263
672
|
|
@@ -279,14 +688,15 @@ def make_data_wrangling_agent(
|
|
279
688
|
|
280
689
|
|
281
690
|
""",
|
282
|
-
input_variables=["recommended_steps", "all_datasets_summary"]
|
691
|
+
input_variables=["recommended_steps", "all_datasets_summary", "function_name"]
|
283
692
|
)
|
284
693
|
|
285
694
|
data_wrangling_agent = data_wrangling_prompt | llm | PythonOutputParser()
|
286
695
|
|
287
696
|
response = data_wrangling_agent.invoke({
|
288
697
|
"recommended_steps": state.get("recommended_steps"),
|
289
|
-
"all_datasets_summary": all_datasets_summary_str
|
698
|
+
"all_datasets_summary": all_datasets_summary_str,
|
699
|
+
"function_name": function_name
|
290
700
|
})
|
291
701
|
|
292
702
|
response = relocate_imports_inside_function(response)
|
@@ -304,7 +714,8 @@ def make_data_wrangling_agent(
|
|
304
714
|
return {
|
305
715
|
"data_wrangler_function" : response,
|
306
716
|
"data_wrangler_function_path": file_path,
|
307
|
-
"
|
717
|
+
"data_wrangler_file_name": file_name_2,
|
718
|
+
"data_wrangler_function_name": function_name,
|
308
719
|
"all_datasets_summary": all_datasets_summary_str
|
309
720
|
}
|
310
721
|
|
@@ -318,6 +729,33 @@ def make_data_wrangling_agent(
|
|
318
729
|
user_instructions_key="user_instructions",
|
319
730
|
recommended_steps_key="recommended_steps"
|
320
731
|
)
|
732
|
+
|
733
|
+
# Human Review
|
734
|
+
|
735
|
+
prompt_text_human_review = "Are the following data wrangling instructions correct? (Answer 'yes' or provide modifications)\n{steps}"
|
736
|
+
|
737
|
+
if not bypass_explain_code:
|
738
|
+
def human_review(state: GraphState) -> Command[Literal["recommend_wrangling_steps", "explain_data_wrangler_code"]]:
|
739
|
+
return node_func_human_review(
|
740
|
+
state=state,
|
741
|
+
prompt_text=prompt_text_human_review,
|
742
|
+
yes_goto= 'explain_data_wrangler_code',
|
743
|
+
no_goto="recommend_wrangling_steps",
|
744
|
+
user_instructions_key="user_instructions",
|
745
|
+
recommended_steps_key="recommended_steps",
|
746
|
+
code_snippet_key="data_wrangler_function",
|
747
|
+
)
|
748
|
+
else:
|
749
|
+
def human_review(state: GraphState) -> Command[Literal["recommend_wrangling_steps", "__end__"]]:
|
750
|
+
return node_func_human_review(
|
751
|
+
state=state,
|
752
|
+
prompt_text=prompt_text_human_review,
|
753
|
+
yes_goto= '__end__',
|
754
|
+
no_goto="recommend_wrangling_steps",
|
755
|
+
user_instructions_key="user_instructions",
|
756
|
+
recommended_steps_key="recommended_steps",
|
757
|
+
code_snippet_key="data_wrangler_function",
|
758
|
+
)
|
321
759
|
|
322
760
|
def execute_data_wrangler_code(state: GraphState):
|
323
761
|
return node_func_execute_agent_code_on_data(
|
@@ -326,7 +764,7 @@ def make_data_wrangling_agent(
|
|
326
764
|
result_key="data_wrangled",
|
327
765
|
error_key="data_wrangler_error",
|
328
766
|
code_snippet_key="data_wrangler_function",
|
329
|
-
agent_function_name="
|
767
|
+
agent_function_name=state.get("data_wrangler_function_name"),
|
330
768
|
# pre_processing=pre_processing,
|
331
769
|
post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
|
332
770
|
error_message_prefix="An error occurred during data wrangling: "
|
@@ -334,11 +772,11 @@ def make_data_wrangling_agent(
|
|
334
772
|
|
335
773
|
def fix_data_wrangler_code(state: GraphState):
|
336
774
|
data_wrangler_prompt = """
|
337
|
-
You are a Data Wrangling Agent. Your job is to create a
|
775
|
+
You are a Data Wrangling Agent. Your job is to create a {function_name}() function that can be run on the data provided. The function is currently broken and needs to be fixed.
|
338
776
|
|
339
|
-
Make sure to only return the function definition for
|
777
|
+
Make sure to only return the function definition for {function_name}().
|
340
778
|
|
341
|
-
Return Python code in ```python``` format with a single function definition,
|
779
|
+
Return Python code in ```python``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.
|
342
780
|
|
343
781
|
This is the broken code (please fix):
|
344
782
|
{code_snippet}
|
@@ -356,6 +794,7 @@ def make_data_wrangling_agent(
|
|
356
794
|
agent_name=AGENT_NAME,
|
357
795
|
log=log,
|
358
796
|
file_path=state.get("data_wrangler_function_path"),
|
797
|
+
function_name=state.get("data_wrangler_function_name"),
|
359
798
|
)
|
360
799
|
|
361
800
|
def explain_data_wrangler_code(state: GraphState):
|