ai-data-science-team 0.0.0.9007__py3-none-any.whl → 0.0.0.9008__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -4,11 +4,11 @@
4
4
  # * Agents: Data Wrangling Agent
5
5
 
6
6
  # Libraries
7
- from typing import TypedDict, Annotated, Sequence, Literal, Union
7
+ from typing import TypedDict, Annotated, Sequence, Literal, Union, Optional
8
8
  import operator
9
9
  import os
10
- import io
11
10
  import pandas as pd
11
+ from IPython.display import Markdown
12
12
 
13
13
  from langchain.prompts import PromptTemplate
14
14
  from langchain_core.messages import BaseMessage
@@ -20,10 +20,11 @@ from ai_data_science_team.templates import(
20
20
  node_func_human_review,
21
21
  node_func_fix_agent_code,
22
22
  node_func_explain_agent_code,
23
- create_coding_agent_graph
23
+ create_coding_agent_graph,
24
+ BaseAgent,
24
25
  )
25
26
  from ai_data_science_team.tools.parsers import PythonOutputParser
26
- from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
27
+ from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name, format_recommended_steps
27
28
  from ai_data_science_team.tools.metadata import get_dataframe_summary
28
29
  from ai_data_science_team.tools.logging import log_ai_function
29
30
 
@@ -31,13 +32,414 @@ from ai_data_science_team.tools.logging import log_ai_function
31
32
  AGENT_NAME = "data_wrangling_agent"
32
33
  LOG_PATH = os.path.join(os.getcwd(), "logs/")
33
34
 
35
+ # Class
36
+
37
+ class DataWranglingAgent(BaseAgent):
38
+ """
39
+ Creates a data wrangling agent that can work with one or more datasets, performing operations such as
40
+ joining/merging multiple datasets, reshaping, aggregating, encoding, creating computed features,
41
+ and ensuring consistent data types. The agent generates a Python function to wrangle the data,
42
+ executes the function, and logs the process (if enabled).
43
+
44
+ The agent can handle:
45
+ - A single dataset (provided as a dictionary of {column: list_of_values})
46
+ - Multiple datasets (provided as a list of such dictionaries)
47
+
48
+ Key wrangling steps can include:
49
+ - Merging or joining datasets
50
+ - Pivoting/melting data for reshaping
51
+ - GroupBy aggregations (sums, means, counts, etc.)
52
+ - Encoding categorical variables
53
+ - Computing new columns from existing ones
54
+ - Dropping or rearranging columns
55
+ - Any additional user instructions
56
+
57
+ Parameters
58
+ ----------
59
+ model : langchain.llms.base.LLM
60
+ The language model used to generate the data wrangling function.
61
+ n_samples : int, optional
62
+ Number of samples to show in the data summary for wrangling. Defaults to 30.
63
+ log : bool, optional
64
+ Whether to log the generated code and errors. Defaults to False.
65
+ log_path : str, optional
66
+ Directory path for storing log files. Defaults to None.
67
+ file_name : str, optional
68
+ Name of the file for saving the generated response. Defaults to "data_wrangler.py".
69
+ function_name : str, optional
70
+ Name of the function to be generated. Defaults to "data_wrangler".
71
+ overwrite : bool, optional
72
+ Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
73
+ human_in_the_loop : bool, optional
74
+ Enables user review of data wrangling instructions. Defaults to False.
75
+ bypass_recommended_steps : bool, optional
76
+ If True, skips the step that generates recommended data wrangling steps. Defaults to False.
77
+ bypass_explain_code : bool, optional
78
+ If True, skips the step that provides code explanations. Defaults to False.
79
+
80
+ Methods
81
+ -------
82
+ update_params(**kwargs)
83
+ Updates the agent's parameters and rebuilds the compiled state graph.
84
+
85
+ ainvoke_agent(data_raw: Union[pd.DataFrame, dict, list], user_instructions: str=None, max_retries=3, retry_count=0, **kwargs)
86
+ Asynchronously wrangles the provided dataset(s) based on user instructions.
87
+
88
+ invoke_agent(data_raw: Union[pd.DataFrame, dict, list], user_instructions: str=None, max_retries=3, retry_count=0, **kwargs)
89
+ Synchronously wrangles the provided dataset(s) based on user instructions.
90
+
91
+ explain_wrangling_steps()
92
+ Returns an explanation of the wrangling steps performed by the agent.
93
+
94
+ get_log_summary()
95
+ Retrieves a summary of logged operations if logging is enabled.
96
+
97
+ get_data_wrangled()
98
+ Retrieves the final wrangled dataset (as a pandas DataFrame), or None if not available.
99
+
100
+ get_data_raw()
101
+ Retrieves the raw dataset(s).
102
+
103
+ get_data_wrangler_function()
104
+ Retrieves the generated Python function used for data wrangling.
105
+
106
+ get_recommended_wrangling_steps()
107
+ Retrieves the agent's recommended wrangling steps.
108
+
109
+ get_response()
110
+ Returns the full response dictionary from the agent.
111
+
112
+ show()
113
+ Displays the agent's mermaid diagram for visual inspection of the compiled graph.
114
+
115
+ Examples
116
+ --------
117
+ ```python
118
+ import pandas as pd
119
+ from langchain_openai import ChatOpenAI
120
+ from ai_data_science_team.agents import DataWranglingAgent
121
+
122
+ # Single dataset example
123
+ llm = ChatOpenAI(model="gpt-4o-mini")
124
+
125
+ data_wrangling_agent = DataWranglingAgent(
126
+ model=llm,
127
+ n_samples=30,
128
+ log=True,
129
+ log_path="logs",
130
+ human_in_the_loop=True
131
+ )
132
+
133
+ df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")
134
+
135
+ data_wrangling_agent.invoke_agent(
136
+ user_instructions="Group by 'gender' and compute mean of 'tenure'.",
137
+ data_raw=df, # data_raw can be df.to_dict() or just a DataFrame
138
+ max_retries=3,
139
+ retry_count=0
140
+ )
141
+
142
+ data_wrangled = data_wrangling_agent.get_data_wrangled()
143
+ response = data_wrangling_agent.get_response()
144
+
145
+ # Multiple dataset example (list of DataFrames or dicts)
146
+ df1 = pd.DataFrame({'id': [1,2,3], 'val1': [10,20,30]})
147
+ df2 = pd.DataFrame({'id': [1,2,3], 'val2': [40,50,60]})
148
+
149
+ data_wrangling_agent.invoke_agent(
150
+ user_instructions="Merge these two datasets on 'id' and compute a new column 'val_sum' = val1+val2",
151
+ data_raw=[df1, df2], # multiple datasets
152
+ max_retries=3,
153
+ retry_count=0
154
+ )
155
+
156
+ data_wrangled = data_wrangling_agent.get_data_wrangled()
157
+ ```
158
+
159
+ Returns
160
+ -------
161
+ DataWranglingAgent : langchain.graphs.CompiledStateGraph
162
+ A data wrangling agent implemented as a compiled state graph.
163
+ """
164
+
165
+ def __init__(
166
+ self,
167
+ model,
168
+ n_samples=30,
169
+ log=False,
170
+ log_path=None,
171
+ file_name="data_wrangler.py",
172
+ function_name="data_wrangler",
173
+ overwrite=True,
174
+ human_in_the_loop=False,
175
+ bypass_recommended_steps=False,
176
+ bypass_explain_code=False
177
+ ):
178
+ self._params = {
179
+ "model": model,
180
+ "n_samples": n_samples,
181
+ "log": log,
182
+ "log_path": log_path,
183
+ "file_name": file_name,
184
+ "function_name": function_name,
185
+ "overwrite": overwrite,
186
+ "human_in_the_loop": human_in_the_loop,
187
+ "bypass_recommended_steps": bypass_recommended_steps,
188
+ "bypass_explain_code": bypass_explain_code
189
+ }
190
+ self._compiled_graph = self._make_compiled_graph()
191
+ self.response = None
192
+
193
+ def _make_compiled_graph(self):
194
+ """
195
+ Create the compiled graph for the data wrangling agent.
196
+ Running this method will reset the response to None.
197
+ """
198
+ self.response = None
199
+ return make_data_wrangling_agent(**self._params)
200
+
201
+ def update_params(self, **kwargs):
202
+ """
203
+ Updates the agent's parameters and rebuilds the compiled graph.
204
+ """
205
+ for k, v in kwargs.items():
206
+ self._params[k] = v
207
+ self._compiled_graph = self._make_compiled_graph()
208
+
209
+ def ainvoke_agent(
210
+ self,
211
+ data_raw: Union[pd.DataFrame, dict, list],
212
+ user_instructions: str=None,
213
+ max_retries:int=3,
214
+ retry_count:int=0,
215
+ **kwargs
216
+ ):
217
+ """
218
+ Asynchronously wrangles the provided dataset(s) based on user instructions.
219
+ The response is stored in the 'response' attribute.
220
+
221
+ Parameters
222
+ ----------
223
+ data_raw : Union[pd.DataFrame, dict, list]
224
+ The raw dataset(s) to be wrangled.
225
+ Can be a single DataFrame, a single dict ({col: list_of_values}),
226
+ or a list of dicts if multiple datasets are provided.
227
+ user_instructions : str
228
+ Instructions for data wrangling.
229
+ max_retries : int
230
+ Maximum retry attempts.
231
+ retry_count : int
232
+ Current retry attempt count.
233
+ **kwargs
234
+ Additional keyword arguments to pass to ainvoke().
235
+
236
+ Returns
237
+ -------
238
+ None
239
+ """
240
+ data_input = self._convert_data_input(data_raw)
241
+ response = self._compiled_graph.ainvoke({
242
+ "user_instructions": user_instructions,
243
+ "data_raw": data_input,
244
+ "max_retries": max_retries,
245
+ "retry_count": retry_count
246
+ }, **kwargs)
247
+ self.response = response
248
+ return None
249
+
250
+ def invoke_agent(
251
+ self,
252
+ data_raw: Union[pd.DataFrame, dict, list],
253
+ user_instructions: str=None,
254
+ max_retries:int=3,
255
+ retry_count:int=0,
256
+ **kwargs
257
+ ):
258
+ """
259
+ Synchronously wrangles the provided dataset(s) based on user instructions.
260
+ The response is stored in the 'response' attribute.
261
+
262
+ Parameters
263
+ ----------
264
+ data_raw : Union[pd.DataFrame, dict, list]
265
+ The raw dataset(s) to be wrangled.
266
+ Can be a single DataFrame, a single dict, or a list of dicts.
267
+ user_instructions : str
268
+ Instructions for data wrangling agent.
269
+ max_retries : int
270
+ Maximum retry attempts.
271
+ retry_count : int
272
+ Current retry attempt count.
273
+ **kwargs
274
+ Additional keyword arguments to pass to invoke().
275
+
276
+ Returns
277
+ -------
278
+ None
279
+ """
280
+ data_input = self._convert_data_input(data_raw)
281
+ response = self._compiled_graph.invoke({
282
+ "user_instructions": user_instructions,
283
+ "data_raw": data_input,
284
+ "max_retries": max_retries,
285
+ "retry_count": retry_count
286
+ }, **kwargs)
287
+ self.response = response
288
+ return None
289
+
290
+ def explain_wrangling_steps(self):
291
+ """
292
+ Provides an explanation of the wrangling steps performed by the agent.
293
+
294
+ Returns
295
+ -------
296
+ str or list
297
+ Explanation of the data wrangling steps.
298
+ """
299
+ if self.response:
300
+ return self.response.get("messages", [])
301
+ return []
302
+
303
+ def get_log_summary(self, markdown=False):
304
+ """
305
+ Returns a summary of the agent's logged operations (the log path), if logging is enabled.
306
+
307
+ Parameters
308
+ ----------
309
+ markdown : bool, optional
310
+ If True, returns the summary in Markdown.
311
+
312
+ Returns
313
+ -------
314
+ str or None
315
+ The log details, or None if not available.
316
+ """
317
+ if self.response and self.response.get("data_wrangler_function_path"):
318
+ log_details = f"Log Path: {self.response.get('data_wrangler_function_path')}"
319
+ if markdown:
320
+ return Markdown(log_details)
321
+ else:
322
+ return log_details
323
+ return None
324
+
325
+ def get_data_wrangled(self) -> Optional[pd.DataFrame]:
326
+ """
327
+ Retrieves the wrangled data after running invoke_agent() or ainvoke_agent().
328
+
329
+ Returns
330
+ -------
331
+ pd.DataFrame or None
332
+ The wrangled dataset as a pandas DataFrame (if available).
333
+ """
334
+ if self.response and "data_wrangled" in self.response:
335
+ return pd.DataFrame(self.response["data_wrangled"])
336
+ return None
337
+
338
+ def get_data_raw(self) -> Union[dict, list, None]:
339
+ """
340
+ Retrieves the original raw data from the last invocation.
341
+
342
+ Returns
343
+ -------
344
+ Union[dict, list, None]
345
+ The original dataset(s) as a single dict or a list of dicts, or None if not available.
346
+ """
347
+ if self.response and "data_raw" in self.response:
348
+ return self.response["data_raw"]
349
+ return None
350
+
351
+ def get_data_wrangler_function(self, markdown=False) -> Optional[str]:
352
+ """
353
+ Retrieves the generated data wrangling function code.
354
+
355
+ Parameters
356
+ ----------
357
+ markdown : bool, optional
358
+ If True, returns the function in Markdown code block format.
359
+
360
+ Returns
361
+ -------
362
+ str or None
363
+ The Python function code, or None if not available.
364
+ """
365
+ if self.response and "data_wrangler_function" in self.response:
366
+ code = self.response["data_wrangler_function"]
367
+ if markdown:
368
+ return Markdown(f"```python\n{code}\n```")
369
+ return code
370
+ return None
371
+
372
+ def get_recommended_wrangling_steps(self, markdown=False) -> Optional[str]:
373
+ """
374
+ Retrieves the agent's recommended data wrangling steps.
375
+
376
+ Parameters
377
+ ----------
378
+ markdown : bool, optional
379
+ If True, returns the steps in Markdown format.
380
+
381
+ Returns
382
+ -------
383
+ str or None
384
+ The recommended steps, or None if not available.
385
+ """
386
+ if self.response and "recommended_steps" in self.response:
387
+ steps = self.response["recommended_steps"]
388
+ if markdown:
389
+ return Markdown(steps)
390
+ return steps
391
+ return None
392
+
393
+ @staticmethod
394
+ def _convert_data_input(data_raw: Union[pd.DataFrame, dict, list]) -> Union[dict, list]:
395
+ """
396
+ Internal utility to convert data_raw (which could be a DataFrame, dict, or list of dicts)
397
+ into the format expected by the underlying agent (dict or list of dicts).
398
+
399
+ Parameters
400
+ ----------
401
+ data_raw : Union[pd.DataFrame, dict, list]
402
+ The raw input data to be converted.
403
+
404
+ Returns
405
+ -------
406
+ Union[dict, list]
407
+ The data in a dictionary or list-of-dictionaries format.
408
+ """
409
+ # If a single DataFrame, convert to dict
410
+ if isinstance(data_raw, pd.DataFrame):
411
+ return data_raw.to_dict()
412
+
413
+ # If it's already a dict (single dataset)
414
+ if isinstance(data_raw, dict):
415
+ return data_raw
416
+
417
+ # If it's already a list, check if it's a list of DataFrames or dicts
418
+ if isinstance(data_raw, list):
419
+ # Convert any DataFrame item to dict
420
+ converted_list = []
421
+ for item in data_raw:
422
+ if isinstance(item, pd.DataFrame):
423
+ converted_list.append(item.to_dict())
424
+ elif isinstance(item, dict):
425
+ converted_list.append(item)
426
+ else:
427
+ raise ValueError("List must contain only DataFrames or dictionaries.")
428
+ return converted_list
429
+
430
+ raise ValueError("data_raw must be a DataFrame, a dict, or a list of dicts/DataFrames.")
431
+
432
+
433
+ # Function
434
+
34
435
  def make_data_wrangling_agent(
35
436
  model,
36
437
  n_samples=30,
37
438
  log=False,
38
439
  log_path=None,
39
440
  file_name="data_wrangler.py",
40
- overwrite = True,
441
+ function_name="data_wrangler",
442
+ overwrite=True,
41
443
  human_in_the_loop=False,
42
444
  bypass_recommended_steps=False,
43
445
  bypass_explain_code=False
@@ -73,6 +475,8 @@ def make_data_wrangling_agent(
73
475
  The path to the directory where the log files should be stored. Defaults to "logs/".
74
476
  file_name : str, optional
75
477
  The name of the file to save the response to. Defaults to "data_wrangler.py".
478
+ function_name : str, optional
479
+ The name of the function to be generated. Defaults to "data_wrangler".
76
480
  overwrite : bool, optional
77
481
  Whether or not to overwrite the log file if it already exists. If False, a unique file name will be created.
78
482
  Defaults to True.
@@ -115,6 +519,11 @@ def make_data_wrangling_agent(
115
519
  """
116
520
  llm = model
117
521
 
522
+ # Human in the loop requires recommended steps
523
+ if bypass_recommended_steps and human_in_the_loop:
524
+ bypass_recommended_steps = False
525
+ print("Bypass recommended steps set to False to enable human in the loop.")
526
+
118
527
  # Setup Log Directory
119
528
  if log:
120
529
  if log_path is None:
@@ -205,7 +614,7 @@ def make_data_wrangling_agent(
205
614
  })
206
615
 
207
616
  return {
208
- "recommended_steps": "\n\n# Recommended Wrangling Steps:\n" + recommended_steps.content.strip(),
617
+ "recommended_steps": format_recommended_steps(recommended_steps.content.strip(), heading="# Recommended Data Wrangling Steps:"),
209
618
  "all_datasets_summary": all_datasets_summary_str,
210
619
  }
211
620
 
@@ -244,7 +653,7 @@ def make_data_wrangling_agent(
244
653
 
245
654
  data_wrangling_prompt = PromptTemplate(
246
655
  template="""
247
- You are a Data Wrangling Coding Agent. Your job is to create a data_wrangler() function that can be run on the provided data.
656
+ You are a Data Wrangling Coding Agent. Your job is to create a {function_name}() function that can be run on the provided data.
248
657
 
249
658
  Follow these recommended steps:
250
659
  {recommended_steps}
@@ -254,10 +663,10 @@ def make_data_wrangling_agent(
254
663
  Below are summaries of all datasets provided. If more than one dataset is provided, you may need to merge or join them.:
255
664
  {all_datasets_summary}
256
665
 
257
- Return Python code in ```python``` format with a single function definition, data_wrangler(), that includes all imports inside the function. And returns a single pandas data frame.
666
+ Return Python code in ```python``` format with a single function definition, {function_name}(), that includes all imports inside the function. And returns a single pandas data frame.
258
667
 
259
668
  ```python
260
- def data_wrangler(data_list):
669
+ def {function_name}(data_list):
261
670
  '''
262
671
  Wrangle the data provided in data.
263
672
 
@@ -279,14 +688,15 @@ def make_data_wrangling_agent(
279
688
 
280
689
 
281
690
  """,
282
- input_variables=["recommended_steps", "all_datasets_summary"]
691
+ input_variables=["recommended_steps", "all_datasets_summary", "function_name"]
283
692
  )
284
693
 
285
694
  data_wrangling_agent = data_wrangling_prompt | llm | PythonOutputParser()
286
695
 
287
696
  response = data_wrangling_agent.invoke({
288
697
  "recommended_steps": state.get("recommended_steps"),
289
- "all_datasets_summary": all_datasets_summary_str
698
+ "all_datasets_summary": all_datasets_summary_str,
699
+ "function_name": function_name
290
700
  })
291
701
 
292
702
  response = relocate_imports_inside_function(response)
@@ -304,7 +714,8 @@ def make_data_wrangling_agent(
304
714
  return {
305
715
  "data_wrangler_function" : response,
306
716
  "data_wrangler_function_path": file_path,
307
- "data_wrangler_function_name": file_name_2,
717
+ "data_wrangler_file_name": file_name_2,
718
+ "data_wrangler_function_name": function_name,
308
719
  "all_datasets_summary": all_datasets_summary_str
309
720
  }
310
721
 
@@ -318,6 +729,33 @@ def make_data_wrangling_agent(
318
729
  user_instructions_key="user_instructions",
319
730
  recommended_steps_key="recommended_steps"
320
731
  )
732
+
733
+ # Human Review
734
+
735
+ prompt_text_human_review = "Are the following data wrangling instructions correct? (Answer 'yes' or provide modifications)\n{steps}"
736
+
737
+ if not bypass_explain_code:
738
+ def human_review(state: GraphState) -> Command[Literal["recommend_wrangling_steps", "explain_data_wrangler_code"]]:
739
+ return node_func_human_review(
740
+ state=state,
741
+ prompt_text=prompt_text_human_review,
742
+ yes_goto= 'explain_data_wrangler_code',
743
+ no_goto="recommend_wrangling_steps",
744
+ user_instructions_key="user_instructions",
745
+ recommended_steps_key="recommended_steps",
746
+ code_snippet_key="data_wrangler_function",
747
+ )
748
+ else:
749
+ def human_review(state: GraphState) -> Command[Literal["recommend_wrangling_steps", "__end__"]]:
750
+ return node_func_human_review(
751
+ state=state,
752
+ prompt_text=prompt_text_human_review,
753
+ yes_goto= '__end__',
754
+ no_goto="recommend_wrangling_steps",
755
+ user_instructions_key="user_instructions",
756
+ recommended_steps_key="recommended_steps",
757
+ code_snippet_key="data_wrangler_function",
758
+ )
321
759
 
322
760
  def execute_data_wrangler_code(state: GraphState):
323
761
  return node_func_execute_agent_code_on_data(
@@ -326,7 +764,7 @@ def make_data_wrangling_agent(
326
764
  result_key="data_wrangled",
327
765
  error_key="data_wrangler_error",
328
766
  code_snippet_key="data_wrangler_function",
329
- agent_function_name="data_wrangler",
767
+ agent_function_name=state.get("data_wrangler_function_name"),
330
768
  # pre_processing=pre_processing,
331
769
  post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
332
770
  error_message_prefix="An error occurred during data wrangling: "
@@ -334,11 +772,11 @@ def make_data_wrangling_agent(
334
772
 
335
773
  def fix_data_wrangler_code(state: GraphState):
336
774
  data_wrangler_prompt = """
337
- You are a Data Wrangling Agent. Your job is to create a data_wrangler() function that can be run on the data provided. The function is currently broken and needs to be fixed.
775
+ You are a Data Wrangling Agent. Your job is to create a {function_name}() function that can be run on the data provided. The function is currently broken and needs to be fixed.
338
776
 
339
- Make sure to only return the function definition for data_wrangler().
777
+ Make sure to only return the function definition for {function_name}().
340
778
 
341
- Return Python code in ```python``` format with a single function definition, data_wrangler(data_raw), that includes all imports inside the function.
779
+ Return Python code in ```python``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.
342
780
 
343
781
  This is the broken code (please fix):
344
782
  {code_snippet}
@@ -356,6 +794,7 @@ def make_data_wrangling_agent(
356
794
  agent_name=AGENT_NAME,
357
795
  log=log,
358
796
  file_path=state.get("data_wrangler_function_path"),
797
+ function_name=state.get("data_wrangler_function_name"),
359
798
  )
360
799
 
361
800
  def explain_data_wrangler_code(state: GraphState):