ai-data-science-team 0.0.0.9000__py3-none-any.whl → 0.0.0.9005__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,409 @@
1
+ from langchain_core.messages import AIMessage
2
+ from langgraph.graph import StateGraph, END
3
+ from langgraph.types import interrupt, Command
4
+
5
+ import pandas as pd
6
+
7
+ from typing import Any, Callable, Dict, Type, Optional
8
+
9
+ from ai_data_science_team.tools.parsers import PythonOutputParser
10
+ from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
11
+
12
+ def create_coding_agent_graph(
13
+ GraphState: Type,
14
+ node_functions: Dict[str, Callable],
15
+ recommended_steps_node_name: str,
16
+ create_code_node_name: str,
17
+ execute_code_node_name: str,
18
+ fix_code_node_name: str,
19
+ explain_code_node_name: str,
20
+ error_key: str,
21
+ max_retries_key: str = "max_retries",
22
+ retry_count_key: str = "retry_count",
23
+ human_in_the_loop: bool = False,
24
+ human_review_node_name: str = "human_review",
25
+ checkpointer: Optional[Callable] = None
26
+ ):
27
+ """
28
+ Creates a generic agent graph using the provided node functions and node names.
29
+
30
+ Parameters
31
+ ----------
32
+ GraphState : Type
33
+ The TypedDict or class used as state for the workflow.
34
+ node_functions : dict
35
+ A dictionary mapping node names to their respective functions.
36
+ Example: {
37
+ "recommend_cleaning_steps": recommend_cleaning_steps,
38
+ "human_review": human_review,
39
+ "create_data_cleaner_code": create_data_cleaner_code,
40
+ "execute_data_cleaner_code": execute_data_cleaner_code,
41
+ "fix_data_cleaner_code": fix_data_cleaner_code,
42
+ "explain_data_cleaner_code": explain_data_cleaner_code
43
+ }
44
+ recommended_steps_node_name : str
45
+ The node name that recommends steps.
46
+ create_code_node_name : str
47
+ The node name that creates the code.
48
+ execute_code_node_name : str
49
+ The node name that executes the generated code.
50
+ fix_code_node_name : str
51
+ The node name that fixes code if errors occur.
52
+ explain_code_node_name : str
53
+ The node name that explains the final code.
54
+ error_key : str
55
+ The state key used to check for errors.
56
+ max_retries_key : str, optional
57
+ The state key used for the maximum number of retries.
58
+ retry_count_key : str, optional
59
+ The state key for the current retry count.
60
+ human_in_the_loop : bool, optional
61
+ Whether to include a human review step.
62
+ human_review_node_name : str, optional
63
+ The node name for human review if human_in_the_loop is True.
64
+ checkpointer : callable, optional
65
+ A checkpointer callable if desired.
66
+
67
+ Returns
68
+ -------
69
+ app : langchain.graphs.StateGraph
70
+ The compiled workflow application.
71
+ """
72
+
73
+ workflow = StateGraph(GraphState)
74
+
75
+ # Add the recommended steps node
76
+ workflow.add_node(recommended_steps_node_name, node_functions[recommended_steps_node_name])
77
+
78
+ # Optionally add the human review node
79
+ if human_in_the_loop:
80
+ workflow.add_node(human_review_node_name, node_functions[human_review_node_name])
81
+
82
+ # Add main nodes
83
+ workflow.add_node(create_code_node_name, node_functions[create_code_node_name])
84
+ workflow.add_node(execute_code_node_name, node_functions[execute_code_node_name])
85
+ workflow.add_node(fix_code_node_name, node_functions[fix_code_node_name])
86
+ workflow.add_node(explain_code_node_name, node_functions[explain_code_node_name])
87
+
88
+ # Set the entry point
89
+ workflow.set_entry_point(recommended_steps_node_name)
90
+
91
+ # Add edges depending on human_in_the_loop
92
+ if human_in_the_loop:
93
+ workflow.add_edge(recommended_steps_node_name, human_review_node_name)
94
+ else:
95
+ workflow.add_edge(recommended_steps_node_name, create_code_node_name)
96
+
97
+ # Connect create_code_node to execution node
98
+ workflow.add_edge(create_code_node_name, execute_code_node_name)
99
+
100
+ # Add conditional edges for error handling
101
+ workflow.add_conditional_edges(
102
+ execute_code_node_name,
103
+ lambda state: "fix_code" if (
104
+ state.get(error_key) is not None and
105
+ state.get(retry_count_key) is not None and
106
+ state.get(max_retries_key) is not None and
107
+ state.get(retry_count_key) < state.get(max_retries_key)
108
+ ) else "explain_code",
109
+ {"fix_code": fix_code_node_name, "explain_code": explain_code_node_name},
110
+ )
111
+
112
+ # From fix_code_node_name back to execution node
113
+ workflow.add_edge(fix_code_node_name, execute_code_node_name)
114
+
115
+ # explain_code_node_name leads to end
116
+ workflow.add_edge(explain_code_node_name, END)
117
+
118
+ # Compile workflow, optionally with checkpointer
119
+ if human_in_the_loop and checkpointer is not None:
120
+ app = workflow.compile(checkpointer=checkpointer)
121
+ else:
122
+ app = workflow.compile()
123
+
124
+ return app
125
+
126
+
127
+ def node_func_human_review(
128
+ state: Any,
129
+ prompt_text: str,
130
+ yes_goto: str,
131
+ no_goto: str,
132
+ user_instructions_key: str = "user_instructions",
133
+ recommended_steps_key: str = "recommended_steps",
134
+ ) -> Command[str]:
135
+ """
136
+ A generic function to handle human review steps.
137
+
138
+ Parameters
139
+ ----------
140
+ state : GraphState
141
+ The current GraphState.
142
+ prompt_text : str
143
+ The text to display to the user before their input.
144
+ yes_goto : str
145
+ The node to go to if the user confirms (answers "yes").
146
+ no_goto : str
147
+ The node to go to if the user suggests modifications.
148
+ user_instructions_key : str, optional
149
+ The key in the state to store user instructions.
150
+ recommended_steps_key : str, optional
151
+ The key in the state to store recommended steps.
152
+
153
+ Returns
154
+ -------
155
+ Command[str]
156
+ A Command object directing the next state and updates to the state.
157
+ """
158
+ print(" * HUMAN REVIEW")
159
+
160
+ # Display instructions and get user response
161
+ user_input = interrupt(value=prompt_text.format(steps=state.get(recommended_steps_key, '')))
162
+
163
+ # Decide next steps based on user input
164
+ if user_input.strip().lower() == "yes":
165
+ goto = yes_goto
166
+ update = {}
167
+ else:
168
+ goto = no_goto
169
+ modifications = "Modifications: \n" + user_input
170
+ if state.get(user_instructions_key) is None:
171
+ update = {user_instructions_key: modifications}
172
+ else:
173
+ update = {user_instructions_key: state.get(user_instructions_key) + modifications}
174
+
175
+ return Command(goto=goto, update=update)
176
+
177
+
178
+ def node_func_execute_agent_code_on_data(
179
+ state: Any,
180
+ data_key: str,
181
+ code_snippet_key: str,
182
+ result_key: str,
183
+ error_key: str,
184
+ agent_function_name: str,
185
+ pre_processing: Optional[Callable[[Any], Any]] = None,
186
+ post_processing: Optional[Callable[[Any], Any]] = None,
187
+ error_message_prefix: str = "An error occurred during agent execution: "
188
+ ) -> Dict[str, Any]:
189
+ """
190
+ Execute a generic agent code defined in a code snippet retrieved from the state on input data and return the result.
191
+
192
+ Parameters
193
+ ----------
194
+ state : Any
195
+ A state object that supports `get(key: str)` method to retrieve values.
196
+ data_key : str
197
+ The key in the state used to retrieve the input data.
198
+ code_snippet_key : str
199
+ The key in the state used to retrieve the Python code snippet defining the agent function.
200
+ result_key : str
201
+ The key in the state used to store the result of the agent function.
202
+ error_key : str
203
+ The key in the state used to store the error message if any.
204
+ agent_function_name : str
205
+ The name of the function (e.g., 'data_cleaner') expected to be defined in the code snippet.
206
+ pre_processing : Callable[[Any], Any], optional
207
+ A function to preprocess the data before passing it to the agent function.
208
+ This might be used to convert raw data into a DataFrame or otherwise transform it.
209
+ If not provided, a default approach will be used if data is a dict.
210
+ post_processing : Callable[[Any], Any], optional
211
+ A function to postprocess the output of the agent function before returning it.
212
+ error_message_prefix : str, optional
213
+ A prefix or full message to use in the error output if an exception occurs.
214
+
215
+ Returns
216
+ -------
217
+ Dict[str, Any]
218
+ A dictionary containing the result and/or error messages. Keys are arbitrary,
219
+ but typically include something like "result" or "error".
220
+ """
221
+
222
+ print(" * EXECUTING AGENT CODE")
223
+
224
+ # Retrieve raw data and code snippet from the state
225
+ data = state.get(data_key)
226
+ agent_code = state.get(code_snippet_key)
227
+
228
+ # Preprocessing: If no pre-processing function is given, attempt a default handling
229
+ if pre_processing is None:
230
+ if isinstance(data, dict):
231
+ df = pd.DataFrame.from_dict(data)
232
+ elif isinstance(data, list):
233
+ df = [pd.DataFrame.from_dict(item) for item in data]
234
+ else:
235
+ raise ValueError("Data is not a dictionary or list and no pre_processing function was provided.")
236
+ else:
237
+ df = pre_processing(data)
238
+
239
+ # Execute the code snippet to define the agent function
240
+ local_vars = {}
241
+ global_vars = {}
242
+ exec(agent_code, global_vars, local_vars)
243
+
244
+ # Retrieve the agent function from the executed code
245
+ agent_function = local_vars.get(agent_function_name, None)
246
+ if agent_function is None or not callable(agent_function):
247
+ raise ValueError(f"Agent function '{agent_function_name}' not found or not callable in the provided code.")
248
+
249
+ # Execute the agent function
250
+ agent_error = None
251
+ result = None
252
+ try:
253
+ result = agent_function(df)
254
+
255
+ # Test an error
256
+ # if state.get("retry_count") == 0:
257
+ # 10/0
258
+
259
+ # Apply post-processing if provided
260
+ if post_processing is not None:
261
+ result = post_processing(result)
262
+ except Exception as e:
263
+ print(e)
264
+ agent_error = f"{error_message_prefix}{str(e)}"
265
+
266
+ # Return results
267
+ output = {result_key: result, error_key: agent_error}
268
+ return output
269
+
270
+ def node_func_fix_agent_code(
271
+ state: Any,
272
+ code_snippet_key: str,
273
+ error_key: str,
274
+ llm: Any,
275
+ prompt_template: str,
276
+ agent_name: str,
277
+ retry_count_key: str = "retry_count",
278
+ log: bool = False,
279
+ file_path: str = "logs/agent_function.py",
280
+ ) -> dict:
281
+ """
282
+ Generic function to fix a given piece of agent code using an LLM and a prompt template.
283
+
284
+ Parameters
285
+ ----------
286
+ state : Any
287
+ A state object that supports `get(key: str)` method to retrieve values.
288
+ code_snippet_key : str
289
+ The key in the state used to retrieve the broken code snippet.
290
+ error_key : str
291
+ The key in the state used to retrieve the related error message.
292
+ llm : Any
293
+ The language model or pipeline capable of receiving prompts and returning responses.
294
+ It should support a call like `(llm | PythonOutputParser()).invoke(prompt)`.
295
+ prompt_template : str
296
+ A string template for the prompt that will be sent to the LLM. It should contain
297
+ placeholders `{code_snippet}` and `{error}` which will be formatted with the actual values.
298
+ agent_name : str
299
+ The name of the agent being fixed. This is used to add comments to the top of the code.
300
+ retry_count_key : str, optional
301
+ The key in the state that tracks how many times we've retried fixing the code.
302
+ log : bool, optional
303
+ Whether to log the returned code to a file.
304
+ file_path : str, optional
305
+ The path to the file where the code will be logged.
306
+
307
+ Returns
308
+ -------
309
+ dict
310
+ A dictionary containing updated code, cleared error, and incremented retry count.
311
+ """
312
+ print(" * FIX AGENT CODE")
313
+ print(" retry_count:" + str(state.get(retry_count_key)))
314
+
315
+ # Retrieve the code snippet and the error from the state
316
+ code_snippet = state.get(code_snippet_key)
317
+ error_message = state.get(error_key)
318
+
319
+ # Format the prompt with the code snippet and the error
320
+ prompt = prompt_template.format(
321
+ code_snippet=code_snippet,
322
+ error=error_message
323
+ )
324
+
325
+ # Execute the prompt with the LLM
326
+ response = (llm | PythonOutputParser()).invoke(prompt)
327
+
328
+ response = relocate_imports_inside_function(response)
329
+ response = add_comments_to_top(response, agent_name="data_wrangler")
330
+
331
+ # Log the response if requested
332
+ if log:
333
+ with open(file_path, 'w') as file:
334
+ file.write(response)
335
+ print(f" File saved to: {file_path}")
336
+
337
+ # Return updated results
338
+ return {
339
+ code_snippet_key: response,
340
+ error_key: None,
341
+ retry_count_key: state.get(retry_count_key) + 1
342
+ }
343
+
344
+ def node_func_explain_agent_code(
345
+ state: Any,
346
+ code_snippet_key: str,
347
+ result_key: str,
348
+ error_key: str,
349
+ llm: Any,
350
+ role: str,
351
+ explanation_prompt_template: str,
352
+ success_prefix: str = "# Agent Explanation:\n\n",
353
+ error_message: str = "The agent encountered an error during execution and cannot be explained."
354
+ ) -> Dict[str, Any]:
355
+ """
356
+ Generic function to explain what a given agent code snippet does.
357
+
358
+ Parameters
359
+ ----------
360
+ state : Any
361
+ A state object that supports `get(key: str)` to retrieve values.
362
+ code_snippet_key : str
363
+ The key in `state` where the agent code snippet is stored.
364
+ result_key : str
365
+ The key in `state` where the LLM's explanation is stored. Typically this is "messages".
366
+ error_key : str
367
+ The key in `state` where any error messages related to the code snippet are stored.
368
+ llm : Any
369
+ The language model used to explain the code. Should support `.invoke(prompt)`.
370
+ role : str
371
+ The role of the agent explaining the code snippet. Examples: "Data Scientist", "Data Engineer", etc.
372
+ explanation_prompt_template : str
373
+ A prompt template that can be used to explain the code. It should contain a placeholder
374
+ for inserting the agent code snippet. For example:
375
+
376
+ "Explain the steps performed by this agent code in a succinct manner:\n\n{code}"
377
+
378
+ success_prefix : str, optional
379
+ A prefix to add before the LLM's explanation, helping format the final message.
380
+ error_message : str, optional
381
+ Message to return if the agent code snippet cannot be explained due to an error.
382
+
383
+ Returns
384
+ -------
385
+ Dict[str, Any]
386
+ A dictionary containing one key "messages", which is a list of messages (e.g., AIMessage)
387
+ describing the explanation or the error.
388
+ """
389
+ print(" * EXPLAIN AGENT CODE")
390
+
391
+ # Check if there's an error associated with the code
392
+ agent_error = state.get(error_key)
393
+ if agent_error is None:
394
+ # Retrieve the code snippet
395
+ code_snippet = state.get(code_snippet_key)
396
+
397
+ # Format the prompt by inserting the code snippet
398
+ prompt = explanation_prompt_template.format(code=code_snippet)
399
+
400
+ # Invoke the LLM to get an explanation
401
+ response = llm.invoke(prompt)
402
+
403
+ # Prepare the success message
404
+ message = AIMessage(content=f"{success_prefix}{response.content}", role=role)
405
+ return {"messages": [message]}
406
+ else:
407
+ # Return an error message if there was a problem with the code
408
+ message = AIMessage(content=error_message)
409
+ return {result_key: [message]}
File without changes
@@ -0,0 +1,116 @@
1
+ import io
2
+ import pandas as pd
3
+ from typing import Union, List, Dict
4
+
5
+ def summarize_dataframes(
6
+ dataframes: Union[pd.DataFrame, List[pd.DataFrame], Dict[str, pd.DataFrame]]
7
+ ) -> List[str]:
8
+ """
9
+ Generate a summary for one or more DataFrames. Accepts a single DataFrame, a list of DataFrames,
10
+ or a dictionary mapping names to DataFrames.
11
+
12
+ Parameters
13
+ ----------
14
+ dataframes : pandas.DataFrame or list of pandas.DataFrame or dict of (str -> pandas.DataFrame)
15
+ - Single DataFrame: produce a single summary (returned within a one-element list).
16
+ - List of DataFrames: produce a summary for each DataFrame, using index-based names.
17
+ - Dictionary of DataFrames: produce a summary for each DataFrame, using dictionary keys as names.
18
+
19
+ Example:
20
+ --------
21
+ ``` python
22
+ import pandas as pd
23
+ from sklearn.datasets import load_iris
24
+ data = load_iris(as_frame=True)
25
+ dataframes = {
26
+ "iris": data.frame,
27
+ "iris_target": data.target,
28
+ }
29
+ summaries = summarize_dataframes(dataframes)
30
+ print(summaries[0])
31
+ ```
32
+
33
+ Returns
34
+ -------
35
+ list of str
36
+ A list of summaries, one for each provided DataFrame. Each summary includes:
37
+ - Shape of the DataFrame (rows, columns)
38
+ - Column data types
39
+ - Missing value percentage
40
+ - Unique value counts
41
+ - First 30 rows
42
+ - Descriptive statistics
43
+ - DataFrame info output
44
+ """
45
+
46
+ summaries = []
47
+
48
+ # --- Dictionary Case ---
49
+ if isinstance(dataframes, dict):
50
+ for dataset_name, df in dataframes.items():
51
+ summaries.append(_summarize_dataframe(df, dataset_name))
52
+
53
+ # --- Single DataFrame Case ---
54
+ elif isinstance(dataframes, pd.DataFrame):
55
+ summaries.append(_summarize_dataframe(dataframes, "Single_Dataset"))
56
+
57
+ # --- List of DataFrames Case ---
58
+ elif isinstance(dataframes, list):
59
+ for idx, df in enumerate(dataframes):
60
+ dataset_name = f"Dataset_{idx}"
61
+ summaries.append(_summarize_dataframe(df, dataset_name))
62
+
63
+ else:
64
+ raise TypeError(
65
+ "Input must be a single DataFrame, a list of DataFrames, or a dictionary of DataFrames."
66
+ )
67
+
68
+ return summaries
69
+
70
+
71
+ def _summarize_dataframe(df: pd.DataFrame, dataset_name: str) -> str:
72
+ """Generate a summary string for a single DataFrame."""
73
+ # 1. Convert dictionary-type cells to strings
74
+ # This prevents unhashable dict errors during df.nunique().
75
+ df = df.apply(lambda col: col.map(lambda x: str(x) if isinstance(x, dict) else x))
76
+
77
+ # 2. Capture df.info() output
78
+ buffer = io.StringIO()
79
+ df.info(buf=buffer)
80
+ info_text = buffer.getvalue()
81
+
82
+ # 3. Calculate missing value stats
83
+ missing_stats = (df.isna().sum() / len(df) * 100).sort_values(ascending=False)
84
+ missing_summary = "\n".join([f"{col}: {val:.2f}%" for col, val in missing_stats.items()])
85
+
86
+ # 4. Get column data types
87
+ column_types = "\n".join([f"{col}: {dtype}" for col, dtype in df.dtypes.items()])
88
+
89
+ # 5. Get unique value counts
90
+ unique_counts = df.nunique() # Will no longer fail on unhashable dict
91
+ unique_counts_summary = "\n".join([f"{col}: {count}" for col, count in unique_counts.items()])
92
+
93
+ summary_text = f"""
94
+ Dataset Name: {dataset_name}
95
+ ----------------------------
96
+ Shape: {df.shape[0]} rows x {df.shape[1]} columns
97
+
98
+ Column Data Types:
99
+ {column_types}
100
+
101
+ Missing Value Percentage:
102
+ {missing_summary}
103
+
104
+ Unique Value Counts:
105
+ {unique_counts_summary}
106
+
107
+ Data (first 30 rows):
108
+ {df.head(30).to_string()}
109
+
110
+ Data Description:
111
+ {df.describe().to_string()}
112
+
113
+ Data Info:
114
+ {info_text}
115
+ """
116
+ return summary_text.strip()
@@ -0,0 +1,61 @@
1
+
2
+ import os
3
+
4
+ def log_ai_function(response: str, file_name: str, log: bool = True, log_path: str = './logs/', overwrite: bool = True):
5
+ """
6
+ Logs the response of an AI function to a file.
7
+
8
+ Parameters
9
+ ----------
10
+ response : str
11
+ The response of the AI function.
12
+ file_name : str
13
+ The name of the file to save the response to.
14
+ log : bool, optional
15
+ Whether to log the response or not. The default is True.
16
+ log_path : str, optional
17
+ The path to save the log file. The default is './logs/'.
18
+ overwrite : bool, optional
19
+ Whether to overwrite the file if it already exists. The default is True.
20
+ - If True, the file will be overwritten.
21
+ - If False, a unique file name will be created.
22
+
23
+ Returns
24
+ -------
25
+ tuple
26
+ The path and name of the log file.
27
+ """
28
+
29
+ if log:
30
+ # Ensure the directory exists
31
+ os.makedirs(log_path, exist_ok=True)
32
+
33
+ # file_name = 'data_wrangler.py'
34
+ file_path = os.path.join(log_path, file_name)
35
+
36
+ if not overwrite:
37
+ # If file already exists and we're NOT overwriting, we create a new name
38
+ if os.path.exists(file_path):
39
+ # Use an incremental suffix (e.g., data_wrangler_1.py, data_wrangler_2.py, etc.)
40
+ # or a time-based suffix if you prefer.
41
+ base_name, ext = os.path.splitext(file_name)
42
+ i = 1
43
+ while True:
44
+ new_file_name = f"{base_name}_{i}{ext}"
45
+ new_file_path = os.path.join(log_path, new_file_name)
46
+ if not os.path.exists(new_file_path):
47
+ file_path = new_file_path
48
+ file_name = new_file_name
49
+ break
50
+ i += 1
51
+
52
+ # Write the file
53
+ with open(file_path, 'w', encoding='utf-8') as file:
54
+ file.write(response)
55
+
56
+ print(f" File saved to: {file_path}")
57
+
58
+ return (file_path, file_name)
59
+
60
+ else:
61
+ return None
@@ -0,0 +1,57 @@
1
+ # BUSINESS SCIENCE UNIVERSITY
2
+ # AI DATA SCIENCE TEAM
3
+ # ***
4
+ # Parsers
5
+
6
+ from langchain_core.output_parsers import JsonOutputParser
7
+ from langchain_core.output_parsers import BaseOutputParser
8
+
9
+ import re
10
+
11
+ # Python Parser for output standardization
12
+ class PythonOutputParser(BaseOutputParser):
13
+ def parse(self, text: str):
14
+ def extract_python_code(text):
15
+ python_code_match = re.search(r'```python(.*?)```', text, re.DOTALL)
16
+ if python_code_match:
17
+ python_code = python_code_match.group(1).strip()
18
+ return python_code
19
+ else:
20
+ python_code_match = re.search(r"python(.*?)'", text, re.DOTALL)
21
+ if python_code_match:
22
+ python_code = python_code_match.group(1).strip()
23
+ return python_code
24
+ else:
25
+ return None
26
+ python_code = extract_python_code(text)
27
+ if python_code is not None:
28
+ return python_code
29
+ else:
30
+ # Assume ```sql wasn't used
31
+ return text
32
+
33
+ # SQL Parser for output standardization
34
+ class SQLOutputParser(BaseOutputParser):
35
+ def parse(self, text: str):
36
+ def extract_sql_code(text):
37
+ sql_code_match = re.search(r'```sql(.*?)```', text, re.DOTALL)
38
+ sql_code_match_2 = re.search(r"SQLQuery:\s*(.*)", text)
39
+ if sql_code_match:
40
+ sql_code = sql_code_match.group(1).strip()
41
+ return sql_code
42
+ if sql_code_match_2:
43
+ sql_code = sql_code_match_2.group(1).strip()
44
+ return sql_code
45
+ else:
46
+ sql_code_match = re.search(r"sql(.*?)'", text, re.DOTALL)
47
+ if sql_code_match:
48
+ sql_code = sql_code_match.group(1).strip()
49
+ return sql_code
50
+ else:
51
+ return None
52
+ sql_code = extract_sql_code(text)
53
+ if sql_code is not None:
54
+ return sql_code
55
+ else:
56
+ # Assume ```sql wasn't used
57
+ return text