ai-data-science-team 0.0.0.9000__py3-none-any.whl → 0.0.0.9005__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +4 -0
- ai_data_science_team/agents/data_cleaning_agent.py +347 -0
- ai_data_science_team/agents/data_wrangling_agent.py +365 -0
- ai_data_science_team/agents/feature_engineering_agent.py +368 -0
- ai_data_science_team/templates/__init__.py +0 -0
- ai_data_science_team/templates/agent_templates.py +409 -0
- ai_data_science_team/tools/__init__.py +0 -0
- ai_data_science_team/tools/data_analysis.py +116 -0
- ai_data_science_team/tools/logging.py +61 -0
- ai_data_science_team/tools/parsers.py +57 -0
- ai_data_science_team/tools/regex.py +73 -0
- ai_data_science_team-0.0.0.9005.dist-info/METADATA +162 -0
- ai_data_science_team-0.0.0.9005.dist-info/RECORD +19 -0
- ai_data_science_team/agents.py +0 -325
- ai_data_science_team-0.0.0.9000.dist-info/METADATA +0 -131
- ai_data_science_team-0.0.0.9000.dist-info/RECORD +0 -9
- {ai_data_science_team-0.0.0.9000.dist-info → ai_data_science_team-0.0.0.9005.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9000.dist-info → ai_data_science_team-0.0.0.9005.dist-info}/WHEEL +0 -0
- {ai_data_science_team-0.0.0.9000.dist-info → ai_data_science_team-0.0.0.9005.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,365 @@
|
|
1
|
+
# BUSINESS SCIENCE UNIVERSITY
|
2
|
+
# AI DATA SCIENCE TEAM
|
3
|
+
# ***
|
4
|
+
# * Agents: Data Wrangling Agent
|
5
|
+
|
6
|
+
# Libraries
|
7
|
+
from typing import TypedDict, Annotated, Sequence, Literal, Union
|
8
|
+
import operator
|
9
|
+
import os
|
10
|
+
import io
|
11
|
+
import pandas as pd
|
12
|
+
|
13
|
+
from langchain.prompts import PromptTemplate
|
14
|
+
from langchain_core.messages import BaseMessage
|
15
|
+
from langgraph.types import Command
|
16
|
+
from langgraph.checkpoint.memory import MemorySaver
|
17
|
+
|
18
|
+
from ai_data_science_team.templates.agent_templates import(
|
19
|
+
node_func_execute_agent_code_on_data,
|
20
|
+
node_func_human_review,
|
21
|
+
node_func_fix_agent_code,
|
22
|
+
node_func_explain_agent_code,
|
23
|
+
create_coding_agent_graph
|
24
|
+
)
|
25
|
+
from ai_data_science_team.tools.parsers import PythonOutputParser
|
26
|
+
from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
|
27
|
+
from ai_data_science_team.tools.data_analysis import summarize_dataframes
|
28
|
+
from ai_data_science_team.tools.logging import log_ai_function
|
29
|
+
|
30
|
+
# Default directory for logged, agent-generated code: a "logs/" folder under
# the current working directory (resolved once, when this module is imported).
LOG_PATH = os.path.join(os.getcwd(), "logs/")

# Identifier for this agent; handed to the shared template helpers below.
AGENT_NAME = "data_wrangling_agent"
|
33
|
+
|
34
|
+
def make_data_wrangling_agent(model, log=False, log_path=None, overwrite=True, human_in_the_loop=False):
    """
    Creates a data wrangling agent that can be run on one or more datasets. The agent can be
    instructed to perform common data wrangling steps such as:

    - Joining or merging multiple datasets
    - Reshaping data (pivoting, melting)
    - Aggregating data via groupby operations
    - Encoding categorical variables (one-hot, label encoding)
    - Creating computed features (e.g., ratio of two columns)
    - Ensuring consistent data types
    - Dropping or rearranging columns

    The agent takes in one or more datasets (passed as a dictionary, or as a list of dictionaries
    when working on multiple datasets), user instructions, and outputs a python function that can
    be used to wrangle the data. If multiple datasets are provided, the agent should combine or
    transform them according to user instructions.

    Parameters
    ----------
    model : LangChain chat model
        The language model used to generate code. It is piped after a ``PromptTemplate`` and the
        ``.content`` attribute of its response is read, so it must return message objects
        (e.g. ``ChatOpenAI``).
    log : bool, optional
        Whether or not to log the code generated and any errors that occur.
        Defaults to False.
    log_path : str, optional
        The path to the directory where the log files should be stored. Only created when
        ``log`` is True. Defaults to a "logs/" directory under the current working directory.
    overwrite : bool, optional
        Whether or not to overwrite the log file if it already exists. If False, a unique file name will be created.
        Defaults to True.
    human_in_the_loop : bool, optional
        Whether or not to use human in the loop. If True, adds an interrupt and human-in-the-loop
        step that asks the user to review the data wrangling instructions, and attaches an
        in-memory checkpointer to the graph. Defaults to False.

    Example
    -------
    ``` python
    from langchain_openai import ChatOpenAI
    import pandas as pd

    df = pd.DataFrame({
        'category': ['A', 'B', 'A', 'C'],
        'value': [10, 20, 15, 5]
    })

    llm = ChatOpenAI(model="gpt-4o-mini")

    data_wrangling_agent = make_data_wrangling_agent(llm)

    response = data_wrangling_agent.invoke({
        "user_instructions": "Calculate the sum and mean of 'value' by 'category'.",
        "data_raw": df.to_dict(),
        "max_retries":3,
        "retry_count":0
    })
    pd.DataFrame(response['data_wrangled'])
    ```

    Returns
    -------
    app : compiled state graph
        The data wrangling agent, as returned by ``create_coding_agent_graph``. It is run
        with ``app.invoke({...})`` as shown in the example.
    """
    llm = model

    # Setup Log Directory
    if log:
        if log_path is None:
            log_path = LOG_PATH
        # exist_ok avoids the check-then-create race of a separate os.path.exists() test.
        os.makedirs(log_path, exist_ok=True)

    class GraphState(TypedDict):
        # Conversation messages; operator.add makes successive node outputs accumulate.
        messages: Annotated[Sequence[BaseMessage], operator.add]
        user_instructions: str
        recommended_steps: str
        # data_raw should be a dict for a single dataset or a list of dicts for multiple datasets
        data_raw: Union[dict, list]
        data_wrangled: dict
        all_datasets_summary: str
        data_wrangler_function: str
        data_wrangler_function_path: str
        data_wrangler_function_name: str
        data_wrangler_error: str
        max_retries: int
        retry_count: int

    def recommend_wrangling_steps(state: GraphState):
        """
        Ask the LLM for a numbered list of wrangling steps.

        Reads ``data_raw``, ``user_instructions`` and ``recommended_steps`` from the state and
        returns the new ``recommended_steps`` text together with ``all_datasets_summary``.

        Raises
        ------
        ValueError
            If ``data_raw`` is neither a dict nor a list of dicts.
        """
        print("---DATA WRANGLING AGENT----")
        print("    * RECOMMEND WRANGLING STEPS")

        data_raw = state.get("data_raw")

        if isinstance(data_raw, dict):
            # Single dataset scenario
            datasets = {"main": data_raw}
        elif isinstance(data_raw, list) and all(isinstance(item, dict) for item in data_raw):
            # Multiple datasets scenario: name them dataset_1, dataset_2, ...
            datasets = {f"dataset_{i}": d for i, d in enumerate(data_raw, start=1)}
        else:
            raise ValueError("data_raw must be a dict or a list of dicts.")

        # Convert all datasets to DataFrames for inspection
        dataframes = {name: pd.DataFrame.from_dict(d) for name, d in datasets.items()}

        # Create a summary for all datasets and join the summaries into one big text block
        all_datasets_summary = summarize_dataframes(dataframes)
        all_datasets_summary_str = "\n\n".join(all_datasets_summary)

        # Prepare the prompt:
        # Summaries for all datasets are included (not just the first one) so the LLM can
        # recommend steps that consider merging/joining.
        recommend_steps_prompt = PromptTemplate(
            template="""
            You are a Data Wrangling Expert. Given the following data (one or multiple datasets) and user instructions,
            recommend a series of numbered steps to wrangle the data based on a user's needs.

            You can use any common data wrangling techniques such as joining, reshaping, aggregating, encoding, etc.

            If multiple datasets are provided, you may need to recommend how to merge or join them.

            Also consider any special transformations requested by the user. If the user instructions
            say to do something else or not to do certain steps, follow those instructions.

            User instructions:
            {user_instructions}

            Previously Recommended Steps (if any):
            {recommended_steps}

            Below are summaries of all datasets provided:
            {all_datasets_summary}

            Return your recommended steps as a numbered point list, explaining briefly why each step is needed.

            Avoid these:
            1. Do not include steps to save files.
            """,
            input_variables=["user_instructions", "recommended_steps", "all_datasets_summary"]
        )

        steps_agent = recommend_steps_prompt | llm
        recommended_steps = steps_agent.invoke({
            "user_instructions": state.get("user_instructions"),
            "recommended_steps": state.get("recommended_steps"),
            "all_datasets_summary": all_datasets_summary_str,
        })

        return {
            "recommended_steps": "\n\n# Recommended Wrangling Steps:\n" + recommended_steps.content.strip(),
            "all_datasets_summary": all_datasets_summary_str,
        }

    def create_data_wrangler_code(state: GraphState):
        """
        Ask the LLM to write ``data_wrangler(data_list)`` from the recommended steps.

        The generated code is post-processed (imports relocated inside the function, header
        comments added), passed to ``log_ai_function``, and returned in the state along with
        the file path and file name that ``log_ai_function`` reports.
        """
        print("    * CREATE DATA WRANGLER CODE")

        data_wrangling_prompt = PromptTemplate(
            template="""
            You are a Data Wrangling Coding Agent. Your job is to create a data_wrangler() function that can be run on the provided data.

            Follow these recommended steps:
            {recommended_steps}

            If multiple datasets are provided, you may need to merge or join them. Make sure to handle that scenario based on the recommended steps and user instructions.

            Below are summaries of all datasets provided. If more than one dataset is provided, you may need to merge or join them.:
            {all_datasets_summary}

            Return Python code in ```python``` format with a single function definition, data_wrangler(), that includes all imports inside the function. And returns a single pandas data frame.

            ```python
            def data_wrangler(data_list):
                '''
                Wrangle the data provided in data_list.

                data_list: A list of one or more pandas data frames containing the raw data to be wrangled.
                '''
                import pandas as pd
                import numpy as np
                # Implement the wrangling steps here

                # Return a single DataFrame
                return data_wrangled
            ```

            Avoid Errors:
            1. If the incoming data is not a list. Convert it to a list first.
            2. Do not specify data types inside the function arguments.

            Make sure to explain any non-trivial steps with inline comments. Follow user instructions. Comment code thoroughly.


            """,
            input_variables=["recommended_steps", "all_datasets_summary"]
        )

        data_wrangling_agent = data_wrangling_prompt | llm | PythonOutputParser()

        response = data_wrangling_agent.invoke({
            "recommended_steps": state.get("recommended_steps"),
            "all_datasets_summary": state.get("all_datasets_summary")
        })

        response = relocate_imports_inside_function(response)
        response = add_comments_to_top(response, agent_name=AGENT_NAME)

        # For logging: store the code generated
        file_path, file_name = log_ai_function(
            response=response,
            file_name="data_wrangler.py",
            log=log,
            log_path=log_path,
            overwrite=overwrite
        )

        return {
            "data_wrangler_function": response,
            "data_wrangler_function_path": file_path,
            "data_wrangler_function_name": file_name
        }

    def human_review(state: GraphState) -> Command[Literal["recommend_wrangling_steps", "create_data_wrangler_code"]]:
        """
        Route on the reviewer's answer: on to code generation for "yes", otherwise back to
        step recommendation. The routing itself is delegated to ``node_func_human_review``.
        """
        return node_func_human_review(
            state=state,
            prompt_text="Are the following data wrangling steps correct? (Answer 'yes' or provide modifications)\n{steps}",
            yes_goto="create_data_wrangler_code",
            no_goto="recommend_wrangling_steps",
            user_instructions_key="user_instructions",
            recommended_steps_key="recommended_steps"
        )

    def execute_data_wrangler_code(state: GraphState):
        """
        Run the generated ``data_wrangler`` function on ``data_raw``.

        Results are stored under ``data_wrangled`` and errors under ``data_wrangler_error``.
        No custom pre-/post-processing is supplied, so conversion of the raw dict(s) is left
        to the defaults of ``node_func_execute_agent_code_on_data``.
        NOTE(review): confirm those defaults handle a list of dicts (multiple datasets).
        """
        return node_func_execute_agent_code_on_data(
            state=state,
            data_key="data_raw",
            result_key="data_wrangled",
            error_key="data_wrangler_error",
            code_snippet_key="data_wrangler_function",
            agent_function_name="data_wrangler",
            error_message_prefix="An error occurred during data wrangling: "
        )

    def fix_data_wrangler_code(state: GraphState):
        """
        Ask the LLM to repair the broken ``data_wrangler`` function using the last error.

        The requested signature uses ``data_list`` so that it agrees with the signature the
        code-generation prompt asked for.
        """
        data_wrangler_prompt = """
        You are a Data Wrangling Agent. Your job is to create a data_wrangler() function that can be run on the data provided. The function is currently broken and needs to be fixed.

        Make sure to only return the function definition for data_wrangler().

        Return Python code in ```python``` format with a single function definition, data_wrangler(data_list), that includes all imports inside the function.

        This is the broken code (please fix):
        {code_snippet}

        Last Known Error:
        {error}
        """

        return node_func_fix_agent_code(
            state=state,
            code_snippet_key="data_wrangler_function",
            error_key="data_wrangler_error",
            llm=llm,
            prompt_template=data_wrangler_prompt,
            agent_name=AGENT_NAME,
            log=log,
            file_path=state.get("data_wrangler_function_path"),
        )

    def explain_data_wrangler_code(state: GraphState):
        """
        Produce a short explanation of the final wrangling function and store it under
        ``messages`` (delegated to ``node_func_explain_agent_code``).
        """
        return node_func_explain_agent_code(
            state=state,
            code_snippet_key="data_wrangler_function",
            result_key="messages",
            error_key="data_wrangler_error",
            llm=llm,
            role=AGENT_NAME,
            explanation_prompt_template="""
            Explain the data wrangling steps that the data wrangling agent performed in this function.
            Keep the summary succinct and to the point.\n\n# Data Wrangling Agent:\n\n{code}
            """,
            success_prefix="# Data Wrangling Agent:\n\n ",
            error_message="The Data Wrangling Agent encountered an error during data wrangling. Data could not be explained."
        )

    # Define the graph
    node_functions = {
        "recommend_wrangling_steps": recommend_wrangling_steps,
        "human_review": human_review,
        "create_data_wrangler_code": create_data_wrangler_code,
        "execute_data_wrangler_code": execute_data_wrangler_code,
        "fix_data_wrangler_code": fix_data_wrangler_code,
        "explain_data_wrangler_code": explain_data_wrangler_code
    }

    app = create_coding_agent_graph(
        GraphState=GraphState,
        node_functions=node_functions,
        recommended_steps_node_name="recommend_wrangling_steps",
        create_code_node_name="create_data_wrangler_code",
        execute_code_node_name="execute_data_wrangler_code",
        fix_code_node_name="fix_data_wrangler_code",
        explain_code_node_name="explain_data_wrangler_code",
        error_key="data_wrangler_error",
        human_in_the_loop=human_in_the_loop,
        human_review_node_name="human_review",
        # A checkpointer is only attached when the human-review interrupt is enabled.
        checkpointer=MemorySaver() if human_in_the_loop else None
    )

    return app
|
362
|
+
|
363
|
+
|
364
|
+
|
365
|
+
|