ai-data-science-team 0.0.0.9007__py3-none-any.whl → 0.0.0.9008__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,11 +4,11 @@
4
4
  # * Agents: Data Wrangling Agent
5
5
 
6
6
  # Libraries
7
- from typing import TypedDict, Annotated, Sequence, Literal, Union
7
+ from typing import TypedDict, Annotated, Sequence, Literal, Union, Optional
8
8
  import operator
9
9
  import os
10
- import io
11
10
  import pandas as pd
11
+ from IPython.display import Markdown
12
12
 
13
13
  from langchain.prompts import PromptTemplate
14
14
  from langchain_core.messages import BaseMessage
@@ -20,10 +20,11 @@ from ai_data_science_team.templates import(
20
20
  node_func_human_review,
21
21
  node_func_fix_agent_code,
22
22
  node_func_explain_agent_code,
23
- create_coding_agent_graph
23
+ create_coding_agent_graph,
24
+ BaseAgent,
24
25
  )
25
26
  from ai_data_science_team.tools.parsers import PythonOutputParser
26
- from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
27
+ from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name, format_recommended_steps
27
28
  from ai_data_science_team.tools.metadata import get_dataframe_summary
28
29
  from ai_data_science_team.tools.logging import log_ai_function
29
30
 
@@ -31,13 +32,414 @@ from ai_data_science_team.tools.logging import log_ai_function
31
32
  AGENT_NAME = "data_wrangling_agent"
32
33
  LOG_PATH = os.path.join(os.getcwd(), "logs/")
33
34
 
35
+ # Class
36
+
37
class DataWranglingAgent(BaseAgent):
    """
    Creates a data wrangling agent that can work with one or more datasets, performing operations such as
    joining/merging multiple datasets, reshaping, aggregating, encoding, creating computed features,
    and ensuring consistent data types. The agent generates a Python function to wrangle the data,
    executes the function, and logs the process (if enabled).

    The agent can handle:
    - A single dataset (a pandas DataFrame or a dictionary of {column: list_of_values})
    - Multiple datasets (a list of DataFrames and/or such dictionaries)

    Key wrangling steps can include:
    - Merging or joining datasets
    - Pivoting/melting data for reshaping
    - GroupBy aggregations (sums, means, counts, etc.)
    - Encoding categorical variables
    - Computing new columns from existing ones
    - Dropping or rearranging columns
    - Any additional user instructions

    Parameters
    ----------
    model : langchain.llms.base.LLM
        The language model used to generate the data wrangling function.
    n_samples : int, optional
        Number of samples to show in the data summary for wrangling. Defaults to 30.
    log : bool, optional
        Whether to log the generated code and errors. Defaults to False.
    log_path : str, optional
        Directory path for storing log files. Defaults to None.
    file_name : str, optional
        Name of the file for saving the generated response. Defaults to "data_wrangler.py".
    function_name : str, optional
        Name of the function to be generated. Defaults to "data_wrangler".
    overwrite : bool, optional
        Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
    human_in_the_loop : bool, optional
        Enables user review of data wrangling instructions. Defaults to False.
    bypass_recommended_steps : bool, optional
        If True, skips the step that generates recommended data wrangling steps. Defaults to False.
    bypass_explain_code : bool, optional
        If True, skips the step that provides code explanations. Defaults to False.

    Attributes
    ----------
    response : dict or None
        The state returned by the most recent invoke_agent()/ainvoke_agent() call,
        or None if the agent has not been run since the graph was (re)built.

    Methods
    -------
    update_params(**kwargs)
        Updates the agent's parameters and rebuilds the compiled state graph.

    ainvoke_agent(data_raw, user_instructions=None, max_retries=3, retry_count=0, **kwargs)
        Coroutine. Asynchronously wrangles the provided dataset(s) based on user instructions.
        Must be awaited.

    invoke_agent(data_raw, user_instructions=None, max_retries=3, retry_count=0, **kwargs)
        Synchronously wrangles the provided dataset(s) based on user instructions.

    explain_wrangling_steps()
        Returns the messages produced by the agent (explanations of the wrangling steps).

    get_log_summary(markdown=False)
        Returns the path of the saved wrangling function if logging produced one.

    get_data_wrangled()
        Retrieves the final wrangled dataset as a pandas DataFrame.

    get_data_raw()
        Retrieves the raw dataset(s) as passed to the graph (dict or list of dicts).

    get_data_wrangler_function(markdown=False)
        Retrieves the generated Python function used for data wrangling.

    get_recommended_wrangling_steps(markdown=False)
        Retrieves the agent's recommended wrangling steps.

    get_response()
        Returns the full response dictionary from the agent (not defined in this class; expected from BaseAgent).

    show()
        Displays the agent's mermaid diagram (not defined in this class; expected from BaseAgent).

    Examples
    --------
    ```python
    import pandas as pd
    from langchain_openai import ChatOpenAI
    from ai_data_science_team.agents import DataWranglingAgent

    # Single dataset example
    llm = ChatOpenAI(model="gpt-4o-mini")

    data_wrangling_agent = DataWranglingAgent(
        model=llm,
        n_samples=30,
        log=True,
        log_path="logs",
        human_in_the_loop=True
    )

    df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")

    data_wrangling_agent.invoke_agent(
        user_instructions="Group by 'gender' and compute mean of 'tenure'.",
        data_raw=df,  # data_raw can be df.to_dict() or just a DataFrame
        max_retries=3,
        retry_count=0
    )

    data_wrangled = data_wrangling_agent.get_data_wrangled()
    response = data_wrangling_agent.get_response()

    # Multiple dataset example (list of DataFrames or dicts)
    df1 = pd.DataFrame({'id': [1, 2, 3], 'val1': [10, 20, 30]})
    df2 = pd.DataFrame({'id': [1, 2, 3], 'val2': [40, 50, 60]})

    data_wrangling_agent.invoke_agent(
        user_instructions="Merge these two datasets on 'id' and compute a new column 'val_sum' = val1+val2",
        data_raw=[df1, df2],  # multiple datasets
        max_retries=3,
        retry_count=0
    )

    data_wrangled = data_wrangling_agent.get_data_wrangled()

    # Async usage (inside an event loop)
    # await data_wrangling_agent.ainvoke_agent(data_raw=df, user_instructions="...")
    ```
    """

    def __init__(
        self,
        model,
        n_samples=30,
        log=False,
        log_path=None,
        file_name="data_wrangler.py",
        function_name="data_wrangler",
        overwrite=True,
        human_in_the_loop=False,
        bypass_recommended_steps=False,
        bypass_explain_code=False
    ):
        # All constructor arguments are forwarded verbatim to make_data_wrangling_agent().
        self._params = {
            "model": model,
            "n_samples": n_samples,
            "log": log,
            "log_path": log_path,
            "file_name": file_name,
            "function_name": function_name,
            "overwrite": overwrite,
            "human_in_the_loop": human_in_the_loop,
            "bypass_recommended_steps": bypass_recommended_steps,
            "bypass_explain_code": bypass_explain_code
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Create the compiled graph for the data wrangling agent.
        Running this method will reset the response to None.
        """
        self.response = None
        return make_data_wrangling_agent(**self._params)

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters and rebuilds the compiled graph.

        Any previously stored response is discarded, because the graph is rebuilt.
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    def _build_graph_input(
        self,
        data_raw: Union[pd.DataFrame, dict, list],
        user_instructions: Optional[str],
        max_retries: int,
        retry_count: int,
    ) -> dict:
        """Build the initial graph state shared by invoke_agent() and ainvoke_agent()."""
        return {
            "user_instructions": user_instructions,
            "data_raw": self._convert_data_input(data_raw),
            "max_retries": max_retries,
            "retry_count": retry_count
        }

    async def ainvoke_agent(
        self,
        data_raw: Union[pd.DataFrame, dict, list],
        user_instructions: Optional[str] = None,
        max_retries: int = 3,
        retry_count: int = 0,
        **kwargs
    ):
        """
        Asynchronously wrangles the provided dataset(s) based on user instructions.
        The response is stored in the 'response' attribute.

        This is a coroutine and must be awaited.

        Parameters
        ----------
        data_raw : Union[pd.DataFrame, dict, list]
            The raw dataset(s) to be wrangled.
            Can be a single DataFrame, a single dict ({col: list_of_values}),
            or a list of DataFrames/dicts if multiple datasets are provided.
        user_instructions : str, optional
            Instructions for data wrangling.
        max_retries : int
            Maximum retry attempts.
        retry_count : int
            Current retry attempt count.
        **kwargs
            Additional keyword arguments to pass to ainvoke().

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If data_raw is not a DataFrame, dict, or list of DataFrames/dicts.
        """
        graph_input = self._build_graph_input(data_raw, user_instructions, max_retries, retry_count)
        # The compiled graph's ainvoke() returns a coroutine; it must be awaited
        # so that self.response holds the resulting state rather than a coroutine object.
        response = await self._compiled_graph.ainvoke(graph_input, **kwargs)
        self.response = response
        return None

    def invoke_agent(
        self,
        data_raw: Union[pd.DataFrame, dict, list],
        user_instructions: Optional[str] = None,
        max_retries: int = 3,
        retry_count: int = 0,
        **kwargs
    ):
        """
        Synchronously wrangles the provided dataset(s) based on user instructions.
        The response is stored in the 'response' attribute.

        Parameters
        ----------
        data_raw : Union[pd.DataFrame, dict, list]
            The raw dataset(s) to be wrangled.
            Can be a single DataFrame, a single dict, or a list of DataFrames/dicts.
        user_instructions : str, optional
            Instructions for data wrangling agent.
        max_retries : int
            Maximum retry attempts.
        retry_count : int
            Current retry attempt count.
        **kwargs
            Additional keyword arguments to pass to invoke().

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If data_raw is not a DataFrame, dict, or list of DataFrames/dicts.
        """
        graph_input = self._build_graph_input(data_raw, user_instructions, max_retries, retry_count)
        response = self._compiled_graph.invoke(graph_input, **kwargs)
        self.response = response
        return None

    def explain_wrangling_steps(self):
        """
        Provides an explanation of the wrangling steps performed by the agent.

        Returns
        -------
        list
            The "messages" entry of the last response, or an empty list if the
            agent has not been run or produced no messages.
        """
        if self.response:
            return self.response.get("messages", [])
        return []

    def get_log_summary(self, markdown=False):
        """
        Returns a short summary of the agent's logging, if a log file was written.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the summary wrapped in an IPython Markdown object.

        Returns
        -------
        str, Markdown or None
            A string of the form "Log Path: <path>", or None if no function path
            is present in the last response.
        """
        if self.response and self.response.get("data_wrangler_function_path"):
            log_details = f"Log Path: {self.response.get('data_wrangler_function_path')}"
            if markdown:
                return Markdown(log_details)
            else:
                return log_details
        return None

    def get_data_wrangled(self) -> Optional[pd.DataFrame]:
        """
        Retrieves the wrangled data after running invoke_agent() or ainvoke_agent().

        Returns
        -------
        pd.DataFrame or None
            The wrangled dataset as a pandas DataFrame, or None if the agent has
            not been run or did not produce wrangled data (e.g. execution failed).
        """
        # Check for None explicitly: pd.DataFrame(None) would silently yield an
        # empty frame and hide a failed run.
        if self.response and self.response.get("data_wrangled") is not None:
            return pd.DataFrame(self.response["data_wrangled"])
        return None

    def get_data_raw(self) -> Union[dict, list, None]:
        """
        Retrieves the original raw data from the last invocation.

        Returns
        -------
        Union[dict, list, None]
            The original dataset(s) as a single dict or a list of dicts, or None if not available.
        """
        if self.response and "data_raw" in self.response:
            return self.response["data_raw"]
        return None

    def get_data_wrangler_function(self, markdown=False) -> Optional[str]:
        """
        Retrieves the generated data wrangling function code.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the function wrapped in a Markdown python code block.

        Returns
        -------
        str, Markdown or None
            The Python function code, or None if not available.
        """
        if self.response and "data_wrangler_function" in self.response:
            code = self.response["data_wrangler_function"]
            if markdown:
                return Markdown(f"```python\n{code}\n```")
            return code
        return None

    def get_recommended_wrangling_steps(self, markdown=False) -> Optional[str]:
        """
        Retrieves the agent's recommended data wrangling steps.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the steps wrapped in an IPython Markdown object.

        Returns
        -------
        str, Markdown or None
            The recommended steps, or None if not available.
        """
        if self.response and "recommended_steps" in self.response:
            steps = self.response["recommended_steps"]
            if markdown:
                return Markdown(steps)
            return steps
        return None

    @staticmethod
    def _convert_data_input(data_raw: Union[pd.DataFrame, dict, list]) -> Union[dict, list]:
        """
        Internal utility to convert data_raw (which could be a DataFrame, dict, or list of dicts)
        into the format expected by the underlying agent (dict or list of dicts).

        Parameters
        ----------
        data_raw : Union[pd.DataFrame, dict, list]
            The raw input data to be converted.

        Returns
        -------
        Union[dict, list]
            The data in a dictionary or list-of-dictionaries format.

        Raises
        ------
        ValueError
            If data_raw (or any list element) is not a DataFrame or dict.
        """
        # If a single DataFrame, convert to dict
        if isinstance(data_raw, pd.DataFrame):
            return data_raw.to_dict()

        # If it's already a dict (single dataset)
        if isinstance(data_raw, dict):
            return data_raw

        # If it's a list, convert any DataFrame items to dicts
        if isinstance(data_raw, list):
            converted_list = []
            for item in data_raw:
                if isinstance(item, pd.DataFrame):
                    converted_list.append(item.to_dict())
                elif isinstance(item, dict):
                    converted_list.append(item)
                else:
                    raise ValueError("List must contain only DataFrames or dictionaries.")
            return converted_list

        raise ValueError("data_raw must be a DataFrame, a dict, or a list of dicts/DataFrames.")
431
+
432
+
433
+ # Function
434
+
34
435
  def make_data_wrangling_agent(
35
436
  model,
36
437
  n_samples=30,
37
438
  log=False,
38
439
  log_path=None,
39
440
  file_name="data_wrangler.py",
40
- overwrite = True,
441
+ function_name="data_wrangler",
442
+ overwrite=True,
41
443
  human_in_the_loop=False,
42
444
  bypass_recommended_steps=False,
43
445
  bypass_explain_code=False
@@ -73,6 +475,8 @@ def make_data_wrangling_agent(
73
475
  The path to the directory where the log files should be stored. Defaults to "logs/".
74
476
  file_name : str, optional
75
477
  The name of the file to save the response to. Defaults to "data_wrangler.py".
478
+ function_name : str, optional
479
+ The name of the function to be generated. Defaults to "data_wrangler".
76
480
  overwrite : bool, optional
77
481
  Whether or not to overwrite the log file if it already exists. If False, a unique file name will be created.
78
482
  Defaults to True.
@@ -115,6 +519,11 @@ def make_data_wrangling_agent(
115
519
  """
116
520
  llm = model
117
521
 
522
+ # Human in th loop requires recommended steps
523
+ if bypass_recommended_steps and human_in_the_loop:
524
+ bypass_recommended_steps = False
525
+ print("Bypass recommended steps set to False to enable human in the loop.")
526
+
118
527
  # Setup Log Directory
119
528
  if log:
120
529
  if log_path is None:
@@ -205,7 +614,7 @@ def make_data_wrangling_agent(
205
614
  })
206
615
 
207
616
  return {
208
- "recommended_steps": "\n\n# Recommended Wrangling Steps:\n" + recommended_steps.content.strip(),
617
+ "recommended_steps": format_recommended_steps(recommended_steps.content.strip(), heading="# Recommended Data Wrangling Steps:"),
209
618
  "all_datasets_summary": all_datasets_summary_str,
210
619
  }
211
620
 
@@ -244,7 +653,7 @@ def make_data_wrangling_agent(
244
653
 
245
654
  data_wrangling_prompt = PromptTemplate(
246
655
  template="""
247
- You are a Data Wrangling Coding Agent. Your job is to create a data_wrangler() function that can be run on the provided data.
656
+ You are a Data Wrangling Coding Agent. Your job is to create a {function_name}() function that can be run on the provided data.
248
657
 
249
658
  Follow these recommended steps:
250
659
  {recommended_steps}
@@ -254,10 +663,10 @@ def make_data_wrangling_agent(
254
663
  Below are summaries of all datasets provided. If more than one dataset is provided, you may need to merge or join them.:
255
664
  {all_datasets_summary}
256
665
 
257
- Return Python code in ```python``` format with a single function definition, data_wrangler(), that includes all imports inside the function. And returns a single pandas data frame.
666
+ Return Python code in ```python``` format with a single function definition, {function_name}(), that includes all imports inside the function. And returns a single pandas data frame.
258
667
 
259
668
  ```python
260
- def data_wrangler(data_list):
669
+ def {function_name}(data_list):
261
670
  '''
262
671
  Wrangle the data provided in data.
263
672
 
@@ -279,14 +688,15 @@ def make_data_wrangling_agent(
279
688
 
280
689
 
281
690
  """,
282
- input_variables=["recommended_steps", "all_datasets_summary"]
691
+ input_variables=["recommended_steps", "all_datasets_summary", "function_name"]
283
692
  )
284
693
 
285
694
  data_wrangling_agent = data_wrangling_prompt | llm | PythonOutputParser()
286
695
 
287
696
  response = data_wrangling_agent.invoke({
288
697
  "recommended_steps": state.get("recommended_steps"),
289
- "all_datasets_summary": all_datasets_summary_str
698
+ "all_datasets_summary": all_datasets_summary_str,
699
+ "function_name": function_name
290
700
  })
291
701
 
292
702
  response = relocate_imports_inside_function(response)
@@ -304,7 +714,8 @@ def make_data_wrangling_agent(
304
714
  return {
305
715
  "data_wrangler_function" : response,
306
716
  "data_wrangler_function_path": file_path,
307
- "data_wrangler_function_name": file_name_2,
717
+ "data_wrangler_file_name": file_name_2,
718
+ "data_wrangler_function_name": function_name,
308
719
  "all_datasets_summary": all_datasets_summary_str
309
720
  }
310
721
 
@@ -318,6 +729,33 @@ def make_data_wrangling_agent(
318
729
  user_instructions_key="user_instructions",
319
730
  recommended_steps_key="recommended_steps"
320
731
  )
732
+
733
+ # Human Review
734
+
735
+ prompt_text_human_review = "Are the following data wrangling instructions correct? (Answer 'yes' or provide modifications)\n{steps}"
736
+
737
+ if not bypass_explain_code:
738
+ def human_review(state: GraphState) -> Command[Literal["recommend_wrangling_steps", "explain_data_wrangler_code"]]:
739
+ return node_func_human_review(
740
+ state=state,
741
+ prompt_text=prompt_text_human_review,
742
+ yes_goto= 'explain_data_wrangler_code',
743
+ no_goto="recommend_wrangling_steps",
744
+ user_instructions_key="user_instructions",
745
+ recommended_steps_key="recommended_steps",
746
+ code_snippet_key="data_wrangler_function",
747
+ )
748
+ else:
749
+ def human_review(state: GraphState) -> Command[Literal["recommend_wrangling_steps", "__end__"]]:
750
+ return node_func_human_review(
751
+ state=state,
752
+ prompt_text=prompt_text_human_review,
753
+ yes_goto= '__end__',
754
+ no_goto="recommend_wrangling_steps",
755
+ user_instructions_key="user_instructions",
756
+ recommended_steps_key="recommended_steps",
757
+ code_snippet_key="data_wrangler_function",
758
+ )
321
759
 
322
760
  def execute_data_wrangler_code(state: GraphState):
323
761
  return node_func_execute_agent_code_on_data(
@@ -326,7 +764,7 @@ def make_data_wrangling_agent(
326
764
  result_key="data_wrangled",
327
765
  error_key="data_wrangler_error",
328
766
  code_snippet_key="data_wrangler_function",
329
- agent_function_name="data_wrangler",
767
+ agent_function_name=state.get("data_wrangler_function_name"),
330
768
  # pre_processing=pre_processing,
331
769
  post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
332
770
  error_message_prefix="An error occurred during data wrangling: "
@@ -334,11 +772,11 @@ def make_data_wrangling_agent(
334
772
 
335
773
  def fix_data_wrangler_code(state: GraphState):
336
774
  data_wrangler_prompt = """
337
- You are a Data Wrangling Agent. Your job is to create a data_wrangler() function that can be run on the data provided. The function is currently broken and needs to be fixed.
775
+ You are a Data Wrangling Agent. Your job is to create a {function_name}() function that can be run on the data provided. The function is currently broken and needs to be fixed.
338
776
 
339
- Make sure to only return the function definition for data_wrangler().
777
+ Make sure to only return the function definition for {function_name}().
340
778
 
341
- Return Python code in ```python``` format with a single function definition, data_wrangler(data_raw), that includes all imports inside the function.
779
+ Return Python code in ```python``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.
342
780
 
343
781
  This is the broken code (please fix):
344
782
  {code_snippet}
@@ -356,6 +794,7 @@ def make_data_wrangling_agent(
356
794
  agent_name=AGENT_NAME,
357
795
  log=log,
358
796
  file_path=state.get("data_wrangler_function_path"),
797
+ function_name=state.get("data_wrangler_function_name"),
359
798
  )
360
799
 
361
800
  def explain_data_wrangler_code(state: GraphState):