ai-data-science-team 0.0.0.9008__py3-none-any.whl → 0.0.0.9010__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +0 -1
- ai_data_science_team/agents/data_cleaning_agent.py +50 -39
- ai_data_science_team/agents/data_loader_tools_agent.py +69 -0
- ai_data_science_team/agents/data_visualization_agent.py +45 -50
- ai_data_science_team/agents/data_wrangling_agent.py +50 -49
- ai_data_science_team/agents/feature_engineering_agent.py +48 -67
- ai_data_science_team/agents/sql_database_agent.py +130 -76
- ai_data_science_team/ml_agents/__init__.py +2 -0
- ai_data_science_team/ml_agents/h2o_ml_agent.py +852 -0
- ai_data_science_team/ml_agents/mlflow_tools_agent.py +327 -0
- ai_data_science_team/multiagents/sql_data_analyst.py +120 -9
- ai_data_science_team/parsers/__init__.py +0 -0
- ai_data_science_team/{tools → parsers}/parsers.py +0 -1
- ai_data_science_team/templates/__init__.py +1 -0
- ai_data_science_team/templates/agent_templates.py +78 -7
- ai_data_science_team/tools/data_loader.py +378 -0
- ai_data_science_team/tools/{metadata.py → dataframe.py} +0 -91
- ai_data_science_team/tools/h2o.py +643 -0
- ai_data_science_team/tools/mlflow.py +961 -0
- ai_data_science_team/tools/sql.py +126 -0
- ai_data_science_team/{tools → utils}/regex.py +59 -1
- {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/METADATA +56 -24
- ai_data_science_team-0.0.0.9010.dist-info/RECORD +35 -0
- ai_data_science_team-0.0.0.9008.dist-info/RECORD +0 -26
- /ai_data_science_team/{tools → utils}/logging.py +0 -0
- {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/WHEEL +0 -0
- {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/top_level.txt +0 -0
ai_data_science_team/_version.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.0.0.9008"
|
1
|
+
__version__ = "0.0.0.9010"
|
@@ -3,4 +3,3 @@ from ai_data_science_team.agents.feature_engineering_agent import make_feature_e
|
|
3
3
|
from ai_data_science_team.agents.data_wrangling_agent import make_data_wrangling_agent, DataWranglingAgent
|
4
4
|
from ai_data_science_team.agents.sql_database_agent import make_sql_database_agent, SQLDatabaseAgent
|
5
5
|
from ai_data_science_team.agents.data_visualization_agent import make_data_visualization_agent, DataVisualizationAgent
|
6
|
-
|
@@ -14,7 +14,7 @@ from langgraph.types import Command
|
|
14
14
|
from langgraph.checkpoint.memory import MemorySaver
|
15
15
|
|
16
16
|
import os
|
17
|
-
import
|
17
|
+
import json
|
18
18
|
import pandas as pd
|
19
19
|
|
20
20
|
from IPython.display import Markdown
|
@@ -23,21 +23,26 @@ from ai_data_science_team.templates import(
|
|
23
23
|
node_func_execute_agent_code_on_data,
|
24
24
|
node_func_human_review,
|
25
25
|
node_func_fix_agent_code,
|
26
|
-
|
26
|
+
node_func_report_agent_outputs,
|
27
27
|
create_coding_agent_graph,
|
28
28
|
BaseAgent,
|
29
29
|
)
|
30
|
-
from ai_data_science_team.
|
31
|
-
from ai_data_science_team.
|
32
|
-
|
33
|
-
|
30
|
+
from ai_data_science_team.parsers.parsers import PythonOutputParser
|
31
|
+
from ai_data_science_team.utils.regex import (
|
32
|
+
relocate_imports_inside_function,
|
33
|
+
add_comments_to_top,
|
34
|
+
format_agent_name,
|
35
|
+
format_recommended_steps,
|
36
|
+
get_generic_summary,
|
37
|
+
)
|
38
|
+
from ai_data_science_team.tools.dataframe import get_dataframe_summary
|
39
|
+
from ai_data_science_team.utils.logging import log_ai_function
|
34
40
|
|
35
41
|
# Setup
|
36
42
|
AGENT_NAME = "data_cleaning_agent"
|
37
43
|
LOG_PATH = os.path.join(os.getcwd(), "logs/")
|
38
44
|
|
39
45
|
|
40
|
-
|
41
46
|
# Class
|
42
47
|
class DataCleaningAgent(BaseAgent):
|
43
48
|
"""
|
@@ -89,8 +94,8 @@ class DataCleaningAgent(BaseAgent):
|
|
89
94
|
Cleans the provided dataset asynchronously based on user instructions.
|
90
95
|
invoke_agent(user_instructions: str, data_raw: pd.DataFrame, max_retries=3, retry_count=0)
|
91
96
|
Cleans the provided dataset synchronously based on user instructions.
|
92
|
-
|
93
|
-
|
97
|
+
get_workflow_summary()
|
98
|
+
Retrieves a summary of the agent's workflow.
|
94
99
|
get_log_summary()
|
95
100
|
Retrieves a summary of logged operations if logging is enabled.
|
96
101
|
get_state_keys()
|
@@ -178,8 +183,7 @@ class DataCleaningAgent(BaseAgent):
|
|
178
183
|
self.response=None
|
179
184
|
return make_data_cleaning_agent(**self._params)
|
180
185
|
|
181
|
-
|
182
|
-
def ainvoke_agent(self, data_raw: pd.DataFrame, user_instructions: str=None, max_retries:int=3, retry_count:int=0, **kwargs):
|
186
|
+
async def ainvoke_agent(self, data_raw: pd.DataFrame, user_instructions: str=None, max_retries:int=3, retry_count:int=0, **kwargs):
|
183
187
|
"""
|
184
188
|
Asynchronously invokes the agent. The response is stored in the response attribute.
|
185
189
|
|
@@ -200,7 +204,7 @@ class DataCleaningAgent(BaseAgent):
|
|
200
204
|
--------
|
201
205
|
None. The response is stored in the response attribute.
|
202
206
|
"""
|
203
|
-
response = self._compiled_graph.ainvoke({
|
207
|
+
response = await self._compiled_graph.ainvoke({
|
204
208
|
"user_instructions": user_instructions,
|
205
209
|
"data_raw": data_raw.to_dict(),
|
206
210
|
"max_retries": max_retries,
|
@@ -239,15 +243,16 @@ class DataCleaningAgent(BaseAgent):
|
|
239
243
|
self.response = response
|
240
244
|
return None
|
241
245
|
|
242
|
-
def
|
246
|
+
def get_workflow_summary(self, markdown=False):
|
243
247
|
"""
|
244
|
-
|
245
|
-
|
246
|
-
Returns:
|
247
|
-
str: Explanation of the cleaning steps.
|
248
|
+
Retrieves the agent's workflow summary, if logging is enabled.
|
248
249
|
"""
|
249
|
-
|
250
|
-
|
250
|
+
if self.response and self.response.get("messages"):
|
251
|
+
summary = get_generic_summary(json.loads(self.response.get("messages")[-1].content))
|
252
|
+
if markdown:
|
253
|
+
return Markdown(summary)
|
254
|
+
else:
|
255
|
+
return summary
|
251
256
|
|
252
257
|
def get_log_summary(self, markdown=False):
|
253
258
|
"""
|
@@ -255,7 +260,13 @@ class DataCleaningAgent(BaseAgent):
|
|
255
260
|
"""
|
256
261
|
if self.response:
|
257
262
|
if self.response.get('data_cleaner_function_path'):
|
258
|
-
log_details = f"
|
263
|
+
log_details = f"""
|
264
|
+
## Data Cleaning Agent Log Summary:
|
265
|
+
|
266
|
+
Function Path: {self.response.get('data_cleaner_function_path')}
|
267
|
+
|
268
|
+
Function Name: {self.response.get('data_cleaner_function_name')}
|
269
|
+
"""
|
259
270
|
if markdown:
|
260
271
|
return Markdown(log_details)
|
261
272
|
else:
|
@@ -462,7 +473,7 @@ def make_data_cleaning_agent(
|
|
462
473
|
Below are summaries of all datasets provided:
|
463
474
|
{all_datasets_summary}
|
464
475
|
|
465
|
-
Return
|
476
|
+
Return steps as a numbered list. You can return short code snippets to demonstrate actions. But do not return a fully coded solution. The code will be generated separately by a Coding Agent.
|
466
477
|
|
467
478
|
Avoid these:
|
468
479
|
1. Do not include steps to save files.
|
@@ -633,32 +644,31 @@ def make_data_cleaning_agent(
|
|
633
644
|
function_name=state.get("data_cleaner_function_name"),
|
634
645
|
)
|
635
646
|
|
636
|
-
|
637
|
-
|
647
|
+
# Final reporting node
|
648
|
+
def report_agent_outputs(state: GraphState):
|
649
|
+
return node_func_report_agent_outputs(
|
638
650
|
state=state,
|
639
|
-
|
651
|
+
keys_to_include=[
|
652
|
+
"recommended_steps",
|
653
|
+
"data_cleaner_function",
|
654
|
+
"data_cleaner_function_path",
|
655
|
+
"data_cleaner_function_name",
|
656
|
+
"data_cleaner_error",
|
657
|
+
],
|
640
658
|
result_key="messages",
|
641
|
-
error_key="data_cleaner_error",
|
642
|
-
llm=llm,
|
643
659
|
role=AGENT_NAME,
|
644
|
-
|
645
|
-
Explain the data cleaning steps that the data cleaning agent performed in this function.
|
646
|
-
Keep the summary succinct and to the point.\n\n# Data Cleaning Agent:\n\n{code}
|
647
|
-
""",
|
648
|
-
success_prefix="# Data Cleaning Agent:\n\n ",
|
649
|
-
error_message="The Data Cleaning Agent encountered an error during data cleaning. Data could not be explained."
|
660
|
+
custom_title="Data Cleaning Agent Outputs"
|
650
661
|
)
|
651
|
-
|
652
|
-
# Define the graph
|
662
|
+
|
653
663
|
node_functions = {
|
654
664
|
"recommend_cleaning_steps": recommend_cleaning_steps,
|
655
665
|
"human_review": human_review,
|
656
666
|
"create_data_cleaner_code": create_data_cleaner_code,
|
657
667
|
"execute_data_cleaner_code": execute_data_cleaner_code,
|
658
668
|
"fix_data_cleaner_code": fix_data_cleaner_code,
|
659
|
-
"
|
669
|
+
"report_agent_outputs": report_agent_outputs,
|
660
670
|
}
|
661
|
-
|
671
|
+
|
662
672
|
app = create_coding_agent_graph(
|
663
673
|
GraphState=GraphState,
|
664
674
|
node_functions=node_functions,
|
@@ -666,16 +676,17 @@ def make_data_cleaning_agent(
|
|
666
676
|
create_code_node_name="create_data_cleaner_code",
|
667
677
|
execute_code_node_name="execute_data_cleaner_code",
|
668
678
|
fix_code_node_name="fix_data_cleaner_code",
|
669
|
-
explain_code_node_name="
|
679
|
+
explain_code_node_name="report_agent_outputs",
|
670
680
|
error_key="data_cleaner_error",
|
671
|
-
human_in_the_loop=human_in_the_loop,
|
681
|
+
human_in_the_loop=human_in_the_loop,
|
672
682
|
human_review_node_name="human_review",
|
673
683
|
checkpointer=MemorySaver() if human_in_the_loop else None,
|
674
684
|
bypass_recommended_steps=bypass_recommended_steps,
|
675
685
|
bypass_explain_code=bypass_explain_code,
|
676
686
|
)
|
677
|
-
|
687
|
+
|
678
688
|
return app
|
689
|
+
|
679
690
|
|
680
691
|
|
681
692
|
|
@@ -0,0 +1,69 @@
|
|
1
|
+
|
2
|
+
|
3
|
+
|
4
|
+
from typing import Any, Optional, Annotated, Sequence, List, Dict
|
5
|
+
import operator
|
6
|
+
|
7
|
+
import pandas as pd
|
8
|
+
import os
|
9
|
+
|
10
|
+
from IPython.display import Markdown
|
11
|
+
|
12
|
+
from langchain_core.messages import BaseMessage, AIMessage
|
13
|
+
|
14
|
+
from langgraph.prebuilt import create_react_agent, ToolNode
|
15
|
+
from langgraph.prebuilt.chat_agent_executor import AgentState
|
16
|
+
from langgraph.graph import START, END, StateGraph
|
17
|
+
|
18
|
+
from ai_data_science_team.templates import BaseAgent
|
19
|
+
from ai_data_science_team.utils.regex import format_agent_name
|
20
|
+
from ai_data_science_team.tools.data_loader import (
|
21
|
+
load_directory,
|
22
|
+
load_file,
|
23
|
+
list_directory_contents,
|
24
|
+
list_directory_recursive,
|
25
|
+
get_file_info,
|
26
|
+
search_files_by_pattern,
|
27
|
+
)
|
28
|
+
|
29
|
+
AGENT_NAME = "data_loader_tools_agent"
|
30
|
+
|
31
|
+
tools = [
|
32
|
+
load_directory,
|
33
|
+
load_file,
|
34
|
+
list_directory_contents,
|
35
|
+
list_directory_recursive,
|
36
|
+
get_file_info,
|
37
|
+
search_files_by_pattern,
|
38
|
+
]
|
39
|
+
|
40
|
+
|
41
|
+
|
42
|
+
def make_data_loader_tools_agent(
|
43
|
+
model: Any,
|
44
|
+
directory: Optional[str] = os.getcwd(),
|
45
|
+
):
|
46
|
+
"""
|
47
|
+
Creates a Data Loader Agent that can interact with data loading tools.
|
48
|
+
|
49
|
+
Parameters:
|
50
|
+
----------
|
51
|
+
model : langchain.llms.base.LLM
|
52
|
+
The language model used to generate the tool calling agent.
|
53
|
+
directory : str, optional
|
54
|
+
The directory to search for files. Defaults to the current working directory.
|
55
|
+
|
56
|
+
Returns:
|
57
|
+
--------
|
58
|
+
Data Loader Agent
|
59
|
+
An agent that can interact with data loading tools.
|
60
|
+
"""
|
61
|
+
|
62
|
+
class GraphState(AgentState):
|
63
|
+
internal_messages: Annotated[Sequence[BaseMessage], operator.add]
|
64
|
+
directory: str
|
65
|
+
user_instructions: str
|
66
|
+
data_artifacts: dict
|
67
|
+
|
68
|
+
pass
|
69
|
+
|
@@ -10,13 +10,13 @@ from typing import TypedDict, Annotated, Sequence, Literal
|
|
10
10
|
import operator
|
11
11
|
|
12
12
|
from langchain.prompts import PromptTemplate
|
13
|
-
from langchain_core.output_parsers import StrOutputParser
|
14
13
|
from langchain_core.messages import BaseMessage
|
15
14
|
|
16
15
|
from langgraph.types import Command
|
17
16
|
from langgraph.checkpoint.memory import MemorySaver
|
18
17
|
|
19
18
|
import os
|
19
|
+
import json
|
20
20
|
import pandas as pd
|
21
21
|
|
22
22
|
from IPython.display import Markdown
|
@@ -25,19 +25,20 @@ from ai_data_science_team.templates import(
|
|
25
25
|
node_func_execute_agent_code_on_data,
|
26
26
|
node_func_human_review,
|
27
27
|
node_func_fix_agent_code,
|
28
|
-
|
28
|
+
node_func_report_agent_outputs,
|
29
29
|
create_coding_agent_graph,
|
30
30
|
BaseAgent,
|
31
31
|
)
|
32
|
-
from ai_data_science_team.
|
33
|
-
from ai_data_science_team.
|
32
|
+
from ai_data_science_team.parsers.parsers import PythonOutputParser
|
33
|
+
from ai_data_science_team.utils.regex import (
|
34
34
|
relocate_imports_inside_function,
|
35
35
|
add_comments_to_top,
|
36
36
|
format_agent_name,
|
37
|
-
format_recommended_steps
|
37
|
+
format_recommended_steps,
|
38
|
+
get_generic_summary,
|
38
39
|
)
|
39
|
-
from ai_data_science_team.tools.
|
40
|
-
from ai_data_science_team.
|
40
|
+
from ai_data_science_team.tools.dataframe import get_dataframe_summary
|
41
|
+
from ai_data_science_team.utils.logging import log_ai_function
|
41
42
|
from ai_data_science_team.utils.plotly import plotly_from_dict
|
42
43
|
|
43
44
|
# Setup
|
@@ -93,8 +94,8 @@ class DataVisualizationAgent(BaseAgent):
|
|
93
94
|
Asynchronously generates a visualization based on user instructions.
|
94
95
|
invoke_agent(user_instructions: str, data_raw: pd.DataFrame, max_retries=3, retry_count=0)
|
95
96
|
Synchronously generates a visualization based on user instructions.
|
96
|
-
|
97
|
-
|
97
|
+
get_workflow_summary()
|
98
|
+
Retrieves a summary of the agent's workflow.
|
98
99
|
get_log_summary()
|
99
100
|
Retrieves a summary of logged operations if logging is enabled.
|
100
101
|
get_plotly_graph()
|
@@ -195,7 +196,7 @@ class DataVisualizationAgent(BaseAgent):
|
|
195
196
|
# Rebuild the compiled graph
|
196
197
|
self._compiled_graph = self._make_compiled_graph()
|
197
198
|
|
198
|
-
def ainvoke_agent(self, data_raw: pd.DataFrame, user_instructions: str=None, max_retries:int=3, retry_count:int=0, **kwargs):
|
199
|
+
async def ainvoke_agent(self, data_raw: pd.DataFrame, user_instructions: str=None, max_retries:int=3, retry_count:int=0, **kwargs):
|
199
200
|
"""
|
200
201
|
Asynchronously invokes the agent to generate a visualization.
|
201
202
|
The response is stored in the 'response' attribute.
|
@@ -217,7 +218,7 @@ class DataVisualizationAgent(BaseAgent):
|
|
217
218
|
-------
|
218
219
|
None
|
219
220
|
"""
|
220
|
-
response = self._compiled_graph.ainvoke({
|
221
|
+
response = await self._compiled_graph.ainvoke({
|
221
222
|
"user_instructions": user_instructions,
|
222
223
|
"data_raw": data_raw.to_dict(),
|
223
224
|
"max_retries": max_retries,
|
@@ -257,40 +258,34 @@ class DataVisualizationAgent(BaseAgent):
|
|
257
258
|
self.response = response
|
258
259
|
return None
|
259
260
|
|
260
|
-
def
|
261
|
+
def get_workflow_summary(self, markdown=False):
|
261
262
|
"""
|
262
|
-
|
263
|
-
|
264
|
-
Returns
|
265
|
-
-------
|
266
|
-
str
|
267
|
-
Explanation of the visualization steps, if any are available.
|
263
|
+
Retrieves the agent's workflow summary, if logging is enabled.
|
268
264
|
"""
|
269
|
-
if self.response:
|
270
|
-
|
271
|
-
|
265
|
+
if self.response and self.response.get("messages"):
|
266
|
+
summary = get_generic_summary(json.loads(self.response.get("messages")[-1].content))
|
267
|
+
if markdown:
|
268
|
+
return Markdown(summary)
|
269
|
+
else:
|
270
|
+
return summary
|
272
271
|
|
273
272
|
def get_log_summary(self, markdown=False):
|
274
273
|
"""
|
275
274
|
Logs a summary of the agent's operations, if logging is enabled.
|
275
|
+
"""
|
276
|
+
if self.response:
|
277
|
+
if self.response.get('data_visualization_function_path'):
|
278
|
+
log_details = f"""
|
279
|
+
## Data Visualization Agent Log Summary:
|
276
280
|
|
277
|
-
|
278
|
-
----------
|
279
|
-
markdown : bool, optional
|
280
|
-
If True, returns Markdown-formatted output.
|
281
|
+
Function Path: {self.response.get('data_visualization_function_path')}
|
281
282
|
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
log_details = f"Log Path: {self.response.get('data_visualization_function_path')}"
|
289
|
-
if markdown:
|
290
|
-
return Markdown(log_details)
|
291
|
-
else:
|
292
|
-
return log_details
|
293
|
-
return None
|
283
|
+
Function Name: {self.response.get('data_visualization_function_name')}
|
284
|
+
"""
|
285
|
+
if markdown:
|
286
|
+
return Markdown(log_details)
|
287
|
+
else:
|
288
|
+
return log_details
|
294
289
|
|
295
290
|
def get_plotly_graph(self):
|
296
291
|
"""
|
@@ -719,20 +714,20 @@ def make_data_visualization_agent(
|
|
719
714
|
function_name=state.get("data_visualization_function_name"),
|
720
715
|
)
|
721
716
|
|
722
|
-
|
723
|
-
|
717
|
+
# Final reporting node
|
718
|
+
def report_agent_outputs(state: GraphState):
|
719
|
+
return node_func_report_agent_outputs(
|
724
720
|
state=state,
|
725
|
-
|
721
|
+
keys_to_include=[
|
722
|
+
"recommended_steps",
|
723
|
+
"data_visualization_function",
|
724
|
+
"data_visualization_function_path",
|
725
|
+
"data_visualization_function_name",
|
726
|
+
"data_visualization_error",
|
727
|
+
],
|
726
728
|
result_key="messages",
|
727
|
-
error_key="data_visualization_error",
|
728
|
-
llm=llm,
|
729
729
|
role=AGENT_NAME,
|
730
|
-
|
731
|
-
Explain the data visualization steps that the data visualization agent performed in this function.
|
732
|
-
Keep the summary succinct and to the point.\n\n# Data Visualization Agent:\n\n{code}
|
733
|
-
""",
|
734
|
-
success_prefix="# Data Visualization Agent:\n\n ",
|
735
|
-
error_message="The Data Visualization Agent encountered an error during data visualization. No explanation could be provided."
|
730
|
+
custom_title="Data Visualization Agent Outputs"
|
736
731
|
)
|
737
732
|
|
738
733
|
# Define the graph
|
@@ -742,7 +737,7 @@ def make_data_visualization_agent(
|
|
742
737
|
"chart_generator": chart_generator,
|
743
738
|
"execute_data_visualization_code": execute_data_visualization_code,
|
744
739
|
"fix_data_visualization_code": fix_data_visualization_code,
|
745
|
-
"
|
740
|
+
"report_agent_outputs": report_agent_outputs,
|
746
741
|
}
|
747
742
|
|
748
743
|
app = create_coding_agent_graph(
|
@@ -752,7 +747,7 @@ def make_data_visualization_agent(
|
|
752
747
|
create_code_node_name="chart_generator",
|
753
748
|
execute_code_node_name="execute_data_visualization_code",
|
754
749
|
fix_code_node_name="fix_data_visualization_code",
|
755
|
-
explain_code_node_name="
|
750
|
+
explain_code_node_name="report_agent_outputs",
|
756
751
|
error_key="data_visualization_error",
|
757
752
|
human_in_the_loop=human_in_the_loop, # or False
|
758
753
|
human_review_node_name="human_review",
|
@@ -7,6 +7,7 @@
|
|
7
7
|
from typing import TypedDict, Annotated, Sequence, Literal, Union, Optional
|
8
8
|
import operator
|
9
9
|
import os
|
10
|
+
import json
|
10
11
|
import pandas as pd
|
11
12
|
from IPython.display import Markdown
|
12
13
|
|
@@ -19,14 +20,20 @@ from ai_data_science_team.templates import(
|
|
19
20
|
node_func_execute_agent_code_on_data,
|
20
21
|
node_func_human_review,
|
21
22
|
node_func_fix_agent_code,
|
22
|
-
|
23
|
+
node_func_report_agent_outputs,
|
23
24
|
create_coding_agent_graph,
|
24
25
|
BaseAgent,
|
25
26
|
)
|
26
|
-
from ai_data_science_team.
|
27
|
-
from ai_data_science_team.
|
28
|
-
|
29
|
-
|
27
|
+
from ai_data_science_team.parsers.parsers import PythonOutputParser
|
28
|
+
from ai_data_science_team.utils.regex import (
|
29
|
+
relocate_imports_inside_function,
|
30
|
+
add_comments_to_top,
|
31
|
+
format_agent_name,
|
32
|
+
format_recommended_steps,
|
33
|
+
get_generic_summary,
|
34
|
+
)
|
35
|
+
from ai_data_science_team.tools.dataframe import get_dataframe_summary
|
36
|
+
from ai_data_science_team.utils.logging import log_ai_function
|
30
37
|
|
31
38
|
# Setup Logging Path
|
32
39
|
AGENT_NAME = "data_wrangling_agent"
|
@@ -88,8 +95,8 @@ class DataWranglingAgent(BaseAgent):
|
|
88
95
|
invoke_agent(user_instructions: str, data_raw: Union[dict, list], max_retries=3, retry_count=0)
|
89
96
|
Synchronously wrangles the provided dataset(s) based on user instructions.
|
90
97
|
|
91
|
-
|
92
|
-
|
98
|
+
get_workflow_summary()
|
99
|
+
Retrieves a summary of the agent's workflow.
|
93
100
|
|
94
101
|
get_log_summary()
|
95
102
|
Retrieves a summary of logged operations if logging is enabled.
|
@@ -206,7 +213,7 @@ class DataWranglingAgent(BaseAgent):
|
|
206
213
|
self._params[k] = v
|
207
214
|
self._compiled_graph = self._make_compiled_graph()
|
208
215
|
|
209
|
-
def ainvoke_agent(
|
216
|
+
async def ainvoke_agent(
|
210
217
|
self,
|
211
218
|
data_raw: Union[pd.DataFrame, dict, list],
|
212
219
|
user_instructions: str=None,
|
@@ -238,7 +245,7 @@ class DataWranglingAgent(BaseAgent):
|
|
238
245
|
None
|
239
246
|
"""
|
240
247
|
data_input = self._convert_data_input(data_raw)
|
241
|
-
response = self._compiled_graph.ainvoke({
|
248
|
+
response = await self._compiled_graph.ainvoke({
|
242
249
|
"user_instructions": user_instructions,
|
243
250
|
"data_raw": data_input,
|
244
251
|
"max_retries": max_retries,
|
@@ -287,40 +294,34 @@ class DataWranglingAgent(BaseAgent):
|
|
287
294
|
self.response = response
|
288
295
|
return None
|
289
296
|
|
290
|
-
def
|
297
|
+
def get_workflow_summary(self, markdown=False):
|
291
298
|
"""
|
292
|
-
|
293
|
-
|
294
|
-
Returns
|
295
|
-
-------
|
296
|
-
str or list
|
297
|
-
Explanation of the data wrangling steps.
|
299
|
+
Retrieves the agent's workflow summary, if logging is enabled.
|
298
300
|
"""
|
299
|
-
if self.response:
|
300
|
-
|
301
|
-
|
301
|
+
if self.response and self.response.get("messages"):
|
302
|
+
summary = get_generic_summary(json.loads(self.response.get("messages")[-1].content))
|
303
|
+
if markdown:
|
304
|
+
return Markdown(summary)
|
305
|
+
else:
|
306
|
+
return summary
|
302
307
|
|
303
308
|
def get_log_summary(self, markdown=False):
|
304
309
|
"""
|
305
310
|
Logs a summary of the agent's operations, if logging is enabled.
|
311
|
+
"""
|
312
|
+
if self.response:
|
313
|
+
if self.response.get('data_wrangler_function_path'):
|
314
|
+
log_details = f"""
|
315
|
+
## Data Wrangling Agent Log Summary:
|
306
316
|
|
307
|
-
|
308
|
-
----------
|
309
|
-
markdown : bool, optional
|
310
|
-
If True, returns the summary in Markdown.
|
317
|
+
Function Path: {self.response.get('data_wrangler_function_path')}
|
311
318
|
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
log_details = f"Log Path: {self.response.get('data_wrangler_function_path')}"
|
319
|
-
if markdown:
|
320
|
-
return Markdown(log_details)
|
321
|
-
else:
|
322
|
-
return log_details
|
323
|
-
return None
|
319
|
+
Function Name: {self.response.get('data_wrangler_function_name')}
|
320
|
+
"""
|
321
|
+
if markdown:
|
322
|
+
return Markdown(log_details)
|
323
|
+
else:
|
324
|
+
return log_details
|
324
325
|
|
325
326
|
def get_data_wrangled(self) -> Optional[pd.DataFrame]:
|
326
327
|
"""
|
@@ -597,7 +598,7 @@ def make_data_wrangling_agent(
|
|
597
598
|
Below are summaries of all datasets provided:
|
598
599
|
{all_datasets_summary}
|
599
600
|
|
600
|
-
Return
|
601
|
+
Return steps as a numbered list. You can return short code snippets to demonstrate actions. But do not return a fully coded solution. The code will be generated separately by a Coding Agent.
|
601
602
|
|
602
603
|
Avoid these:
|
603
604
|
1. Do not include steps to save files.
|
@@ -797,20 +798,20 @@ def make_data_wrangling_agent(
|
|
797
798
|
function_name=state.get("data_wrangler_function_name"),
|
798
799
|
)
|
799
800
|
|
800
|
-
|
801
|
-
|
801
|
+
# Final reporting node
|
802
|
+
def report_agent_outputs(state: GraphState):
|
803
|
+
return node_func_report_agent_outputs(
|
802
804
|
state=state,
|
803
|
-
|
805
|
+
keys_to_include=[
|
806
|
+
"recommended_steps",
|
807
|
+
"data_wrangler_function",
|
808
|
+
"data_wrangler_function_path",
|
809
|
+
"data_wrangler_function_name",
|
810
|
+
"data_wrangler_error",
|
811
|
+
],
|
804
812
|
result_key="messages",
|
805
|
-
error_key="data_wrangler_error",
|
806
|
-
llm=llm,
|
807
813
|
role=AGENT_NAME,
|
808
|
-
|
809
|
-
Explain the data wrangling steps that the data wrangling agent performed in this function.
|
810
|
-
Keep the summary succinct and to the point.\n\n# Data Wrangling Agent:\n\n{code}
|
811
|
-
""",
|
812
|
-
success_prefix="# Data Wrangling Agent:\n\n ",
|
813
|
-
error_message="The Data Wrangling Agent encountered an error during data wrangling. Data could not be explained."
|
814
|
+
custom_title="Data Wrangling Agent Outputs"
|
814
815
|
)
|
815
816
|
|
816
817
|
# Define the graph
|
@@ -820,7 +821,7 @@ def make_data_wrangling_agent(
|
|
820
821
|
"create_data_wrangler_code": create_data_wrangler_code,
|
821
822
|
"execute_data_wrangler_code": execute_data_wrangler_code,
|
822
823
|
"fix_data_wrangler_code": fix_data_wrangler_code,
|
823
|
-
"
|
824
|
+
"report_agent_outputs": report_agent_outputs,
|
824
825
|
}
|
825
826
|
|
826
827
|
app = create_coding_agent_graph(
|
@@ -830,7 +831,7 @@ def make_data_wrangling_agent(
|
|
830
831
|
create_code_node_name="create_data_wrangler_code",
|
831
832
|
execute_code_node_name="execute_data_wrangler_code",
|
832
833
|
fix_code_node_name="fix_data_wrangler_code",
|
833
|
-
explain_code_node_name="
|
834
|
+
explain_code_node_name="report_agent_outputs",
|
834
835
|
error_key="data_wrangler_error",
|
835
836
|
human_in_the_loop=human_in_the_loop,
|
836
837
|
human_review_node_name="human_review",
|