ai-data-science-team 0.0.0.9007__py3-none-any.whl → 0.0.0.9009__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +4 -5
- ai_data_science_team/agents/data_cleaning_agent.py +268 -116
- ai_data_science_team/agents/data_visualization_agent.py +470 -41
- ai_data_science_team/agents/data_wrangling_agent.py +471 -31
- ai_data_science_team/agents/feature_engineering_agent.py +426 -41
- ai_data_science_team/agents/sql_database_agent.py +458 -58
- ai_data_science_team/ml_agents/__init__.py +1 -0
- ai_data_science_team/ml_agents/h2o_ml_agent.py +1032 -0
- ai_data_science_team/multiagents/__init__.py +1 -0
- ai_data_science_team/multiagents/sql_data_analyst.py +398 -0
- ai_data_science_team/multiagents/supervised_data_analyst.py +2 -0
- ai_data_science_team/templates/__init__.py +3 -1
- ai_data_science_team/templates/agent_templates.py +319 -43
- ai_data_science_team/tools/metadata.py +94 -62
- ai_data_science_team/tools/regex.py +86 -1
- ai_data_science_team/utils/__init__.py +0 -0
- ai_data_science_team/utils/plotly.py +24 -0
- ai_data_science_team-0.0.0.9009.dist-info/METADATA +245 -0
- ai_data_science_team-0.0.0.9009.dist-info/RECORD +28 -0
- ai_data_science_team-0.0.0.9007.dist-info/METADATA +0 -183
- ai_data_science_team-0.0.0.9007.dist-info/RECORD +0 -21
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/WHEEL +0 -0
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/top_level.txt +0 -0
@@ -4,11 +4,12 @@
|
|
4
4
|
# * Agents: Data Wrangling Agent
|
5
5
|
|
6
6
|
# Libraries
|
7
|
-
from typing import TypedDict, Annotated, Sequence, Literal, Union
|
7
|
+
from typing import TypedDict, Annotated, Sequence, Literal, Union, Optional
|
8
8
|
import operator
|
9
9
|
import os
|
10
|
-
import
|
10
|
+
import json
|
11
11
|
import pandas as pd
|
12
|
+
from IPython.display import Markdown
|
12
13
|
|
13
14
|
from langchain.prompts import PromptTemplate
|
14
15
|
from langchain_core.messages import BaseMessage
|
@@ -19,11 +20,18 @@ from ai_data_science_team.templates import(
|
|
19
20
|
node_func_execute_agent_code_on_data,
|
20
21
|
node_func_human_review,
|
21
22
|
node_func_fix_agent_code,
|
22
|
-
|
23
|
-
create_coding_agent_graph
|
23
|
+
node_func_report_agent_outputs,
|
24
|
+
create_coding_agent_graph,
|
25
|
+
BaseAgent,
|
24
26
|
)
|
25
27
|
from ai_data_science_team.tools.parsers import PythonOutputParser
|
26
|
-
from ai_data_science_team.tools.regex import
|
28
|
+
from ai_data_science_team.tools.regex import (
|
29
|
+
relocate_imports_inside_function,
|
30
|
+
add_comments_to_top,
|
31
|
+
format_agent_name,
|
32
|
+
format_recommended_steps,
|
33
|
+
get_generic_summary,
|
34
|
+
)
|
27
35
|
from ai_data_science_team.tools.metadata import get_dataframe_summary
|
28
36
|
from ai_data_science_team.tools.logging import log_ai_function
|
29
37
|
|
@@ -31,13 +39,408 @@ from ai_data_science_team.tools.logging import log_ai_function
|
|
31
39
|
AGENT_NAME = "data_wrangling_agent"
|
32
40
|
LOG_PATH = os.path.join(os.getcwd(), "logs/")
|
33
41
|
|
42
|
+
# Class
|
43
|
+
|
44
|
+
class DataWranglingAgent(BaseAgent):
    """
    Creates a data wrangling agent that can work with one or more datasets, performing operations such as
    joining/merging multiple datasets, reshaping, aggregating, encoding, creating computed features,
    and ensuring consistent data types. The agent generates a Python function to wrangle the data,
    executes the function, and logs the process (if enabled).

    The agent can handle:
    - A single dataset (provided as a DataFrame or a dictionary of {column: list_of_values})
    - Multiple datasets (provided as a list of DataFrames and/or such dictionaries)

    Key wrangling steps can include:
    - Merging or joining datasets
    - Pivoting/melting data for reshaping
    - GroupBy aggregations (sums, means, counts, etc.)
    - Encoding categorical variables
    - Computing new columns from existing ones
    - Dropping or rearranging columns
    - Any additional user instructions

    Parameters
    ----------
    model : langchain.llms.base.LLM
        The language model used to generate the data wrangling function.
    n_samples : int, optional
        Number of samples to show in the data summary for wrangling. Defaults to 30.
    log : bool, optional
        Whether to log the generated code and errors. Defaults to False.
    log_path : str, optional
        Directory path for storing log files. Defaults to None.
    file_name : str, optional
        Name of the file for saving the generated response. Defaults to "data_wrangler.py".
    function_name : str, optional
        Name of the function to be generated. Defaults to "data_wrangler".
    overwrite : bool, optional
        Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
    human_in_the_loop : bool, optional
        Enables user review of data wrangling instructions. Defaults to False.
    bypass_recommended_steps : bool, optional
        If True, skips the step that generates recommended data wrangling steps. Defaults to False.
    bypass_explain_code : bool, optional
        If True, skips the step that provides code explanations. Defaults to False.

    Methods
    -------
    update_params(**kwargs)
        Updates the agent's parameters and rebuilds the compiled state graph.

    ainvoke_agent(data_raw: Union[pd.DataFrame, dict, list], user_instructions: str=None, max_retries=3, retry_count=0)
        Asynchronously wrangles the provided dataset(s) based on user instructions (coroutine; must be awaited).

    invoke_agent(data_raw: Union[pd.DataFrame, dict, list], user_instructions: str=None, max_retries=3, retry_count=0)
        Synchronously wrangles the provided dataset(s) based on user instructions.

    get_workflow_summary()
        Retrieves a summary of the agent's workflow.

    get_log_summary()
        Retrieves a summary of logged operations if logging is enabled.

    get_data_wrangled()
        Retrieves the final wrangled dataset (as a pandas DataFrame).

    get_data_raw()
        Retrieves the raw dataset(s).

    get_data_wrangler_function()
        Retrieves the generated Python function used for data wrangling.

    get_recommended_wrangling_steps()
        Retrieves the agent's recommended wrangling steps.

    get_response()
        Returns the full response dictionary from the agent.

    show()
        Displays the agent's mermaid diagram for visual inspection of the compiled graph.

    Examples
    --------
    ```python
    import pandas as pd
    from langchain_openai import ChatOpenAI
    from ai_data_science_team.agents import DataWranglingAgent

    # Single dataset example
    llm = ChatOpenAI(model="gpt-4o-mini")

    data_wrangling_agent = DataWranglingAgent(
        model=llm,
        n_samples=30,
        log=True,
        log_path="logs",
        human_in_the_loop=True
    )

    df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")

    data_wrangling_agent.invoke_agent(
        user_instructions="Group by 'gender' and compute mean of 'tenure'.",
        data_raw=df, # data_raw can be df.to_dict() or just a DataFrame
        max_retries=3,
        retry_count=0
    )

    data_wrangled = data_wrangling_agent.get_data_wrangled()
    response = data_wrangling_agent.get_response()

    # Multiple dataset example (list of dicts)
    df1 = pd.DataFrame({'id': [1,2,3], 'val1': [10,20,30]})
    df2 = pd.DataFrame({'id': [1,2,3], 'val2': [40,50,60]})

    data_wrangling_agent.invoke_agent(
        user_instructions="Merge these two datasets on 'id' and compute a new column 'val_sum' = val1+val2",
        data_raw=[df1, df2], # multiple datasets
        max_retries=3,
        retry_count=0
    )

    data_wrangled = data_wrangling_agent.get_data_wrangled()
    ```

    Returns
    -------
    DataWranglingAgent : langchain.graphs.CompiledStateGraph
        A data wrangling agent implemented as a compiled state graph.
    """

    def __init__(
        self,
        model,
        n_samples=30,
        log=False,
        log_path=None,
        file_name="data_wrangler.py",
        function_name="data_wrangler",
        overwrite=True,
        human_in_the_loop=False,
        bypass_recommended_steps=False,
        bypass_explain_code=False
    ):
        # The parameters are kept in a dict so update_params() can patch
        # individual entries and rebuild the graph from the same source.
        self._params = {
            "model": model,
            "n_samples": n_samples,
            "log": log,
            "log_path": log_path,
            "file_name": file_name,
            "function_name": function_name,
            "overwrite": overwrite,
            "human_in_the_loop": human_in_the_loop,
            "bypass_recommended_steps": bypass_recommended_steps,
            "bypass_explain_code": bypass_explain_code
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Create the compiled graph for the data wrangling agent.
        Running this method will reset the response to None.
        """
        # A response produced by a previous graph is stale once the graph is rebuilt.
        self.response = None
        return make_data_wrangling_agent(**self._params)

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters and rebuilds the compiled graph.

        Parameters
        ----------
        **kwargs
            Parameter names and their new values. They are merged into the
            stored parameters and forwarded to make_data_wrangling_agent().
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    async def ainvoke_agent(
        self,
        data_raw: Union[pd.DataFrame, dict, list],
        user_instructions: Optional[str] = None,
        max_retries: int = 3,
        retry_count: int = 0,
        **kwargs
    ):
        """
        Asynchronously wrangles the provided dataset(s) based on user instructions.
        The response is stored in the 'response' attribute.

        This is a coroutine and must be awaited. The graph's ainvoke() result
        is awaited before being stored, so 'response' holds the finished
        result rather than a pending coroutine object.

        Parameters
        ----------
        data_raw : Union[pd.DataFrame, dict, list]
            The raw dataset(s) to be wrangled.
            Can be a single DataFrame, a single dict ({col: list_of_values}),
            or a list of DataFrames/dicts if multiple datasets are provided.
        user_instructions : str, optional
            Instructions for data wrangling.
        max_retries : int
            Maximum retry attempts.
        retry_count : int
            Current retry attempt count.
        **kwargs
            Additional keyword arguments to pass to ainvoke().

        Returns
        -------
        None
        """
        data_input = self._convert_data_input(data_raw)
        graph_input = {
            "user_instructions": user_instructions,
            "data_raw": data_input,
            "max_retries": max_retries,
            "retry_count": retry_count
        }
        self.response = await self._compiled_graph.ainvoke(graph_input, **kwargs)
        return None

    def invoke_agent(
        self,
        data_raw: Union[pd.DataFrame, dict, list],
        user_instructions: Optional[str] = None,
        max_retries: int = 3,
        retry_count: int = 0,
        **kwargs
    ):
        """
        Synchronously wrangles the provided dataset(s) based on user instructions.
        The response is stored in the 'response' attribute.

        Parameters
        ----------
        data_raw : Union[pd.DataFrame, dict, list]
            The raw dataset(s) to be wrangled.
            Can be a single DataFrame, a single dict, or a list of DataFrames/dicts.
        user_instructions : str, optional
            Instructions for data wrangling agent.
        max_retries : int
            Maximum retry attempts.
        retry_count : int
            Current retry attempt count.
        **kwargs
            Additional keyword arguments to pass to invoke().

        Returns
        -------
        None
        """
        data_input = self._convert_data_input(data_raw)
        graph_input = {
            "user_instructions": user_instructions,
            "data_raw": data_input,
            "max_retries": max_retries,
            "retry_count": retry_count
        }
        self.response = self._compiled_graph.invoke(graph_input, **kwargs)
        return None

    def get_workflow_summary(self, markdown=False):
        """
        Retrieves the agent's workflow summary from the last message of the response.

        The content of the last entry in response["messages"] is parsed as
        JSON and passed to get_generic_summary().

        Parameters
        ----------
        markdown : bool, optional
            If True, wraps the summary in an IPython Markdown object.

        Returns
        -------
        str, Markdown, or None
            The summary, or None if there is no response or it has no messages.
        """
        if self.response and self.response.get("messages"):
            last_message = self.response.get("messages")[-1]
            summary = get_generic_summary(json.loads(last_message.content))
            if markdown:
                return Markdown(summary)
            return summary
        return None

    def get_log_summary(self, markdown=False):
        """
        Retrieves a summary of the logged function (path and name).

        Parameters
        ----------
        markdown : bool, optional
            If True, wraps the summary in an IPython Markdown object.

        Returns
        -------
        str, Markdown, or None
            The log summary, or None if there is no response or the response
            has no 'data_wrangler_function_path'.
        """
        if self.response and self.response.get('data_wrangler_function_path'):
            # Kept flush-left so the Markdown heading renders as a heading
            # instead of an indented code block.
            log_details = f"""
## Data Wrangling Agent Log Summary:

Function Path: {self.response.get('data_wrangler_function_path')}

Function Name: {self.response.get('data_wrangler_function_name')}
"""
            if markdown:
                return Markdown(log_details)
            return log_details
        return None

    def get_data_wrangled(self) -> Optional[pd.DataFrame]:
        """
        Retrieves the wrangled data after running invoke_agent() or ainvoke_agent().

        Returns
        -------
        pd.DataFrame or None
            The wrangled dataset as a pandas DataFrame (if available).
        """
        if self.response and "data_wrangled" in self.response:
            return pd.DataFrame(self.response["data_wrangled"])
        return None

    def get_data_raw(self) -> Union[dict, list, None]:
        """
        Retrieves the original raw data from the last invocation.

        Returns
        -------
        Union[dict, list, None]
            The original dataset(s) as a single dict or a list of dicts, or None if not available.
        """
        if self.response and "data_raw" in self.response:
            return self.response["data_raw"]
        return None

    def get_data_wrangler_function(self, markdown=False) -> Union[str, Markdown, None]:
        """
        Retrieves the generated data wrangling function code.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the function in Markdown code block format.

        Returns
        -------
        str, Markdown, or None
            The Python function code (wrapped in a Markdown object when
            markdown=True), or None if not available.
        """
        if self.response and "data_wrangler_function" in self.response:
            code = self.response["data_wrangler_function"]
            if markdown:
                return Markdown(f"```python\n{code}\n```")
            return code
        return None

    def get_recommended_wrangling_steps(self, markdown=False) -> Union[str, Markdown, None]:
        """
        Retrieves the agent's recommended data wrangling steps.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the steps in Markdown format.

        Returns
        -------
        str, Markdown, or None
            The recommended steps (wrapped in a Markdown object when
            markdown=True), or None if not available.
        """
        if self.response and "recommended_steps" in self.response:
            steps = self.response["recommended_steps"]
            if markdown:
                return Markdown(steps)
            return steps
        return None

    @staticmethod
    def _convert_data_input(data_raw: Union[pd.DataFrame, dict, list]) -> Union[dict, list]:
        """
        Internal utility to convert data_raw (which could be a DataFrame, dict, or list of
        DataFrames/dicts) into the format passed to the compiled graph (dict or list of dicts).

        Parameters
        ----------
        data_raw : Union[pd.DataFrame, dict, list]
            The raw input data to be converted.

        Returns
        -------
        Union[dict, list]
            The data in a dictionary or list-of-dictionaries format.

        Raises
        ------
        ValueError
            If data_raw is not a DataFrame, dict, or list, or if a list
            contains an item that is neither a DataFrame nor a dict.
        """
        # A single DataFrame becomes a dict
        if isinstance(data_raw, pd.DataFrame):
            return data_raw.to_dict()

        # A single dict is passed through unchanged
        if isinstance(data_raw, dict):
            return data_raw

        # A list may mix DataFrames and dicts; normalise every item to a dict
        if isinstance(data_raw, list):
            converted_list = []
            for item in data_raw:
                if isinstance(item, pd.DataFrame):
                    converted_list.append(item.to_dict())
                elif isinstance(item, dict):
                    converted_list.append(item)
                else:
                    raise ValueError("List must contain only DataFrames or dictionaries.")
            return converted_list

        raise ValueError("data_raw must be a DataFrame, a dict, or a list of dicts/DataFrames.")
|
432
|
+
|
433
|
+
|
434
|
+
# Function
|
435
|
+
|
34
436
|
def make_data_wrangling_agent(
|
35
437
|
model,
|
36
438
|
n_samples=30,
|
37
439
|
log=False,
|
38
440
|
log_path=None,
|
39
441
|
file_name="data_wrangler.py",
|
40
|
-
|
442
|
+
function_name="data_wrangler",
|
443
|
+
overwrite=True,
|
41
444
|
human_in_the_loop=False,
|
42
445
|
bypass_recommended_steps=False,
|
43
446
|
bypass_explain_code=False
|
@@ -73,6 +476,8 @@ def make_data_wrangling_agent(
|
|
73
476
|
The path to the directory where the log files should be stored. Defaults to "logs/".
|
74
477
|
file_name : str, optional
|
75
478
|
The name of the file to save the response to. Defaults to "data_wrangler.py".
|
479
|
+
function_name : str, optional
|
480
|
+
The name of the function to be generated. Defaults to "data_wrangler".
|
76
481
|
overwrite : bool, optional
|
77
482
|
Whether or not to overwrite the log file if it already exists. If False, a unique file name will be created.
|
78
483
|
Defaults to True.
|
@@ -115,6 +520,11 @@ def make_data_wrangling_agent(
|
|
115
520
|
"""
|
116
521
|
llm = model
|
117
522
|
|
523
|
+
# Human in the loop requires recommended steps
|
524
|
+
if bypass_recommended_steps and human_in_the_loop:
|
525
|
+
bypass_recommended_steps = False
|
526
|
+
print("Bypass recommended steps set to False to enable human in the loop.")
|
527
|
+
|
118
528
|
# Setup Log Directory
|
119
529
|
if log:
|
120
530
|
if log_path is None:
|
@@ -188,7 +598,7 @@ def make_data_wrangling_agent(
|
|
188
598
|
Below are summaries of all datasets provided:
|
189
599
|
{all_datasets_summary}
|
190
600
|
|
191
|
-
Return
|
601
|
+
Return steps as a numbered list. You can return short code snippets to demonstrate actions. But do not return a fully coded solution. The code will be generated separately by a Coding Agent.
|
192
602
|
|
193
603
|
Avoid these:
|
194
604
|
1. Do not include steps to save files.
|
@@ -205,7 +615,7 @@ def make_data_wrangling_agent(
|
|
205
615
|
})
|
206
616
|
|
207
617
|
return {
|
208
|
-
"recommended_steps": "
|
618
|
+
"recommended_steps": format_recommended_steps(recommended_steps.content.strip(), heading="# Recommended Data Wrangling Steps:"),
|
209
619
|
"all_datasets_summary": all_datasets_summary_str,
|
210
620
|
}
|
211
621
|
|
@@ -244,7 +654,7 @@ def make_data_wrangling_agent(
|
|
244
654
|
|
245
655
|
data_wrangling_prompt = PromptTemplate(
|
246
656
|
template="""
|
247
|
-
You are a Data Wrangling Coding Agent. Your job is to create a
|
657
|
+
You are a Data Wrangling Coding Agent. Your job is to create a {function_name}() function that can be run on the provided data.
|
248
658
|
|
249
659
|
Follow these recommended steps:
|
250
660
|
{recommended_steps}
|
@@ -254,10 +664,10 @@ def make_data_wrangling_agent(
|
|
254
664
|
Below are summaries of all datasets provided. If more than one dataset is provided, you may need to merge or join them.:
|
255
665
|
{all_datasets_summary}
|
256
666
|
|
257
|
-
Return Python code in ```python``` format with a single function definition,
|
667
|
+
Return Python code in ```python``` format with a single function definition, {function_name}(), that includes all imports inside the function. And returns a single pandas data frame.
|
258
668
|
|
259
669
|
```python
|
260
|
-
def
|
670
|
+
def {function_name}(data_list):
|
261
671
|
'''
|
262
672
|
Wrangle the data provided in data.
|
263
673
|
|
@@ -279,14 +689,15 @@ def make_data_wrangling_agent(
|
|
279
689
|
|
280
690
|
|
281
691
|
""",
|
282
|
-
input_variables=["recommended_steps", "all_datasets_summary"]
|
692
|
+
input_variables=["recommended_steps", "all_datasets_summary", "function_name"]
|
283
693
|
)
|
284
694
|
|
285
695
|
data_wrangling_agent = data_wrangling_prompt | llm | PythonOutputParser()
|
286
696
|
|
287
697
|
response = data_wrangling_agent.invoke({
|
288
698
|
"recommended_steps": state.get("recommended_steps"),
|
289
|
-
"all_datasets_summary": all_datasets_summary_str
|
699
|
+
"all_datasets_summary": all_datasets_summary_str,
|
700
|
+
"function_name": function_name
|
290
701
|
})
|
291
702
|
|
292
703
|
response = relocate_imports_inside_function(response)
|
@@ -304,7 +715,8 @@ def make_data_wrangling_agent(
|
|
304
715
|
return {
|
305
716
|
"data_wrangler_function" : response,
|
306
717
|
"data_wrangler_function_path": file_path,
|
307
|
-
"
|
718
|
+
"data_wrangler_file_name": file_name_2,
|
719
|
+
"data_wrangler_function_name": function_name,
|
308
720
|
"all_datasets_summary": all_datasets_summary_str
|
309
721
|
}
|
310
722
|
|
@@ -318,6 +730,33 @@ def make_data_wrangling_agent(
|
|
318
730
|
user_instructions_key="user_instructions",
|
319
731
|
recommended_steps_key="recommended_steps"
|
320
732
|
)
|
733
|
+
|
734
|
+
# Human Review
|
735
|
+
|
736
|
+
prompt_text_human_review = "Are the following data wrangling instructions correct? (Answer 'yes' or provide modifications)\n{steps}"
|
737
|
+
|
738
|
+
if not bypass_explain_code:
|
739
|
+
def human_review(state: GraphState) -> Command[Literal["recommend_wrangling_steps", "explain_data_wrangler_code"]]:
|
740
|
+
return node_func_human_review(
|
741
|
+
state=state,
|
742
|
+
prompt_text=prompt_text_human_review,
|
743
|
+
yes_goto= 'explain_data_wrangler_code',
|
744
|
+
no_goto="recommend_wrangling_steps",
|
745
|
+
user_instructions_key="user_instructions",
|
746
|
+
recommended_steps_key="recommended_steps",
|
747
|
+
code_snippet_key="data_wrangler_function",
|
748
|
+
)
|
749
|
+
else:
|
750
|
+
def human_review(state: GraphState) -> Command[Literal["recommend_wrangling_steps", "__end__"]]:
|
751
|
+
return node_func_human_review(
|
752
|
+
state=state,
|
753
|
+
prompt_text=prompt_text_human_review,
|
754
|
+
yes_goto= '__end__',
|
755
|
+
no_goto="recommend_wrangling_steps",
|
756
|
+
user_instructions_key="user_instructions",
|
757
|
+
recommended_steps_key="recommended_steps",
|
758
|
+
code_snippet_key="data_wrangler_function",
|
759
|
+
)
|
321
760
|
|
322
761
|
def execute_data_wrangler_code(state: GraphState):
|
323
762
|
return node_func_execute_agent_code_on_data(
|
@@ -326,7 +765,7 @@ def make_data_wrangling_agent(
|
|
326
765
|
result_key="data_wrangled",
|
327
766
|
error_key="data_wrangler_error",
|
328
767
|
code_snippet_key="data_wrangler_function",
|
329
|
-
agent_function_name="
|
768
|
+
agent_function_name=state.get("data_wrangler_function_name"),
|
330
769
|
# pre_processing=pre_processing,
|
331
770
|
post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
|
332
771
|
error_message_prefix="An error occurred during data wrangling: "
|
@@ -334,11 +773,11 @@ def make_data_wrangling_agent(
|
|
334
773
|
|
335
774
|
def fix_data_wrangler_code(state: GraphState):
|
336
775
|
data_wrangler_prompt = """
|
337
|
-
You are a Data Wrangling Agent. Your job is to create a
|
776
|
+
You are a Data Wrangling Agent. Your job is to create a {function_name}() function that can be run on the data provided. The function is currently broken and needs to be fixed.
|
338
777
|
|
339
|
-
Make sure to only return the function definition for
|
778
|
+
Make sure to only return the function definition for {function_name}().
|
340
779
|
|
341
|
-
Return Python code in ```python``` format with a single function definition,
|
780
|
+
Return Python code in ```python``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.
|
342
781
|
|
343
782
|
This is the broken code (please fix):
|
344
783
|
{code_snippet}
|
@@ -356,22 +795,23 @@ def make_data_wrangling_agent(
|
|
356
795
|
agent_name=AGENT_NAME,
|
357
796
|
log=log,
|
358
797
|
file_path=state.get("data_wrangler_function_path"),
|
798
|
+
function_name=state.get("data_wrangler_function_name"),
|
359
799
|
)
|
360
800
|
|
361
|
-
|
362
|
-
|
801
|
+
# Final reporting node
|
802
|
+
def report_agent_outputs(state: GraphState):
|
803
|
+
return node_func_report_agent_outputs(
|
363
804
|
state=state,
|
364
|
-
|
805
|
+
keys_to_include=[
|
806
|
+
"recommended_steps",
|
807
|
+
"data_wrangler_function",
|
808
|
+
"data_wrangler_function_path",
|
809
|
+
"data_wrangler_function_name",
|
810
|
+
"data_wrangler_error",
|
811
|
+
],
|
365
812
|
result_key="messages",
|
366
|
-
error_key="data_wrangler_error",
|
367
|
-
llm=llm,
|
368
813
|
role=AGENT_NAME,
|
369
|
-
|
370
|
-
Explain the data wrangling steps that the data wrangling agent performed in this function.
|
371
|
-
Keep the summary succinct and to the point.\n\n# Data Wrangling Agent:\n\n{code}
|
372
|
-
""",
|
373
|
-
success_prefix="# Data Wrangling Agent:\n\n ",
|
374
|
-
error_message="The Data Wrangling Agent encountered an error during data wrangling. Data could not be explained."
|
814
|
+
custom_title="Data Wrangling Agent Outputs"
|
375
815
|
)
|
376
816
|
|
377
817
|
# Define the graph
|
@@ -381,7 +821,7 @@ def make_data_wrangling_agent(
|
|
381
821
|
"create_data_wrangler_code": create_data_wrangler_code,
|
382
822
|
"execute_data_wrangler_code": execute_data_wrangler_code,
|
383
823
|
"fix_data_wrangler_code": fix_data_wrangler_code,
|
384
|
-
"
|
824
|
+
"report_agent_outputs": report_agent_outputs,
|
385
825
|
}
|
386
826
|
|
387
827
|
app = create_coding_agent_graph(
|
@@ -391,7 +831,7 @@ def make_data_wrangling_agent(
|
|
391
831
|
create_code_node_name="create_data_wrangler_code",
|
392
832
|
execute_code_node_name="execute_data_wrangler_code",
|
393
833
|
fix_code_node_name="fix_data_wrangler_code",
|
394
|
-
explain_code_node_name="
|
834
|
+
explain_code_node_name="report_agent_outputs",
|
395
835
|
error_key="data_wrangler_error",
|
396
836
|
human_in_the_loop=human_in_the_loop,
|
397
837
|
human_review_node_name="human_review",
|