ai-data-science-team 0.0.0.9006__py3-none-any.whl → 0.0.0.9008__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +5 -4
- ai_data_science_team/agents/data_cleaning_agent.py +371 -45
- ai_data_science_team/agents/data_visualization_agent.py +764 -0
- ai_data_science_team/agents/data_wrangling_agent.py +507 -23
- ai_data_science_team/agents/feature_engineering_agent.py +467 -34
- ai_data_science_team/agents/sql_database_agent.py +394 -30
- ai_data_science_team/multiagents/__init__.py +1 -0
- ai_data_science_team/multiagents/sql_data_analyst.py +286 -0
- ai_data_science_team/multiagents/supervised_data_analyst.py +2 -0
- ai_data_science_team/templates/__init__.py +9 -0
- ai_data_science_team/templates/agent_templates.py +247 -42
- ai_data_science_team/tools/metadata.py +110 -47
- ai_data_science_team/tools/regex.py +33 -0
- ai_data_science_team/utils/__init__.py +0 -0
- ai_data_science_team/utils/plotly.py +24 -0
- ai_data_science_team-0.0.0.9008.dist-info/METADATA +231 -0
- ai_data_science_team-0.0.0.9008.dist-info/RECORD +26 -0
- {ai_data_science_team-0.0.0.9006.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/WHEEL +1 -1
- ai_data_science_team-0.0.0.9006.dist-info/METADATA +0 -165
- ai_data_science_team-0.0.0.9006.dist-info/RECORD +0 -20
- {ai_data_science_team-0.0.0.9006.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9006.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/top_level.txt +0 -0
@@ -4,26 +4,27 @@
|
|
4
4
|
# * Agents: Data Wrangling Agent
|
5
5
|
|
6
6
|
# Libraries
|
7
|
-
from typing import TypedDict, Annotated, Sequence, Literal, Union
|
7
|
+
from typing import TypedDict, Annotated, Sequence, Literal, Union, Optional
|
8
8
|
import operator
|
9
9
|
import os
|
10
|
-
import io
|
11
10
|
import pandas as pd
|
11
|
+
from IPython.display import Markdown
|
12
12
|
|
13
13
|
from langchain.prompts import PromptTemplate
|
14
14
|
from langchain_core.messages import BaseMessage
|
15
15
|
from langgraph.types import Command
|
16
16
|
from langgraph.checkpoint.memory import MemorySaver
|
17
17
|
|
18
|
-
from ai_data_science_team.templates
|
18
|
+
from ai_data_science_team.templates import(
|
19
19
|
node_func_execute_agent_code_on_data,
|
20
20
|
node_func_human_review,
|
21
21
|
node_func_fix_agent_code,
|
22
22
|
node_func_explain_agent_code,
|
23
|
-
create_coding_agent_graph
|
23
|
+
create_coding_agent_graph,
|
24
|
+
BaseAgent,
|
24
25
|
)
|
25
26
|
from ai_data_science_team.tools.parsers import PythonOutputParser
|
26
|
-
from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
|
27
|
+
from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name, format_recommended_steps
|
27
28
|
from ai_data_science_team.tools.metadata import get_dataframe_summary
|
28
29
|
from ai_data_science_team.tools.logging import log_ai_function
|
29
30
|
|
@@ -31,7 +32,418 @@ from ai_data_science_team.tools.logging import log_ai_function
|
|
31
32
|
# Identifier for this agent; it is passed to format_agent_name(),
# add_comments_to_top() and the fix-code node when printing progress and
# annotating generated code.
AGENT_NAME = "data_wrangling_agent"

# A "logs/" directory under the current working directory.
# Resolved once, at import time.
LOG_PATH = os.path.join(os.getcwd(), "logs/")
|
33
34
|
|
34
|
-
|
35
|
+
# Class
|
36
|
+
|
37
|
+
class DataWranglingAgent(BaseAgent):
    """
    Creates a data wrangling agent that can work with one or more datasets, performing operations such as
    joining/merging multiple datasets, reshaping, aggregating, encoding, creating computed features,
    and ensuring consistent data types. The agent generates a Python function to wrangle the data,
    executes the function, and logs the process (if enabled).

    The agent can handle:
    - A single dataset (a pandas DataFrame, or a dict of columns such as the output of DataFrame.to_dict())
    - Multiple datasets (a list of DataFrames and/or such dicts)

    Key wrangling steps can include:
    - Merging or joining datasets
    - Pivoting/melting data for reshaping
    - GroupBy aggregations (sums, means, counts, etc.)
    - Encoding categorical variables
    - Computing new columns from existing ones
    - Dropping or rearranging columns
    - Any additional user instructions

    Parameters
    ----------
    model : langchain.llms.base.LLM
        The language model used to generate the data wrangling function.
    n_samples : int, optional
        Number of samples to show in the data summary for wrangling. Defaults to 30.
    log : bool, optional
        Whether to log the generated code and errors. Defaults to False.
    log_path : str, optional
        Directory path for storing log files. Defaults to None.
    file_name : str, optional
        Name of the file for saving the generated response. Defaults to "data_wrangler.py".
    function_name : str, optional
        Name of the function to be generated. Defaults to "data_wrangler".
    overwrite : bool, optional
        Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
    human_in_the_loop : bool, optional
        Enables user review of data wrangling instructions. Defaults to False.
    bypass_recommended_steps : bool, optional
        If True, skips the step that generates recommended data wrangling steps. Defaults to False.
    bypass_explain_code : bool, optional
        If True, skips the step that provides code explanations. Defaults to False.

    Methods
    -------
    update_params(**kwargs)
        Updates the agent's parameters and rebuilds the compiled state graph.

    ainvoke_agent(data_raw, user_instructions=None, max_retries=3, retry_count=0, **kwargs)
        Coroutine that asynchronously wrangles the provided dataset(s) based on user instructions.
        Must be awaited.

    invoke_agent(data_raw, user_instructions=None, max_retries=3, retry_count=0, **kwargs)
        Synchronously wrangles the provided dataset(s) based on user instructions.

    explain_wrangling_steps()
        Returns the messages recorded by the agent (explanation of the wrangling steps).

    get_log_summary(markdown=False)
        Returns the path of the logged function file, if logging produced one.

    get_data_wrangled()
        Retrieves the final wrangled dataset as a pandas DataFrame.

    get_data_raw()
        Retrieves the raw dataset(s) as passed to the graph (dict or list of dicts).

    get_data_wrangler_function(markdown=False)
        Retrieves the generated Python function used for data wrangling.

    get_recommended_wrangling_steps(markdown=False)
        Retrieves the agent's recommended wrangling steps.

    get_response()
        Returns the full response dictionary from the agent.

    show()
        Displays the agent's mermaid diagram for visual inspection of the compiled graph.

    Examples
    --------
    ```python
    import pandas as pd
    from langchain_openai import ChatOpenAI
    from ai_data_science_team.agents import DataWranglingAgent

    # Single dataset example
    llm = ChatOpenAI(model="gpt-4o-mini")

    data_wrangling_agent = DataWranglingAgent(
        model=llm,
        n_samples=30,
        log=True,
        log_path="logs",
        human_in_the_loop=True
    )

    df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")

    data_wrangling_agent.invoke_agent(
        user_instructions="Group by 'gender' and compute mean of 'tenure'.",
        data_raw=df, # data_raw can be df.to_dict() or just a DataFrame
        max_retries=3,
        retry_count=0
    )

    data_wrangled = data_wrangling_agent.get_data_wrangled()
    response = data_wrangling_agent.get_response()

    # Multiple dataset example (list of DataFrames or dicts)
    df1 = pd.DataFrame({'id': [1,2,3], 'val1': [10,20,30]})
    df2 = pd.DataFrame({'id': [1,2,3], 'val2': [40,50,60]})

    data_wrangling_agent.invoke_agent(
        user_instructions="Merge these two datasets on 'id' and compute a new column 'val_sum' = val1+val2",
        data_raw=[df1, df2], # multiple datasets
        max_retries=3,
        retry_count=0
    )

    data_wrangled = data_wrangling_agent.get_data_wrangled()

    # Async usage (inside a coroutine / running event loop):
    # await data_wrangling_agent.ainvoke_agent(data_raw=df, user_instructions="...")
    ```

    Attributes
    ----------
    response : dict or None
        The state returned by the last invoke_agent()/ainvoke_agent() call. It is None before
        any run and is reset to None whenever the graph is rebuilt (e.g. by update_params()).
    """

    def __init__(
        self,
        model,
        n_samples=30,
        log=False,
        log_path=None,
        file_name="data_wrangler.py",
        function_name="data_wrangler",
        overwrite=True,
        human_in_the_loop=False,
        bypass_recommended_steps=False,
        bypass_explain_code=False
    ):
        # Keep every constructor argument so the graph can be rebuilt by update_params().
        self._params = {
            "model": model,
            "n_samples": n_samples,
            "log": log,
            "log_path": log_path,
            "file_name": file_name,
            "function_name": function_name,
            "overwrite": overwrite,
            "human_in_the_loop": human_in_the_loop,
            "bypass_recommended_steps": bypass_recommended_steps,
            "bypass_explain_code": bypass_explain_code
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Create the compiled graph for the data wrangling agent.
        Running this method will reset the response to None.
        """
        self.response = None
        return make_data_wrangling_agent(**self._params)

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters and rebuilds the compiled graph.

        Any previous response is discarded, because rebuilding the graph resets it to None.
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    async def ainvoke_agent(
        self,
        data_raw: Union[pd.DataFrame, dict, list],
        user_instructions: Optional[str] = None,
        max_retries: int = 3,
        retry_count: int = 0,
        **kwargs
    ):
        """
        Asynchronously wrangles the provided dataset(s) based on user instructions.
        The response is stored in the 'response' attribute.

        This method is a coroutine and must be awaited, e.g.
        ``await agent.ainvoke_agent(data_raw=df, user_instructions="...")``.

        Parameters
        ----------
        data_raw : Union[pd.DataFrame, dict, list]
            The raw dataset(s) to be wrangled.
            Can be a single DataFrame, a single dict of columns,
            or a list of DataFrames/dicts if multiple datasets are provided.
        user_instructions : str, optional
            Instructions for data wrangling.
        max_retries : int
            Maximum retry attempts.
        retry_count : int
            Current retry attempt count.
        **kwargs
            Additional keyword arguments to pass to ainvoke().

        Returns
        -------
        None
        """
        data_input = self._convert_data_input(data_raw)
        # ainvoke() is asynchronous. It must be awaited so that self.response holds
        # the final state dict rather than an un-awaited coroutine object.
        response = await self._compiled_graph.ainvoke({
            "user_instructions": user_instructions,
            "data_raw": data_input,
            "max_retries": max_retries,
            "retry_count": retry_count
        }, **kwargs)
        self.response = response
        return None

    def invoke_agent(
        self,
        data_raw: Union[pd.DataFrame, dict, list],
        user_instructions: Optional[str] = None,
        max_retries: int = 3,
        retry_count: int = 0,
        **kwargs
    ):
        """
        Synchronously wrangles the provided dataset(s) based on user instructions.
        The response is stored in the 'response' attribute.

        Parameters
        ----------
        data_raw : Union[pd.DataFrame, dict, list]
            The raw dataset(s) to be wrangled.
            Can be a single DataFrame, a single dict, or a list of DataFrames/dicts.
        user_instructions : str, optional
            Instructions for data wrangling agent.
        max_retries : int
            Maximum retry attempts.
        retry_count : int
            Current retry attempt count.
        **kwargs
            Additional keyword arguments to pass to invoke().

        Returns
        -------
        None
        """
        data_input = self._convert_data_input(data_raw)
        response = self._compiled_graph.invoke({
            "user_instructions": user_instructions,
            "data_raw": data_input,
            "max_retries": max_retries,
            "retry_count": retry_count
        }, **kwargs)
        self.response = response
        return None

    def explain_wrangling_steps(self):
        """
        Provides an explanation of the wrangling steps performed by the agent.

        Returns
        -------
        list
            The "messages" entry of the last response, or an empty list if there is
            no response (or it has no messages).
        """
        if self.response:
            return self.response.get("messages", [])
        return []

    def get_log_summary(self, markdown=False):
        """
        Returns a short summary (the log file path) of the agent's logged code, if available.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the summary wrapped in an IPython Markdown object.

        Returns
        -------
        str, Markdown or None
            The log details, or None if no function path was recorded.
        """
        if self.response and self.response.get("data_wrangler_function_path"):
            log_details = f"Log Path: {self.response.get('data_wrangler_function_path')}"
            if markdown:
                return Markdown(log_details)
            else:
                return log_details
        return None

    def get_data_wrangled(self) -> Optional[pd.DataFrame]:
        """
        Retrieves the wrangled data after running invoke_agent() or ainvoke_agent().

        Returns
        -------
        pd.DataFrame or None
            The wrangled dataset as a pandas DataFrame, or None if no result is available.
        """
        if self.response and "data_wrangled" in self.response:
            data_wrangled = self.response["data_wrangled"]
            # pd.DataFrame(None) builds an empty frame. Return None instead, so callers can
            # tell "no result" apart from a genuinely empty result.
            if data_wrangled is None:
                return None
            return pd.DataFrame(data_wrangled)
        return None

    def get_data_raw(self) -> Union[dict, list, None]:
        """
        Retrieves the original raw data from the last invocation.

        Returns
        -------
        Union[dict, list, None]
            The original dataset(s) as a single dict or a list of dicts, or None if not available.
        """
        if self.response and "data_raw" in self.response:
            return self.response["data_raw"]
        return None

    def get_data_wrangler_function(self, markdown=False) -> Optional[Union[str, "Markdown"]]:
        """
        Retrieves the generated data wrangling function code.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the function as an IPython Markdown object containing
            a python code block.

        Returns
        -------
        str, Markdown or None
            The Python function code, or None if not available.
        """
        if self.response and "data_wrangler_function" in self.response:
            code = self.response["data_wrangler_function"]
            if markdown:
                return Markdown(f"```python\n{code}\n```")
            return code
        return None

    def get_recommended_wrangling_steps(self, markdown=False) -> Optional[Union[str, "Markdown"]]:
        """
        Retrieves the agent's recommended data wrangling steps.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the steps wrapped in an IPython Markdown object.

        Returns
        -------
        str, Markdown or None
            The recommended steps, or None if not available.
        """
        if self.response and "recommended_steps" in self.response:
            steps = self.response["recommended_steps"]
            if markdown:
                return Markdown(steps)
            return steps
        return None

    @staticmethod
    def _convert_data_input(data_raw: Union[pd.DataFrame, dict, list]) -> Union[dict, list]:
        """
        Internal utility to convert data_raw (which could be a DataFrame, dict, or list of
        DataFrames/dicts) into the format expected by the underlying agent (dict or list of dicts).

        Parameters
        ----------
        data_raw : Union[pd.DataFrame, dict, list]
            The raw input data to be converted.

        Returns
        -------
        Union[dict, list]
            The data in a dictionary or list-of-dictionaries format.

        Raises
        ------
        ValueError
            If data_raw (or any list element) is not a DataFrame or dict.
        """
        # A single DataFrame becomes a dict of columns (DataFrame.to_dict() default orient).
        if isinstance(data_raw, pd.DataFrame):
            return data_raw.to_dict()

        # Already a dict: treat it as a single dataset and pass it through unchanged.
        if isinstance(data_raw, dict):
            return data_raw

        # A list of datasets: convert each DataFrame element, keep dict elements as-is.
        if isinstance(data_raw, list):
            converted_list = []
            for item in data_raw:
                if isinstance(item, pd.DataFrame):
                    converted_list.append(item.to_dict())
                elif isinstance(item, dict):
                    converted_list.append(item)
                else:
                    raise ValueError("List must contain only DataFrames or dictionaries.")
            return converted_list

        raise ValueError("data_raw must be a DataFrame, a dict, or a list of dicts/DataFrames.")
|
431
|
+
|
432
|
+
|
433
|
+
# Function
|
434
|
+
|
435
|
+
def make_data_wrangling_agent(
|
436
|
+
model,
|
437
|
+
n_samples=30,
|
438
|
+
log=False,
|
439
|
+
log_path=None,
|
440
|
+
file_name="data_wrangler.py",
|
441
|
+
function_name="data_wrangler",
|
442
|
+
overwrite=True,
|
443
|
+
human_in_the_loop=False,
|
444
|
+
bypass_recommended_steps=False,
|
445
|
+
bypass_explain_code=False
|
446
|
+
):
|
35
447
|
"""
|
36
448
|
Creates a data wrangling agent that can be run on one or more datasets. The agent can be
|
37
449
|
instructed to perform common data wrangling steps such as:
|
@@ -52,11 +464,19 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
|
|
52
464
|
----------
|
53
465
|
model : langchain.llms.base.LLM
|
54
466
|
The language model to use to generate code.
|
467
|
+
n_samples : int, optional
|
468
|
+
The number of samples to show in the data summary. Defaults to 30.
|
469
|
+
If you get an error due to maximum tokens, try reducing this number.
|
470
|
+
> "This model's maximum context length is 128000 tokens. However, your messages resulted in 333858 tokens. Please reduce the length of the messages."
|
55
471
|
log : bool, optional
|
56
472
|
Whether or not to log the code generated and any errors that occur.
|
57
473
|
Defaults to False.
|
58
474
|
log_path : str, optional
|
59
475
|
The path to the directory where the log files should be stored. Defaults to "logs/".
|
476
|
+
file_name : str, optional
|
477
|
+
The name of the file to save the response to. Defaults to "data_wrangler.py".
|
478
|
+
function_name : str, optional
|
479
|
+
The name of the function to be generated. Defaults to "data_wrangler".
|
60
480
|
overwrite : bool, optional
|
61
481
|
Whether or not to overwrite the log file if it already exists. If False, a unique file name will be created.
|
62
482
|
Defaults to True.
|
@@ -94,11 +514,16 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
|
|
94
514
|
|
95
515
|
Returns
|
96
516
|
-------
|
97
|
-
app : langchain.graphs.
|
517
|
+
app : langchain.graphs.CompiledStateGraph
|
98
518
|
The data wrangling agent as a state graph.
|
99
519
|
"""
|
100
520
|
llm = model
|
101
521
|
|
522
|
+
# Human in th loop requires recommended steps
|
523
|
+
if bypass_recommended_steps and human_in_the_loop:
|
524
|
+
bypass_recommended_steps = False
|
525
|
+
print("Bypass recommended steps set to False to enable human in the loop.")
|
526
|
+
|
102
527
|
# Setup Log Directory
|
103
528
|
if log:
|
104
529
|
if log_path is None:
|
@@ -122,7 +547,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
|
|
122
547
|
retry_count: int
|
123
548
|
|
124
549
|
def recommend_wrangling_steps(state: GraphState):
|
125
|
-
print(
|
550
|
+
print(format_agent_name(AGENT_NAME))
|
126
551
|
print(" * RECOMMEND WRANGLING STEPS")
|
127
552
|
|
128
553
|
data_raw = state.get("data_raw")
|
@@ -143,7 +568,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
|
|
143
568
|
|
144
569
|
# Create a summary for all datasets
|
145
570
|
# We'll include a short sample and info for each dataset
|
146
|
-
all_datasets_summary = get_dataframe_summary(dataframes)
|
571
|
+
all_datasets_summary = get_dataframe_summary(dataframes, n_sample=n_samples)
|
147
572
|
|
148
573
|
# Join all datasets summaries into one big text block
|
149
574
|
all_datasets_summary_str = "\n\n".join(all_datasets_summary)
|
@@ -176,6 +601,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
|
|
176
601
|
|
177
602
|
Avoid these:
|
178
603
|
1. Do not include steps to save files.
|
604
|
+
2. Do not include unrelated user instructions that are not related to the data wrangling.
|
179
605
|
""",
|
180
606
|
input_variables=["user_instructions", "recommended_steps", "all_datasets_summary"]
|
181
607
|
)
|
@@ -188,19 +614,46 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
|
|
188
614
|
})
|
189
615
|
|
190
616
|
return {
|
191
|
-
"recommended_steps": "
|
617
|
+
"recommended_steps": format_recommended_steps(recommended_steps.content.strip(), heading="# Recommended Data Wrangling Steps:"),
|
192
618
|
"all_datasets_summary": all_datasets_summary_str,
|
193
619
|
}
|
194
620
|
|
195
621
|
|
196
622
|
def create_data_wrangler_code(state: GraphState):
|
197
623
|
if bypass_recommended_steps:
|
198
|
-
print(
|
624
|
+
print(format_agent_name(AGENT_NAME))
|
625
|
+
|
626
|
+
data_raw = state.get("data_raw")
|
627
|
+
|
628
|
+
if isinstance(data_raw, dict):
|
629
|
+
# Single dataset scenario
|
630
|
+
primary_dataset_name = "main"
|
631
|
+
datasets = {primary_dataset_name: data_raw}
|
632
|
+
elif isinstance(data_raw, list) and all(isinstance(item, dict) for item in data_raw):
|
633
|
+
# Multiple datasets scenario
|
634
|
+
datasets = {f"dataset_{i}": d for i, d in enumerate(data_raw, start=1)}
|
635
|
+
primary_dataset_name = "dataset_1"
|
636
|
+
else:
|
637
|
+
raise ValueError("data_raw must be a dict or a list of dicts.")
|
638
|
+
|
639
|
+
# Convert all datasets to DataFrames for inspection
|
640
|
+
dataframes = {name: pd.DataFrame.from_dict(d) for name, d in datasets.items()}
|
641
|
+
|
642
|
+
# Create a summary for all datasets
|
643
|
+
# We'll include a short sample and info for each dataset
|
644
|
+
all_datasets_summary = get_dataframe_summary(dataframes, n_sample=n_samples)
|
645
|
+
|
646
|
+
# Join all datasets summaries into one big text block
|
647
|
+
all_datasets_summary_str = "\n\n".join(all_datasets_summary)
|
648
|
+
|
649
|
+
else:
|
650
|
+
all_datasets_summary_str = state.get("all_datasets_summary")
|
651
|
+
|
199
652
|
print(" * CREATE DATA WRANGLER CODE")
|
200
653
|
|
201
654
|
data_wrangling_prompt = PromptTemplate(
|
202
655
|
template="""
|
203
|
-
You are a Data Wrangling Coding Agent. Your job is to create a
|
656
|
+
You are a Data Wrangling Coding Agent. Your job is to create a {function_name}() function that can be run on the provided data.
|
204
657
|
|
205
658
|
Follow these recommended steps:
|
206
659
|
{recommended_steps}
|
@@ -210,10 +663,10 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
|
|
210
663
|
Below are summaries of all datasets provided. If more than one dataset is provided, you may need to merge or join them.:
|
211
664
|
{all_datasets_summary}
|
212
665
|
|
213
|
-
Return Python code in ```python``` format with a single function definition,
|
666
|
+
Return Python code in ```python``` format with a single function definition, {function_name}(), that includes all imports inside the function. And returns a single pandas data frame.
|
214
667
|
|
215
668
|
```python
|
216
|
-
def
|
669
|
+
def {function_name}(data_list):
|
217
670
|
'''
|
218
671
|
Wrangle the data provided in data.
|
219
672
|
|
@@ -235,23 +688,24 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
|
|
235
688
|
|
236
689
|
|
237
690
|
""",
|
238
|
-
input_variables=["recommended_steps", "all_datasets_summary"]
|
691
|
+
input_variables=["recommended_steps", "all_datasets_summary", "function_name"]
|
239
692
|
)
|
240
693
|
|
241
694
|
data_wrangling_agent = data_wrangling_prompt | llm | PythonOutputParser()
|
242
695
|
|
243
696
|
response = data_wrangling_agent.invoke({
|
244
697
|
"recommended_steps": state.get("recommended_steps"),
|
245
|
-
"all_datasets_summary":
|
698
|
+
"all_datasets_summary": all_datasets_summary_str,
|
699
|
+
"function_name": function_name
|
246
700
|
})
|
247
701
|
|
248
702
|
response = relocate_imports_inside_function(response)
|
249
703
|
response = add_comments_to_top(response, agent_name=AGENT_NAME)
|
250
704
|
|
251
705
|
# For logging: store the code generated
|
252
|
-
file_path,
|
706
|
+
file_path, file_name_2 = log_ai_function(
|
253
707
|
response=response,
|
254
|
-
file_name=
|
708
|
+
file_name=file_name,
|
255
709
|
log=log,
|
256
710
|
log_path=log_path,
|
257
711
|
overwrite=overwrite
|
@@ -260,7 +714,9 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
|
|
260
714
|
return {
|
261
715
|
"data_wrangler_function" : response,
|
262
716
|
"data_wrangler_function_path": file_path,
|
263
|
-
"
|
717
|
+
"data_wrangler_file_name": file_name_2,
|
718
|
+
"data_wrangler_function_name": function_name,
|
719
|
+
"all_datasets_summary": all_datasets_summary_str
|
264
720
|
}
|
265
721
|
|
266
722
|
|
@@ -273,6 +729,33 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
|
|
273
729
|
user_instructions_key="user_instructions",
|
274
730
|
recommended_steps_key="recommended_steps"
|
275
731
|
)
|
732
|
+
|
733
|
+
# Human Review
|
734
|
+
|
735
|
+
prompt_text_human_review = "Are the following data wrangling instructions correct? (Answer 'yes' or provide modifications)\n{steps}"
|
736
|
+
|
737
|
+
if not bypass_explain_code:
|
738
|
+
def human_review(state: GraphState) -> Command[Literal["recommend_wrangling_steps", "explain_data_wrangler_code"]]:
|
739
|
+
return node_func_human_review(
|
740
|
+
state=state,
|
741
|
+
prompt_text=prompt_text_human_review,
|
742
|
+
yes_goto= 'explain_data_wrangler_code',
|
743
|
+
no_goto="recommend_wrangling_steps",
|
744
|
+
user_instructions_key="user_instructions",
|
745
|
+
recommended_steps_key="recommended_steps",
|
746
|
+
code_snippet_key="data_wrangler_function",
|
747
|
+
)
|
748
|
+
else:
|
749
|
+
def human_review(state: GraphState) -> Command[Literal["recommend_wrangling_steps", "__end__"]]:
|
750
|
+
return node_func_human_review(
|
751
|
+
state=state,
|
752
|
+
prompt_text=prompt_text_human_review,
|
753
|
+
yes_goto= '__end__',
|
754
|
+
no_goto="recommend_wrangling_steps",
|
755
|
+
user_instructions_key="user_instructions",
|
756
|
+
recommended_steps_key="recommended_steps",
|
757
|
+
code_snippet_key="data_wrangler_function",
|
758
|
+
)
|
276
759
|
|
277
760
|
def execute_data_wrangler_code(state: GraphState):
|
278
761
|
return node_func_execute_agent_code_on_data(
|
@@ -281,7 +764,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
|
|
281
764
|
result_key="data_wrangled",
|
282
765
|
error_key="data_wrangler_error",
|
283
766
|
code_snippet_key="data_wrangler_function",
|
284
|
-
agent_function_name="
|
767
|
+
agent_function_name=state.get("data_wrangler_function_name"),
|
285
768
|
# pre_processing=pre_processing,
|
286
769
|
post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
|
287
770
|
error_message_prefix="An error occurred during data wrangling: "
|
@@ -289,11 +772,11 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
|
|
289
772
|
|
290
773
|
def fix_data_wrangler_code(state: GraphState):
|
291
774
|
data_wrangler_prompt = """
|
292
|
-
You are a Data Wrangling Agent. Your job is to create a
|
775
|
+
You are a Data Wrangling Agent. Your job is to create a {function_name}() function that can be run on the data provided. The function is currently broken and needs to be fixed.
|
293
776
|
|
294
|
-
Make sure to only return the function definition for
|
777
|
+
Make sure to only return the function definition for {function_name}().
|
295
778
|
|
296
|
-
Return Python code in ```python``` format with a single function definition,
|
779
|
+
Return Python code in ```python``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.
|
297
780
|
|
298
781
|
This is the broken code (please fix):
|
299
782
|
{code_snippet}
|
@@ -311,6 +794,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
|
|
311
794
|
agent_name=AGENT_NAME,
|
312
795
|
log=log,
|
313
796
|
file_path=state.get("data_wrangler_function_path"),
|
797
|
+
function_name=state.get("data_wrangler_function_name"),
|
314
798
|
)
|
315
799
|
|
316
800
|
def explain_data_wrangler_code(state: GraphState):
|