ai-data-science-team 0.0.0.9000__py3-none-any.whl → 0.0.0.9005__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,365 @@
1
+ # BUSINESS SCIENCE UNIVERSITY
2
+ # AI DATA SCIENCE TEAM
3
+ # ***
4
+ # * Agents: Data Wrangling Agent
5
+
6
+ # Libraries
7
+ from typing import TypedDict, Annotated, Sequence, Literal, Union
8
+ import operator
9
+ import os
10
+ import io
11
+ import pandas as pd
12
+
13
+ from langchain.prompts import PromptTemplate
14
+ from langchain_core.messages import BaseMessage
15
+ from langgraph.types import Command
16
+ from langgraph.checkpoint.memory import MemorySaver
17
+
18
+ from ai_data_science_team.templates.agent_templates import(
19
+ node_func_execute_agent_code_on_data,
20
+ node_func_human_review,
21
+ node_func_fix_agent_code,
22
+ node_func_explain_agent_code,
23
+ create_coding_agent_graph
24
+ )
25
+ from ai_data_science_team.tools.parsers import PythonOutputParser
26
+ from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
27
+ from ai_data_science_team.tools.data_analysis import summarize_dataframes
28
+ from ai_data_science_team.tools.logging import log_ai_function
29
+
30
# Identifier this agent passes to the code-commenting, code-fixing and
# code-explaining helpers so generated artifacts are attributed to it.
AGENT_NAME: str = "data_wrangling_agent"
# Default directory for logged agent code when logging is enabled and no
# explicit log_path is supplied; resolved against the working directory
# once, when this module is imported.
LOG_PATH: str = os.path.join(os.getcwd(), "logs/")
33
+
34
def _datasets_from_data_raw(data_raw):
    """
    Normalize ``data_raw`` into an ordered mapping of dataset name -> raw dict.

    A single dict is named ``"main"``; a list of dicts is named ``"dataset_1"``,
    ``"dataset_2"``, ... in list order.

    Raises
    ------
    ValueError
        If ``data_raw`` is neither a dict nor a non-empty list of dicts.
    """
    if isinstance(data_raw, dict):
        return {"main": data_raw}
    # An empty list would satisfy all(...) vacuously and leave the LLM with no
    # data to describe, so it is rejected explicitly.
    if isinstance(data_raw, list) and data_raw and all(isinstance(item, dict) for item in data_raw):
        return {f"dataset_{i}": d for i, d in enumerate(data_raw, start=1)}
    raise ValueError("data_raw must be a dict or a non-empty list of dicts.")


def _data_raw_to_dataframes(data_raw):
    """
    Convert ``data_raw`` into the argument handed to the generated ``data_wrangler()``.

    Returns a single DataFrame when ``data_raw`` is a dict, or a list of
    DataFrames (same order) when it is a list of dicts. The code-generation
    prompt instructs the LLM to wrap a non-list input into a list itself.
    """
    frames = [pd.DataFrame.from_dict(d) for d in _datasets_from_data_raw(data_raw).values()]
    return frames[0] if isinstance(data_raw, dict) else frames


def _wrangled_to_dict(result):
    """Serialize a DataFrame result to a dict for the graph state; pass anything else through unchanged."""
    return result.to_dict() if isinstance(result, pd.DataFrame) else result


def make_data_wrangling_agent(model, log=False, log_path=None, overwrite=True, human_in_the_loop=False):
    """
    Creates a data wrangling agent that can be run on one or more datasets. The agent can be
    instructed to perform common data wrangling steps such as:

    - Joining or merging multiple datasets
    - Reshaping data (pivoting, melting)
    - Aggregating data via groupby operations
    - Encoding categorical variables (one-hot, label encoding)
    - Creating computed features (e.g., ratio of two columns)
    - Ensuring consistent data types
    - Dropping or rearranging columns

    The agent takes in one or more datasets (passed as a dictionary, or a list of dictionaries
    when working on multiple datasets) plus user instructions, and outputs a python function
    that can be used to wrangle the data. If multiple datasets are provided, the agent should
    combine or transform them according to user instructions.

    Parameters
    ----------
    model : langchain.llms.base.LLM
        The language model to use to generate code.
    log : bool, optional
        Whether or not to log the code generated and any errors that occur.
        Defaults to False.
    log_path : str, optional
        The path to the directory where the log files should be stored. Defaults to "logs/"
        under the current working directory. The directory is created if missing.
    overwrite : bool, optional
        Whether or not to overwrite the log file if it already exists. If False, a unique file name will be created.
        Defaults to True.
    human_in_the_loop : bool, optional
        Whether or not to use human in the loop. If True, adds an interrupt and human-in-the-loop
        step that asks the user to review the data wrangling instructions. Defaults to False.

    Example
    -------
    ``` python
    from langchain_openai import ChatOpenAI
    import pandas as pd

    df = pd.DataFrame({
        'category': ['A', 'B', 'A', 'C'],
        'value': [10, 20, 15, 5]
    })

    llm = ChatOpenAI(model="gpt-4o-mini")

    data_wrangling_agent = make_data_wrangling_agent(llm)

    response = data_wrangling_agent.invoke({
        "user_instructions": "Calculate the sum and mean of 'value' by 'category'.",
        "data_raw": df.to_dict(),
        "max_retries":3,
        "retry_count":0
    })
    pd.DataFrame(response['data_wrangled'])
    ```

    Returns
    -------
    app : langchain.graphs.StateGraph
        The data wrangling agent graph built by ``create_coding_agent_graph``.

    Notes
    -----
    When the graph runs, the recommendation node raises ValueError if ``data_raw``
    is not a dict or a non-empty list of dicts.
    """
    llm = model

    # Setup Log Directory
    if log:
        if log_path is None:
            log_path = LOG_PATH
        # exist_ok avoids a race between an existence check and the creation.
        os.makedirs(log_path, exist_ok=True)

    class GraphState(TypedDict):
        messages: Annotated[Sequence[BaseMessage], operator.add]
        user_instructions: str
        recommended_steps: str
        # data_raw should be a dict for a single dataset or a list of dicts for multiple datasets
        data_raw: Union[dict, list]
        data_wrangled: dict
        all_datasets_summary: str
        data_wrangler_function: str
        data_wrangler_function_path: str
        data_wrangler_function_name: str
        data_wrangler_error: str
        max_retries: int
        retry_count: int

    def recommend_wrangling_steps(state: GraphState):
        """Summarize every input dataset and ask the LLM for numbered wrangling steps."""
        print("---DATA WRANGLING AGENT----")
        print(" * RECOMMEND WRANGLING STEPS")

        # Validates data_raw (raises ValueError) and names each dataset.
        datasets = _datasets_from_data_raw(state.get("data_raw"))

        # Convert all datasets to DataFrames for inspection
        dataframes = {name: pd.DataFrame.from_dict(d) for name, d in datasets.items()}

        # Summaries of all datasets, joined into one text block for the prompt
        all_datasets_summary = summarize_dataframes(dataframes)
        all_datasets_summary_str = "\n\n".join(all_datasets_summary)

        # The prompt includes summaries of all datasets so the LLM can recommend
        # merge/join steps when more than one is provided.
        recommend_steps_prompt = PromptTemplate(
            template="""
            You are a Data Wrangling Expert. Given the following data (one or multiple datasets) and user instructions,
            recommend a series of numbered steps to wrangle the data based on a user's needs.

            You can use any common data wrangling techniques such as joining, reshaping, aggregating, encoding, etc.

            If multiple datasets are provided, you may need to recommend how to merge or join them.

            Also consider any special transformations requested by the user. If the user instructions
            say to do something else or not to do certain steps, follow those instructions.

            User instructions:
            {user_instructions}

            Previously Recommended Steps (if any):
            {recommended_steps}

            Below are summaries of all datasets provided:
            {all_datasets_summary}

            Return your recommended steps as a numbered point list, explaining briefly why each step is needed.

            Avoid these:
            1. Do not include steps to save files.
            """,
            input_variables=["user_instructions", "recommended_steps", "all_datasets_summary"]
        )

        steps_agent = recommend_steps_prompt | llm
        recommended_steps = steps_agent.invoke({
            "user_instructions": state.get("user_instructions"),
            "recommended_steps": state.get("recommended_steps"),
            "all_datasets_summary": all_datasets_summary_str,
        })

        return {
            "recommended_steps": "\n\n# Recommended Wrangling Steps:\n" + recommended_steps.content.strip(),
            "all_datasets_summary": all_datasets_summary_str,
        }

    def create_data_wrangler_code(state: GraphState):
        """Ask the LLM for a data_wrangler() function implementing the recommended steps, then log it."""
        print(" * CREATE DATA WRANGLER CODE")

        data_wrangling_prompt = PromptTemplate(
            template="""
            You are a Data Wrangling Coding Agent. Your job is to create a data_wrangler() function that can be run on the provided data.

            Follow these recommended steps:
            {recommended_steps}

            If multiple datasets are provided, you may need to merge or join them. Make sure to handle that scenario based on the recommended steps and user instructions.

            Below are summaries of all datasets provided. If more than one dataset is provided, you may need to merge or join them.:
            {all_datasets_summary}

            Return Python code in ```python``` format with a single function definition, data_wrangler(), that includes all imports inside the function. And returns a single pandas data frame.

            ```python
            def data_wrangler(data_list):
                '''
                Wrangle the data provided in data.

                data_list: A list of one or more pandas data frames containing the raw data to be wrangled.
                '''
                import pandas as pd
                import numpy as np
                # Implement the wrangling steps here

                # Return a single DataFrame
                return data_wrangled
            ```

            Avoid Errors:
            1. If the incoming data is not a list. Convert it to a list first.
            2. Do not specify data types inside the function arguments.

            Make sure to explain any non-trivial steps with inline comments. Follow user instructions. Comment code thoroughly.


            """,
            input_variables=["recommended_steps", "all_datasets_summary"]
        )

        data_wrangling_agent = data_wrangling_prompt | llm | PythonOutputParser()

        response = data_wrangling_agent.invoke({
            "recommended_steps": state.get("recommended_steps"),
            "all_datasets_summary": state.get("all_datasets_summary")
        })

        response = relocate_imports_inside_function(response)
        response = add_comments_to_top(response, agent_name=AGENT_NAME)

        # For logging: store the code generated
        file_path, file_name = log_ai_function(
            response=response,
            file_name="data_wrangler.py",
            log=log,
            log_path=log_path,
            overwrite=overwrite
        )

        return {
            "data_wrangler_function": response,
            "data_wrangler_function_path": file_path,
            "data_wrangler_function_name": file_name
        }

    def human_review(state: GraphState) -> Command[Literal["recommend_wrangling_steps", "create_data_wrangler_code"]]:
        """Ask a human to approve the recommended steps or send modifications back for re-recommendation."""
        return node_func_human_review(
            state=state,
            prompt_text="Are the following data wrangling steps correct? (Answer 'yes' or provide modifications)\n{steps}",
            yes_goto="create_data_wrangler_code",
            no_goto="recommend_wrangling_steps",
            user_instructions_key="user_instructions",
            recommended_steps_key="recommended_steps"
        )

    def execute_data_wrangler_code(state: GraphState):
        """Run the generated data_wrangler() on data_raw and store the result under data_wrangled."""
        return node_func_execute_agent_code_on_data(
            state=state,
            data_key="data_raw",
            result_key="data_wrangled",
            error_key="data_wrangler_error",
            code_snippet_key="data_wrangler_function",
            agent_function_name="data_wrangler",
            # The generated function is prompted to accept pandas DataFrame(s),
            # so convert the raw dict / list of dicts before calling it.
            pre_processing=_data_raw_to_dataframes,
            # Keep data_wrangled as a plain dict, matching GraphState.
            post_processing=_wrangled_to_dict,
            error_message_prefix="An error occurred during data wrangling: "
        )

    def fix_data_wrangler_code(state: GraphState):
        """Ask the LLM to repair the generated data_wrangler() using the last recorded error."""
        data_wrangler_prompt = """
        You are a Data Wrangling Agent. Your job is to create a data_wrangler() function that can be run on the data provided. The function is currently broken and needs to be fixed.

        Make sure to only return the function definition for data_wrangler().

        Return Python code in ```python``` format with a single function definition, data_wrangler(data_list), that includes all imports inside the function.

        The function receives a pandas DataFrame (single dataset) or a list of pandas DataFrames (multiple datasets) and must return a single pandas DataFrame.

        This is the broken code (please fix):
        {code_snippet}

        Last Known Error:
        {error}
        """

        return node_func_fix_agent_code(
            state=state,
            code_snippet_key="data_wrangler_function",
            error_key="data_wrangler_error",
            llm=llm,
            prompt_template=data_wrangler_prompt,
            agent_name=AGENT_NAME,
            log=log,
            file_path=state.get("data_wrangler_function_path"),
        )

    def explain_data_wrangler_code(state: GraphState):
        """Ask the LLM for a succinct explanation of the generated code, appended to messages."""
        return node_func_explain_agent_code(
            state=state,
            code_snippet_key="data_wrangler_function",
            result_key="messages",
            error_key="data_wrangler_error",
            llm=llm,
            role=AGENT_NAME,
            explanation_prompt_template="""
            Explain the data wrangling steps that the data wrangling agent performed in this function.
            Keep the summary succinct and to the point.\n\n# Data Wrangling Agent:\n\n{code}
            """,
            success_prefix="# Data Wrangling Agent:\n\n ",
            error_message="The Data Wrangling Agent encountered an error during data wrangling. Data could not be explained."
        )

    # Define the graph
    node_functions = {
        "recommend_wrangling_steps": recommend_wrangling_steps,
        "human_review": human_review,
        "create_data_wrangler_code": create_data_wrangler_code,
        "execute_data_wrangler_code": execute_data_wrangler_code,
        "fix_data_wrangler_code": fix_data_wrangler_code,
        "explain_data_wrangler_code": explain_data_wrangler_code
    }

    app = create_coding_agent_graph(
        GraphState=GraphState,
        node_functions=node_functions,
        recommended_steps_node_name="recommend_wrangling_steps",
        create_code_node_name="create_data_wrangler_code",
        execute_code_node_name="execute_data_wrangler_code",
        fix_code_node_name="fix_data_wrangler_code",
        explain_code_node_name="explain_data_wrangler_code",
        error_key="data_wrangler_error",
        human_in_the_loop=human_in_the_loop,
        human_review_node_name="human_review",
        # A checkpointer is only needed to resume after the human-review interrupt.
        checkpointer=MemorySaver() if human_in_the_loop else None
    )

    return app
362
+
363
+
364
+
365
+