ai-data-science-team 0.0.0.9006__py3-none-any.whl → 0.0.0.9008__py3-none-any.whl
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +5 -4
- ai_data_science_team/agents/data_cleaning_agent.py +371 -45
- ai_data_science_team/agents/data_visualization_agent.py +764 -0
- ai_data_science_team/agents/data_wrangling_agent.py +507 -23
- ai_data_science_team/agents/feature_engineering_agent.py +467 -34
- ai_data_science_team/agents/sql_database_agent.py +394 -30
- ai_data_science_team/multiagents/__init__.py +1 -0
- ai_data_science_team/multiagents/sql_data_analyst.py +286 -0
- ai_data_science_team/multiagents/supervised_data_analyst.py +2 -0
- ai_data_science_team/templates/__init__.py +9 -0
- ai_data_science_team/templates/agent_templates.py +247 -42
- ai_data_science_team/tools/metadata.py +110 -47
- ai_data_science_team/tools/regex.py +33 -0
- ai_data_science_team/utils/__init__.py +0 -0
- ai_data_science_team/utils/plotly.py +24 -0
- ai_data_science_team-0.0.0.9008.dist-info/METADATA +231 -0
- ai_data_science_team-0.0.0.9008.dist-info/RECORD +26 -0
- {ai_data_science_team-0.0.0.9006.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/WHEEL +1 -1
- ai_data_science_team-0.0.0.9006.dist-info/METADATA +0 -165
- ai_data_science_team-0.0.0.9006.dist-info/RECORD +0 -20
- {ai_data_science_team-0.0.0.9006.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9006.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/top_level.txt +0 -0
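The headline change in `ai_data_science_team/agents/data_wrangling_agent.py` (the diff shown below) is a new class-based API: `DataWranglingAgent` wraps `make_data_wrangling_agent()` and keeps the compiled LangGraph state graph and the latest response internally. As a reading aid, here is a minimal usage sketch assembled from the `DataWranglingAgent` docstring example that appears in the diff; the toy DataFrame is made up for illustration and is not part of the package.

```python
# Minimal sketch of the class-based API added in 0.0.0.9008, adapted from the
# DataWranglingAgent docstring example shown in the diff below.
# The toy DataFrame is illustrative only.
import pandas as pd
from langchain_openai import ChatOpenAI
from ai_data_science_team.agents import DataWranglingAgent

llm = ChatOpenAI(model="gpt-4o-mini")

# The constructor stores the parameters and compiles the underlying state graph.
agent = DataWranglingAgent(model=llm, n_samples=30, log=False)

df = pd.DataFrame({"gender": ["M", "F", "M", "F"], "tenure": [12, 24, 36, 48]})

# data_raw may be a DataFrame, a dict of {column: values}, or a list of dicts/DataFrames.
agent.invoke_agent(
    user_instructions="Group by 'gender' and compute mean of 'tenure'.",
    data_raw=df,
)

wrangled_df = agent.get_data_wrangled()             # pandas DataFrame (or None)
wrangler_code = agent.get_data_wrangler_function()  # generated data_wrangler() source
```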
@@ -4,26 +4,27 @@
 # * Agents: Data Wrangling Agent
 
 # Libraries
-from typing import TypedDict, Annotated, Sequence, Literal, Union
+from typing import TypedDict, Annotated, Sequence, Literal, Union, Optional
 import operator
 import os
-import io
 import pandas as pd
+from IPython.display import Markdown
 
 from langchain.prompts import PromptTemplate
 from langchain_core.messages import BaseMessage
 from langgraph.types import Command
 from langgraph.checkpoint.memory import MemorySaver
 
-from ai_data_science_team.templates
+from ai_data_science_team.templates import(
     node_func_execute_agent_code_on_data,
     node_func_human_review,
     node_func_fix_agent_code,
     node_func_explain_agent_code,
-    create_coding_agent_graph
+    create_coding_agent_graph,
+    BaseAgent,
 )
 from ai_data_science_team.tools.parsers import PythonOutputParser
-from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
+from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name, format_recommended_steps
 from ai_data_science_team.tools.metadata import get_dataframe_summary
 from ai_data_science_team.tools.logging import log_ai_function
 
@@ -31,7 +32,418 @@ from ai_data_science_team.tools.logging import log_ai_function
 AGENT_NAME = "data_wrangling_agent"
 LOG_PATH = os.path.join(os.getcwd(), "logs/")
 
-
+# Class
+
+class DataWranglingAgent(BaseAgent):
+    """
+    Creates a data wrangling agent that can work with one or more datasets, performing operations such as
+    joining/merging multiple datasets, reshaping, aggregating, encoding, creating computed features,
+    and ensuring consistent data types. The agent generates a Python function to wrangle the data,
+    executes the function, and logs the process (if enabled).
+
+    The agent can handle:
+    - A single dataset (provided as a dictionary of {column: list_of_values})
+    - Multiple datasets (provided as a list of such dictionaries)
+
+    Key wrangling steps can include:
+    - Merging or joining datasets
+    - Pivoting/melting data for reshaping
+    - GroupBy aggregations (sums, means, counts, etc.)
+    - Encoding categorical variables
+    - Computing new columns from existing ones
+    - Dropping or rearranging columns
+    - Any additional user instructions
+
+    Parameters
+    ----------
+    model : langchain.llms.base.LLM
+        The language model used to generate the data wrangling function.
+    n_samples : int, optional
+        Number of samples to show in the data summary for wrangling. Defaults to 30.
+    log : bool, optional
+        Whether to log the generated code and errors. Defaults to False.
+    log_path : str, optional
+        Directory path for storing log files. Defaults to None.
+    file_name : str, optional
+        Name of the file for saving the generated response. Defaults to "data_wrangler.py".
+    function_name : str, optional
+        Name of the function to be generated. Defaults to "data_wrangler".
+    overwrite : bool, optional
+        Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
+    human_in_the_loop : bool, optional
+        Enables user review of data wrangling instructions. Defaults to False.
+    bypass_recommended_steps : bool, optional
+        If True, skips the step that generates recommended data wrangling steps. Defaults to False.
+    bypass_explain_code : bool, optional
+        If True, skips the step that provides code explanations. Defaults to False.
+
+    Methods
+    -------
+    update_params(**kwargs)
+        Updates the agent's parameters and rebuilds the compiled state graph.
+
+    ainvoke_agent(user_instructions: str, data_raw: Union[dict, list], max_retries=3, retry_count=0)
+        Asynchronously wrangles the provided dataset(s) based on user instructions.
+
+    invoke_agent(user_instructions: str, data_raw: Union[dict, list], max_retries=3, retry_count=0)
+        Synchronously wrangles the provided dataset(s) based on user instructions.
+
+    explain_wrangling_steps()
+        Returns an explanation of the wrangling steps performed by the agent.
+
+    get_log_summary()
+        Retrieves a summary of logged operations if logging is enabled.
+
+    get_data_wrangled()
+        Retrieves the final wrangled dataset (as a dictionary of {column: list_of_values}).
+
+    get_data_raw()
+        Retrieves the raw dataset(s).
+
+    get_data_wrangler_function()
+        Retrieves the generated Python function used for data wrangling.
+
+    get_recommended_wrangling_steps()
+        Retrieves the agent's recommended wrangling steps.
+
+    get_response()
+        Returns the full response dictionary from the agent.
+
+    show()
+        Displays the agent's mermaid diagram for visual inspection of the compiled graph.
+
+    Examples
+    --------
+    ```python
+    import pandas as pd
+    from langchain_openai import ChatOpenAI
+    from ai_data_science_team.agents import DataWranglingAgent
+
+    # Single dataset example
+    llm = ChatOpenAI(model="gpt-4o-mini")
+
+    data_wrangling_agent = DataWranglingAgent(
+        model=llm,
+        n_samples=30,
+        log=True,
+        log_path="logs",
+        human_in_the_loop=True
+    )
+
+    df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")
+
+    data_wrangling_agent.invoke_agent(
+        user_instructions="Group by 'gender' and compute mean of 'tenure'.",
+        data_raw=df, # data_raw can be df.to_dict() or just a DataFrame
+        max_retries=3,
+        retry_count=0
+    )
+
+    data_wrangled = data_wrangling_agent.get_data_wrangled()
+    response = data_wrangling_agent.get_response()
+
+    # Multiple dataset example (list of dicts)
+    df1 = pd.DataFrame({'id': [1,2,3], 'val1': [10,20,30]})
+    df2 = pd.DataFrame({'id': [1,2,3], 'val2': [40,50,60]})
+
+    data_wrangling_agent.invoke_agent(
+        user_instructions="Merge these two datasets on 'id' and compute a new column 'val_sum' = val1+val2",
+        data_raw=[df1, df2], # multiple datasets
+        max_retries=3,
+        retry_count=0
+    )
+
+    data_wrangled = data_wrangling_agent.get_data_wrangled()
+    ```
+
+    Returns
+    -------
+    DataWranglingAgent : langchain.graphs.CompiledStateGraph
+        A data wrangling agent implemented as a compiled state graph.
+    """
+
+    def __init__(
+        self,
+        model,
+        n_samples=30,
+        log=False,
+        log_path=None,
+        file_name="data_wrangler.py",
+        function_name="data_wrangler",
+        overwrite=True,
+        human_in_the_loop=False,
+        bypass_recommended_steps=False,
+        bypass_explain_code=False
+    ):
+        self._params = {
+            "model": model,
+            "n_samples": n_samples,
+            "log": log,
+            "log_path": log_path,
+            "file_name": file_name,
+            "function_name": function_name,
+            "overwrite": overwrite,
+            "human_in_the_loop": human_in_the_loop,
+            "bypass_recommended_steps": bypass_recommended_steps,
+            "bypass_explain_code": bypass_explain_code
+        }
+        self._compiled_graph = self._make_compiled_graph()
+        self.response = None
+
+    def _make_compiled_graph(self):
+        """
+        Create the compiled graph for the data wrangling agent.
+        Running this method will reset the response to None.
+        """
+        self.response = None
+        return make_data_wrangling_agent(**self._params)
+
+    def update_params(self, **kwargs):
+        """
+        Updates the agent's parameters and rebuilds the compiled graph.
+        """
+        for k, v in kwargs.items():
+            self._params[k] = v
+        self._compiled_graph = self._make_compiled_graph()
+
+    def ainvoke_agent(
+        self,
+        data_raw: Union[pd.DataFrame, dict, list],
+        user_instructions: str=None,
+        max_retries:int=3,
+        retry_count:int=0,
+        **kwargs
+    ):
+        """
+        Asynchronously wrangles the provided dataset(s) based on user instructions.
+        The response is stored in the 'response' attribute.
+
+        Parameters
+        ----------
+        data_raw : Union[pd.DataFrame, dict, list]
+            The raw dataset(s) to be wrangled.
+            Can be a single DataFrame, a single dict ({col: list_of_values}),
+            or a list of dicts if multiple datasets are provided.
+        user_instructions : str
+            Instructions for data wrangling.
+        max_retries : int
+            Maximum retry attempts.
+        retry_count : int
+            Current retry attempt count.
+        **kwargs
+            Additional keyword arguments to pass to ainvoke().
+
+        Returns
+        -------
+        None
+        """
+        data_input = self._convert_data_input(data_raw)
+        response = self._compiled_graph.ainvoke({
+            "user_instructions": user_instructions,
+            "data_raw": data_input,
+            "max_retries": max_retries,
+            "retry_count": retry_count
+        }, **kwargs)
+        self.response = response
+        return None
+
+    def invoke_agent(
+        self,
+        data_raw: Union[pd.DataFrame, dict, list],
+        user_instructions: str=None,
+        max_retries:int=3,
+        retry_count:int=0,
+        **kwargs
+    ):
+        """
+        Synchronously wrangles the provided dataset(s) based on user instructions.
+        The response is stored in the 'response' attribute.
+
+        Parameters
+        ----------
+        data_raw : Union[pd.DataFrame, dict, list]
+            The raw dataset(s) to be wrangled.
+            Can be a single DataFrame, a single dict, or a list of dicts.
+        user_instructions : str
+            Instructions for data wrangling agent.
+        max_retries : int
+            Maximum retry attempts.
+        retry_count : int
+            Current retry attempt count.
+        **kwargs
+            Additional keyword arguments to pass to invoke().
+
+        Returns
+        -------
+        None
+        """
+        data_input = self._convert_data_input(data_raw)
+        response = self._compiled_graph.invoke({
+            "user_instructions": user_instructions,
+            "data_raw": data_input,
+            "max_retries": max_retries,
+            "retry_count": retry_count
+        }, **kwargs)
+        self.response = response
+        return None
+
+    def explain_wrangling_steps(self):
+        """
+        Provides an explanation of the wrangling steps performed by the agent.
+
+        Returns
+        -------
+        str or list
+            Explanation of the data wrangling steps.
+        """
+        if self.response:
+            return self.response.get("messages", [])
+        return []
+
+    def get_log_summary(self, markdown=False):
+        """
+        Logs a summary of the agent's operations, if logging is enabled.
+
+        Parameters
+        ----------
+        markdown : bool, optional
+            If True, returns the summary in Markdown.
+
+        Returns
+        -------
+        str or None
+            The log details, or None if not available.
+        """
+        if self.response and self.response.get("data_wrangler_function_path"):
+            log_details = f"Log Path: {self.response.get('data_wrangler_function_path')}"
+            if markdown:
+                return Markdown(log_details)
+            else:
+                return log_details
+        return None
+
+    def get_data_wrangled(self) -> Optional[pd.DataFrame]:
+        """
+        Retrieves the wrangled data after running invoke_agent() or ainvoke_agent().
+
+        Returns
+        -------
+        pd.DataFrame or None
+            The wrangled dataset as a pandas DataFrame (if available).
+        """
+        if self.response and "data_wrangled" in self.response:
+            return pd.DataFrame(self.response["data_wrangled"])
+        return None
+
+    def get_data_raw(self) -> Union[dict, list, None]:
+        """
+        Retrieves the original raw data from the last invocation.
+
+        Returns
+        -------
+        Union[dict, list, None]
+            The original dataset(s) as a single dict or a list of dicts, or None if not available.
+        """
+        if self.response and "data_raw" in self.response:
+            return self.response["data_raw"]
+        return None
+
+    def get_data_wrangler_function(self, markdown=False) -> Optional[str]:
+        """
+        Retrieves the generated data wrangling function code.
+
+        Parameters
+        ----------
+        markdown : bool, optional
+            If True, returns the function in Markdown code block format.
+
+        Returns
+        -------
+        str or None
+            The Python function code, or None if not available.
+        """
+        if self.response and "data_wrangler_function" in self.response:
+            code = self.response["data_wrangler_function"]
+            if markdown:
+                return Markdown(f"```python\n{code}\n```")
+            return code
+        return None
+
+    def get_recommended_wrangling_steps(self, markdown=False) -> Optional[str]:
+        """
+        Retrieves the agent's recommended data wrangling steps.
+
+        Parameters
+        ----------
+        markdown : bool, optional
+            If True, returns the steps in Markdown format.
+
+        Returns
+        -------
+        str or None
+            The recommended steps, or None if not available.
+        """
+        if self.response and "recommended_steps" in self.response:
+            steps = self.response["recommended_steps"]
+            if markdown:
+                return Markdown(steps)
+            return steps
+        return None
+
+    @staticmethod
+    def _convert_data_input(data_raw: Union[pd.DataFrame, dict, list]) -> Union[dict, list]:
+        """
+        Internal utility to convert data_raw (which could be a DataFrame, dict, or list of dicts)
+        into the format expected by the underlying agent (dict or list of dicts).
+
+        Parameters
+        ----------
+        data_raw : Union[pd.DataFrame, dict, list]
+            The raw input data to be converted.
+
+        Returns
+        -------
+        Union[dict, list]
+            The data in a dictionary or list-of-dictionaries format.
+        """
+        # If a single DataFrame, convert to dict
+        if isinstance(data_raw, pd.DataFrame):
+            return data_raw.to_dict()
+
+        # If it's already a dict (single dataset)
+        if isinstance(data_raw, dict):
+            return data_raw
+
+        # If it's already a list, check if it's a list of DataFrames or dicts
+        if isinstance(data_raw, list):
+            # Convert any DataFrame item to dict
+            converted_list = []
+            for item in data_raw:
+                if isinstance(item, pd.DataFrame):
+                    converted_list.append(item.to_dict())
+                elif isinstance(item, dict):
+                    converted_list.append(item)
+                else:
+                    raise ValueError("List must contain only DataFrames or dictionaries.")
+            return converted_list
+
+        raise ValueError("data_raw must be a DataFrame, a dict, or a list of dicts/DataFrames.")
+
+
+# Function
+
+def make_data_wrangling_agent(
+    model,
+    n_samples=30,
+    log=False,
+    log_path=None,
+    file_name="data_wrangler.py",
+    function_name="data_wrangler",
+    overwrite=True,
+    human_in_the_loop=False,
+    bypass_recommended_steps=False,
+    bypass_explain_code=False
+):
     """
     Creates a data wrangling agent that can be run on one or more datasets. The agent can be
     instructed to perform common data wrangling steps such as:
@@ -52,11 +464,19 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
     ----------
     model : langchain.llms.base.LLM
         The language model to use to generate code.
+    n_samples : int, optional
+        The number of samples to show in the data summary. Defaults to 30.
+        If you get an error due to maximum tokens, try reducing this number.
+        > "This model's maximum context length is 128000 tokens. However, your messages resulted in 333858 tokens. Please reduce the length of the messages."
     log : bool, optional
         Whether or not to log the code generated and any errors that occur.
         Defaults to False.
     log_path : str, optional
         The path to the directory where the log files should be stored. Defaults to "logs/".
+    file_name : str, optional
+        The name of the file to save the response to. Defaults to "data_wrangler.py".
+    function_name : str, optional
+        The name of the function to be generated. Defaults to "data_wrangler".
     overwrite : bool, optional
         Whether or not to overwrite the log file if it already exists. If False, a unique file name will be created.
         Defaults to True.
@@ -94,11 +514,16 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
 
     Returns
     -------
-    app : langchain.graphs.
+    app : langchain.graphs.CompiledStateGraph
         The data wrangling agent as a state graph.
     """
     llm = model
 
+    # Human in th loop requires recommended steps
+    if bypass_recommended_steps and human_in_the_loop:
+        bypass_recommended_steps = False
+        print("Bypass recommended steps set to False to enable human in the loop.")
+
     # Setup Log Directory
     if log:
         if log_path is None:
@@ -122,7 +547,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
         retry_count: int
 
     def recommend_wrangling_steps(state: GraphState):
-        print(
+        print(format_agent_name(AGENT_NAME))
         print(" * RECOMMEND WRANGLING STEPS")
 
         data_raw = state.get("data_raw")
@@ -143,7 +568,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
 
         # Create a summary for all datasets
         # We'll include a short sample and info for each dataset
-        all_datasets_summary = get_dataframe_summary(dataframes)
+        all_datasets_summary = get_dataframe_summary(dataframes, n_sample=n_samples)
 
         # Join all datasets summaries into one big text block
         all_datasets_summary_str = "\n\n".join(all_datasets_summary)
@@ -176,6 +601,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
 
             Avoid these:
             1. Do not include steps to save files.
+            2. Do not include unrelated user instructions that are not related to the data wrangling.
             """,
             input_variables=["user_instructions", "recommended_steps", "all_datasets_summary"]
         )
@@ -188,19 +614,46 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
         })
 
         return {
-            "recommended_steps": "
+            "recommended_steps": format_recommended_steps(recommended_steps.content.strip(), heading="# Recommended Data Wrangling Steps:"),
             "all_datasets_summary": all_datasets_summary_str,
         }
 
 
     def create_data_wrangler_code(state: GraphState):
         if bypass_recommended_steps:
-            print(
+            print(format_agent_name(AGENT_NAME))
+
+            data_raw = state.get("data_raw")
+
+            if isinstance(data_raw, dict):
+                # Single dataset scenario
+                primary_dataset_name = "main"
+                datasets = {primary_dataset_name: data_raw}
+            elif isinstance(data_raw, list) and all(isinstance(item, dict) for item in data_raw):
+                # Multiple datasets scenario
+                datasets = {f"dataset_{i}": d for i, d in enumerate(data_raw, start=1)}
+                primary_dataset_name = "dataset_1"
+            else:
+                raise ValueError("data_raw must be a dict or a list of dicts.")
+
+            # Convert all datasets to DataFrames for inspection
+            dataframes = {name: pd.DataFrame.from_dict(d) for name, d in datasets.items()}
+
+            # Create a summary for all datasets
+            # We'll include a short sample and info for each dataset
+            all_datasets_summary = get_dataframe_summary(dataframes, n_sample=n_samples)
+
+            # Join all datasets summaries into one big text block
+            all_datasets_summary_str = "\n\n".join(all_datasets_summary)
+
+        else:
+            all_datasets_summary_str = state.get("all_datasets_summary")
+
         print(" * CREATE DATA WRANGLER CODE")
 
         data_wrangling_prompt = PromptTemplate(
             template="""
-            You are a Data Wrangling Coding Agent. Your job is to create a
+            You are a Data Wrangling Coding Agent. Your job is to create a {function_name}() function that can be run on the provided data.
 
             Follow these recommended steps:
             {recommended_steps}
@@ -210,10 +663,10 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
             Below are summaries of all datasets provided. If more than one dataset is provided, you may need to merge or join them.:
             {all_datasets_summary}
 
-            Return Python code in ```python``` format with a single function definition,
+            Return Python code in ```python``` format with a single function definition, {function_name}(), that includes all imports inside the function. And returns a single pandas data frame.
 
             ```python
-            def
+            def {function_name}(data_list):
                 '''
                 Wrangle the data provided in data.
 
@@ -235,23 +688,24 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
 
 
             """,
-            input_variables=["recommended_steps", "all_datasets_summary"]
+            input_variables=["recommended_steps", "all_datasets_summary", "function_name"]
         )
 
         data_wrangling_agent = data_wrangling_prompt | llm | PythonOutputParser()
 
         response = data_wrangling_agent.invoke({
             "recommended_steps": state.get("recommended_steps"),
-            "all_datasets_summary":
+            "all_datasets_summary": all_datasets_summary_str,
+            "function_name": function_name
         })
 
         response = relocate_imports_inside_function(response)
         response = add_comments_to_top(response, agent_name=AGENT_NAME)
 
         # For logging: store the code generated
-        file_path,
+        file_path, file_name_2 = log_ai_function(
             response=response,
-            file_name=
+            file_name=file_name,
             log=log,
             log_path=log_path,
             overwrite=overwrite
@@ -260,7 +714,9 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
         return {
             "data_wrangler_function" : response,
             "data_wrangler_function_path": file_path,
-            "
+            "data_wrangler_file_name": file_name_2,
+            "data_wrangler_function_name": function_name,
+            "all_datasets_summary": all_datasets_summary_str
         }
 
 
@@ -273,6 +729,33 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
             user_instructions_key="user_instructions",
             recommended_steps_key="recommended_steps"
         )
+
+    # Human Review
+
+    prompt_text_human_review = "Are the following data wrangling instructions correct? (Answer 'yes' or provide modifications)\n{steps}"
+
+    if not bypass_explain_code:
+        def human_review(state: GraphState) -> Command[Literal["recommend_wrangling_steps", "explain_data_wrangler_code"]]:
+            return node_func_human_review(
+                state=state,
+                prompt_text=prompt_text_human_review,
+                yes_goto= 'explain_data_wrangler_code',
+                no_goto="recommend_wrangling_steps",
+                user_instructions_key="user_instructions",
+                recommended_steps_key="recommended_steps",
+                code_snippet_key="data_wrangler_function",
+            )
+    else:
+        def human_review(state: GraphState) -> Command[Literal["recommend_wrangling_steps", "__end__"]]:
+            return node_func_human_review(
+                state=state,
+                prompt_text=prompt_text_human_review,
+                yes_goto= '__end__',
+                no_goto="recommend_wrangling_steps",
+                user_instructions_key="user_instructions",
+                recommended_steps_key="recommended_steps",
+                code_snippet_key="data_wrangler_function",
+            )
 
     def execute_data_wrangler_code(state: GraphState):
         return node_func_execute_agent_code_on_data(
@@ -281,7 +764,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
             result_key="data_wrangled",
             error_key="data_wrangler_error",
             code_snippet_key="data_wrangler_function",
-            agent_function_name="
+            agent_function_name=state.get("data_wrangler_function_name"),
             # pre_processing=pre_processing,
             post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
             error_message_prefix="An error occurred during data wrangling: "
@@ -289,11 +772,11 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
 
     def fix_data_wrangler_code(state: GraphState):
         data_wrangler_prompt = """
-        You are a Data Wrangling Agent. Your job is to create a
+        You are a Data Wrangling Agent. Your job is to create a {function_name}() function that can be run on the data provided. The function is currently broken and needs to be fixed.
 
-        Make sure to only return the function definition for
+        Make sure to only return the function definition for {function_name}().
 
-        Return Python code in ```python``` format with a single function definition,
+        Return Python code in ```python``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.
 
         This is the broken code (please fix):
         {code_snippet}
@@ -311,6 +794,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
             agent_name=AGENT_NAME,
             log=log,
             file_path=state.get("data_wrangler_function_path"),
+            function_name=state.get("data_wrangler_function_name"),
         )
 
     def explain_data_wrangler_code(state: GraphState):
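The other recurring change in this diff is that the generated function's name and log file are now configurable (`function_name`, `file_name`, threaded through the prompts and the execute/fix nodes), and that `human_in_the_loop=True` forces `bypass_recommended_steps` back to False. A minimal sketch of the functional API with these parameters, assuming `make_data_wrangling_agent` is importable from `ai_data_science_team.agents` as in 0.0.0.9006; the instructions and inline datasets are illustrative only.

```python
# Sketch of the functional API with the parameters added in 0.0.0.9008
# (n_samples, file_name, function_name). Assumes make_data_wrangling_agent
# is exported from ai_data_science_team.agents; example data is made up.
from langchain_openai import ChatOpenAI
from ai_data_science_team.agents import make_data_wrangling_agent

llm = ChatOpenAI(model="gpt-4o-mini")

app = make_data_wrangling_agent(
    model=llm,
    n_samples=10,                   # smaller summary helps avoid the max-token error quoted in the docstring
    log=True,
    log_path="logs/",
    file_name="data_wrangler.py",   # new: file the generated code is logged to
    function_name="data_wrangler",  # new: name of the generated function
    human_in_the_loop=False,        # if True, bypass_recommended_steps is forced back to False (see diff above)
)

result = app.invoke({
    "user_instructions": "Merge the two datasets on 'id' and add 'val_sum' = val1 + val2.",
    "data_raw": [
        {"id": [1, 2, 3], "val1": [10, 20, 30]},
        {"id": [1, 2, 3], "val2": [40, 50, 60]},
    ],
    "max_retries": 3,
    "retry_count": 0,
})
wrangled = result.get("data_wrangled")  # dict form; wrap in pd.DataFrame(...) to rebuild the frame
```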