ai-data-science-team 0.0.0.9007__py3-none-any.whl → 0.0.0.9009__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25) hide show
  1. ai_data_science_team/_version.py +1 -1
  2. ai_data_science_team/agents/__init__.py +4 -5
  3. ai_data_science_team/agents/data_cleaning_agent.py +268 -116
  4. ai_data_science_team/agents/data_visualization_agent.py +470 -41
  5. ai_data_science_team/agents/data_wrangling_agent.py +471 -31
  6. ai_data_science_team/agents/feature_engineering_agent.py +426 -41
  7. ai_data_science_team/agents/sql_database_agent.py +458 -58
  8. ai_data_science_team/ml_agents/__init__.py +1 -0
  9. ai_data_science_team/ml_agents/h2o_ml_agent.py +1032 -0
  10. ai_data_science_team/multiagents/__init__.py +1 -0
  11. ai_data_science_team/multiagents/sql_data_analyst.py +398 -0
  12. ai_data_science_team/multiagents/supervised_data_analyst.py +2 -0
  13. ai_data_science_team/templates/__init__.py +3 -1
  14. ai_data_science_team/templates/agent_templates.py +319 -43
  15. ai_data_science_team/tools/metadata.py +94 -62
  16. ai_data_science_team/tools/regex.py +86 -1
  17. ai_data_science_team/utils/__init__.py +0 -0
  18. ai_data_science_team/utils/plotly.py +24 -0
  19. ai_data_science_team-0.0.0.9009.dist-info/METADATA +245 -0
  20. ai_data_science_team-0.0.0.9009.dist-info/RECORD +28 -0
  21. ai_data_science_team-0.0.0.9007.dist-info/METADATA +0 -183
  22. ai_data_science_team-0.0.0.9007.dist-info/RECORD +0 -21
  23. {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/LICENSE +0 -0
  24. {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/WHEEL +0 -0
  25. {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/top_level.txt +0 -0
@@ -4,11 +4,12 @@
4
4
  # * Agents: Data Wrangling Agent
5
5
 
6
6
  # Libraries
7
- from typing import TypedDict, Annotated, Sequence, Literal, Union
7
+ from typing import TypedDict, Annotated, Sequence, Literal, Union, Optional
8
8
  import operator
9
9
  import os
10
- import io
10
+ import json
11
11
  import pandas as pd
12
+ from IPython.display import Markdown
12
13
 
13
14
  from langchain.prompts import PromptTemplate
14
15
  from langchain_core.messages import BaseMessage
@@ -19,11 +20,18 @@ from ai_data_science_team.templates import(
19
20
  node_func_execute_agent_code_on_data,
20
21
  node_func_human_review,
21
22
  node_func_fix_agent_code,
22
- node_func_explain_agent_code,
23
- create_coding_agent_graph
23
+ node_func_report_agent_outputs,
24
+ create_coding_agent_graph,
25
+ BaseAgent,
24
26
  )
25
27
  from ai_data_science_team.tools.parsers import PythonOutputParser
26
- from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
28
+ from ai_data_science_team.tools.regex import (
29
+ relocate_imports_inside_function,
30
+ add_comments_to_top,
31
+ format_agent_name,
32
+ format_recommended_steps,
33
+ get_generic_summary,
34
+ )
27
35
  from ai_data_science_team.tools.metadata import get_dataframe_summary
28
36
  from ai_data_science_team.tools.logging import log_ai_function
29
37
 
@@ -31,13 +39,408 @@ from ai_data_science_team.tools.logging import log_ai_function
31
39
  AGENT_NAME = "data_wrangling_agent"
32
40
  LOG_PATH = os.path.join(os.getcwd(), "logs/")
33
41
 
42
+ # Class
43
+
44
class DataWranglingAgent(BaseAgent):
    """
    Creates a data wrangling agent that can work with one or more datasets, performing operations such as
    joining/merging multiple datasets, reshaping, aggregating, encoding, creating computed features,
    and ensuring consistent data types. The agent generates a Python function to wrangle the data,
    executes the function, and logs the process (if enabled).

    The agent can handle:
    - A single dataset (a pandas DataFrame, or a dictionary of {column: list_of_values})
    - Multiple datasets (a list of DataFrames and/or such dictionaries)

    Key wrangling steps can include:
    - Merging or joining datasets
    - Pivoting/melting data for reshaping
    - GroupBy aggregations (sums, means, counts, etc.)
    - Encoding categorical variables
    - Computing new columns from existing ones
    - Dropping or rearranging columns
    - Any additional user instructions

    Parameters
    ----------
    model : langchain.llms.base.LLM
        The language model used to generate the data wrangling function.
    n_samples : int, optional
        Number of samples to show in the data summary for wrangling. Defaults to 30.
    log : bool, optional
        Whether to log the generated code and errors. Defaults to False.
    log_path : str, optional
        Directory path for storing log files. Defaults to None.
    file_name : str, optional
        Name of the file for saving the generated response. Defaults to "data_wrangler.py".
    function_name : str, optional
        Name of the function to be generated. Defaults to "data_wrangler".
    overwrite : bool, optional
        Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
    human_in_the_loop : bool, optional
        Enables user review of data wrangling instructions. Defaults to False.
    bypass_recommended_steps : bool, optional
        If True, skips the step that generates recommended data wrangling steps. Defaults to False.
        When human_in_the_loop is True, the graph builder forces this back to False because the
        human review step needs recommended steps to review.
    bypass_explain_code : bool, optional
        If True, skips the final step that reports the agent's outputs. Defaults to False.

    Attributes
    ----------
    response : dict or None
        The full graph state returned by the most recent invoke_agent()/ainvoke_agent() call,
        or None before the first call and after update_params().

    Methods
    -------
    update_params(**kwargs)
        Updates the agent's parameters and rebuilds the compiled state graph.

    ainvoke_agent(data_raw, user_instructions=None, max_retries=3, retry_count=0, **kwargs)
        Asynchronously wrangles the provided dataset(s) based on user instructions.
        This is a coroutine and must be awaited.

    invoke_agent(data_raw, user_instructions=None, max_retries=3, retry_count=0, **kwargs)
        Synchronously wrangles the provided dataset(s) based on user instructions.

    get_workflow_summary(markdown=False)
        Retrieves a summary of the agent's workflow.

    get_log_summary(markdown=False)
        Retrieves a summary of logged operations if logging is enabled.

    get_data_wrangled()
        Retrieves the final wrangled dataset as a pandas DataFrame.

    get_data_raw()
        Retrieves the raw dataset(s) as a dict or a list of dicts.

    get_data_wrangler_function(markdown=False)
        Retrieves the generated Python function used for data wrangling.

    get_recommended_wrangling_steps(markdown=False)
        Retrieves the agent's recommended wrangling steps.

    get_response()
        Returns the full response dictionary from the agent (not defined in this class; inherited).

    show()
        Displays the agent's mermaid diagram for visual inspection of the compiled graph
        (not defined in this class; inherited).

    Examples
    --------
    ```python
    import pandas as pd
    from langchain_openai import ChatOpenAI
    from ai_data_science_team.agents import DataWranglingAgent

    # Single dataset example
    llm = ChatOpenAI(model="gpt-4o-mini")

    data_wrangling_agent = DataWranglingAgent(
        model=llm,
        n_samples=30,
        log=True,
        log_path="logs",
        human_in_the_loop=True
    )

    df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")

    data_wrangling_agent.invoke_agent(
        user_instructions="Group by 'gender' and compute mean of 'tenure'.",
        data_raw=df,  # data_raw can be df.to_dict() or just a DataFrame
        max_retries=3,
        retry_count=0
    )

    data_wrangled = data_wrangling_agent.get_data_wrangled()
    response = data_wrangling_agent.get_response()

    # Multiple dataset example (list of DataFrames or dicts)
    df1 = pd.DataFrame({'id': [1,2,3], 'val1': [10,20,30]})
    df2 = pd.DataFrame({'id': [1,2,3], 'val2': [40,50,60]})

    data_wrangling_agent.invoke_agent(
        user_instructions="Merge these two datasets on 'id' and compute a new column 'val_sum' = val1+val2",
        data_raw=[df1, df2],  # multiple datasets
        max_retries=3,
        retry_count=0
    )

    data_wrangled = data_wrangling_agent.get_data_wrangled()

    # Asynchronous usage (inside an async function / running event loop)
    # await data_wrangling_agent.ainvoke_agent(data_raw=df, user_instructions="...")
    ```
    """

    def __init__(
        self,
        model,
        n_samples=30,
        log=False,
        log_path=None,
        file_name="data_wrangler.py",
        function_name="data_wrangler",
        overwrite=True,
        human_in_the_loop=False,
        bypass_recommended_steps=False,
        bypass_explain_code=False,
    ):
        # Kept as a dict so update_params() can change any subset and rebuild the graph.
        self._params = {
            "model": model,
            "n_samples": n_samples,
            "log": log,
            "log_path": log_path,
            "file_name": file_name,
            "function_name": function_name,
            "overwrite": overwrite,
            "human_in_the_loop": human_in_the_loop,
            "bypass_recommended_steps": bypass_recommended_steps,
            "bypass_explain_code": bypass_explain_code,
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Create the compiled graph for the data wrangling agent.

        Running this method resets the stored response to None, because any previous
        response was produced by a graph built from different parameters.
        """
        self.response = None
        return make_data_wrangling_agent(**self._params)

    def update_params(self, **kwargs):
        """
        Update one or more constructor parameters and rebuild the compiled graph.

        Parameters
        ----------
        **kwargs
            Parameter names and new values (same names as the constructor arguments).
            The stored response is discarded when the graph is rebuilt.
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    def _build_graph_input(self, data_raw, user_instructions, max_retries, retry_count):
        """Build the initial graph state shared by invoke_agent() and ainvoke_agent()."""
        return {
            "user_instructions": user_instructions,
            "data_raw": self._convert_data_input(data_raw),
            "max_retries": max_retries,
            "retry_count": retry_count,
        }

    async def ainvoke_agent(
        self,
        data_raw: Union[pd.DataFrame, dict, list],
        user_instructions: str = None,
        max_retries: int = 3,
        retry_count: int = 0,
        **kwargs,
    ):
        """
        Asynchronously wrangles the provided dataset(s) based on user instructions.
        The response is stored in the 'response' attribute. This method is a coroutine
        and must be awaited.

        Parameters
        ----------
        data_raw : Union[pd.DataFrame, dict, list]
            The raw dataset(s) to be wrangled.
            Can be a single DataFrame, a single dict ({col: list_of_values}),
            or a list of DataFrames/dicts if multiple datasets are provided.
        user_instructions : str
            Instructions for data wrangling.
        max_retries : int
            Maximum retry attempts.
        retry_count : int
            Current retry attempt count.
        **kwargs
            Additional keyword arguments to pass to ainvoke().

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If data_raw is not a DataFrame, dict, or list of DataFrames/dicts.
        """
        graph_input = self._build_graph_input(data_raw, user_instructions, max_retries, retry_count)
        # ainvoke() is asynchronous: without awaiting it, an un-run coroutine would be
        # stored as the response and every get_* accessor would fail on it.
        response = await self._compiled_graph.ainvoke(graph_input, **kwargs)
        self.response = response
        return None

    def invoke_agent(
        self,
        data_raw: Union[pd.DataFrame, dict, list],
        user_instructions: str = None,
        max_retries: int = 3,
        retry_count: int = 0,
        **kwargs,
    ):
        """
        Synchronously wrangles the provided dataset(s) based on user instructions.
        The response is stored in the 'response' attribute.

        Parameters
        ----------
        data_raw : Union[pd.DataFrame, dict, list]
            The raw dataset(s) to be wrangled.
            Can be a single DataFrame, a single dict, or a list of DataFrames/dicts.
        user_instructions : str
            Instructions for data wrangling agent.
        max_retries : int
            Maximum retry attempts.
        retry_count : int
            Current retry attempt count.
        **kwargs
            Additional keyword arguments to pass to invoke().

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If data_raw is not a DataFrame, dict, or list of DataFrames/dicts.
        """
        graph_input = self._build_graph_input(data_raw, user_instructions, max_retries, retry_count)
        response = self._compiled_graph.invoke(graph_input, **kwargs)
        self.response = response
        return None

    def get_workflow_summary(self, markdown=False):
        """
        Retrieve a summary of the agent's workflow from the last invocation.

        The summary is built by JSON-decoding the content of the last message in the
        response and passing it to get_generic_summary().

        Parameters
        ----------
        markdown : bool, optional
            If True, wrap the summary in an IPython Markdown object.

        Returns
        -------
        str, Markdown, or None
            None when there is no response or it contains no messages.
        """
        messages = self.response.get("messages") if self.response else None
        if not messages:
            return None
        summary = get_generic_summary(json.loads(messages[-1].content))
        return Markdown(summary) if markdown else summary

    def get_log_summary(self, markdown=False):
        """
        Summarize where the generated function was logged, if logging was enabled.

        Parameters
        ----------
        markdown : bool, optional
            If True, wrap the summary in an IPython Markdown object.

        Returns
        -------
        str, Markdown, or None
            None when there is no response or no function path was recorded.
        """
        if not self.response:
            return None
        function_path = self.response.get("data_wrangler_function_path")
        if not function_path:
            return None
        log_details = (
            "\n## Data Wrangling Agent Log Summary:\n\n"
            f"Function Path: {function_path}\n\n"
            f"Function Name: {self.response.get('data_wrangler_function_name')}\n"
        )
        return Markdown(log_details) if markdown else log_details

    def get_data_wrangled(self) -> Optional[pd.DataFrame]:
        """
        Retrieves the wrangled data after running invoke_agent() or ainvoke_agent().

        Returns
        -------
        pd.DataFrame or None
            The wrangled dataset as a pandas DataFrame, or None if no result is
            available (no response yet, or the stored result is None).
        """
        if not self.response:
            return None
        data_wrangled = self.response.get("data_wrangled")
        # pd.DataFrame(None) would silently produce an empty frame; report "no result" instead.
        if data_wrangled is None:
            return None
        return pd.DataFrame(data_wrangled)

    def get_data_raw(self) -> Union[dict, list, None]:
        """
        Retrieves the original raw data from the last invocation.

        Returns
        -------
        Union[dict, list, None]
            The original dataset(s) as a single dict or a list of dicts, or None if not available.
        """
        if self.response and "data_raw" in self.response:
            return self.response["data_raw"]
        return None

    def get_data_wrangler_function(self, markdown=False) -> Optional[str]:
        """
        Retrieves the generated data wrangling function code.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the function wrapped in a Markdown python code block.

        Returns
        -------
        str, Markdown, or None
            The Python function code, or None if not available.
        """
        if self.response and "data_wrangler_function" in self.response:
            code = self.response["data_wrangler_function"]
            if markdown:
                return Markdown(f"```python\n{code}\n```")
            return code
        return None

    def get_recommended_wrangling_steps(self, markdown=False) -> Optional[str]:
        """
        Retrieves the agent's recommended data wrangling steps.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the steps wrapped in an IPython Markdown object.

        Returns
        -------
        str, Markdown, or None
            The recommended steps, or None if not available.
        """
        if self.response and "recommended_steps" in self.response:
            steps = self.response["recommended_steps"]
            if markdown:
                return Markdown(steps)
            return steps
        return None

    @staticmethod
    def _convert_data_input(data_raw: Union[pd.DataFrame, dict, list]) -> Union[dict, list]:
        """
        Convert data_raw (a DataFrame, dict, or list of DataFrames/dicts) into the
        format expected by the underlying agent (a dict or a list of dicts).

        Parameters
        ----------
        data_raw : Union[pd.DataFrame, dict, list]
            The raw input data to be converted.

        Returns
        -------
        Union[dict, list]
            The data in a dictionary or list-of-dictionaries format.

        Raises
        ------
        ValueError
            If data_raw (or any list element) is not a DataFrame or dict.
        """
        # Single DataFrame -> dict
        if isinstance(data_raw, pd.DataFrame):
            return data_raw.to_dict()

        # Already a dict (single dataset)
        if isinstance(data_raw, dict):
            return data_raw

        # List of datasets: convert each DataFrame element, pass dicts through
        if isinstance(data_raw, list):
            converted_list = []
            for item in data_raw:
                if isinstance(item, pd.DataFrame):
                    converted_list.append(item.to_dict())
                elif isinstance(item, dict):
                    converted_list.append(item)
                else:
                    raise ValueError("List must contain only DataFrames or dictionaries.")
            return converted_list

        raise ValueError("data_raw must be a DataFrame, a dict, or a list of dicts/DataFrames.")
432
+
433
+
434
+ # Function
435
+
34
436
  def make_data_wrangling_agent(
35
437
  model,
36
438
  n_samples=30,
37
439
  log=False,
38
440
  log_path=None,
39
441
  file_name="data_wrangler.py",
40
- overwrite = True,
442
+ function_name="data_wrangler",
443
+ overwrite=True,
41
444
  human_in_the_loop=False,
42
445
  bypass_recommended_steps=False,
43
446
  bypass_explain_code=False
@@ -73,6 +476,8 @@ def make_data_wrangling_agent(
73
476
  The path to the directory where the log files should be stored. Defaults to "logs/".
74
477
  file_name : str, optional
75
478
  The name of the file to save the response to. Defaults to "data_wrangler.py".
479
+ function_name : str, optional
480
+ The name of the function to be generated. Defaults to "data_wrangler".
76
481
  overwrite : bool, optional
77
482
  Whether or not to overwrite the log file if it already exists. If False, a unique file name will be created.
78
483
  Defaults to True.
@@ -115,6 +520,11 @@ def make_data_wrangling_agent(
115
520
  """
116
521
  llm = model
117
522
 
523
+ # Human in the loop requires recommended steps
524
+ if bypass_recommended_steps and human_in_the_loop:
525
+ bypass_recommended_steps = False
526
+ print("Bypass recommended steps set to False to enable human in the loop.")
527
+
118
528
  # Setup Log Directory
119
529
  if log:
120
530
  if log_path is None:
@@ -188,7 +598,7 @@ def make_data_wrangling_agent(
188
598
  Below are summaries of all datasets provided:
189
599
  {all_datasets_summary}
190
600
 
191
- Return your recommended steps as a numbered point list, explaining briefly why each step is needed.
601
+ Return steps as a numbered list. You can return short code snippets to demonstrate actions. But do not return a fully coded solution. The code will be generated separately by a Coding Agent.
192
602
 
193
603
  Avoid these:
194
604
  1. Do not include steps to save files.
@@ -205,7 +615,7 @@ def make_data_wrangling_agent(
205
615
  })
206
616
 
207
617
  return {
208
- "recommended_steps": "\n\n# Recommended Wrangling Steps:\n" + recommended_steps.content.strip(),
618
+ "recommended_steps": format_recommended_steps(recommended_steps.content.strip(), heading="# Recommended Data Wrangling Steps:"),
209
619
  "all_datasets_summary": all_datasets_summary_str,
210
620
  }
211
621
 
@@ -244,7 +654,7 @@ def make_data_wrangling_agent(
244
654
 
245
655
  data_wrangling_prompt = PromptTemplate(
246
656
  template="""
247
- You are a Data Wrangling Coding Agent. Your job is to create a data_wrangler() function that can be run on the provided data.
657
+ You are a Data Wrangling Coding Agent. Your job is to create a {function_name}() function that can be run on the provided data.
248
658
 
249
659
  Follow these recommended steps:
250
660
  {recommended_steps}
@@ -254,10 +664,10 @@ def make_data_wrangling_agent(
254
664
  Below are summaries of all datasets provided. If more than one dataset is provided, you may need to merge or join them.:
255
665
  {all_datasets_summary}
256
666
 
257
- Return Python code in ```python``` format with a single function definition, data_wrangler(), that includes all imports inside the function. And returns a single pandas data frame.
667
+ Return Python code in ```python``` format with a single function definition, {function_name}(), that includes all imports inside the function. And returns a single pandas data frame.
258
668
 
259
669
  ```python
260
- def data_wrangler(data_list):
670
+ def {function_name}(data_list):
261
671
  '''
262
672
  Wrangle the data provided in data.
263
673
 
@@ -279,14 +689,15 @@ def make_data_wrangling_agent(
279
689
 
280
690
 
281
691
  """,
282
- input_variables=["recommended_steps", "all_datasets_summary"]
692
+ input_variables=["recommended_steps", "all_datasets_summary", "function_name"]
283
693
  )
284
694
 
285
695
  data_wrangling_agent = data_wrangling_prompt | llm | PythonOutputParser()
286
696
 
287
697
  response = data_wrangling_agent.invoke({
288
698
  "recommended_steps": state.get("recommended_steps"),
289
- "all_datasets_summary": all_datasets_summary_str
699
+ "all_datasets_summary": all_datasets_summary_str,
700
+ "function_name": function_name
290
701
  })
291
702
 
292
703
  response = relocate_imports_inside_function(response)
@@ -304,7 +715,8 @@ def make_data_wrangling_agent(
304
715
  return {
305
716
  "data_wrangler_function" : response,
306
717
  "data_wrangler_function_path": file_path,
307
- "data_wrangler_function_name": file_name_2,
718
+ "data_wrangler_file_name": file_name_2,
719
+ "data_wrangler_function_name": function_name,
308
720
  "all_datasets_summary": all_datasets_summary_str
309
721
  }
310
722
 
@@ -318,6 +730,33 @@ def make_data_wrangling_agent(
318
730
  user_instructions_key="user_instructions",
319
731
  recommended_steps_key="recommended_steps"
320
732
  )
733
+
734
+ # Human Review
735
+
736
+ prompt_text_human_review = "Are the following data wrangling instructions correct? (Answer 'yes' or provide modifications)\n{steps}"
737
+
738
+ if not bypass_explain_code:
739
+ def human_review(state: GraphState) -> Command[Literal["recommend_wrangling_steps", "explain_data_wrangler_code"]]:
740
+ return node_func_human_review(
741
+ state=state,
742
+ prompt_text=prompt_text_human_review,
743
+ yes_goto= 'explain_data_wrangler_code',
744
+ no_goto="recommend_wrangling_steps",
745
+ user_instructions_key="user_instructions",
746
+ recommended_steps_key="recommended_steps",
747
+ code_snippet_key="data_wrangler_function",
748
+ )
749
+ else:
750
+ def human_review(state: GraphState) -> Command[Literal["recommend_wrangling_steps", "__end__"]]:
751
+ return node_func_human_review(
752
+ state=state,
753
+ prompt_text=prompt_text_human_review,
754
+ yes_goto= '__end__',
755
+ no_goto="recommend_wrangling_steps",
756
+ user_instructions_key="user_instructions",
757
+ recommended_steps_key="recommended_steps",
758
+ code_snippet_key="data_wrangler_function",
759
+ )
321
760
 
322
761
  def execute_data_wrangler_code(state: GraphState):
323
762
  return node_func_execute_agent_code_on_data(
@@ -326,7 +765,7 @@ def make_data_wrangling_agent(
326
765
  result_key="data_wrangled",
327
766
  error_key="data_wrangler_error",
328
767
  code_snippet_key="data_wrangler_function",
329
- agent_function_name="data_wrangler",
768
+ agent_function_name=state.get("data_wrangler_function_name"),
330
769
  # pre_processing=pre_processing,
331
770
  post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
332
771
  error_message_prefix="An error occurred during data wrangling: "
@@ -334,11 +773,11 @@ def make_data_wrangling_agent(
334
773
 
335
774
  def fix_data_wrangler_code(state: GraphState):
336
775
  data_wrangler_prompt = """
337
- You are a Data Wrangling Agent. Your job is to create a data_wrangler() function that can be run on the data provided. The function is currently broken and needs to be fixed.
776
+ You are a Data Wrangling Agent. Your job is to create a {function_name}() function that can be run on the data provided. The function is currently broken and needs to be fixed.
338
777
 
339
- Make sure to only return the function definition for data_wrangler().
778
+ Make sure to only return the function definition for {function_name}().
340
779
 
341
- Return Python code in ```python``` format with a single function definition, data_wrangler(data_raw), that includes all imports inside the function.
780
+ Return Python code in ```python``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.
342
781
 
343
782
  This is the broken code (please fix):
344
783
  {code_snippet}
@@ -356,22 +795,23 @@ def make_data_wrangling_agent(
356
795
  agent_name=AGENT_NAME,
357
796
  log=log,
358
797
  file_path=state.get("data_wrangler_function_path"),
798
+ function_name=state.get("data_wrangler_function_name"),
359
799
  )
360
800
 
361
- def explain_data_wrangler_code(state: GraphState):
362
- return node_func_explain_agent_code(
801
+ # Final reporting node
802
+ def report_agent_outputs(state: GraphState):
803
+ return node_func_report_agent_outputs(
363
804
  state=state,
364
- code_snippet_key="data_wrangler_function",
805
+ keys_to_include=[
806
+ "recommended_steps",
807
+ "data_wrangler_function",
808
+ "data_wrangler_function_path",
809
+ "data_wrangler_function_name",
810
+ "data_wrangler_error",
811
+ ],
365
812
  result_key="messages",
366
- error_key="data_wrangler_error",
367
- llm=llm,
368
813
  role=AGENT_NAME,
369
- explanation_prompt_template="""
370
- Explain the data wrangling steps that the data wrangling agent performed in this function.
371
- Keep the summary succinct and to the point.\n\n# Data Wrangling Agent:\n\n{code}
372
- """,
373
- success_prefix="# Data Wrangling Agent:\n\n ",
374
- error_message="The Data Wrangling Agent encountered an error during data wrangling. Data could not be explained."
814
+ custom_title="Data Wrangling Agent Outputs"
375
815
  )
376
816
 
377
817
  # Define the graph
@@ -381,7 +821,7 @@ def make_data_wrangling_agent(
381
821
  "create_data_wrangler_code": create_data_wrangler_code,
382
822
  "execute_data_wrangler_code": execute_data_wrangler_code,
383
823
  "fix_data_wrangler_code": fix_data_wrangler_code,
384
- "explain_data_wrangler_code": explain_data_wrangler_code
824
+ "report_agent_outputs": report_agent_outputs,
385
825
  }
386
826
 
387
827
  app = create_coding_agent_graph(
@@ -391,7 +831,7 @@ def make_data_wrangling_agent(
391
831
  create_code_node_name="create_data_wrangler_code",
392
832
  execute_code_node_name="execute_data_wrangler_code",
393
833
  fix_code_node_name="fix_data_wrangler_code",
394
- explain_code_node_name="explain_data_wrangler_code",
834
+ explain_code_node_name="report_agent_outputs",
395
835
  error_key="data_wrangler_error",
396
836
  human_in_the_loop=human_in_the_loop,
397
837
  human_review_node_name="human_review",