ai-data-science-team 0.0.0.9006__py3-none-any.whl → 0.0.0.9008__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,26 +4,27 @@
4
4
  # * Agents: Data Wrangling Agent
5
5
 
6
6
  # Libraries
7
- from typing import TypedDict, Annotated, Sequence, Literal, Union
7
+ from typing import TypedDict, Annotated, Sequence, Literal, Union, Optional
8
8
  import operator
9
9
  import os
10
- import io
11
10
  import pandas as pd
11
+ from IPython.display import Markdown
12
12
 
13
13
  from langchain.prompts import PromptTemplate
14
14
  from langchain_core.messages import BaseMessage
15
15
  from langgraph.types import Command
16
16
  from langgraph.checkpoint.memory import MemorySaver
17
17
 
18
- from ai_data_science_team.templates.agent_templates import(
18
+ from ai_data_science_team.templates import(
19
19
  node_func_execute_agent_code_on_data,
20
20
  node_func_human_review,
21
21
  node_func_fix_agent_code,
22
22
  node_func_explain_agent_code,
23
- create_coding_agent_graph
23
+ create_coding_agent_graph,
24
+ BaseAgent,
24
25
  )
25
26
  from ai_data_science_team.tools.parsers import PythonOutputParser
26
- from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
27
+ from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name, format_recommended_steps
27
28
  from ai_data_science_team.tools.metadata import get_dataframe_summary
28
29
  from ai_data_science_team.tools.logging import log_ai_function
29
30
 
@@ -31,7 +32,418 @@ from ai_data_science_team.tools.logging import log_ai_function
31
32
  AGENT_NAME = "data_wrangling_agent"
32
33
  LOG_PATH = os.path.join(os.getcwd(), "logs/")
33
34
 
34
- def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True, human_in_the_loop=False, bypass_recommended_steps=False, bypass_explain_code=False):
35
+ # Class
36
+
37
+ class DataWranglingAgent(BaseAgent):
38
+ """
39
+ Creates a data wrangling agent that can work with one or more datasets, performing operations such as
40
+ joining/merging multiple datasets, reshaping, aggregating, encoding, creating computed features,
41
+ and ensuring consistent data types. The agent generates a Python function to wrangle the data,
42
+ executes the function, and logs the process (if enabled).
43
+
44
+ The agent can handle:
45
+ - A single dataset (provided as a dictionary of {column: list_of_values})
46
+ - Multiple datasets (provided as a list of such dictionaries)
47
+
48
+ Key wrangling steps can include:
49
+ - Merging or joining datasets
50
+ - Pivoting/melting data for reshaping
51
+ - GroupBy aggregations (sums, means, counts, etc.)
52
+ - Encoding categorical variables
53
+ - Computing new columns from existing ones
54
+ - Dropping or rearranging columns
55
+ - Any additional user instructions
56
+
57
+ Parameters
58
+ ----------
59
+ model : langchain.llms.base.LLM
60
+ The language model used to generate the data wrangling function.
61
+ n_samples : int, optional
62
+ Number of samples to show in the data summary for wrangling. Defaults to 30.
63
+ log : bool, optional
64
+ Whether to log the generated code and errors. Defaults to False.
65
+ log_path : str, optional
66
+ Directory path for storing log files. Defaults to None.
67
+ file_name : str, optional
68
+ Name of the file for saving the generated response. Defaults to "data_wrangler.py".
69
+ function_name : str, optional
70
+ Name of the function to be generated. Defaults to "data_wrangler".
71
+ overwrite : bool, optional
72
+ Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
73
+ human_in_the_loop : bool, optional
74
+ Enables user review of data wrangling instructions. Defaults to False.
75
+ bypass_recommended_steps : bool, optional
76
+ If True, skips the step that generates recommended data wrangling steps. Defaults to False.
77
+ bypass_explain_code : bool, optional
78
+ If True, skips the step that provides code explanations. Defaults to False.
79
+
80
+ Methods
81
+ -------
82
+ update_params(**kwargs)
83
+ Updates the agent's parameters and rebuilds the compiled state graph.
84
+
85
+ ainvoke_agent(data_raw: Union[pd.DataFrame, dict, list], user_instructions: str=None, max_retries=3, retry_count=0, **kwargs)
86
+ Asynchronously wrangles the provided dataset(s) based on user instructions.
87
+
88
+ invoke_agent(data_raw: Union[pd.DataFrame, dict, list], user_instructions: str=None, max_retries=3, retry_count=0, **kwargs)
89
+ Synchronously wrangles the provided dataset(s) based on user instructions.
90
+
91
+ explain_wrangling_steps()
92
+ Returns an explanation of the wrangling steps performed by the agent.
93
+
94
+ get_log_summary()
95
+ Retrieves a summary of logged operations if logging is enabled.
96
+
97
+ get_data_wrangled()
98
+ Retrieves the final wrangled dataset as a pandas DataFrame (or None if not available).
99
+
100
+ get_data_raw()
101
+ Retrieves the raw dataset(s).
102
+
103
+ get_data_wrangler_function()
104
+ Retrieves the generated Python function used for data wrangling.
105
+
106
+ get_recommended_wrangling_steps()
107
+ Retrieves the agent's recommended wrangling steps.
108
+
109
+ get_response()
110
+ Returns the full response dictionary from the agent.
111
+
112
+ show()
113
+ Displays the agent's mermaid diagram for visual inspection of the compiled graph.
114
+
115
+ Examples
116
+ --------
117
+ ```python
118
+ import pandas as pd
119
+ from langchain_openai import ChatOpenAI
120
+ from ai_data_science_team.agents import DataWranglingAgent
121
+
122
+ # Single dataset example
123
+ llm = ChatOpenAI(model="gpt-4o-mini")
124
+
125
+ data_wrangling_agent = DataWranglingAgent(
126
+ model=llm,
127
+ n_samples=30,
128
+ log=True,
129
+ log_path="logs",
130
+ human_in_the_loop=True
131
+ )
132
+
133
+ df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")
134
+
135
+ data_wrangling_agent.invoke_agent(
136
+ user_instructions="Group by 'gender' and compute mean of 'tenure'.",
137
+ data_raw=df, # data_raw can be df.to_dict() or just a DataFrame
138
+ max_retries=3,
139
+ retry_count=0
140
+ )
141
+
142
+ data_wrangled = data_wrangling_agent.get_data_wrangled()
143
+ response = data_wrangling_agent.get_response()
144
+
145
+ # Multiple dataset example (list of DataFrames; a list of dicts also works)
146
+ df1 = pd.DataFrame({'id': [1,2,3], 'val1': [10,20,30]})
147
+ df2 = pd.DataFrame({'id': [1,2,3], 'val2': [40,50,60]})
148
+
149
+ data_wrangling_agent.invoke_agent(
150
+ user_instructions="Merge these two datasets on 'id' and compute a new column 'val_sum' = val1+val2",
151
+ data_raw=[df1, df2], # multiple datasets
152
+ max_retries=3,
153
+ retry_count=0
154
+ )
155
+
156
+ data_wrangled = data_wrangling_agent.get_data_wrangled()
157
+ ```
158
+
159
+ Returns
160
+ -------
161
+ DataWranglingAgent : langchain.graphs.CompiledStateGraph
162
+ A data wrangling agent implemented as a compiled state graph.
163
+ """
164
+
165
+ def __init__(
166
+ self,
167
+ model,
168
+ n_samples=30,
169
+ log=False,
170
+ log_path=None,
171
+ file_name="data_wrangler.py",
172
+ function_name="data_wrangler",
173
+ overwrite=True,
174
+ human_in_the_loop=False,
175
+ bypass_recommended_steps=False,
176
+ bypass_explain_code=False
177
+ ):
178
+ self._params = {
179
+ "model": model,
180
+ "n_samples": n_samples,
181
+ "log": log,
182
+ "log_path": log_path,
183
+ "file_name": file_name,
184
+ "function_name": function_name,
185
+ "overwrite": overwrite,
186
+ "human_in_the_loop": human_in_the_loop,
187
+ "bypass_recommended_steps": bypass_recommended_steps,
188
+ "bypass_explain_code": bypass_explain_code
189
+ }
190
+ self._compiled_graph = self._make_compiled_graph()
191
+ self.response = None
192
+
193
+ def _make_compiled_graph(self):
194
+ """
195
+ Create the compiled graph for the data wrangling agent.
196
+ Running this method will reset the response to None.
197
+ """
198
+ self.response = None
199
+ return make_data_wrangling_agent(**self._params)
200
+
201
+ def update_params(self, **kwargs):
202
+ """
203
+ Updates the agent's parameters and rebuilds the compiled graph.
204
+ """
205
+ for k, v in kwargs.items():
206
+ self._params[k] = v
207
+ self._compiled_graph = self._make_compiled_graph()
208
+
209
+ def ainvoke_agent(
210
+ self,
211
+ data_raw: Union[pd.DataFrame, dict, list],
212
+ user_instructions: str=None,
213
+ max_retries:int=3,
214
+ retry_count:int=0,
215
+ **kwargs
216
+ ):
217
+ """
218
+ Asynchronously wrangles the provided dataset(s) based on user instructions.
219
+ The response is stored in the 'response' attribute.
220
+
221
+ Parameters
222
+ ----------
223
+ data_raw : Union[pd.DataFrame, dict, list]
224
+ The raw dataset(s) to be wrangled.
225
+ Can be a single DataFrame, a single dict ({col: list_of_values}),
226
+ or a list of dicts and/or DataFrames if multiple datasets are provided.
227
+ user_instructions : str
228
+ Instructions for data wrangling.
229
+ max_retries : int
230
+ Maximum retry attempts.
231
+ retry_count : int
232
+ Current retry attempt count.
233
+ **kwargs
234
+ Additional keyword arguments to pass to ainvoke().
235
+
236
+ Returns
237
+ -------
238
+ None
239
+ """
240
+ data_input = self._convert_data_input(data_raw)
241
+ response = self._compiled_graph.ainvoke({
242
+ "user_instructions": user_instructions,
243
+ "data_raw": data_input,
244
+ "max_retries": max_retries,
245
+ "retry_count": retry_count
246
+ }, **kwargs)
247
+ self.response = response
248
+ return None
249
+
250
+ def invoke_agent(
251
+ self,
252
+ data_raw: Union[pd.DataFrame, dict, list],
253
+ user_instructions: str=None,
254
+ max_retries:int=3,
255
+ retry_count:int=0,
256
+ **kwargs
257
+ ):
258
+ """
259
+ Synchronously wrangles the provided dataset(s) based on user instructions.
260
+ The response is stored in the 'response' attribute.
261
+
262
+ Parameters
263
+ ----------
264
+ data_raw : Union[pd.DataFrame, dict, list]
265
+ The raw dataset(s) to be wrangled.
266
+ Can be a single DataFrame, a single dict, or a list of dicts and/or DataFrames.
267
+ user_instructions : str
268
+ Instructions for data wrangling agent.
269
+ max_retries : int
270
+ Maximum retry attempts.
271
+ retry_count : int
272
+ Current retry attempt count.
273
+ **kwargs
274
+ Additional keyword arguments to pass to invoke().
275
+
276
+ Returns
277
+ -------
278
+ None
279
+ """
280
+ data_input = self._convert_data_input(data_raw)
281
+ response = self._compiled_graph.invoke({
282
+ "user_instructions": user_instructions,
283
+ "data_raw": data_input,
284
+ "max_retries": max_retries,
285
+ "retry_count": retry_count
286
+ }, **kwargs)
287
+ self.response = response
288
+ return None
289
+
290
+ def explain_wrangling_steps(self):
291
+ """
292
+ Provides an explanation of the wrangling steps performed by the agent.
293
+
294
+ Returns
295
+ -------
296
+ str or list
297
+ Explanation of the data wrangling steps.
298
+ """
299
+ if self.response:
300
+ return self.response.get("messages", [])
301
+ return []
302
+
303
+ def get_log_summary(self, markdown=False):
304
+ """
305
+ Returns the log path of the generated data wrangler function, if the response contains one.
306
+
307
+ Parameters
308
+ ----------
309
+ markdown : bool, optional
310
+ If True, returns the summary in Markdown.
311
+
312
+ Returns
313
+ -------
314
+ str or None
315
+ The log details, or None if not available.
316
+ """
317
+ if self.response and self.response.get("data_wrangler_function_path"):
318
+ log_details = f"Log Path: {self.response.get('data_wrangler_function_path')}"
319
+ if markdown:
320
+ return Markdown(log_details)
321
+ else:
322
+ return log_details
323
+ return None
324
+
325
+ def get_data_wrangled(self) -> Optional[pd.DataFrame]:
326
+ """
327
+ Retrieves the wrangled data after running invoke_agent() or ainvoke_agent().
328
+
329
+ Returns
330
+ -------
331
+ pd.DataFrame or None
332
+ The wrangled dataset as a pandas DataFrame (if available).
333
+ """
334
+ if self.response and "data_wrangled" in self.response:
335
+ return pd.DataFrame(self.response["data_wrangled"])
336
+ return None
337
+
338
+ def get_data_raw(self) -> Union[dict, list, None]:
339
+ """
340
+ Retrieves the original raw data from the last invocation.
341
+
342
+ Returns
343
+ -------
344
+ Union[dict, list, None]
345
+ The original dataset(s) as a single dict or a list of dicts, or None if not available.
346
+ """
347
+ if self.response and "data_raw" in self.response:
348
+ return self.response["data_raw"]
349
+ return None
350
+
351
+ def get_data_wrangler_function(self, markdown=False) -> Optional[str]:
352
+ """
353
+ Retrieves the generated data wrangling function code.
354
+
355
+ Parameters
356
+ ----------
357
+ markdown : bool, optional
358
+ If True, returns the function in Markdown code block format.
359
+
360
+ Returns
361
+ -------
362
+ str or None
363
+ The Python function code, or None if not available.
364
+ """
365
+ if self.response and "data_wrangler_function" in self.response:
366
+ code = self.response["data_wrangler_function"]
367
+ if markdown:
368
+ return Markdown(f"```python\n{code}\n```")
369
+ return code
370
+ return None
371
+
372
+ def get_recommended_wrangling_steps(self, markdown=False) -> Optional[str]:
373
+ """
374
+ Retrieves the agent's recommended data wrangling steps.
375
+
376
+ Parameters
377
+ ----------
378
+ markdown : bool, optional
379
+ If True, returns the steps in Markdown format.
380
+
381
+ Returns
382
+ -------
383
+ str or None
384
+ The recommended steps, or None if not available.
385
+ """
386
+ if self.response and "recommended_steps" in self.response:
387
+ steps = self.response["recommended_steps"]
388
+ if markdown:
389
+ return Markdown(steps)
390
+ return steps
391
+ return None
392
+
393
+ @staticmethod
394
+ def _convert_data_input(data_raw: Union[pd.DataFrame, dict, list]) -> Union[dict, list]:
395
+ """
396
+ Internal utility to convert data_raw (which could be a DataFrame, dict, or list of dicts)
397
+ into the format expected by the underlying agent (dict or list of dicts).
398
+
399
+ Parameters
400
+ ----------
401
+ data_raw : Union[pd.DataFrame, dict, list]
402
+ The raw input data to be converted.
403
+
404
+ Returns
405
+ -------
406
+ Union[dict, list]
407
+ The data in a dictionary or list-of-dictionaries format.
408
+ """
409
+ # If a single DataFrame, convert to dict
410
+ if isinstance(data_raw, pd.DataFrame):
411
+ return data_raw.to_dict()
412
+
413
+ # If it's already a dict (single dataset)
414
+ if isinstance(data_raw, dict):
415
+ return data_raw
416
+
417
+ # If it's already a list, check if it's a list of DataFrames or dicts
418
+ if isinstance(data_raw, list):
419
+ # Convert any DataFrame item to dict
420
+ converted_list = []
421
+ for item in data_raw:
422
+ if isinstance(item, pd.DataFrame):
423
+ converted_list.append(item.to_dict())
424
+ elif isinstance(item, dict):
425
+ converted_list.append(item)
426
+ else:
427
+ raise ValueError("List must contain only DataFrames or dictionaries.")
428
+ return converted_list
429
+
430
+ raise ValueError("data_raw must be a DataFrame, a dict, or a list of dicts/DataFrames.")
431
+
432
+
433
+ # Function
434
+
435
+ def make_data_wrangling_agent(
436
+ model,
437
+ n_samples=30,
438
+ log=False,
439
+ log_path=None,
440
+ file_name="data_wrangler.py",
441
+ function_name="data_wrangler",
442
+ overwrite=True,
443
+ human_in_the_loop=False,
444
+ bypass_recommended_steps=False,
445
+ bypass_explain_code=False
446
+ ):
35
447
  """
36
448
  Creates a data wrangling agent that can be run on one or more datasets. The agent can be
37
449
  instructed to perform common data wrangling steps such as:
@@ -52,11 +464,19 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
52
464
  ----------
53
465
  model : langchain.llms.base.LLM
54
466
  The language model to use to generate code.
467
+ n_samples : int, optional
468
+ The number of samples to show in the data summary. Defaults to 30.
469
+ If you get an error due to maximum tokens, try reducing this number.
470
+ > "This model's maximum context length is 128000 tokens. However, your messages resulted in 333858 tokens. Please reduce the length of the messages."
55
471
  log : bool, optional
56
472
  Whether or not to log the code generated and any errors that occur.
57
473
  Defaults to False.
58
474
  log_path : str, optional
59
475
  The path to the directory where the log files should be stored. Defaults to "logs/".
476
+ file_name : str, optional
477
+ The name of the file to save the response to. Defaults to "data_wrangler.py".
478
+ function_name : str, optional
479
+ The name of the function to be generated. Defaults to "data_wrangler".
60
480
  overwrite : bool, optional
61
481
  Whether or not to overwrite the log file if it already exists. If False, a unique file name will be created.
62
482
  Defaults to True.
@@ -94,11 +514,16 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
94
514
 
95
515
  Returns
96
516
  -------
97
- app : langchain.graphs.StateGraph
517
+ app : langchain.graphs.CompiledStateGraph
98
518
  The data wrangling agent as a state graph.
99
519
  """
100
520
  llm = model
101
521
 
522
+ # Human in the loop requires recommended steps
523
+ if bypass_recommended_steps and human_in_the_loop:
524
+ bypass_recommended_steps = False
525
+ print("Bypass recommended steps set to False to enable human in the loop.")
526
+
102
527
  # Setup Log Directory
103
528
  if log:
104
529
  if log_path is None:
@@ -122,7 +547,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
122
547
  retry_count: int
123
548
 
124
549
  def recommend_wrangling_steps(state: GraphState):
125
- print("---DATA WRANGLING AGENT----")
550
+ print(format_agent_name(AGENT_NAME))
126
551
  print(" * RECOMMEND WRANGLING STEPS")
127
552
 
128
553
  data_raw = state.get("data_raw")
@@ -143,7 +568,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
143
568
 
144
569
  # Create a summary for all datasets
145
570
  # We'll include a short sample and info for each dataset
146
- all_datasets_summary = get_dataframe_summary(dataframes)
571
+ all_datasets_summary = get_dataframe_summary(dataframes, n_sample=n_samples)
147
572
 
148
573
  # Join all datasets summaries into one big text block
149
574
  all_datasets_summary_str = "\n\n".join(all_datasets_summary)
@@ -176,6 +601,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
176
601
 
177
602
  Avoid these:
178
603
  1. Do not include steps to save files.
604
+ 2. Do not include unrelated user instructions that are not related to the data wrangling.
179
605
  """,
180
606
  input_variables=["user_instructions", "recommended_steps", "all_datasets_summary"]
181
607
  )
@@ -188,19 +614,46 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
188
614
  })
189
615
 
190
616
  return {
191
- "recommended_steps": "\n\n# Recommended Wrangling Steps:\n" + recommended_steps.content.strip(),
617
+ "recommended_steps": format_recommended_steps(recommended_steps.content.strip(), heading="# Recommended Data Wrangling Steps:"),
192
618
  "all_datasets_summary": all_datasets_summary_str,
193
619
  }
194
620
 
195
621
 
196
622
  def create_data_wrangler_code(state: GraphState):
197
623
  if bypass_recommended_steps:
198
- print("---DATA WRANGLING AGENT----")
624
+ print(format_agent_name(AGENT_NAME))
625
+
626
+ data_raw = state.get("data_raw")
627
+
628
+ if isinstance(data_raw, dict):
629
+ # Single dataset scenario
630
+ primary_dataset_name = "main"
631
+ datasets = {primary_dataset_name: data_raw}
632
+ elif isinstance(data_raw, list) and all(isinstance(item, dict) for item in data_raw):
633
+ # Multiple datasets scenario
634
+ datasets = {f"dataset_{i}": d for i, d in enumerate(data_raw, start=1)}
635
+ primary_dataset_name = "dataset_1"
636
+ else:
637
+ raise ValueError("data_raw must be a dict or a list of dicts.")
638
+
639
+ # Convert all datasets to DataFrames for inspection
640
+ dataframes = {name: pd.DataFrame.from_dict(d) for name, d in datasets.items()}
641
+
642
+ # Create a summary for all datasets
643
+ # We'll include a short sample and info for each dataset
644
+ all_datasets_summary = get_dataframe_summary(dataframes, n_sample=n_samples)
645
+
646
+ # Join all datasets summaries into one big text block
647
+ all_datasets_summary_str = "\n\n".join(all_datasets_summary)
648
+
649
+ else:
650
+ all_datasets_summary_str = state.get("all_datasets_summary")
651
+
199
652
  print(" * CREATE DATA WRANGLER CODE")
200
653
 
201
654
  data_wrangling_prompt = PromptTemplate(
202
655
  template="""
203
- You are a Data Wrangling Coding Agent. Your job is to create a data_wrangler() function that can be run on the provided data.
656
+ You are a Data Wrangling Coding Agent. Your job is to create a {function_name}() function that can be run on the provided data.
204
657
 
205
658
  Follow these recommended steps:
206
659
  {recommended_steps}
@@ -210,10 +663,10 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
210
663
  Below are summaries of all datasets provided. If more than one dataset is provided, you may need to merge or join them.:
211
664
  {all_datasets_summary}
212
665
 
213
- Return Python code in ```python``` format with a single function definition, data_wrangler(), that includes all imports inside the function. And returns a single pandas data frame.
666
+ Return Python code in ```python``` format with a single function definition, {function_name}(), that includes all imports inside the function. And returns a single pandas data frame.
214
667
 
215
668
  ```python
216
- def data_wrangler(data_list):
669
+ def {function_name}(data_list):
217
670
  '''
218
671
  Wrangle the data provided in data.
219
672
 
@@ -235,23 +688,24 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
235
688
 
236
689
 
237
690
  """,
238
- input_variables=["recommended_steps", "all_datasets_summary"]
691
+ input_variables=["recommended_steps", "all_datasets_summary", "function_name"]
239
692
  )
240
693
 
241
694
  data_wrangling_agent = data_wrangling_prompt | llm | PythonOutputParser()
242
695
 
243
696
  response = data_wrangling_agent.invoke({
244
697
  "recommended_steps": state.get("recommended_steps"),
245
- "all_datasets_summary": state.get("all_datasets_summary")
698
+ "all_datasets_summary": all_datasets_summary_str,
699
+ "function_name": function_name
246
700
  })
247
701
 
248
702
  response = relocate_imports_inside_function(response)
249
703
  response = add_comments_to_top(response, agent_name=AGENT_NAME)
250
704
 
251
705
  # For logging: store the code generated
252
- file_path, file_name = log_ai_function(
706
+ file_path, file_name_2 = log_ai_function(
253
707
  response=response,
254
- file_name="data_wrangler.py",
708
+ file_name=file_name,
255
709
  log=log,
256
710
  log_path=log_path,
257
711
  overwrite=overwrite
@@ -260,7 +714,9 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
260
714
  return {
261
715
  "data_wrangler_function" : response,
262
716
  "data_wrangler_function_path": file_path,
263
- "data_wrangler_function_name": file_name
717
+ "data_wrangler_file_name": file_name_2,
718
+ "data_wrangler_function_name": function_name,
719
+ "all_datasets_summary": all_datasets_summary_str
264
720
  }
265
721
 
266
722
 
@@ -273,6 +729,33 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
273
729
  user_instructions_key="user_instructions",
274
730
  recommended_steps_key="recommended_steps"
275
731
  )
732
+
733
+ # Human Review
734
+
735
+ prompt_text_human_review = "Are the following data wrangling instructions correct? (Answer 'yes' or provide modifications)\n{steps}"
736
+
737
+ if not bypass_explain_code:
738
+ def human_review(state: GraphState) -> Command[Literal["recommend_wrangling_steps", "explain_data_wrangler_code"]]:
739
+ return node_func_human_review(
740
+ state=state,
741
+ prompt_text=prompt_text_human_review,
742
+ yes_goto= 'explain_data_wrangler_code',
743
+ no_goto="recommend_wrangling_steps",
744
+ user_instructions_key="user_instructions",
745
+ recommended_steps_key="recommended_steps",
746
+ code_snippet_key="data_wrangler_function",
747
+ )
748
+ else:
749
+ def human_review(state: GraphState) -> Command[Literal["recommend_wrangling_steps", "__end__"]]:
750
+ return node_func_human_review(
751
+ state=state,
752
+ prompt_text=prompt_text_human_review,
753
+ yes_goto= '__end__',
754
+ no_goto="recommend_wrangling_steps",
755
+ user_instructions_key="user_instructions",
756
+ recommended_steps_key="recommended_steps",
757
+ code_snippet_key="data_wrangler_function",
758
+ )
276
759
 
277
760
  def execute_data_wrangler_code(state: GraphState):
278
761
  return node_func_execute_agent_code_on_data(
@@ -281,7 +764,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
281
764
  result_key="data_wrangled",
282
765
  error_key="data_wrangler_error",
283
766
  code_snippet_key="data_wrangler_function",
284
- agent_function_name="data_wrangler",
767
+ agent_function_name=state.get("data_wrangler_function_name"),
285
768
  # pre_processing=pre_processing,
286
769
  post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
287
770
  error_message_prefix="An error occurred during data wrangling: "
@@ -289,11 +772,11 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
289
772
 
290
773
  def fix_data_wrangler_code(state: GraphState):
291
774
  data_wrangler_prompt = """
292
- You are a Data Wrangling Agent. Your job is to create a data_wrangler() function that can be run on the data provided. The function is currently broken and needs to be fixed.
775
+ You are a Data Wrangling Agent. Your job is to create a {function_name}() function that can be run on the data provided. The function is currently broken and needs to be fixed.
293
776
 
294
- Make sure to only return the function definition for data_wrangler().
777
+ Make sure to only return the function definition for {function_name}().
295
778
 
296
- Return Python code in ```python``` format with a single function definition, data_wrangler(data_raw), that includes all imports inside the function.
779
+ Return Python code in ```python``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.
297
780
 
298
781
  This is the broken code (please fix):
299
782
  {code_snippet}
@@ -311,6 +794,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
311
794
  agent_name=AGENT_NAME,
312
795
  log=log,
313
796
  file_path=state.get("data_wrangler_function_path"),
797
+ function_name=state.get("data_wrangler_function_name"),
314
798
  )
315
799
 
316
800
  def explain_data_wrangler_code(state: GraphState):