ai-data-science-team 0.0.0.9006__py3-none-any.whl → 0.0.0.9008__py3-none-any.whl

@@ -4,26 +4,27 @@
  # * Agents: Data Wrangling Agent
  
  # Libraries
- from typing import TypedDict, Annotated, Sequence, Literal, Union
+ from typing import TypedDict, Annotated, Sequence, Literal, Union, Optional
  import operator
  import os
- import io
  import pandas as pd
+ from IPython.display import Markdown
  
  from langchain.prompts import PromptTemplate
  from langchain_core.messages import BaseMessage
  from langgraph.types import Command
  from langgraph.checkpoint.memory import MemorySaver
  
- from ai_data_science_team.templates.agent_templates import(
+ from ai_data_science_team.templates import(
      node_func_execute_agent_code_on_data,
      node_func_human_review,
      node_func_fix_agent_code,
      node_func_explain_agent_code,
-     create_coding_agent_graph
+     create_coding_agent_graph,
+     BaseAgent,
  )
  from ai_data_science_team.tools.parsers import PythonOutputParser
- from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
+ from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name, format_recommended_steps
  from ai_data_science_team.tools.metadata import get_dataframe_summary
  from ai_data_science_team.tools.logging import log_ai_function
  
@@ -31,7 +32,418 @@ from ai_data_science_team.tools.logging import log_ai_function
  AGENT_NAME = "data_wrangling_agent"
  LOG_PATH = os.path.join(os.getcwd(), "logs/")
  
- def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True, human_in_the_loop=False, bypass_recommended_steps=False, bypass_explain_code=False):
+ # Class
+ 
+ class DataWranglingAgent(BaseAgent):
+     """
+     Creates a data wrangling agent that can work with one or more datasets, performing operations such as
+     joining/merging multiple datasets, reshaping, aggregating, encoding, creating computed features,
+     and ensuring consistent data types. The agent generates a Python function to wrangle the data,
+     executes the function, and logs the process (if enabled).
+ 
+     The agent can handle:
+     - A single dataset (provided as a dictionary of {column: list_of_values})
+     - Multiple datasets (provided as a list of such dictionaries)
+ 
+     Key wrangling steps can include:
+     - Merging or joining datasets
+     - Pivoting/melting data for reshaping
+     - GroupBy aggregations (sums, means, counts, etc.)
+     - Encoding categorical variables
+     - Computing new columns from existing ones
+     - Dropping or rearranging columns
+     - Any additional user instructions
+ 
+     Parameters
+     ----------
+     model : langchain.llms.base.LLM
+         The language model used to generate the data wrangling function.
+     n_samples : int, optional
+         Number of samples to show in the data summary for wrangling. Defaults to 30.
+     log : bool, optional
+         Whether to log the generated code and errors. Defaults to False.
+     log_path : str, optional
+         Directory path for storing log files. Defaults to None.
+     file_name : str, optional
+         Name of the file for saving the generated response. Defaults to "data_wrangler.py".
+     function_name : str, optional
+         Name of the function to be generated. Defaults to "data_wrangler".
+     overwrite : bool, optional
+         Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
+     human_in_the_loop : bool, optional
+         Enables user review of data wrangling instructions. Defaults to False.
+     bypass_recommended_steps : bool, optional
+         If True, skips the step that generates recommended data wrangling steps. Defaults to False.
+     bypass_explain_code : bool, optional
+         If True, skips the step that provides code explanations. Defaults to False.
+ 
+     Methods
+     -------
+     update_params(**kwargs)
+         Updates the agent's parameters and rebuilds the compiled state graph.
+ 
+     ainvoke_agent(user_instructions: str, data_raw: Union[dict, list], max_retries=3, retry_count=0)
+         Asynchronously wrangles the provided dataset(s) based on user instructions.
+ 
+     invoke_agent(user_instructions: str, data_raw: Union[dict, list], max_retries=3, retry_count=0)
+         Synchronously wrangles the provided dataset(s) based on user instructions.
+ 
+     explain_wrangling_steps()
+         Returns an explanation of the wrangling steps performed by the agent.
+ 
+     get_log_summary()
+         Retrieves a summary of logged operations if logging is enabled.
+ 
+     get_data_wrangled()
+         Retrieves the final wrangled dataset (as a dictionary of {column: list_of_values}).
+ 
+     get_data_raw()
+         Retrieves the raw dataset(s).
+ 
+     get_data_wrangler_function()
+         Retrieves the generated Python function used for data wrangling.
+ 
+     get_recommended_wrangling_steps()
+         Retrieves the agent's recommended wrangling steps.
+ 
+     get_response()
+         Returns the full response dictionary from the agent.
+ 
+     show()
+         Displays the agent's mermaid diagram for visual inspection of the compiled graph.
+ 
+     Examples
+     --------
+     ```python
+     import pandas as pd
+     from langchain_openai import ChatOpenAI
+     from ai_data_science_team.agents import DataWranglingAgent
+ 
+     # Single dataset example
+     llm = ChatOpenAI(model="gpt-4o-mini")
+ 
+     data_wrangling_agent = DataWranglingAgent(
+         model=llm,
+         n_samples=30,
+         log=True,
+         log_path="logs",
+         human_in_the_loop=True
+     )
+ 
+     df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")
+ 
+     data_wrangling_agent.invoke_agent(
+         user_instructions="Group by 'gender' and compute mean of 'tenure'.",
+         data_raw=df, # data_raw can be df.to_dict() or just a DataFrame
+         max_retries=3,
+         retry_count=0
+     )
+ 
+     data_wrangled = data_wrangling_agent.get_data_wrangled()
+     response = data_wrangling_agent.get_response()
+ 
+     # Multiple dataset example (list of dicts)
+     df1 = pd.DataFrame({'id': [1,2,3], 'val1': [10,20,30]})
+     df2 = pd.DataFrame({'id': [1,2,3], 'val2': [40,50,60]})
+ 
+     data_wrangling_agent.invoke_agent(
+         user_instructions="Merge these two datasets on 'id' and compute a new column 'val_sum' = val1+val2",
+         data_raw=[df1, df2], # multiple datasets
+         max_retries=3,
+         retry_count=0
+     )
+ 
+     data_wrangled = data_wrangling_agent.get_data_wrangled()
+     ```
+ 
+     Returns
+     -------
+     DataWranglingAgent : langchain.graphs.CompiledStateGraph
+         A data wrangling agent implemented as a compiled state graph.
+     """
+ 
+     def __init__(
+         self,
+         model,
+         n_samples=30,
+         log=False,
+         log_path=None,
+         file_name="data_wrangler.py",
+         function_name="data_wrangler",
+         overwrite=True,
+         human_in_the_loop=False,
+         bypass_recommended_steps=False,
+         bypass_explain_code=False
+     ):
+         self._params = {
+             "model": model,
+             "n_samples": n_samples,
+             "log": log,
+             "log_path": log_path,
+             "file_name": file_name,
+             "function_name": function_name,
+             "overwrite": overwrite,
+             "human_in_the_loop": human_in_the_loop,
+             "bypass_recommended_steps": bypass_recommended_steps,
+             "bypass_explain_code": bypass_explain_code
+         }
+         self._compiled_graph = self._make_compiled_graph()
+         self.response = None
+ 
+     def _make_compiled_graph(self):
+         """
+         Create the compiled graph for the data wrangling agent.
+         Running this method will reset the response to None.
+         """
+         self.response = None
+         return make_data_wrangling_agent(**self._params)
+ 
+     def update_params(self, **kwargs):
+         """
+         Updates the agent's parameters and rebuilds the compiled graph.
+         """
+         for k, v in kwargs.items():
+             self._params[k] = v
+         self._compiled_graph = self._make_compiled_graph()
+ 
+     def ainvoke_agent(
+         self,
+         data_raw: Union[pd.DataFrame, dict, list],
+         user_instructions: str=None,
+         max_retries:int=3,
+         retry_count:int=0,
+         **kwargs
+     ):
+         """
+         Asynchronously wrangles the provided dataset(s) based on user instructions.
+         The response is stored in the 'response' attribute.
+ 
+         Parameters
+         ----------
+         data_raw : Union[pd.DataFrame, dict, list]
+             The raw dataset(s) to be wrangled.
+             Can be a single DataFrame, a single dict ({col: list_of_values}),
+             or a list of dicts if multiple datasets are provided.
+         user_instructions : str
+             Instructions for data wrangling.
+         max_retries : int
+             Maximum retry attempts.
+         retry_count : int
+             Current retry attempt count.
+         **kwargs
+             Additional keyword arguments to pass to ainvoke().
+ 
+         Returns
+         -------
+         None
+         """
+         data_input = self._convert_data_input(data_raw)
+         response = self._compiled_graph.ainvoke({
+             "user_instructions": user_instructions,
+             "data_raw": data_input,
+             "max_retries": max_retries,
+             "retry_count": retry_count
+         }, **kwargs)
+         self.response = response
+         return None
+ 
+     def invoke_agent(
+         self,
+         data_raw: Union[pd.DataFrame, dict, list],
+         user_instructions: str=None,
+         max_retries:int=3,
+         retry_count:int=0,
+         **kwargs
+     ):
+         """
+         Synchronously wrangles the provided dataset(s) based on user instructions.
+         The response is stored in the 'response' attribute.
+ 
+         Parameters
+         ----------
+         data_raw : Union[pd.DataFrame, dict, list]
+             The raw dataset(s) to be wrangled.
+             Can be a single DataFrame, a single dict, or a list of dicts.
+         user_instructions : str
+             Instructions for the data wrangling agent.
+         max_retries : int
+             Maximum retry attempts.
+         retry_count : int
+             Current retry attempt count.
+         **kwargs
+             Additional keyword arguments to pass to invoke().
+ 
+         Returns
+         -------
+         None
+         """
+         data_input = self._convert_data_input(data_raw)
+         response = self._compiled_graph.invoke({
+             "user_instructions": user_instructions,
+             "data_raw": data_input,
+             "max_retries": max_retries,
+             "retry_count": retry_count
+         }, **kwargs)
+         self.response = response
+         return None
+ 
+     def explain_wrangling_steps(self):
+         """
+         Provides an explanation of the wrangling steps performed by the agent.
+ 
+         Returns
+         -------
+         str or list
+             Explanation of the data wrangling steps.
+         """
+         if self.response:
+             return self.response.get("messages", [])
+         return []
+ 
+     def get_log_summary(self, markdown=False):
+         """
+         Returns a summary of the agent's operations, if logging is enabled.
+ 
+         Parameters
+         ----------
+         markdown : bool, optional
+             If True, returns the summary in Markdown.
+ 
+         Returns
+         -------
+         str or None
+             The log details, or None if not available.
+         """
+         if self.response and self.response.get("data_wrangler_function_path"):
+             log_details = f"Log Path: {self.response.get('data_wrangler_function_path')}"
+             if markdown:
+                 return Markdown(log_details)
+             else:
+                 return log_details
+         return None
+ 
+     def get_data_wrangled(self) -> Optional[pd.DataFrame]:
+         """
+         Retrieves the wrangled data after running invoke_agent() or ainvoke_agent().
+ 
+         Returns
+         -------
+         pd.DataFrame or None
+             The wrangled dataset as a pandas DataFrame (if available).
+         """
+         if self.response and "data_wrangled" in self.response:
+             return pd.DataFrame(self.response["data_wrangled"])
+         return None
+ 
+     def get_data_raw(self) -> Union[dict, list, None]:
+         """
+         Retrieves the original raw data from the last invocation.
+ 
+         Returns
+         -------
+         Union[dict, list, None]
+             The original dataset(s) as a single dict or a list of dicts, or None if not available.
+         """
+         if self.response and "data_raw" in self.response:
+             return self.response["data_raw"]
+         return None
+ 
+     def get_data_wrangler_function(self, markdown=False) -> Optional[str]:
+         """
+         Retrieves the generated data wrangling function code.
+ 
+         Parameters
+         ----------
+         markdown : bool, optional
+             If True, returns the function in Markdown code block format.
+ 
+         Returns
+         -------
+         str or None
+             The Python function code, or None if not available.
+         """
+         if self.response and "data_wrangler_function" in self.response:
+             code = self.response["data_wrangler_function"]
+             if markdown:
+                 return Markdown(f"```python\n{code}\n```")
+             return code
+         return None
+ 
+     def get_recommended_wrangling_steps(self, markdown=False) -> Optional[str]:
+         """
+         Retrieves the agent's recommended data wrangling steps.
+ 
+         Parameters
+         ----------
+         markdown : bool, optional
+             If True, returns the steps in Markdown format.
+ 
+         Returns
+         -------
+         str or None
+             The recommended steps, or None if not available.
+         """
+         if self.response and "recommended_steps" in self.response:
+             steps = self.response["recommended_steps"]
+             if markdown:
+                 return Markdown(steps)
+             return steps
+         return None
+ 
+     @staticmethod
+     def _convert_data_input(data_raw: Union[pd.DataFrame, dict, list]) -> Union[dict, list]:
+         """
+         Internal utility to convert data_raw (which could be a DataFrame, dict, or list of dicts)
+         into the format expected by the underlying agent (dict or list of dicts).
+ 
+         Parameters
+         ----------
+         data_raw : Union[pd.DataFrame, dict, list]
+             The raw input data to be converted.
+ 
+         Returns
+         -------
+         Union[dict, list]
+             The data in a dictionary or list-of-dictionaries format.
+         """
+         # If a single DataFrame, convert to dict
+         if isinstance(data_raw, pd.DataFrame):
+             return data_raw.to_dict()
+ 
+         # If it's already a dict (single dataset)
+         if isinstance(data_raw, dict):
+             return data_raw
+ 
+         # If it's already a list, check if it's a list of DataFrames or dicts
+         if isinstance(data_raw, list):
+             # Convert any DataFrame item to dict
+             converted_list = []
+             for item in data_raw:
+                 if isinstance(item, pd.DataFrame):
+                     converted_list.append(item.to_dict())
+                 elif isinstance(item, dict):
+                     converted_list.append(item)
+                 else:
+                     raise ValueError("List must contain only DataFrames or dictionaries.")
+             return converted_list
+ 
+         raise ValueError("data_raw must be a DataFrame, a dict, or a list of dicts/DataFrames.")
+ 
+ 
+ # Function
+ 
+ def make_data_wrangling_agent(
+     model,
+     n_samples=30,
+     log=False,
+     log_path=None,
+     file_name="data_wrangler.py",
+     function_name="data_wrangler",
+     overwrite=True,
+     human_in_the_loop=False,
+     bypass_recommended_steps=False,
+     bypass_explain_code=False
+ ):
      """
      Creates a data wrangling agent that can be run on one or more datasets. The agent can be
      instructed to perform common data wrangling steps such as:
@@ -52,11 +464,19 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
      ----------
      model : langchain.llms.base.LLM
          The language model to use to generate code.
+     n_samples : int, optional
+         The number of samples to show in the data summary. Defaults to 30.
+         If you get an error due to maximum tokens, try reducing this number.
+         > "This model's maximum context length is 128000 tokens. However, your messages resulted in 333858 tokens. Please reduce the length of the messages."
      log : bool, optional
          Whether or not to log the code generated and any errors that occur.
          Defaults to False.
      log_path : str, optional
          The path to the directory where the log files should be stored. Defaults to "logs/".
+     file_name : str, optional
+         The name of the file to save the response to. Defaults to "data_wrangler.py".
+     function_name : str, optional
+         The name of the function to be generated. Defaults to "data_wrangler".
      overwrite : bool, optional
          Whether or not to overwrite the log file if it already exists. If False, a unique file name will be created.
          Defaults to True.
@@ -94,11 +514,16 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
  
      Returns
      -------
-     app : langchain.graphs.StateGraph
+     app : langchain.graphs.CompiledStateGraph
          The data wrangling agent as a state graph.
      """
      llm = model
  
+     # Human in the loop requires recommended steps
+     if bypass_recommended_steps and human_in_the_loop:
+         bypass_recommended_steps = False
+         print("Bypass recommended steps set to False to enable human in the loop.")
+ 
      # Setup Log Directory
      if log:
          if log_path is None:
@@ -122,7 +547,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
          retry_count: int
  
      def recommend_wrangling_steps(state: GraphState):
-         print("---DATA WRANGLING AGENT----")
+         print(format_agent_name(AGENT_NAME))
          print(" * RECOMMEND WRANGLING STEPS")
  
          data_raw = state.get("data_raw")
@@ -143,7 +568,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
  
          # Create a summary for all datasets
          # We'll include a short sample and info for each dataset
-         all_datasets_summary = get_dataframe_summary(dataframes)
+         all_datasets_summary = get_dataframe_summary(dataframes, n_sample=n_samples)
  
          # Join all datasets summaries into one big text block
          all_datasets_summary_str = "\n\n".join(all_datasets_summary)
@@ -176,6 +601,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
  
              Avoid these:
              1. Do not include steps to save files.
+             2. Do not include unrelated user instructions that are not related to the data wrangling.
              """,
              input_variables=["user_instructions", "recommended_steps", "all_datasets_summary"]
          )
@@ -188,19 +614,46 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
          })
  
          return {
-             "recommended_steps": "\n\n# Recommended Wrangling Steps:\n" + recommended_steps.content.strip(),
+             "recommended_steps": format_recommended_steps(recommended_steps.content.strip(), heading="# Recommended Data Wrangling Steps:"),
              "all_datasets_summary": all_datasets_summary_str,
          }
  
  
      def create_data_wrangler_code(state: GraphState):
          if bypass_recommended_steps:
-             print("---DATA WRANGLING AGENT----")
+             print(format_agent_name(AGENT_NAME))
+ 
+             data_raw = state.get("data_raw")
+ 
+             if isinstance(data_raw, dict):
+                 # Single dataset scenario
+                 primary_dataset_name = "main"
+                 datasets = {primary_dataset_name: data_raw}
+             elif isinstance(data_raw, list) and all(isinstance(item, dict) for item in data_raw):
+                 # Multiple datasets scenario
+                 datasets = {f"dataset_{i}": d for i, d in enumerate(data_raw, start=1)}
+                 primary_dataset_name = "dataset_1"
+             else:
+                 raise ValueError("data_raw must be a dict or a list of dicts.")
+ 
+             # Convert all datasets to DataFrames for inspection
+             dataframes = {name: pd.DataFrame.from_dict(d) for name, d in datasets.items()}
+ 
+             # Create a summary for all datasets
+             # We'll include a short sample and info for each dataset
+             all_datasets_summary = get_dataframe_summary(dataframes, n_sample=n_samples)
+ 
+             # Join all datasets summaries into one big text block
+             all_datasets_summary_str = "\n\n".join(all_datasets_summary)
+ 
+         else:
+             all_datasets_summary_str = state.get("all_datasets_summary")
+ 
          print(" * CREATE DATA WRANGLER CODE")
  
          data_wrangling_prompt = PromptTemplate(
              template="""
-             You are a Data Wrangling Coding Agent. Your job is to create a data_wrangler() function that can be run on the provided data.
+             You are a Data Wrangling Coding Agent. Your job is to create a {function_name}() function that can be run on the provided data.
  
              Follow these recommended steps:
              {recommended_steps}
@@ -210,10 +663,10 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
              Below are summaries of all datasets provided. If more than one dataset is provided, you may need to merge or join them.:
              {all_datasets_summary}
  
-             Return Python code in ```python``` format with a single function definition, data_wrangler(), that includes all imports inside the function. And returns a single pandas data frame.
+             Return Python code in ```python``` format with a single function definition, {function_name}(), that includes all imports inside the function. And returns a single pandas data frame.
  
              ```python
-             def data_wrangler(data_list):
+             def {function_name}(data_list):
                  '''
                  Wrangle the data provided in data.
  
@@ -235,23 +688,24 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
  
  
              """,
-             input_variables=["recommended_steps", "all_datasets_summary"]
+             input_variables=["recommended_steps", "all_datasets_summary", "function_name"]
          )
  
          data_wrangling_agent = data_wrangling_prompt | llm | PythonOutputParser()
  
          response = data_wrangling_agent.invoke({
              "recommended_steps": state.get("recommended_steps"),
-             "all_datasets_summary": state.get("all_datasets_summary")
+             "all_datasets_summary": all_datasets_summary_str,
+             "function_name": function_name
          })
  
          response = relocate_imports_inside_function(response)
          response = add_comments_to_top(response, agent_name=AGENT_NAME)
  
          # For logging: store the code generated
-         file_path, file_name = log_ai_function(
+         file_path, file_name_2 = log_ai_function(
              response=response,
-             file_name="data_wrangler.py",
+             file_name=file_name,
              log=log,
              log_path=log_path,
              overwrite=overwrite
@@ -260,7 +714,9 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
          return {
              "data_wrangler_function" : response,
              "data_wrangler_function_path": file_path,
-             "data_wrangler_function_name": file_name
+             "data_wrangler_file_name": file_name_2,
+             "data_wrangler_function_name": function_name,
+             "all_datasets_summary": all_datasets_summary_str
          }
  
  
@@ -273,6 +729,33 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
              user_instructions_key="user_instructions",
              recommended_steps_key="recommended_steps"
          )
+ 
+     # Human Review
+ 
+     prompt_text_human_review = "Are the following data wrangling instructions correct? (Answer 'yes' or provide modifications)\n{steps}"
+ 
+     if not bypass_explain_code:
+         def human_review(state: GraphState) -> Command[Literal["recommend_wrangling_steps", "explain_data_wrangler_code"]]:
+             return node_func_human_review(
+                 state=state,
+                 prompt_text=prompt_text_human_review,
+                 yes_goto= 'explain_data_wrangler_code',
+                 no_goto="recommend_wrangling_steps",
+                 user_instructions_key="user_instructions",
+                 recommended_steps_key="recommended_steps",
+                 code_snippet_key="data_wrangler_function",
+             )
+     else:
+         def human_review(state: GraphState) -> Command[Literal["recommend_wrangling_steps", "__end__"]]:
+             return node_func_human_review(
+                 state=state,
+                 prompt_text=prompt_text_human_review,
+                 yes_goto= '__end__',
+                 no_goto="recommend_wrangling_steps",
+                 user_instructions_key="user_instructions",
+                 recommended_steps_key="recommended_steps",
+                 code_snippet_key="data_wrangler_function",
+             )
  
      def execute_data_wrangler_code(state: GraphState):
          return node_func_execute_agent_code_on_data(
@@ -281,7 +764,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
              result_key="data_wrangled",
              error_key="data_wrangler_error",
              code_snippet_key="data_wrangler_function",
-             agent_function_name="data_wrangler",
+             agent_function_name=state.get("data_wrangler_function_name"),
              # pre_processing=pre_processing,
              post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
              error_message_prefix="An error occurred during data wrangling: "
@@ -289,11 +772,11 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
  
      def fix_data_wrangler_code(state: GraphState):
          data_wrangler_prompt = """
-         You are a Data Wrangling Agent. Your job is to create a data_wrangler() function that can be run on the data provided. The function is currently broken and needs to be fixed.
+         You are a Data Wrangling Agent. Your job is to create a {function_name}() function that can be run on the data provided. The function is currently broken and needs to be fixed.
  
-         Make sure to only return the function definition for data_wrangler().
+         Make sure to only return the function definition for {function_name}().
  
-         Return Python code in ```python``` format with a single function definition, data_wrangler(data_raw), that includes all imports inside the function.
+         Return Python code in ```python``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.
  
          This is the broken code (please fix):
          {code_snippet}
@@ -311,6 +794,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
              agent_name=AGENT_NAME,
              log=log,
              file_path=state.get("data_wrangler_function_path"),
+             function_name=state.get("data_wrangler_function_name"),
          )
  
      def explain_data_wrangler_code(state: GraphState):
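
For readers scanning this diff, the snippet below is a minimal, illustrative usage sketch of the class-based API introduced in 0.0.0.9008, assembled from the `DataWranglingAgent` docstring and method signatures shown above. It assumes `langchain_openai` is installed and an OpenAI API key is configured; the model name, sample data, and keyword choices are placeholders for illustration, not part of the diff itself.

```python
# Illustrative sketch of the DataWranglingAgent API added in this release (not part of the diff).
import pandas as pd
from langchain_openai import ChatOpenAI
from ai_data_science_team.agents import DataWranglingAgent

llm = ChatOpenAI(model="gpt-4o-mini")  # model choice is an assumption; any LangChain chat model should work

agent = DataWranglingAgent(
    model=llm,
    n_samples=30,                    # rows shown in the data summary sent to the LLM
    log=True,
    log_path="logs",
    file_name="data_wrangler.py",    # new in this version: controls the logged file name
    function_name="data_wrangler",   # new in this version: controls the generated function name
    bypass_explain_code=True,        # skip the explanation node for a non-interactive run
)

# Example inputs; invoke_agent() accepts a DataFrame, a dict, or a list of either
# (see _convert_data_input in the class above).
df1 = pd.DataFrame({"id": [1, 2, 3], "val1": [10, 20, 30]})
df2 = pd.DataFrame({"id": [1, 2, 3], "val2": [40, 50, 60]})

agent.invoke_agent(
    data_raw=[df1, df2],
    user_instructions="Merge the datasets on 'id' and add a column 'val_sum' = val1 + val2.",
    max_retries=3,
    retry_count=0,
)

print(agent.get_data_wrangled())           # wrangled result as a pandas DataFrame
print(agent.get_data_wrangler_function())  # source of the generated data_wrangler() function
```

Note how the new `file_name` and `function_name` parameters flow through the graph state (`data_wrangler_file_name`, `data_wrangler_function_name`) into the execution and fix nodes shown in the later hunks.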