ai-data-science-team 0.0.0.9007__py3-none-any.whl → 0.0.0.9009__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (25) hide show
  1. ai_data_science_team/_version.py +1 -1
  2. ai_data_science_team/agents/__init__.py +4 -5
  3. ai_data_science_team/agents/data_cleaning_agent.py +268 -116
  4. ai_data_science_team/agents/data_visualization_agent.py +470 -41
  5. ai_data_science_team/agents/data_wrangling_agent.py +471 -31
  6. ai_data_science_team/agents/feature_engineering_agent.py +426 -41
  7. ai_data_science_team/agents/sql_database_agent.py +458 -58
  8. ai_data_science_team/ml_agents/__init__.py +1 -0
  9. ai_data_science_team/ml_agents/h2o_ml_agent.py +1032 -0
  10. ai_data_science_team/multiagents/__init__.py +1 -0
  11. ai_data_science_team/multiagents/sql_data_analyst.py +398 -0
  12. ai_data_science_team/multiagents/supervised_data_analyst.py +2 -0
  13. ai_data_science_team/templates/__init__.py +3 -1
  14. ai_data_science_team/templates/agent_templates.py +319 -43
  15. ai_data_science_team/tools/metadata.py +94 -62
  16. ai_data_science_team/tools/regex.py +86 -1
  17. ai_data_science_team/utils/__init__.py +0 -0
  18. ai_data_science_team/utils/plotly.py +24 -0
  19. ai_data_science_team-0.0.0.9009.dist-info/METADATA +245 -0
  20. ai_data_science_team-0.0.0.9009.dist-info/RECORD +28 -0
  21. ai_data_science_team-0.0.0.9007.dist-info/METADATA +0 -183
  22. ai_data_science_team-0.0.0.9007.dist-info/RECORD +0 -21
  23. {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/LICENSE +0 -0
  24. {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/WHEEL +0 -0
  25. {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/top_level.txt +0 -0
@@ -4,11 +4,12 @@
4
4
  # * Agents: Data Wrangling Agent
5
5
 
6
6
  # Libraries
7
- from typing import TypedDict, Annotated, Sequence, Literal, Union
7
+ from typing import TypedDict, Annotated, Sequence, Literal, Union, Optional
8
8
  import operator
9
9
  import os
10
- import io
10
+ import json
11
11
  import pandas as pd
12
+ from IPython.display import Markdown
12
13
 
13
14
  from langchain.prompts import PromptTemplate
14
15
  from langchain_core.messages import BaseMessage
@@ -19,11 +20,18 @@ from ai_data_science_team.templates import(
19
20
  node_func_execute_agent_code_on_data,
20
21
  node_func_human_review,
21
22
  node_func_fix_agent_code,
22
- node_func_explain_agent_code,
23
- create_coding_agent_graph
23
+ node_func_report_agent_outputs,
24
+ create_coding_agent_graph,
25
+ BaseAgent,
24
26
  )
25
27
  from ai_data_science_team.tools.parsers import PythonOutputParser
26
- from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
28
+ from ai_data_science_team.tools.regex import (
29
+ relocate_imports_inside_function,
30
+ add_comments_to_top,
31
+ format_agent_name,
32
+ format_recommended_steps,
33
+ get_generic_summary,
34
+ )
27
35
  from ai_data_science_team.tools.metadata import get_dataframe_summary
28
36
  from ai_data_science_team.tools.logging import log_ai_function
29
37
 
@@ -31,13 +39,408 @@ from ai_data_science_team.tools.logging import log_ai_function
31
39
  AGENT_NAME = "data_wrangling_agent"
32
40
  LOG_PATH = os.path.join(os.getcwd(), "logs/")
33
41
 
42
+ # Class
43
+
44
+ class DataWranglingAgent(BaseAgent):
45
+ """
46
+ Creates a data wrangling agent that can work with one or more datasets, performing operations such as
47
+ joining/merging multiple datasets, reshaping, aggregating, encoding, creating computed features,
48
+ and ensuring consistent data types. The agent generates a Python function to wrangle the data,
49
+ executes the function, and logs the process (if enabled).
50
+
51
+ The agent can handle:
52
+ - A single dataset (provided as a pandas DataFrame or a dictionary of {column: list_of_values})
53
+ - Multiple datasets (provided as a list of DataFrames and/or such dictionaries)
54
+
55
+ Key wrangling steps can include:
56
+ - Merging or joining datasets
57
+ - Pivoting/melting data for reshaping
58
+ - GroupBy aggregations (sums, means, counts, etc.)
59
+ - Encoding categorical variables
60
+ - Computing new columns from existing ones
61
+ - Dropping or rearranging columns
62
+ - Any additional user instructions
63
+
64
+ Parameters
65
+ ----------
66
+ model : langchain.llms.base.LLM
67
+ The language model used to generate the data wrangling function.
68
+ n_samples : int, optional
69
+ Number of samples to show in the data summary for wrangling. Defaults to 30.
70
+ log : bool, optional
71
+ Whether to log the generated code and errors. Defaults to False.
72
+ log_path : str, optional
73
+ Directory path for storing log files. Defaults to None.
74
+ file_name : str, optional
75
+ Name of the file for saving the generated response. Defaults to "data_wrangler.py".
76
+ function_name : str, optional
77
+ Name of the function to be generated. Defaults to "data_wrangler".
78
+ overwrite : bool, optional
79
+ Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
80
+ human_in_the_loop : bool, optional
81
+ Enables user review of data wrangling instructions. Defaults to False.
82
+ bypass_recommended_steps : bool, optional
83
+ If True, skips the step that generates recommended data wrangling steps. Defaults to False.
84
+ bypass_explain_code : bool, optional
85
+ If True, skips the step that provides code explanations. Defaults to False.
86
+
87
+ Methods
88
+ -------
89
+ update_params(**kwargs)
90
+ Updates the agent's parameters and rebuilds the compiled state graph.
91
+
92
+ ainvoke_agent(data_raw: Union[pd.DataFrame, dict, list], user_instructions: str=None, max_retries=3, retry_count=0, **kwargs)
93
+ Asynchronously wrangles the provided dataset(s) based on user instructions.
94
+
95
+ invoke_agent(data_raw: Union[pd.DataFrame, dict, list], user_instructions: str=None, max_retries=3, retry_count=0, **kwargs)
96
+ Synchronously wrangles the provided dataset(s) based on user instructions.
97
+
98
+ get_workflow_summary()
99
+ Retrieves a summary of the agent's workflow.
100
+
101
+ get_log_summary()
102
+ Retrieves a summary of logged operations if logging is enabled.
103
+
104
+ get_data_wrangled()
105
+ Retrieves the final wrangled dataset as a pandas DataFrame (or None if no wrangled data is available).
106
+
107
+ get_data_raw()
108
+ Retrieves the raw dataset(s).
109
+
110
+ get_data_wrangler_function()
111
+ Retrieves the generated Python function used for data wrangling.
112
+
113
+ get_recommended_wrangling_steps()
114
+ Retrieves the agent's recommended wrangling steps.
115
+
116
+ get_response()
117
+ Returns the full response dictionary from the agent.
118
+
119
+ show()
120
+ Displays the agent's mermaid diagram for visual inspection of the compiled graph.
121
+
122
+ Examples
123
+ --------
124
+ ```python
125
+ import pandas as pd
126
+ from langchain_openai import ChatOpenAI
127
+ from ai_data_science_team.agents import DataWranglingAgent
128
+
129
+ # Single dataset example
130
+ llm = ChatOpenAI(model="gpt-4o-mini")
131
+
132
+ data_wrangling_agent = DataWranglingAgent(
133
+ model=llm,
134
+ n_samples=30,
135
+ log=True,
136
+ log_path="logs",
137
+ human_in_the_loop=True
138
+ )
139
+
140
+ df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")
141
+
142
+ data_wrangling_agent.invoke_agent(
143
+ user_instructions="Group by 'gender' and compute mean of 'tenure'.",
144
+ data_raw=df, # data_raw can be df.to_dict() or just a DataFrame
145
+ max_retries=3,
146
+ retry_count=0
147
+ )
148
+
149
+ data_wrangled = data_wrangling_agent.get_data_wrangled()
150
+ response = data_wrangling_agent.get_response()
151
+
152
+ # Multiple dataset example (list of DataFrames; a list of dicts is also accepted)
153
+ df1 = pd.DataFrame({'id': [1,2,3], 'val1': [10,20,30]})
154
+ df2 = pd.DataFrame({'id': [1,2,3], 'val2': [40,50,60]})
155
+
156
+ data_wrangling_agent.invoke_agent(
157
+ user_instructions="Merge these two datasets on 'id' and compute a new column 'val_sum' = val1+val2",
158
+ data_raw=[df1, df2], # multiple datasets
159
+ max_retries=3,
160
+ retry_count=0
161
+ )
162
+
163
+ data_wrangled = data_wrangling_agent.get_data_wrangled()
164
+ ```
165
+
166
+ Returns
167
+ -------
168
+ DataWranglingAgent : langchain.graphs.CompiledStateGraph
169
+ A data wrangling agent implemented as a compiled state graph.
170
+ """
171
+
172
+ def __init__(
173
+ self,
174
+ model,
175
+ n_samples=30,
176
+ log=False,
177
+ log_path=None,
178
+ file_name="data_wrangler.py",
179
+ function_name="data_wrangler",
180
+ overwrite=True,
181
+ human_in_the_loop=False,
182
+ bypass_recommended_steps=False,
183
+ bypass_explain_code=False
184
+ ):
185
+ self._params = {
186
+ "model": model,
187
+ "n_samples": n_samples,
188
+ "log": log,
189
+ "log_path": log_path,
190
+ "file_name": file_name,
191
+ "function_name": function_name,
192
+ "overwrite": overwrite,
193
+ "human_in_the_loop": human_in_the_loop,
194
+ "bypass_recommended_steps": bypass_recommended_steps,
195
+ "bypass_explain_code": bypass_explain_code
196
+ }
197
+ self._compiled_graph = self._make_compiled_graph()
198
+ self.response = None
199
+
200
+ def _make_compiled_graph(self):
201
+ """
202
+ Create the compiled graph for the data wrangling agent.
203
+ Running this method will reset the response to None.
204
+ """
205
+ self.response = None
206
+ return make_data_wrangling_agent(**self._params)
207
+
208
+ def update_params(self, **kwargs):
209
+ """
210
+ Updates the agent's parameters and rebuilds the compiled graph.
211
+ """
212
+ for k, v in kwargs.items():
213
+ self._params[k] = v
214
+ self._compiled_graph = self._make_compiled_graph()
215
+
216
+ def ainvoke_agent(
217
+ self,
218
+ data_raw: Union[pd.DataFrame, dict, list],
219
+ user_instructions: str=None,
220
+ max_retries:int=3,
221
+ retry_count:int=0,
222
+ **kwargs
223
+ ):
224
+ """
225
+ Asynchronously wrangles the provided dataset(s) based on user instructions.
226
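+ The value returned by the compiled graph's ainvoke() is stored in the 'response' attribute. NOTE(review): this method is a plain `def` and does not await that call, so the stored value may be an un-awaited awaitable rather than the final state - confirm and await it before using the getters.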
227
+
228
+ Parameters
229
+ ----------
230
+ data_raw : Union[pd.DataFrame, dict, list]
231
+ The raw dataset(s) to be wrangled.
232
+ Can be a single DataFrame, a single dict ({col: list_of_values}),
233
+ or a list of dicts if multiple datasets are provided.
234
+ user_instructions : str
235
+ Instructions for data wrangling.
236
+ max_retries : int
237
+ Maximum retry attempts.
238
+ retry_count : int
239
+ Current retry attempt count.
240
+ **kwargs
241
+ Additional keyword arguments to pass to ainvoke().
242
+
243
+ Returns
244
+ -------
245
+ None
246
+ """
247
+ data_input = self._convert_data_input(data_raw)
248
+ response = self._compiled_graph.ainvoke({
249
+ "user_instructions": user_instructions,
250
+ "data_raw": data_input,
251
+ "max_retries": max_retries,
252
+ "retry_count": retry_count
253
+ }, **kwargs)
254
+ self.response = response
255
+ return None
256
+
257
+ def invoke_agent(
258
+ self,
259
+ data_raw: Union[pd.DataFrame, dict, list],
260
+ user_instructions: str=None,
261
+ max_retries:int=3,
262
+ retry_count:int=0,
263
+ **kwargs
264
+ ):
265
+ """
266
+ Synchronously wrangles the provided dataset(s) based on user instructions.
267
+ The response is stored in the 'response' attribute.
268
+
269
+ Parameters
270
+ ----------
271
+ data_raw : Union[pd.DataFrame, dict, list]
272
+ The raw dataset(s) to be wrangled.
273
+ Can be a single DataFrame, a single dict, or a list of dicts.
274
+ user_instructions : str
275
+ Instructions for data wrangling agent.
276
+ max_retries : int
277
+ Maximum retry attempts.
278
+ retry_count : int
279
+ Current retry attempt count.
280
+ **kwargs
281
+ Additional keyword arguments to pass to invoke().
282
+
283
+ Returns
284
+ -------
285
+ None
286
+ """
287
+ data_input = self._convert_data_input(data_raw)
288
+ response = self._compiled_graph.invoke({
289
+ "user_instructions": user_instructions,
290
+ "data_raw": data_input,
291
+ "max_retries": max_retries,
292
+ "retry_count": retry_count
293
+ }, **kwargs)
294
+ self.response = response
295
+ return None
296
+
297
+ def get_workflow_summary(self, markdown=False):
298
+ """
299
+ Retrieves the agent's workflow summary, built from the JSON content of the last message in the response. Returns None if no response or messages are available.
300
+ """
301
+ if self.response and self.response.get("messages"):
302
+ summary = get_generic_summary(json.loads(self.response.get("messages")[-1].content))
303
+ if markdown:
304
+ return Markdown(summary)
305
+ else:
306
+ return summary
307
+
308
+ def get_log_summary(self, markdown=False):
309
+ """
310
+ Returns a summary of the logged function path and function name, if the response contains a function path. Returns None otherwise.
311
+ """
312
+ if self.response:
313
+ if self.response.get('data_wrangler_function_path'):
314
+ log_details = f"""
315
+ ## Data Wrangling Agent Log Summary:
316
+
317
+ Function Path: {self.response.get('data_wrangler_function_path')}
318
+
319
+ Function Name: {self.response.get('data_wrangler_function_name')}
320
+ """
321
+ if markdown:
322
+ return Markdown(log_details)
323
+ else:
324
+ return log_details
325
+
326
+ def get_data_wrangled(self) -> Optional[pd.DataFrame]:
327
+ """
328
+ Retrieves the wrangled data after running invoke_agent() or ainvoke_agent().
329
+
330
+ Returns
331
+ -------
332
+ pd.DataFrame or None
333
+ The wrangled dataset as a pandas DataFrame (if available).
334
+ """
335
+ if self.response and "data_wrangled" in self.response:
336
+ return pd.DataFrame(self.response["data_wrangled"])
337
+ return None
338
+
339
+ def get_data_raw(self) -> Union[dict, list, None]:
340
+ """
341
+ Retrieves the original raw data from the last invocation.
342
+
343
+ Returns
344
+ -------
345
+ Union[dict, list, None]
346
+ The original dataset(s) as a single dict or a list of dicts, or None if not available.
347
+ """
348
+ if self.response and "data_raw" in self.response:
349
+ return self.response["data_raw"]
350
+ return None
351
+
352
+ def get_data_wrangler_function(self, markdown=False) -> Optional[str]:
353
+ """
354
+ Retrieves the generated data wrangling function code.
355
+
356
+ Parameters
357
+ ----------
358
+ markdown : bool, optional
359
+ If True, returns the function in Markdown code block format.
360
+
361
+ Returns
362
+ -------
363
+ str or None
364
+ The Python function code, or None if not available.
365
+ """
366
+ if self.response and "data_wrangler_function" in self.response:
367
+ code = self.response["data_wrangler_function"]
368
+ if markdown:
369
+ return Markdown(f"```python\n{code}\n```")
370
+ return code
371
+ return None
372
+
373
+ def get_recommended_wrangling_steps(self, markdown=False) -> Optional[str]:
374
+ """
375
+ Retrieves the agent's recommended data wrangling steps.
376
+
377
+ Parameters
378
+ ----------
379
+ markdown : bool, optional
380
+ If True, returns the steps in Markdown format.
381
+
382
+ Returns
383
+ -------
384
+ str or None
385
+ The recommended steps, or None if not available.
386
+ """
387
+ if self.response and "recommended_steps" in self.response:
388
+ steps = self.response["recommended_steps"]
389
+ if markdown:
390
+ return Markdown(steps)
391
+ return steps
392
+ return None
393
+
394
+ @staticmethod
395
+ def _convert_data_input(data_raw: Union[pd.DataFrame, dict, list]) -> Union[dict, list]:
396
+ """
397
+ Internal utility to convert data_raw (which could be a DataFrame, dict, or list of dicts)
398
+ into the format expected by the underlying agent (dict or list of dicts).
399
+
400
+ Parameters
401
+ ----------
402
+ data_raw : Union[pd.DataFrame, dict, list]
403
+ The raw input data to be converted.
404
+
405
+ Returns
406
+ -------
407
+ Union[dict, list]
408
+ The data in a dictionary or list-of-dictionaries format.
409
+ """
410
+ # If a single DataFrame, convert to dict
411
+ if isinstance(data_raw, pd.DataFrame):
412
+ return data_raw.to_dict()
413
+
414
+ # If it's already a dict (single dataset)
415
+ if isinstance(data_raw, dict):
416
+ return data_raw
417
+
418
+ # If it's already a list, check if it's a list of DataFrames or dicts
419
+ if isinstance(data_raw, list):
420
+ # Convert any DataFrame item to dict
421
+ converted_list = []
422
+ for item in data_raw:
423
+ if isinstance(item, pd.DataFrame):
424
+ converted_list.append(item.to_dict())
425
+ elif isinstance(item, dict):
426
+ converted_list.append(item)
427
+ else:
428
+ raise ValueError("List must contain only DataFrames or dictionaries.")
429
+ return converted_list
430
+
431
+ raise ValueError("data_raw must be a DataFrame, a dict, or a list of dicts/DataFrames.")
432
+
433
+
434
+ # Function
435
+
34
436
  def make_data_wrangling_agent(
35
437
  model,
36
438
  n_samples=30,
37
439
  log=False,
38
440
  log_path=None,
39
441
  file_name="data_wrangler.py",
40
- overwrite = True,
442
+ function_name="data_wrangler",
443
+ overwrite=True,
41
444
  human_in_the_loop=False,
42
445
  bypass_recommended_steps=False,
43
446
  bypass_explain_code=False
@@ -73,6 +476,8 @@ def make_data_wrangling_agent(
73
476
  The path to the directory where the log files should be stored. Defaults to "logs/".
74
477
  file_name : str, optional
75
478
  The name of the file to save the response to. Defaults to "data_wrangler.py".
479
+ function_name : str, optional
480
+ The name of the function to be generated. Defaults to "data_wrangler".
76
481
  overwrite : bool, optional
77
482
  Whether or not to overwrite the log file if it already exists. If False, a unique file name will be created.
78
483
  Defaults to True.
@@ -115,6 +520,11 @@ def make_data_wrangling_agent(
115
520
  """
116
521
  llm = model
117
522
 
523
+ # Human in the loop requires recommended steps
524
+ if bypass_recommended_steps and human_in_the_loop:
525
+ bypass_recommended_steps = False
526
+ print("Bypass recommended steps set to False to enable human in the loop.")
527
+
118
528
  # Setup Log Directory
119
529
  if log:
120
530
  if log_path is None:
@@ -188,7 +598,7 @@ def make_data_wrangling_agent(
188
598
  Below are summaries of all datasets provided:
189
599
  {all_datasets_summary}
190
600
 
191
- Return your recommended steps as a numbered point list, explaining briefly why each step is needed.
601
+ Return steps as a numbered list. You can return short code snippets to demonstrate actions. But do not return a fully coded solution. The code will be generated separately by a Coding Agent.
192
602
 
193
603
  Avoid these:
194
604
  1. Do not include steps to save files.
@@ -205,7 +615,7 @@ def make_data_wrangling_agent(
205
615
  })
206
616
 
207
617
  return {
208
- "recommended_steps": "\n\n# Recommended Wrangling Steps:\n" + recommended_steps.content.strip(),
618
+ "recommended_steps": format_recommended_steps(recommended_steps.content.strip(), heading="# Recommended Data Wrangling Steps:"),
209
619
  "all_datasets_summary": all_datasets_summary_str,
210
620
  }
211
621
 
@@ -244,7 +654,7 @@ def make_data_wrangling_agent(
244
654
 
245
655
  data_wrangling_prompt = PromptTemplate(
246
656
  template="""
247
- You are a Data Wrangling Coding Agent. Your job is to create a data_wrangler() function that can be run on the provided data.
657
+ You are a Data Wrangling Coding Agent. Your job is to create a {function_name}() function that can be run on the provided data.
248
658
 
249
659
  Follow these recommended steps:
250
660
  {recommended_steps}
@@ -254,10 +664,10 @@ def make_data_wrangling_agent(
254
664
  Below are summaries of all datasets provided. If more than one dataset is provided, you may need to merge or join them.:
255
665
  {all_datasets_summary}
256
666
 
257
- Return Python code in ```python``` format with a single function definition, data_wrangler(), that includes all imports inside the function. And returns a single pandas data frame.
667
+ Return Python code in ```python``` format with a single function definition, {function_name}(), that includes all imports inside the function. And returns a single pandas data frame.
258
668
 
259
669
  ```python
260
- def data_wrangler(data_list):
670
+ def {function_name}(data_list):
261
671
  '''
262
672
  Wrangle the data provided in data.
263
673
 
@@ -279,14 +689,15 @@ def make_data_wrangling_agent(
279
689
 
280
690
 
281
691
  """,
282
- input_variables=["recommended_steps", "all_datasets_summary"]
692
+ input_variables=["recommended_steps", "all_datasets_summary", "function_name"]
283
693
  )
284
694
 
285
695
  data_wrangling_agent = data_wrangling_prompt | llm | PythonOutputParser()
286
696
 
287
697
  response = data_wrangling_agent.invoke({
288
698
  "recommended_steps": state.get("recommended_steps"),
289
- "all_datasets_summary": all_datasets_summary_str
699
+ "all_datasets_summary": all_datasets_summary_str,
700
+ "function_name": function_name
290
701
  })
291
702
 
292
703
  response = relocate_imports_inside_function(response)
@@ -304,7 +715,8 @@ def make_data_wrangling_agent(
304
715
  return {
305
716
  "data_wrangler_function" : response,
306
717
  "data_wrangler_function_path": file_path,
307
- "data_wrangler_function_name": file_name_2,
718
+ "data_wrangler_file_name": file_name_2,
719
+ "data_wrangler_function_name": function_name,
308
720
  "all_datasets_summary": all_datasets_summary_str
309
721
  }
310
722
 
@@ -318,6 +730,33 @@ def make_data_wrangling_agent(
318
730
  user_instructions_key="user_instructions",
319
731
  recommended_steps_key="recommended_steps"
320
732
  )
733
+
734
+ # Human Review
735
+
736
+ prompt_text_human_review = "Are the following data wrangling instructions correct? (Answer 'yes' or provide modifications)\n{steps}"
737
+
738
+ if not bypass_explain_code:
739
+ def human_review(state: GraphState) -> Command[Literal["recommend_wrangling_steps", "explain_data_wrangler_code"]]:
740
+ return node_func_human_review(
741
+ state=state,
742
+ prompt_text=prompt_text_human_review,
743
+ yes_goto= 'explain_data_wrangler_code',
744
+ no_goto="recommend_wrangling_steps",
745
+ user_instructions_key="user_instructions",
746
+ recommended_steps_key="recommended_steps",
747
+ code_snippet_key="data_wrangler_function",
748
+ )
749
+ else:
750
+ def human_review(state: GraphState) -> Command[Literal["recommend_wrangling_steps", "__end__"]]:
751
+ return node_func_human_review(
752
+ state=state,
753
+ prompt_text=prompt_text_human_review,
754
+ yes_goto= '__end__',
755
+ no_goto="recommend_wrangling_steps",
756
+ user_instructions_key="user_instructions",
757
+ recommended_steps_key="recommended_steps",
758
+ code_snippet_key="data_wrangler_function",
759
+ )
321
760
 
322
761
  def execute_data_wrangler_code(state: GraphState):
323
762
  return node_func_execute_agent_code_on_data(
@@ -326,7 +765,7 @@ def make_data_wrangling_agent(
326
765
  result_key="data_wrangled",
327
766
  error_key="data_wrangler_error",
328
767
  code_snippet_key="data_wrangler_function",
329
- agent_function_name="data_wrangler",
768
+ agent_function_name=state.get("data_wrangler_function_name"),
330
769
  # pre_processing=pre_processing,
331
770
  post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
332
771
  error_message_prefix="An error occurred during data wrangling: "
@@ -334,11 +773,11 @@ def make_data_wrangling_agent(
334
773
 
335
774
  def fix_data_wrangler_code(state: GraphState):
336
775
  data_wrangler_prompt = """
337
- You are a Data Wrangling Agent. Your job is to create a data_wrangler() function that can be run on the data provided. The function is currently broken and needs to be fixed.
776
+ You are a Data Wrangling Agent. Your job is to create a {function_name}() function that can be run on the data provided. The function is currently broken and needs to be fixed.
338
777
 
339
- Make sure to only return the function definition for data_wrangler().
778
+ Make sure to only return the function definition for {function_name}().
340
779
 
341
- Return Python code in ```python``` format with a single function definition, data_wrangler(data_raw), that includes all imports inside the function.
780
+ Return Python code in ```python``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.
342
781
 
343
782
  This is the broken code (please fix):
344
783
  {code_snippet}
@@ -356,22 +795,23 @@ def make_data_wrangling_agent(
356
795
  agent_name=AGENT_NAME,
357
796
  log=log,
358
797
  file_path=state.get("data_wrangler_function_path"),
798
+ function_name=state.get("data_wrangler_function_name"),
359
799
  )
360
800
 
361
- def explain_data_wrangler_code(state: GraphState):
362
- return node_func_explain_agent_code(
801
+ # Final reporting node
802
+ def report_agent_outputs(state: GraphState):
803
+ return node_func_report_agent_outputs(
363
804
  state=state,
364
- code_snippet_key="data_wrangler_function",
805
+ keys_to_include=[
806
+ "recommended_steps",
807
+ "data_wrangler_function",
808
+ "data_wrangler_function_path",
809
+ "data_wrangler_function_name",
810
+ "data_wrangler_error",
811
+ ],
365
812
  result_key="messages",
366
- error_key="data_wrangler_error",
367
- llm=llm,
368
813
  role=AGENT_NAME,
369
- explanation_prompt_template="""
370
- Explain the data wrangling steps that the data wrangling agent performed in this function.
371
- Keep the summary succinct and to the point.\n\n# Data Wrangling Agent:\n\n{code}
372
- """,
373
- success_prefix="# Data Wrangling Agent:\n\n ",
374
- error_message="The Data Wrangling Agent encountered an error during data wrangling. Data could not be explained."
814
+ custom_title="Data Wrangling Agent Outputs"
375
815
  )
376
816
 
377
817
  # Define the graph
@@ -381,7 +821,7 @@ def make_data_wrangling_agent(
381
821
  "create_data_wrangler_code": create_data_wrangler_code,
382
822
  "execute_data_wrangler_code": execute_data_wrangler_code,
383
823
  "fix_data_wrangler_code": fix_data_wrangler_code,
384
- "explain_data_wrangler_code": explain_data_wrangler_code
824
+ "report_agent_outputs": report_agent_outputs,
385
825
  }
386
826
 
387
827
  app = create_coding_agent_graph(
@@ -391,7 +831,7 @@ def make_data_wrangling_agent(
391
831
  create_code_node_name="create_data_wrangler_code",
392
832
  execute_code_node_name="execute_data_wrangler_code",
393
833
  fix_code_node_name="fix_data_wrangler_code",
394
- explain_code_node_name="explain_data_wrangler_code",
834
+ explain_code_node_name="report_agent_outputs",
395
835
  error_key="data_wrangler_error",
396
836
  human_in_the_loop=human_in_the_loop,
397
837
  human_review_node_name="human_review",