ai-data-science-team 0.0.0.9000__py3-none-any.whl → 0.0.0.9005__py3-none-any.whl
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +4 -0
- ai_data_science_team/agents/data_cleaning_agent.py +347 -0
- ai_data_science_team/agents/data_wrangling_agent.py +365 -0
- ai_data_science_team/agents/feature_engineering_agent.py +368 -0
- ai_data_science_team/templates/__init__.py +0 -0
- ai_data_science_team/templates/agent_templates.py +409 -0
- ai_data_science_team/tools/__init__.py +0 -0
- ai_data_science_team/tools/data_analysis.py +116 -0
- ai_data_science_team/tools/logging.py +61 -0
- ai_data_science_team/tools/parsers.py +57 -0
- ai_data_science_team/tools/regex.py +73 -0
- ai_data_science_team-0.0.0.9005.dist-info/METADATA +162 -0
- ai_data_science_team-0.0.0.9005.dist-info/RECORD +19 -0
- ai_data_science_team/agents.py +0 -325
- ai_data_science_team-0.0.0.9000.dist-info/METADATA +0 -131
- ai_data_science_team-0.0.0.9000.dist-info/RECORD +0 -9
- {ai_data_science_team-0.0.0.9000.dist-info → ai_data_science_team-0.0.0.9005.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9000.dist-info → ai_data_science_team-0.0.0.9005.dist-info}/WHEEL +0 -0
- {ai_data_science_team-0.0.0.9000.dist-info → ai_data_science_team-0.0.0.9005.dist-info}/top_level.txt +0 -0
ai_data_science_team/agents/data_wrangling_agent.py (new file)
@@ -0,0 +1,365 @@
# BUSINESS SCIENCE UNIVERSITY
# AI DATA SCIENCE TEAM
# ***
# * Agents: Data Wrangling Agent

# Libraries
from typing import TypedDict, Annotated, Sequence, Literal, Union
import operator
import os
import io
import pandas as pd

from langchain.prompts import PromptTemplate
from langchain_core.messages import BaseMessage
from langgraph.types import Command
from langgraph.checkpoint.memory import MemorySaver

from ai_data_science_team.templates.agent_templates import (
    node_func_execute_agent_code_on_data,
    node_func_human_review,
    node_func_fix_agent_code,
    node_func_explain_agent_code,
    create_coding_agent_graph
)
from ai_data_science_team.tools.parsers import PythonOutputParser
from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
from ai_data_science_team.tools.data_analysis import summarize_dataframes
from ai_data_science_team.tools.logging import log_ai_function

# Setup Logging Path
AGENT_NAME = "data_wrangling_agent"
LOG_PATH = os.path.join(os.getcwd(), "logs/")

def make_data_wrangling_agent(model, log=False, log_path=None, overwrite=True, human_in_the_loop=False):
    """
    Creates a data wrangling agent that can be run on one or more datasets. The agent can be
    instructed to perform common data wrangling steps such as:

    - Joining or merging multiple datasets
    - Reshaping data (pivoting, melting)
    - Aggregating data via groupby operations
    - Encoding categorical variables (one-hot, label encoding)
    - Creating computed features (e.g., ratio of two columns)
    - Ensuring consistent data types
    - Dropping or rearranging columns

    The agent takes in one or more datasets (passed as a dictionary, or as a list of dictionaries
    when working with multiple datasets) along with user instructions, and outputs a Python function
    that can be used to wrangle the data. If multiple datasets are provided, the agent combines or
    transforms them according to the user instructions.

    Parameters
    ----------
    model : langchain.llms.base.LLM
        The language model used to generate code.
    log : bool, optional
        Whether or not to log the generated code and any errors that occur.
        Defaults to False.
    log_path : str, optional
        The path to the directory where the log files should be stored. Defaults to "logs/".
    overwrite : bool, optional
        Whether or not to overwrite the log file if it already exists. If False, a unique file name is created.
        Defaults to True.
    human_in_the_loop : bool, optional
        Whether or not to use human-in-the-loop review. If True, adds an interrupt step that asks
        the user to review the recommended data wrangling steps. Defaults to False.

    Example
    -------
    ``` python
    from langchain_openai import ChatOpenAI
    import pandas as pd

    df = pd.DataFrame({
        'category': ['A', 'B', 'A', 'C'],
        'value': [10, 20, 15, 5]
    })

    llm = ChatOpenAI(model="gpt-4o-mini")

    data_wrangling_agent = make_data_wrangling_agent(llm)

    response = data_wrangling_agent.invoke({
        "user_instructions": "Calculate the sum and mean of 'value' by 'category'.",
        "data_raw": df.to_dict(),
        "max_retries": 3,
        "retry_count": 0
    })

    pd.DataFrame(response['data_wrangled'])
    ```

    Returns
    -------
    app : langchain.graphs.StateGraph
        The data wrangling agent as a state graph.
    """
    llm = model

    # Setup Log Directory
    if log:
        if log_path is None:
            log_path = LOG_PATH
        if not os.path.exists(log_path):
            os.makedirs(log_path)

    class GraphState(TypedDict):
        messages: Annotated[Sequence[BaseMessage], operator.add]
        user_instructions: str
        recommended_steps: str
        # data_raw should be a dict for a single dataset or a list of dicts for multiple datasets
        data_raw: Union[dict, list]
        data_wrangled: dict
        all_datasets_summary: str
        data_wrangler_function: str
        data_wrangler_function_path: str
        data_wrangler_function_name: str
        data_wrangler_error: str
        max_retries: int
        retry_count: int

    def recommend_wrangling_steps(state: GraphState):
        print("---DATA WRANGLING AGENT----")
        print(" * RECOMMEND WRANGLING STEPS")

        data_raw = state.get("data_raw")

        if isinstance(data_raw, dict):
            # Single dataset scenario
            primary_dataset_name = "main"
            datasets = {primary_dataset_name: data_raw}
        elif isinstance(data_raw, list) and all(isinstance(item, dict) for item in data_raw):
            # Multiple datasets scenario
            datasets = {f"dataset_{i}": d for i, d in enumerate(data_raw, start=1)}
            primary_dataset_name = "dataset_1"
        else:
            raise ValueError("data_raw must be a dict or a list of dicts.")

        # Convert all datasets to DataFrames for inspection
        dataframes = {name: pd.DataFrame.from_dict(d) for name, d in datasets.items()}

        # Create a summary for all datasets
        # We'll include a short sample and info for each dataset
        all_datasets_summary = summarize_dataframes(dataframes)

        # Join all dataset summaries into one big text block
        all_datasets_summary_str = "\n\n".join(all_datasets_summary)

        # Prepare the prompt:
        # We now include summaries for all datasets, not just the primary dataset.
        # The LLM can then use all this info to recommend steps that consider merging/joining.
        recommend_steps_prompt = PromptTemplate(
            template="""
            You are a Data Wrangling Expert. Given the following data (one or multiple datasets) and user instructions,
            recommend a series of numbered steps to wrangle the data based on the user's needs.

            You can use any common data wrangling techniques such as joining, reshaping, aggregating, encoding, etc.

            If multiple datasets are provided, you may need to recommend how to merge or join them.

            Also consider any special transformations requested by the user. If the user instructions
            say to do something else or not to do certain steps, follow those instructions.

            User instructions:
            {user_instructions}

            Previously Recommended Steps (if any):
            {recommended_steps}

            Below are summaries of all datasets provided:
            {all_datasets_summary}

            Return your recommended steps as a numbered list, explaining briefly why each step is needed.

            Avoid these:
            1. Do not include steps to save files.
            """,
            input_variables=["user_instructions", "recommended_steps", "all_datasets_summary"]
        )

        steps_agent = recommend_steps_prompt | llm
        recommended_steps = steps_agent.invoke({
            "user_instructions": state.get("user_instructions"),
            "recommended_steps": state.get("recommended_steps"),
            "all_datasets_summary": all_datasets_summary_str,
        })

        return {
            "recommended_steps": "\n\n# Recommended Wrangling Steps:\n" + recommended_steps.content.strip(),
            "all_datasets_summary": all_datasets_summary_str,
        }

    def create_data_wrangler_code(state: GraphState):
        print(" * CREATE DATA WRANGLER CODE")

        data_wrangling_prompt = PromptTemplate(
            template="""
            You are a Data Wrangling Coding Agent. Your job is to create a data_wrangler() function that can be run on the provided data.

            Follow these recommended steps:
            {recommended_steps}

            If multiple datasets are provided, you may need to merge or join them. Make sure to handle that scenario based on the recommended steps and user instructions.

            Below are summaries of all datasets provided. If more than one dataset is provided, you may need to merge or join them:
            {all_datasets_summary}

            Return Python code in ```python``` format with a single function definition, data_wrangler(), that includes all imports inside the function and returns a single pandas data frame.

            ```python
            def data_wrangler(data_list):
                '''
                Wrangle the data provided in data.

                data_list: A list of one or more pandas data frames containing the raw data to be wrangled.
                '''
                import pandas as pd
                import numpy as np
                # Implement the wrangling steps here

                # Return a single DataFrame
                return data_wrangled
            ```

            Avoid Errors:
            1. If the incoming data is not a list, convert it to a list first.
            2. Do not specify data types inside the function arguments.

            Make sure to explain any non-trivial steps with inline comments. Follow user instructions. Comment code thoroughly.
            """,
            input_variables=["recommended_steps", "all_datasets_summary"]
        )

        data_wrangling_agent = data_wrangling_prompt | llm | PythonOutputParser()

        response = data_wrangling_agent.invoke({
            "recommended_steps": state.get("recommended_steps"),
            "all_datasets_summary": state.get("all_datasets_summary")
        })

        response = relocate_imports_inside_function(response)
        response = add_comments_to_top(response, agent_name=AGENT_NAME)

        # For logging: store the generated code
        file_path, file_name = log_ai_function(
            response=response,
            file_name="data_wrangler.py",
            log=log,
            log_path=log_path,
            overwrite=overwrite
        )

        return {
            "data_wrangler_function": response,
            "data_wrangler_function_path": file_path,
            "data_wrangler_function_name": file_name
        }

    def human_review(state: GraphState) -> Command[Literal["recommend_wrangling_steps", "create_data_wrangler_code"]]:
        return node_func_human_review(
            state=state,
            prompt_text="Are the following data wrangling steps correct? (Answer 'yes' or provide modifications)\n{steps}",
            yes_goto="create_data_wrangler_code",
            no_goto="recommend_wrangling_steps",
            user_instructions_key="user_instructions",
            recommended_steps_key="recommended_steps"
        )

    def execute_data_wrangler_code(state: GraphState):
        # Handle multiple datasets as lists
        # def pre_processing(data):
        #     df = []
        #     for i in range(len(data)):
        #         df[i] = pd.DataFrame.from_dict(data[i])
        #     return df

        # def post_processing(df):
        #     return df.to_dict()

        return node_func_execute_agent_code_on_data(
            state=state,
            data_key="data_raw",
            result_key="data_wrangled",
            error_key="data_wrangler_error",
            code_snippet_key="data_wrangler_function",
            agent_function_name="data_wrangler",
            # pre_processing=pre_processing,
            # post_processing=post_processing,
            error_message_prefix="An error occurred during data wrangling: "
        )

    def fix_data_wrangler_code(state: GraphState):
        data_wrangler_prompt = """
        You are a Data Wrangling Agent. Your job is to create a data_wrangler() function that can be run on the data provided. The function is currently broken and needs to be fixed.

        Make sure to only return the function definition for data_wrangler().

        Return Python code in ```python``` format with a single function definition, data_wrangler(data_raw), that includes all imports inside the function.

        This is the broken code (please fix):
        {code_snippet}

        Last Known Error:
        {error}
        """

        return node_func_fix_agent_code(
            state=state,
            code_snippet_key="data_wrangler_function",
            error_key="data_wrangler_error",
            llm=llm,
            prompt_template=data_wrangler_prompt,
            agent_name=AGENT_NAME,
            log=log,
            file_path=state.get("data_wrangler_function_path"),
        )

    def explain_data_wrangler_code(state: GraphState):
        return node_func_explain_agent_code(
            state=state,
            code_snippet_key="data_wrangler_function",
            result_key="messages",
            error_key="data_wrangler_error",
            llm=llm,
            role=AGENT_NAME,
            explanation_prompt_template="""
            Explain the data wrangling steps that the data wrangling agent performed in this function.
            Keep the summary succinct and to the point.\n\n# Data Wrangling Agent:\n\n{code}
            """,
            success_prefix="# Data Wrangling Agent:\n\n ",
            error_message="The Data Wrangling Agent encountered an error during data wrangling. Data could not be explained."
        )

    # Define the graph
    node_functions = {
        "recommend_wrangling_steps": recommend_wrangling_steps,
        "human_review": human_review,
        "create_data_wrangler_code": create_data_wrangler_code,
        "execute_data_wrangler_code": execute_data_wrangler_code,
        "fix_data_wrangler_code": fix_data_wrangler_code,
        "explain_data_wrangler_code": explain_data_wrangler_code
    }

    app = create_coding_agent_graph(
        GraphState=GraphState,
        node_functions=node_functions,
        recommended_steps_node_name="recommend_wrangling_steps",
        create_code_node_name="create_data_wrangler_code",
        execute_code_node_name="execute_data_wrangler_code",
        fix_code_node_name="fix_data_wrangler_code",
        explain_code_node_name="explain_data_wrangler_code",
        error_key="data_wrangler_error",
        human_in_the_loop=human_in_the_loop,
        human_review_node_name="human_review",
        checkpointer=MemorySaver() if human_in_the_loop else None
    )

    return app
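
The `create_data_wrangler_code` node prompts the LLM for a `data_wrangler(data_list)` function, which `execute_data_wrangler_code` then runs against `data_raw`. For context, below is a minimal, illustrative sketch of the kind of function that exchange produces for the docstring example ("sum and mean of 'value' by 'category'"). The function name and signature come from the prompt skeleton above; the body is an assumption for illustration, since the real function is generated by the LLM at runtime and will differ.

```python
# Illustrative sketch only -- not output from the agent.
def data_wrangler(data_list):
    '''
    Wrangle the data provided in data.

    data_list: A list of one or more pandas data frames containing the raw data to be wrangled.
    '''
    import pandas as pd

    # The prompt instructs the model to coerce non-list inputs to a list first
    if not isinstance(data_list, list):
        data_list = [data_list]

    df = data_list[0]

    # Aggregate 'value' by 'category' and return a single DataFrame
    data_wrangled = (
        df.groupby('category')['value']
          .agg(['sum', 'mean'])
          .reset_index()
    )
    return data_wrangled
```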
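
The docstring example only covers a single DataFrame. The sketch below shows the multi-dataset path, where `data_raw` is passed as a list of dicts and mapped to `dataset_1`, `dataset_2`, ... by `recommend_wrangling_steps` before merge/join steps are recommended. The tables, column names, and instruction text here are invented for illustration; the invoke keys mirror the docstring example.

```python
from langchain_openai import ChatOpenAI
import pandas as pd

# Hypothetical example tables; any columns suitable for a join would do.
orders = pd.DataFrame({'order_id': [1, 2, 3], 'category': ['A', 'B', 'A'], 'value': [10, 20, 15]})
prices = pd.DataFrame({'category': ['A', 'B'], 'unit_price': [2.5, 4.0]})

llm = ChatOpenAI(model="gpt-4o-mini")
data_wrangling_agent = make_data_wrangling_agent(llm)

# Multiple datasets are passed as a list of dicts.
response = data_wrangling_agent.invoke({
    "user_instructions": "Join the tables on 'category' and add a total_price column (value * unit_price).",
    "data_raw": [orders.to_dict(), prices.to_dict()],
    "max_retries": 3,
    "retry_count": 0
})

pd.DataFrame(response['data_wrangled'])
```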