ai-data-science-team 0.0.0.9000__py3-none-any.whl → 0.0.0.9005__py3-none-any.whl

@@ -0,0 +1,365 @@
+ # BUSINESS SCIENCE UNIVERSITY
+ # AI DATA SCIENCE TEAM
+ # ***
+ # * Agents: Data Wrangling Agent
+
+ # Libraries
+ from typing import TypedDict, Annotated, Sequence, Literal, Union
+ import operator
+ import os
+ import io
+ import pandas as pd
+
+ from langchain.prompts import PromptTemplate
+ from langchain_core.messages import BaseMessage
+ from langgraph.types import Command
+ from langgraph.checkpoint.memory import MemorySaver
+
+ from ai_data_science_team.templates.agent_templates import (
+     node_func_execute_agent_code_on_data,
+     node_func_human_review,
+     node_func_fix_agent_code,
+     node_func_explain_agent_code,
+     create_coding_agent_graph
+ )
+ from ai_data_science_team.tools.parsers import PythonOutputParser
+ from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
+ from ai_data_science_team.tools.data_analysis import summarize_dataframes
+ from ai_data_science_team.tools.logging import log_ai_function
+
+ # Setup Logging Path
+ AGENT_NAME = "data_wrangling_agent"
+ LOG_PATH = os.path.join(os.getcwd(), "logs/")
+
+ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite=True, human_in_the_loop=False):
+     """
+     Creates a data wrangling agent that can be run on one or more datasets. The agent can be
+     instructed to perform common data wrangling steps such as:
+
+     - Joining or merging multiple datasets
+     - Reshaping data (pivoting, melting)
+     - Aggregating data via groupby operations
+     - Encoding categorical variables (one-hot, label encoding)
+     - Creating computed features (e.g., ratio of two columns)
+     - Ensuring consistent data types
+     - Dropping or rearranging columns
+
+     The agent takes in one or more datasets (passed as a dictionary for a single dataset, or a list of
+     dictionaries for multiple datasets), along with user instructions, and outputs a Python function
+     that wrangles the data. If multiple datasets are provided, the agent combines or transforms them
+     according to the user instructions (see the second example below).
+
+     Parameters
+     ----------
+     model : langchain.llms.base.LLM
+         The language model used to generate code.
+     log : bool, optional
+         Whether or not to log the code generated and any errors that occur.
+         Defaults to False.
+     log_path : str, optional
+         The path to the directory where the log files should be stored. Defaults to "logs/".
+     overwrite : bool, optional
+         Whether or not to overwrite the log file if it already exists. If False, a unique file name will be created.
+         Defaults to True.
+     human_in_the_loop : bool, optional
+         Whether or not to use human in the loop. If True, adds an interrupt and human-in-the-loop
+         step that asks the user to review the data wrangling instructions. Defaults to False.
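+
+         When `human_in_the_loop=True`, the graph is compiled with a `MemorySaver` checkpointer, so each
+         call must supply a thread config. Below is a minimal sketch that assumes the review node pauses
+         via LangGraph's interrupt mechanism and is resumed with `Command(resume=...)`; the data, the
+         instructions, and the `thread_id` value are illustrative only.
+
+         ``` python
+         from langchain_openai import ChatOpenAI
+         from langgraph.types import Command
+         import pandas as pd
+
+         llm = ChatOpenAI(model="gpt-4o-mini")
+         df = pd.DataFrame({'category': ['A', 'B', 'A'], 'value': [10, 20, 15]})
+
+         agent = make_data_wrangling_agent(llm, human_in_the_loop=True)
+         config = {"configurable": {"thread_id": "example-thread"}}
+
+         # Runs until the human review node pauses the graph for feedback
+         agent.invoke({
+             "user_instructions": "Aggregate 'value' by 'category'.",
+             "data_raw": df.to_dict(),
+             "max_retries": 3,
+             "retry_count": 0
+         }, config=config)
+
+         # Resume with an answer to the review prompt (assumed interrupt/resume flow)
+         agent.invoke(Command(resume="yes"), config=config)
+         ```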
+
+     Example
+     -------
+     ``` python
+     from langchain_openai import ChatOpenAI
+     import pandas as pd
+
+     df = pd.DataFrame({
+         'category': ['A', 'B', 'A', 'C'],
+         'value': [10, 20, 15, 5]
+     })
+
+     llm = ChatOpenAI(model="gpt-4o-mini")
+
+     data_wrangling_agent = make_data_wrangling_agent(llm)
+
+     response = data_wrangling_agent.invoke({
+         "user_instructions": "Calculate the sum and mean of 'value' by 'category'.",
+         "data_raw": df.to_dict(),
+         "max_retries": 3,
+         "retry_count": 0
+     })
+     pd.DataFrame(response['data_wrangled'])
+     ```
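+
+     For multiple datasets, pass `data_raw` as a list of dictionaries (one `df.to_dict()` per dataset).
+     The sketch below continues the example above; the two frames and the join instruction are made up
+     for illustration.
+
+     ``` python
+     df_orders = pd.DataFrame({
+         'customer_id': [1, 2, 1],
+         'amount': [100, 250, 75]
+     })
+     df_customers = pd.DataFrame({
+         'customer_id': [1, 2],
+         'segment': ['retail', 'wholesale']
+     })
+
+     response = data_wrangling_agent.invoke({
+         "user_instructions": "Join the two datasets on 'customer_id' and total 'amount' by 'segment'.",
+         "data_raw": [df_orders.to_dict(), df_customers.to_dict()],
+         "max_retries": 3,
+         "retry_count": 0
+     })
+     pd.DataFrame(response['data_wrangled'])
+     ```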
+
+     Returns
+     -------
+     app : CompiledStateGraph
+         The data wrangling agent as a compiled LangGraph state graph.
+     """
+     llm = model
+
+     # Setup Log Directory
+     if log:
+         if log_path is None:
+             log_path = LOG_PATH
+         if not os.path.exists(log_path):
+             os.makedirs(log_path)
+
+     class GraphState(TypedDict):
+         messages: Annotated[Sequence[BaseMessage], operator.add]
+         user_instructions: str
+         recommended_steps: str
+         # data_raw should be a dict for a single dataset or a list of dicts for multiple datasets
+         data_raw: Union[dict, list]
+         data_wrangled: dict
+         all_datasets_summary: str
+         data_wrangler_function: str
+         data_wrangler_function_path: str
+         data_wrangler_function_name: str
+         data_wrangler_error: str
+         max_retries: int
+         retry_count: int
+
+     def recommend_wrangling_steps(state: GraphState):
+         print("---DATA WRANGLING AGENT----")
+         print(" * RECOMMEND WRANGLING STEPS")
+
+         data_raw = state.get("data_raw")
+
+         if isinstance(data_raw, dict):
+             # Single dataset scenario
+             primary_dataset_name = "main"
+             datasets = {primary_dataset_name: data_raw}
+         elif isinstance(data_raw, list) and all(isinstance(item, dict) for item in data_raw):
+             # Multiple datasets scenario
+             datasets = {f"dataset_{i}": d for i, d in enumerate(data_raw, start=1)}
+             primary_dataset_name = "dataset_1"
+         else:
+             raise ValueError("data_raw must be a dict or a list of dicts.")
+
+         # Convert all datasets to DataFrames for inspection
+         dataframes = {name: pd.DataFrame.from_dict(d) for name, d in datasets.items()}
+
+         # Create a summary for all datasets
+         # We'll include a short sample and info for each dataset
+         all_datasets_summary = summarize_dataframes(dataframes)
+
+         # Join all dataset summaries into one big text block
+         all_datasets_summary_str = "\n\n".join(all_datasets_summary)
+
+         # Prepare the prompt:
+         # We now include summaries for all datasets, not just the primary dataset.
+         # The LLM can then use all this info to recommend steps that consider merging/joining.
+         recommend_steps_prompt = PromptTemplate(
+             template="""
+             You are a Data Wrangling Expert. Given the following data (one or multiple datasets) and user instructions,
+             recommend a series of numbered steps to wrangle the data based on a user's needs.
+
+             You can use any common data wrangling techniques such as joining, reshaping, aggregating, encoding, etc.
+
+             If multiple datasets are provided, you may need to recommend how to merge or join them.
+
+             Also consider any special transformations requested by the user. If the user instructions
+             say to do something else or not to do certain steps, follow those instructions.
+
+             User instructions:
+             {user_instructions}
+
+             Previously Recommended Steps (if any):
+             {recommended_steps}
+
+             Below are summaries of all datasets provided:
+             {all_datasets_summary}
+
+             Return your recommended steps as a numbered point list, explaining briefly why each step is needed.
+
+             Avoid these:
+             1. Do not include steps to save files.
+             """,
+             input_variables=["user_instructions", "recommended_steps", "all_datasets_summary"]
+         )
+
+         steps_agent = recommend_steps_prompt | llm
+         recommended_steps = steps_agent.invoke({
+             "user_instructions": state.get("user_instructions"),
+             "recommended_steps": state.get("recommended_steps"),
+             "all_datasets_summary": all_datasets_summary_str,
+         })
+
+         return {
+             "recommended_steps": "\n\n# Recommended Wrangling Steps:\n" + recommended_steps.content.strip(),
+             "all_datasets_summary": all_datasets_summary_str,
+         }
+
+
+     def create_data_wrangler_code(state: GraphState):
+         print(" * CREATE DATA WRANGLER CODE")
+
+         data_wrangling_prompt = PromptTemplate(
+             template="""
+             You are a Data Wrangling Coding Agent. Your job is to create a data_wrangler() function that can be run on the provided data.
+
+             Follow these recommended steps:
+             {recommended_steps}
+
+             If multiple datasets are provided, you may need to merge or join them. Make sure to handle that scenario based on the recommended steps and user instructions.
+
+             Below are summaries of all datasets provided. If more than one dataset is provided, you may need to merge or join them:
+             {all_datasets_summary}
+
+             Return Python code in ```python``` format with a single function definition, data_wrangler(), that includes all imports inside the function and returns a single pandas DataFrame.
+
+             ```python
+             def data_wrangler(data_list):
+                 '''
+                 Wrangle the data provided in data_list.
+
+                 data_list: A list of one or more pandas data frames containing the raw data to be wrangled.
+                 '''
+                 import pandas as pd
+                 import numpy as np
+                 # Implement the wrangling steps here
+
+                 # Return a single DataFrame
+                 return data_wrangled
+             ```
+
+             Avoid Errors:
+             1. If the incoming data is not a list, convert it to a list first.
+             2. Do not specify data types inside the function arguments.
+
+             Make sure to explain any non-trivial steps with inline comments. Follow user instructions. Comment code thoroughly.
+             """,
+             input_variables=["recommended_steps", "all_datasets_summary"]
+         )
+
+         data_wrangling_agent = data_wrangling_prompt | llm | PythonOutputParser()
+
+         response = data_wrangling_agent.invoke({
+             "recommended_steps": state.get("recommended_steps"),
+             "all_datasets_summary": state.get("all_datasets_summary")
+         })
+
+         response = relocate_imports_inside_function(response)
+         response = add_comments_to_top(response, agent_name=AGENT_NAME)
+
+         # For logging: store the code generated
+         file_path, file_name = log_ai_function(
+             response=response,
+             file_name="data_wrangler.py",
+             log=log,
+             log_path=log_path,
+             overwrite=overwrite
+         )
+
+         return {
+             "data_wrangler_function": response,
+             "data_wrangler_function_path": file_path,
+             "data_wrangler_function_name": file_name
+         }
+
+
+     def human_review(state: GraphState) -> Command[Literal["recommend_wrangling_steps", "create_data_wrangler_code"]]:
+         return node_func_human_review(
+             state=state,
+             prompt_text="Are the following data wrangling steps correct? (Answer 'yes' or provide modifications)\n{steps}",
+             yes_goto="create_data_wrangler_code",
+             no_goto="recommend_wrangling_steps",
+             user_instructions_key="user_instructions",
+             recommended_steps_key="recommended_steps"
+         )
+
+     def execute_data_wrangler_code(state: GraphState):
+
+         # Handle multiple datasets as lists (currently disabled)
+         # def pre_processing(data):
+         #     df = []
+         #     for i in range(len(data)):
+         #         df.append(pd.DataFrame.from_dict(data[i]))
+         #     return df
+
+         # def post_processing(df):
+         #     return df.to_dict()
+
+         return node_func_execute_agent_code_on_data(
+             state=state,
+             data_key="data_raw",
+             result_key="data_wrangled",
+             error_key="data_wrangler_error",
+             code_snippet_key="data_wrangler_function",
+             agent_function_name="data_wrangler",
+             # pre_processing=pre_processing,
+             # post_processing=post_processing,
+             error_message_prefix="An error occurred during data wrangling: "
+         )
+
+     def fix_data_wrangler_code(state: GraphState):
+         data_wrangler_prompt = """
+         You are a Data Wrangling Agent. Your job is to create a data_wrangler() function that can be run on the data provided. The function is currently broken and needs to be fixed.
+
+         Make sure to only return the function definition for data_wrangler().
+
+         Return Python code in ```python``` format with a single function definition, data_wrangler(data_list), that includes all imports inside the function.
+
+         This is the broken code (please fix):
+         {code_snippet}
+
+         Last Known Error:
+         {error}
+         """
+
+         return node_func_fix_agent_code(
+             state=state,
+             code_snippet_key="data_wrangler_function",
+             error_key="data_wrangler_error",
+             llm=llm,
+             prompt_template=data_wrangler_prompt,
+             agent_name=AGENT_NAME,
+             log=log,
+             file_path=state.get("data_wrangler_function_path"),
+         )
+
+     def explain_data_wrangler_code(state: GraphState):
+         return node_func_explain_agent_code(
+             state=state,
+             code_snippet_key="data_wrangler_function",
+             result_key="messages",
+             error_key="data_wrangler_error",
+             llm=llm,
+             role=AGENT_NAME,
+             explanation_prompt_template="""
+             Explain the data wrangling steps that the data wrangling agent performed in this function.
+             Keep the summary succinct and to the point.\n\n# Data Wrangling Agent:\n\n{code}
+             """,
+             success_prefix="# Data Wrangling Agent:\n\n ",
+             error_message="The Data Wrangling Agent encountered an error during data wrangling. Data could not be explained."
+         )
+
+     # Define the graph
+     node_functions = {
+         "recommend_wrangling_steps": recommend_wrangling_steps,
+         "human_review": human_review,
+         "create_data_wrangler_code": create_data_wrangler_code,
+         "execute_data_wrangler_code": execute_data_wrangler_code,
+         "fix_data_wrangler_code": fix_data_wrangler_code,
+         "explain_data_wrangler_code": explain_data_wrangler_code
+     }
+
+     app = create_coding_agent_graph(
+         GraphState=GraphState,
+         node_functions=node_functions,
+         recommended_steps_node_name="recommend_wrangling_steps",
+         create_code_node_name="create_data_wrangler_code",
+         execute_code_node_name="execute_data_wrangler_code",
+         fix_code_node_name="fix_data_wrangler_code",
+         explain_code_node_name="explain_data_wrangler_code",
+         error_key="data_wrangler_error",
+         human_in_the_loop=human_in_the_loop,
+         human_review_node_name="human_review",
+         checkpointer=MemorySaver() if human_in_the_loop else None
+     )
+
+     return app
+