ai-data-science-team 0.0.0.9006__py3-none-any.whl → 0.0.0.9008__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1 +1 @@
1
- __version__ = "0.0.0.9006"
1
+ __version__ = "0.0.0.9008"
@@ -1,5 +1,6 @@
1
- from ai_data_science_team.agents.data_cleaning_agent import make_data_cleaning_agent
2
- from ai_data_science_team.agents.feature_engineering_agent import make_feature_engineering_agent
3
- from ai_data_science_team.agents.data_wrangling_agent import make_data_wrangling_agent
4
- from ai_data_science_team.agents.sql_database_agent import make_sql_database_agent
1
+ from ai_data_science_team.agents.data_cleaning_agent import make_data_cleaning_agent, DataCleaningAgent
2
+ from ai_data_science_team.agents.feature_engineering_agent import make_feature_engineering_agent, FeatureEngineeringAgent
3
+ from ai_data_science_team.agents.data_wrangling_agent import make_data_wrangling_agent, DataWranglingAgent
4
+ from ai_data_science_team.agents.sql_database_agent import make_sql_database_agent, SQLDatabaseAgent
5
+ from ai_data_science_team.agents.data_visualization_agent import make_data_visualization_agent, DataVisualizationAgent
5
6
 
@@ -17,15 +17,18 @@ import os
17
17
  import io
18
18
  import pandas as pd
19
19
 
20
- from ai_data_science_team.templates.agent_templates import(
20
+ from IPython.display import Markdown
21
+
22
+ from ai_data_science_team.templates import(
21
23
  node_func_execute_agent_code_on_data,
22
24
  node_func_human_review,
23
25
  node_func_fix_agent_code,
24
26
  node_func_explain_agent_code,
25
- create_coding_agent_graph
27
+ create_coding_agent_graph,
28
+ BaseAgent,
26
29
  )
27
30
  from ai_data_science_team.tools.parsers import PythonOutputParser
28
- from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
31
+ from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name, format_recommended_steps
29
32
  from ai_data_science_team.tools.metadata import get_dataframe_summary
30
33
  from ai_data_science_team.tools.logging import log_ai_function
31
34
 
@@ -33,9 +36,281 @@ from ai_data_science_team.tools.logging import log_ai_function
33
36
  AGENT_NAME = "data_cleaning_agent"
34
37
  LOG_PATH = os.path.join(os.getcwd(), "logs/")
35
38
 
39
+
40
+
41
+ # Class
42
class DataCleaningAgent(BaseAgent):
    """
    Creates a data cleaning agent that can process datasets based on user-defined instructions or default cleaning steps.
    The agent generates a Python function to clean the dataset, performs the cleaning, and logs the process, including code
    and errors. It is designed to facilitate reproducible and customizable data cleaning workflows.

    The agent performs the following default cleaning steps unless instructed otherwise:

    - Removing columns with more than 40% missing values.
    - Imputing missing values with the mean for numeric columns.
    - Imputing missing values with the mode for categorical columns.
    - Converting columns to appropriate data types.
    - Removing duplicate rows.
    - Removing rows with missing values.
    - Removing rows with extreme outliers (values 3x the interquartile range).

    User instructions can modify, add, or remove any of these steps to tailor the cleaning process.

    Parameters
    ----------
    model : langchain.llms.base.LLM
        The language model used to generate the data cleaning function.
    n_samples : int, optional
        Number of samples used when summarizing the dataset. Defaults to 30. Reducing this number can help
        avoid exceeding the model's token limits.
    log : bool, optional
        Whether to log the generated code and errors. Defaults to False.
    log_path : str, optional
        Directory path for storing log files. Defaults to None, in which case the default chosen by
        make_data_cleaning_agent is used.
    file_name : str, optional
        Name of the file for saving the generated response. Defaults to "data_cleaner.py".
    function_name : str, optional
        Name of the generated data cleaning function. Defaults to "data_cleaner".
    overwrite : bool, optional
        Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
    human_in_the_loop : bool, optional
        Enables user review of data cleaning instructions. Defaults to False. Human review requires the
        recommended steps, so make_data_cleaning_agent turns bypass_recommended_steps off when this is True.
    bypass_recommended_steps : bool, optional
        If True, skips the default recommended cleaning steps. Defaults to False.
    bypass_explain_code : bool, optional
        If True, skips the step that provides code explanations. Defaults to False.

    Attributes
    ----------
    response : dict or None
        Final graph state from the most recent invocation, or None before the first run (and after the
        graph is rebuilt).

    Methods
    -------
    update_params(**kwargs)
        Updates the agent's parameters and rebuilds the compiled state graph.
    ainvoke_agent(data_raw: pd.DataFrame, user_instructions: str = None, max_retries=3, retry_count=0)
        Coroutine: cleans the provided dataset asynchronously based on user instructions. Must be awaited.
    invoke_agent(data_raw: pd.DataFrame, user_instructions: str = None, max_retries=3, retry_count=0)
        Cleans the provided dataset synchronously based on user instructions.
    explain_cleaning_steps()
        Returns the messages produced by the agent (its explanation of the cleaning steps).
    get_log_summary()
        Retrieves a summary of logged operations if logging is enabled.
    get_state_keys()
        Returns a list of keys from the state graph response.
    get_state_properties()
        Returns detailed properties of the state graph response.
    get_data_cleaned()
        Retrieves the cleaned dataset as a pandas DataFrame.
    get_data_raw()
        Retrieves the raw dataset as a pandas DataFrame.
    get_data_cleaner_function()
        Retrieves the generated Python function used for cleaning the data.
    get_recommended_cleaning_steps()
        Retrieves the agent's recommended cleaning steps.
    get_response()
        Returns the response from the agent as a dictionary.
    show()
        Displays the agent's mermaid diagram.

    Examples
    --------
    ```python
    import pandas as pd
    from langchain_openai import ChatOpenAI
    from ai_data_science_team.agents import DataCleaningAgent

    llm = ChatOpenAI(model="gpt-4o-mini")

    data_cleaning_agent = DataCleaningAgent(
        model=llm, n_samples=50, log=True, log_path="logs", human_in_the_loop=True
    )

    df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")

    data_cleaning_agent.invoke_agent(
        user_instructions="Don't remove outliers when cleaning the data.",
        data_raw=df,
        max_retries=3,
        retry_count=0
    )

    cleaned_data = data_cleaning_agent.get_data_cleaned()

    response = data_cleaning_agent.response
    ```
    """

    def __init__(
        self,
        model,
        n_samples=30,
        log=False,
        log_path=None,
        file_name="data_cleaner.py",
        function_name="data_cleaner",
        overwrite=True,
        human_in_the_loop=False,
        bypass_recommended_steps=False,
        bypass_explain_code=False
    ):
        # Stored as a dict so the whole set can be forwarded to make_data_cleaning_agent
        # (and updated/rebuilt later).
        self._params = {
            "model": model,
            "n_samples": n_samples,
            "log": log,
            "log_path": log_path,
            "file_name": file_name,
            "function_name": function_name,
            "overwrite": overwrite,
            "human_in_the_loop": human_in_the_loop,
            "bypass_recommended_steps": bypass_recommended_steps,
            "bypass_explain_code": bypass_explain_code,
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Create the compiled graph for the data cleaning agent. Running this method will reset the response to None.
        """
        self.response = None
        return make_data_cleaning_agent(**self._params)

    @staticmethod
    def _make_graph_input(data_raw, user_instructions, max_retries, retry_count):
        """Build the initial graph state shared by invoke_agent and ainvoke_agent."""
        return {
            "user_instructions": user_instructions,
            # The graph state stores data as a plain dict, not a DataFrame.
            "data_raw": data_raw.to_dict(),
            "max_retries": max_retries,
            "retry_count": retry_count,
        }

    async def ainvoke_agent(self, data_raw: pd.DataFrame, user_instructions: str=None, max_retries:int=3, retry_count:int=0, **kwargs):
        """
        Asynchronously invokes the agent. The response is stored in the response attribute.

        This is a coroutine and must be awaited, e.g. ``await agent.ainvoke_agent(df)``.

        Parameters:
        ----------
        data_raw (pd.DataFrame):
            The raw dataset to be cleaned.
        user_instructions (str):
            Instructions for data cleaning agent.
        max_retries (int):
            Maximum retry attempts for cleaning.
        retry_count (int):
            Current retry attempt.
        **kwargs
            Additional keyword arguments to pass to ainvoke().

        Returns:
        --------
        None. The response is stored in the response attribute.
        """
        # ainvoke() returns an awaitable. Awaiting it stores the final state rather than an
        # un-run coroutine.
        response = await self._compiled_graph.ainvoke(
            self._make_graph_input(data_raw, user_instructions, max_retries, retry_count),
            **kwargs,
        )
        self.response = response
        return None

    def invoke_agent(self, data_raw: pd.DataFrame, user_instructions: str=None, max_retries:int=3, retry_count:int=0, **kwargs):
        """
        Invokes the agent. The response is stored in the response attribute.

        Parameters:
        ----------
        data_raw (pd.DataFrame):
            The raw dataset to be cleaned.
        user_instructions (str):
            Instructions for data cleaning agent.
        max_retries (int):
            Maximum retry attempts for cleaning.
        retry_count (int):
            Current retry attempt.
        **kwargs
            Additional keyword arguments to pass to invoke().

        Returns:
        --------
        None. The response is stored in the response attribute.
        """
        response = self._compiled_graph.invoke(
            self._make_graph_input(data_raw, user_instructions, max_retries, retry_count),
            **kwargs,
        )
        self.response = response
        return None

    def explain_cleaning_steps(self):
        """
        Provides an explanation of the cleaning steps performed by the agent.

        Returns:
            list: The "messages" entry of the response (empty if the agent has not been run
            or produced no messages).
        """
        # Guard against being called before invoke_agent: response is None until then.
        if not self.response:
            return []
        return self.response.get("messages", [])

    def get_log_summary(self, markdown=False):
        """
        Returns a summary of the agent's logged output, if logging is enabled.

        Parameters:
            markdown (bool): If True, wrap the summary in an IPython Markdown object.

        Returns:
            str or Markdown or None: The log path summary, or None when there is no response
            or no logged file path.
        """
        if self.response:
            if self.response.get('data_cleaner_function_path'):
                log_details = f"Log Path: {self.response.get('data_cleaner_function_path')}"
                if markdown:
                    return Markdown(log_details)
                else:
                    return log_details

    def get_data_cleaned(self):
        """
        Retrieves the cleaned data stored after running invoke_agent or ainvoke_agent.

        Returns None if the agent has not been run.
        """
        if self.response:
            return pd.DataFrame(self.response.get("data_cleaned"))

    def get_data_raw(self):
        """
        Retrieves the raw data from the last response as a DataFrame, or None if not run yet.
        """
        if self.response:
            return pd.DataFrame(self.response.get("data_raw"))

    def get_data_cleaner_function(self, markdown=False):
        """
        Retrieves the generated data cleaning function source code.

        Parameters:
            markdown (bool): If True, return it as a fenced Python Markdown block.
        """
        if self.response:
            if markdown:
                return Markdown(f"```python\n{self.response.get('data_cleaner_function')}\n```")
            else:
                return self.response.get("data_cleaner_function")

    def get_recommended_cleaning_steps(self, markdown=False):
        """
        Retrieves the agent's recommended cleaning steps.

        Parameters:
            markdown (bool): If True, wrap the steps in an IPython Markdown object.
        """
        if self.response:
            if markdown:
                return Markdown(self.response.get('recommended_steps'))
            else:
                return self.response.get('recommended_steps')
297
+
298
+
299
+
36
300
  # Agent
37
301
 
38
- def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True, human_in_the_loop=False, bypass_recommended_steps=False, bypass_explain_code=False):
302
+ def make_data_cleaning_agent(
303
+ model,
304
+ n_samples = 30,
305
+ log=False,
306
+ log_path=None,
307
+ file_name="data_cleaner.py",
308
+ function_name="data_cleaner",
309
+ overwrite = True,
310
+ human_in_the_loop=False,
311
+ bypass_recommended_steps=False,
312
+ bypass_explain_code=False
313
+ ):
39
314
  """
40
315
  Creates a data cleaning agent that can be run on a dataset. The agent can be used to clean a dataset in a variety of
41
316
  ways, such as removing columns with more than 40% missing values, imputing missing
@@ -44,9 +319,9 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
44
319
  The agent takes in a dataset and some user instructions, and outputs a python
45
320
  function that can be used to clean the dataset. The agent also logs the code
46
321
  generated and any errors that occur.
47
-
322
+
48
323
  The agent is instructed to perform the following data cleaning steps:
49
-
324
+
50
325
  - Removing columns if more than 40 percent of the data is missing
51
326
  - Imputing missing values with the mean of the column if the column is numeric
52
327
  - Imputing missing values with the mode of the column if the column is categorical
@@ -60,12 +335,20 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
60
335
  ----------
61
336
  model : langchain.llms.base.LLM
62
337
  The language model to use to generate code.
338
+ n_samples : int, optional
339
+ The number of samples to use when summarizing the dataset. Defaults to 30.
340
+ If you get an error due to maximum tokens, try reducing this number.
341
+ > "This model's maximum context length is 128000 tokens. However, your messages resulted in 333858 tokens. Please reduce the length of the messages."
63
342
  log : bool, optional
64
343
  Whether or not to log the code generated and any errors that occur.
65
344
  Defaults to False.
66
345
  log_path : str, optional
67
346
  The path to the directory where the log files should be stored. Defaults to
68
347
  "logs/".
348
+ file_name : str, optional
349
+ The name of the file to save the response to. Defaults to "data_cleaner.py".
350
+ function_name : str, optional
351
+ The name of the function that will be generated to clean the data. Defaults to "data_cleaner".
69
352
  overwrite : bool, optional
70
353
  Whether or not to overwrite the log file if it already exists. If False, a unique file name will be created.
71
354
  Defaults to True.
@@ -82,30 +365,35 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
82
365
  import pandas as pd
83
366
  from langchain_openai import ChatOpenAI
84
367
  from ai_data_science_team.agents import data_cleaning_agent
85
-
368
+
86
369
  llm = ChatOpenAI(model = "gpt-4o-mini")
87
370
 
88
371
  data_cleaning_agent = make_data_cleaning_agent(llm)
89
-
372
+
90
373
  df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")
91
-
374
+
92
375
  response = data_cleaning_agent.invoke({
93
376
  "user_instructions": "Don't remove outliers when cleaning the data.",
94
377
  "data_raw": df.to_dict(),
95
378
  "max_retries":3,
96
379
  "retry_count":0
97
380
  })
98
-
381
+
99
382
  pd.DataFrame(response['data_cleaned'])
100
383
  ```
101
384
 
102
385
  Returns
103
386
  -------
104
- app : langchain.graphs.StateGraph
387
+ app : langchain.graphs.CompiledStateGraph
105
388
  The data cleaning agent as a state graph.
106
389
  """
107
390
  llm = model
108
391
 
392
+ # Human in the loop requires recommended steps
393
+ if bypass_recommended_steps and human_in_the_loop:
394
+ bypass_recommended_steps = False
395
+ print("Bypass recommended steps set to False to enable human in the loop.")
396
+
109
397
  # Setup Log Directory
110
398
  if log:
111
399
  if log_path is None:
@@ -123,6 +411,7 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
123
411
  all_datasets_summary: str
124
412
  data_cleaner_function: str
125
413
  data_cleaner_function_path: str
414
+ data_cleaner_file_name: str
126
415
  data_cleaner_function_name: str
127
416
  data_cleaner_error: str
128
417
  max_retries: int
@@ -134,7 +423,7 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
134
423
  Recommend a series of data cleaning steps based on the input data.
135
424
  These recommended steps will be appended to the user_instructions.
136
425
  """
137
- print("---DATA CLEANING AGENT----")
426
+ print(format_agent_name(AGENT_NAME))
138
427
  print(" * RECOMMEND CLEANING STEPS")
139
428
 
140
429
  # Prompt to get recommended steps from the LLM
@@ -177,6 +466,7 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
177
466
 
178
467
  Avoid these:
179
468
  1. Do not include steps to save files.
469
+ 2. Do not include unrelated user instructions that are not related to the data cleaning.
180
470
  """,
181
471
  input_variables=["user_instructions", "recommended_steps", "all_datasets_summary"]
182
472
  )
@@ -184,7 +474,7 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
184
474
  data_raw = state.get("data_raw")
185
475
  df = pd.DataFrame.from_dict(data_raw)
186
476
 
187
- all_datasets_summary = get_dataframe_summary([df])
477
+ all_datasets_summary = get_dataframe_summary([df], n_sample=n_samples)
188
478
 
189
479
  all_datasets_summary_str = "\n\n".join(all_datasets_summary)
190
480
 
@@ -196,60 +486,73 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
196
486
  })
197
487
 
198
488
  return {
199
- "recommended_steps": "\n\n# Recommended Data Cleaning Steps:\n" + recommended_steps.content.strip(),
489
+ "recommended_steps": format_recommended_steps(recommended_steps.content.strip(), heading="# Recommended Data Cleaning Steps:"),
200
490
  "all_datasets_summary": all_datasets_summary_str
201
491
  }
202
492
 
203
493
  def create_data_cleaner_code(state: GraphState):
204
- if bypass_recommended_steps:
205
- print("---DATA CLEANING AGENT----")
494
+
206
495
  print(" * CREATE DATA CLEANER CODE")
207
496
 
497
+ if bypass_recommended_steps:
498
+ print(format_agent_name(AGENT_NAME))
499
+
500
+ data_raw = state.get("data_raw")
501
+ df = pd.DataFrame.from_dict(data_raw)
502
+
503
+ all_datasets_summary = get_dataframe_summary([df], n_sample=n_samples)
504
+
505
+ all_datasets_summary_str = "\n\n".join(all_datasets_summary)
506
+ else:
507
+ all_datasets_summary_str = state.get("all_datasets_summary")
508
+
509
+
208
510
  data_cleaning_prompt = PromptTemplate(
209
511
  template="""
210
- You are a Data Cleaning Agent. Your job is to create a data_cleaner() function that can be run on the data provided using the following recommended steps.
211
-
512
+ You are a Data Cleaning Agent. Your job is to create a {function_name}() function that can be run on the data provided using the following recommended steps.
513
+
212
514
  Recommended Steps:
213
515
  {recommended_steps}
214
-
516
+
215
517
  You can use Pandas, Numpy, and Scikit Learn libraries to clean the data.
216
-
518
+
217
519
  Below are summaries of all datasets provided. Use this information about the data to help determine how to clean the data:
218
520
 
219
521
  {all_datasets_summary}
220
-
221
- Return Python code in ```python ``` format with a single function definition, data_cleaner(data_raw), that incldues all imports inside the function.
222
-
522
+
523
+ Return Python code in ```python``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.
524
+
223
525
  Return code to provide the data cleaning function:
224
-
225
- def data_cleaner(data_raw):
526
+
527
+ def {function_name}(data_raw):
226
528
  import pandas as pd
227
529
  import numpy as np
228
530
  ...
229
531
  return data_cleaned
230
-
532
+
231
533
  Best Practices and Error Preventions:
232
-
534
+
233
535
  Always ensure that when assigning the output of fit_transform() from SimpleImputer to a Pandas DataFrame column, you call .ravel() or flatten the array, because fit_transform() returns a 2D array while a DataFrame column is 1D.
234
536
 
235
537
  """,
236
- input_variables=["recommended_steps", "all_datasets_summary"]
538
+ input_variables=["recommended_steps", "all_datasets_summary", "function_name"]
237
539
  )
238
540
 
239
541
  data_cleaning_agent = data_cleaning_prompt | llm | PythonOutputParser()
240
542
 
241
543
  response = data_cleaning_agent.invoke({
242
544
  "recommended_steps": state.get("recommended_steps"),
243
- "all_datasets_summary": state.get("all_datasets_summary")
545
+ "all_datasets_summary": all_datasets_summary_str,
546
+ "function_name": function_name
244
547
  })
245
548
 
246
549
  response = relocate_imports_inside_function(response)
247
550
  response = add_comments_to_top(response, agent_name=AGENT_NAME)
248
551
 
249
552
  # For logging: store the code generated:
250
- file_path, file_name = log_ai_function(
553
+ file_path, file_name_2 = log_ai_function(
251
554
  response=response,
252
- file_name="data_cleaner.py",
555
+ file_name=file_name,
253
556
  log=log,
254
557
  log_path=log_path,
255
558
  overwrite=overwrite
@@ -258,18 +561,37 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
258
561
  return {
259
562
  "data_cleaner_function" : response,
260
563
  "data_cleaner_function_path": file_path,
261
- "data_cleaner_function_name": file_name
564
+ "data_cleaner_file_name": file_name_2,
565
+ "data_cleaner_function_name": function_name,
566
+ "all_datasets_summary": all_datasets_summary_str
262
567
  }
568
+
569
+ # Human Review
570
+
571
+ prompt_text_human_review = "Are the following data cleaning instructions correct? (Answer 'yes' or provide modifications)\n{steps}"
263
572
 
264
- def human_review(state: GraphState) -> Command[Literal["recommend_cleaning_steps", "create_data_cleaner_code"]]:
265
- return node_func_human_review(
266
- state=state,
267
- prompt_text="Is the following data cleaning instructions correct? (Answer 'yes' or provide modifications)\n{steps}",
268
- yes_goto="create_data_cleaner_code",
269
- no_goto="recommend_cleaning_steps",
270
- user_instructions_key="user_instructions",
271
- recommended_steps_key="recommended_steps"
272
- )
573
+ if not bypass_explain_code:
574
+ def human_review(state: GraphState) -> Command[Literal["recommend_cleaning_steps", "explain_data_cleaner_code"]]:
575
+ return node_func_human_review(
576
+ state=state,
577
+ prompt_text=prompt_text_human_review,
578
+ yes_goto= 'explain_data_cleaner_code',
579
+ no_goto="recommend_cleaning_steps",
580
+ user_instructions_key="user_instructions",
581
+ recommended_steps_key="recommended_steps",
582
+ code_snippet_key="data_cleaner_function",
583
+ )
584
+ else:
585
+ def human_review(state: GraphState) -> Command[Literal["recommend_cleaning_steps", "__end__"]]:
586
+ return node_func_human_review(
587
+ state=state,
588
+ prompt_text=prompt_text_human_review,
589
+ yes_goto= '__end__',
590
+ no_goto="recommend_cleaning_steps",
591
+ user_instructions_key="user_instructions",
592
+ recommended_steps_key="recommended_steps",
593
+ code_snippet_key="data_cleaner_function",
594
+ )
273
595
 
274
596
  def execute_data_cleaner_code(state):
275
597
  return node_func_execute_agent_code_on_data(
@@ -278,7 +600,7 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
278
600
  result_key="data_cleaned",
279
601
  error_key="data_cleaner_error",
280
602
  code_snippet_key="data_cleaner_function",
281
- agent_function_name="data_cleaner",
603
+ agent_function_name=state.get("data_cleaner_function_name"),
282
604
  pre_processing=lambda data: pd.DataFrame.from_dict(data),
283
605
  post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
284
606
  error_message_prefix="An error occurred during data cleaning: "
@@ -286,11 +608,11 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
286
608
 
287
609
  def fix_data_cleaner_code(state: GraphState):
288
610
  data_cleaner_prompt = """
289
- You are a Data Cleaning Agent. Your job is to create a data_cleaner() function that can be run on the data provided. The function is currently broken and needs to be fixed.
611
+ You are a Data Cleaning Agent. Your job is to create a {function_name}() function that can be run on the data provided. The function is currently broken and needs to be fixed.
290
612
 
291
- Make sure to only return the function definition for data_cleaner().
613
+ Make sure to only return the function definition for {function_name}().
292
614
 
293
- Return Python code in ```python``` format with a single function definition, data_cleaner(data_raw), that includes all imports inside the function.
615
+ Return Python code in ```python``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.
294
616
 
295
617
  This is the broken code (please fix):
296
618
  {code_snippet}
@@ -308,6 +630,7 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
308
630
  agent_name=AGENT_NAME,
309
631
  log=log,
310
632
  file_path=state.get("data_cleaner_function_path"),
633
+ function_name=state.get("data_cleaner_function_name"),
311
634
  )
312
635
 
313
636
  def explain_data_cleaner_code(state: GraphState):
@@ -353,3 +676,6 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
353
676
  )
354
677
 
355
678
  return app
679
+
680
+
681
+