ai-data-science-team 0.0.0.9008__py3-none-any.whl → 0.0.0.9010__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +0 -1
- ai_data_science_team/agents/data_cleaning_agent.py +50 -39
- ai_data_science_team/agents/data_loader_tools_agent.py +69 -0
- ai_data_science_team/agents/data_visualization_agent.py +45 -50
- ai_data_science_team/agents/data_wrangling_agent.py +50 -49
- ai_data_science_team/agents/feature_engineering_agent.py +48 -67
- ai_data_science_team/agents/sql_database_agent.py +130 -76
- ai_data_science_team/ml_agents/__init__.py +2 -0
- ai_data_science_team/ml_agents/h2o_ml_agent.py +852 -0
- ai_data_science_team/ml_agents/mlflow_tools_agent.py +327 -0
- ai_data_science_team/multiagents/sql_data_analyst.py +120 -9
- ai_data_science_team/parsers/__init__.py +0 -0
- ai_data_science_team/{tools → parsers}/parsers.py +0 -1
- ai_data_science_team/templates/__init__.py +1 -0
- ai_data_science_team/templates/agent_templates.py +78 -7
- ai_data_science_team/tools/data_loader.py +378 -0
- ai_data_science_team/tools/{metadata.py → dataframe.py} +0 -91
- ai_data_science_team/tools/h2o.py +643 -0
- ai_data_science_team/tools/mlflow.py +961 -0
- ai_data_science_team/tools/sql.py +126 -0
- ai_data_science_team/{tools → utils}/regex.py +59 -1
- {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/METADATA +56 -24
- ai_data_science_team-0.0.0.9010.dist-info/RECORD +35 -0
- ai_data_science_team-0.0.0.9008.dist-info/RECORD +0 -26
- /ai_data_science_team/{tools → utils}/logging.py +0 -0
- {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/WHEEL +0 -0
- {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/top_level.txt +0 -0
ai_data_science_team/_version.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.0.0.
|
1
|
+
__version__ = "0.0.0.9010"
|
@@ -3,4 +3,3 @@ from ai_data_science_team.agents.feature_engineering_agent import make_feature_e
|
|
3
3
|
from ai_data_science_team.agents.data_wrangling_agent import make_data_wrangling_agent, DataWranglingAgent
|
4
4
|
from ai_data_science_team.agents.sql_database_agent import make_sql_database_agent, SQLDatabaseAgent
|
5
5
|
from ai_data_science_team.agents.data_visualization_agent import make_data_visualization_agent, DataVisualizationAgent
|
6
|
-
|
@@ -14,7 +14,7 @@ from langgraph.types import Command
|
|
14
14
|
from langgraph.checkpoint.memory import MemorySaver
|
15
15
|
|
16
16
|
import os
|
17
|
-
import
|
17
|
+
import json
|
18
18
|
import pandas as pd
|
19
19
|
|
20
20
|
from IPython.display import Markdown
|
@@ -23,21 +23,26 @@ from ai_data_science_team.templates import(
|
|
23
23
|
node_func_execute_agent_code_on_data,
|
24
24
|
node_func_human_review,
|
25
25
|
node_func_fix_agent_code,
|
26
|
-
|
26
|
+
node_func_report_agent_outputs,
|
27
27
|
create_coding_agent_graph,
|
28
28
|
BaseAgent,
|
29
29
|
)
|
30
|
-
from ai_data_science_team.
|
31
|
-
from ai_data_science_team.
|
32
|
-
|
33
|
-
|
30
|
+
from ai_data_science_team.parsers.parsers import PythonOutputParser
|
31
|
+
from ai_data_science_team.utils.regex import (
|
32
|
+
relocate_imports_inside_function,
|
33
|
+
add_comments_to_top,
|
34
|
+
format_agent_name,
|
35
|
+
format_recommended_steps,
|
36
|
+
get_generic_summary,
|
37
|
+
)
|
38
|
+
from ai_data_science_team.tools.dataframe import get_dataframe_summary
|
39
|
+
from ai_data_science_team.utils.logging import log_ai_function
|
34
40
|
|
35
41
|
# Setup
|
36
42
|
AGENT_NAME = "data_cleaning_agent"
|
37
43
|
LOG_PATH = os.path.join(os.getcwd(), "logs/")
|
38
44
|
|
39
45
|
|
40
|
-
|
41
46
|
# Class
|
42
47
|
class DataCleaningAgent(BaseAgent):
|
43
48
|
"""
|
@@ -89,8 +94,8 @@ class DataCleaningAgent(BaseAgent):
|
|
89
94
|
Cleans the provided dataset asynchronously based on user instructions.
|
90
95
|
invoke_agent(user_instructions: str, data_raw: pd.DataFrame, max_retries=3, retry_count=0)
|
91
96
|
Cleans the provided dataset synchronously based on user instructions.
|
92
|
-
|
93
|
-
|
97
|
+
get_workflow_summary()
|
98
|
+
Retrieves a summary of the agent's workflow.
|
94
99
|
get_log_summary()
|
95
100
|
Retrieves a summary of logged operations if logging is enabled.
|
96
101
|
get_state_keys()
|
@@ -178,8 +183,7 @@ class DataCleaningAgent(BaseAgent):
|
|
178
183
|
self.response=None
|
179
184
|
return make_data_cleaning_agent(**self._params)
|
180
185
|
|
181
|
-
|
182
|
-
def ainvoke_agent(self, data_raw: pd.DataFrame, user_instructions: str=None, max_retries:int=3, retry_count:int=0, **kwargs):
|
186
|
+
async def ainvoke_agent(self, data_raw: pd.DataFrame, user_instructions: str=None, max_retries:int=3, retry_count:int=0, **kwargs):
|
183
187
|
"""
|
184
188
|
Asynchronously invokes the agent. The response is stored in the response attribute.
|
185
189
|
|
@@ -200,7 +204,7 @@ class DataCleaningAgent(BaseAgent):
|
|
200
204
|
--------
|
201
205
|
None. The response is stored in the response attribute.
|
202
206
|
"""
|
203
|
-
response = self._compiled_graph.ainvoke({
|
207
|
+
response = await self._compiled_graph.ainvoke({
|
204
208
|
"user_instructions": user_instructions,
|
205
209
|
"data_raw": data_raw.to_dict(),
|
206
210
|
"max_retries": max_retries,
|
@@ -239,15 +243,16 @@ class DataCleaningAgent(BaseAgent):
|
|
239
243
|
self.response = response
|
240
244
|
return None
|
241
245
|
|
242
|
-
def
|
246
|
+
def get_workflow_summary(self, markdown=False):
|
243
247
|
"""
|
244
|
-
|
245
|
-
|
246
|
-
Returns:
|
247
|
-
str: Explanation of the cleaning steps.
|
248
|
+
Retrieves the agent's workflow summary, if logging is enabled.
|
248
249
|
"""
|
249
|
-
|
250
|
-
|
250
|
+
if self.response and self.response.get("messages"):
|
251
|
+
summary = get_generic_summary(json.loads(self.response.get("messages")[-1].content))
|
252
|
+
if markdown:
|
253
|
+
return Markdown(summary)
|
254
|
+
else:
|
255
|
+
return summary
|
251
256
|
|
252
257
|
def get_log_summary(self, markdown=False):
|
253
258
|
"""
|
@@ -255,7 +260,13 @@ class DataCleaningAgent(BaseAgent):
|
|
255
260
|
"""
|
256
261
|
if self.response:
|
257
262
|
if self.response.get('data_cleaner_function_path'):
|
258
|
-
log_details = f"
|
263
|
+
log_details = f"""
|
264
|
+
## Data Cleaning Agent Log Summary:
|
265
|
+
|
266
|
+
Function Path: {self.response.get('data_cleaner_function_path')}
|
267
|
+
|
268
|
+
Function Name: {self.response.get('data_cleaner_function_name')}
|
269
|
+
"""
|
259
270
|
if markdown:
|
260
271
|
return Markdown(log_details)
|
261
272
|
else:
|
@@ -462,7 +473,7 @@ def make_data_cleaning_agent(
|
|
462
473
|
Below are summaries of all datasets provided:
|
463
474
|
{all_datasets_summary}
|
464
475
|
|
465
|
-
Return
|
476
|
+
Return steps as a numbered list. You can return short code snippets to demonstrate actions. But do not return a fully coded solution. The code will be generated separately by a Coding Agent.
|
466
477
|
|
467
478
|
Avoid these:
|
468
479
|
1. Do not include steps to save files.
|
@@ -633,32 +644,31 @@ def make_data_cleaning_agent(
|
|
633
644
|
function_name=state.get("data_cleaner_function_name"),
|
634
645
|
)
|
635
646
|
|
636
|
-
|
637
|
-
|
647
|
+
# Final reporting node
|
648
|
+
def report_agent_outputs(state: GraphState):
|
649
|
+
return node_func_report_agent_outputs(
|
638
650
|
state=state,
|
639
|
-
|
651
|
+
keys_to_include=[
|
652
|
+
"recommended_steps",
|
653
|
+
"data_cleaner_function",
|
654
|
+
"data_cleaner_function_path",
|
655
|
+
"data_cleaner_function_name",
|
656
|
+
"data_cleaner_error",
|
657
|
+
],
|
640
658
|
result_key="messages",
|
641
|
-
error_key="data_cleaner_error",
|
642
|
-
llm=llm,
|
643
659
|
role=AGENT_NAME,
|
644
|
-
|
645
|
-
Explain the data cleaning steps that the data cleaning agent performed in this function.
|
646
|
-
Keep the summary succinct and to the point.\n\n# Data Cleaning Agent:\n\n{code}
|
647
|
-
""",
|
648
|
-
success_prefix="# Data Cleaning Agent:\n\n ",
|
649
|
-
error_message="The Data Cleaning Agent encountered an error during data cleaning. Data could not be explained."
|
660
|
+
custom_title="Data Cleaning Agent Outputs"
|
650
661
|
)
|
651
|
-
|
652
|
-
# Define the graph
|
662
|
+
|
653
663
|
node_functions = {
|
654
664
|
"recommend_cleaning_steps": recommend_cleaning_steps,
|
655
665
|
"human_review": human_review,
|
656
666
|
"create_data_cleaner_code": create_data_cleaner_code,
|
657
667
|
"execute_data_cleaner_code": execute_data_cleaner_code,
|
658
668
|
"fix_data_cleaner_code": fix_data_cleaner_code,
|
659
|
-
"
|
669
|
+
"report_agent_outputs": report_agent_outputs,
|
660
670
|
}
|
661
|
-
|
671
|
+
|
662
672
|
app = create_coding_agent_graph(
|
663
673
|
GraphState=GraphState,
|
664
674
|
node_functions=node_functions,
|
@@ -666,16 +676,17 @@ def make_data_cleaning_agent(
|
|
666
676
|
create_code_node_name="create_data_cleaner_code",
|
667
677
|
execute_code_node_name="execute_data_cleaner_code",
|
668
678
|
fix_code_node_name="fix_data_cleaner_code",
|
669
|
-
explain_code_node_name="
|
679
|
+
explain_code_node_name="report_agent_outputs",
|
670
680
|
error_key="data_cleaner_error",
|
671
|
-
human_in_the_loop=human_in_the_loop,
|
681
|
+
human_in_the_loop=human_in_the_loop,
|
672
682
|
human_review_node_name="human_review",
|
673
683
|
checkpointer=MemorySaver() if human_in_the_loop else None,
|
674
684
|
bypass_recommended_steps=bypass_recommended_steps,
|
675
685
|
bypass_explain_code=bypass_explain_code,
|
676
686
|
)
|
677
|
-
|
687
|
+
|
678
688
|
return app
|
689
|
+
|
679
690
|
|
680
691
|
|
681
692
|
|
@@ -0,0 +1,69 @@
|
|
1
|
+
|
2
|
+
|
3
|
+
|
4
|
+
from typing import Any, Optional, Annotated, Sequence, List, Dict
|
5
|
+
import operator
|
6
|
+
|
7
|
+
import pandas as pd
|
8
|
+
import os
|
9
|
+
|
10
|
+
from IPython.display import Markdown
|
11
|
+
|
12
|
+
from langchain_core.messages import BaseMessage, AIMessage
|
13
|
+
|
14
|
+
from langgraph.prebuilt import create_react_agent, ToolNode
|
15
|
+
from langgraph.prebuilt.chat_agent_executor import AgentState
|
16
|
+
from langgraph.graph import START, END, StateGraph
|
17
|
+
|
18
|
+
from ai_data_science_team.templates import BaseAgent
|
19
|
+
from ai_data_science_team.utils.regex import format_agent_name
|
20
|
+
from ai_data_science_team.tools.data_loader import (
|
21
|
+
load_directory,
|
22
|
+
load_file,
|
23
|
+
list_directory_contents,
|
24
|
+
list_directory_recursive,
|
25
|
+
get_file_info,
|
26
|
+
search_files_by_pattern,
|
27
|
+
)
|
28
|
+
|
29
|
+
AGENT_NAME = "data_loader_tools_agent"
|
30
|
+
|
31
|
+
tools = [
|
32
|
+
load_directory,
|
33
|
+
load_file,
|
34
|
+
list_directory_contents,
|
35
|
+
list_directory_recursive,
|
36
|
+
get_file_info,
|
37
|
+
search_files_by_pattern,
|
38
|
+
]
|
39
|
+
|
40
|
+
|
41
|
+
|
42
|
+
def make_data_loader_tools_agent(
|
43
|
+
model: Any,
|
44
|
+
directory: Optional[str] = os.getcwd(),
|
45
|
+
):
|
46
|
+
"""
|
47
|
+
Creates a Data Loader Agent that can interact with data loading tools.
|
48
|
+
|
49
|
+
Parameters:
|
50
|
+
----------
|
51
|
+
model : langchain.llms.base.LLM
|
52
|
+
The language model used to generate the tool calling agent.
|
53
|
+
directory : str, optional
|
54
|
+
The directory to search for files. Defaults to the current working directory.
|
55
|
+
|
56
|
+
Returns:
|
57
|
+
--------
|
58
|
+
Data Loader Agent
|
59
|
+
An agent that can interact with data loading tools.
|
60
|
+
"""
|
61
|
+
|
62
|
+
class GraphState(AgentState):
|
63
|
+
internal_messages: Annotated[Sequence[BaseMessage], operator.add]
|
64
|
+
directory: str
|
65
|
+
user_instructions: str
|
66
|
+
data_artifacts: dict
|
67
|
+
|
68
|
+
pass
|
69
|
+
|
@@ -10,13 +10,13 @@ from typing import TypedDict, Annotated, Sequence, Literal
|
|
10
10
|
import operator
|
11
11
|
|
12
12
|
from langchain.prompts import PromptTemplate
|
13
|
-
from langchain_core.output_parsers import StrOutputParser
|
14
13
|
from langchain_core.messages import BaseMessage
|
15
14
|
|
16
15
|
from langgraph.types import Command
|
17
16
|
from langgraph.checkpoint.memory import MemorySaver
|
18
17
|
|
19
18
|
import os
|
19
|
+
import json
|
20
20
|
import pandas as pd
|
21
21
|
|
22
22
|
from IPython.display import Markdown
|
@@ -25,19 +25,20 @@ from ai_data_science_team.templates import(
|
|
25
25
|
node_func_execute_agent_code_on_data,
|
26
26
|
node_func_human_review,
|
27
27
|
node_func_fix_agent_code,
|
28
|
-
|
28
|
+
node_func_report_agent_outputs,
|
29
29
|
create_coding_agent_graph,
|
30
30
|
BaseAgent,
|
31
31
|
)
|
32
|
-
from ai_data_science_team.
|
33
|
-
from ai_data_science_team.
|
32
|
+
from ai_data_science_team.parsers.parsers import PythonOutputParser
|
33
|
+
from ai_data_science_team.utils.regex import (
|
34
34
|
relocate_imports_inside_function,
|
35
35
|
add_comments_to_top,
|
36
36
|
format_agent_name,
|
37
|
-
format_recommended_steps
|
37
|
+
format_recommended_steps,
|
38
|
+
get_generic_summary,
|
38
39
|
)
|
39
|
-
from ai_data_science_team.tools.
|
40
|
-
from ai_data_science_team.
|
40
|
+
from ai_data_science_team.tools.dataframe import get_dataframe_summary
|
41
|
+
from ai_data_science_team.utils.logging import log_ai_function
|
41
42
|
from ai_data_science_team.utils.plotly import plotly_from_dict
|
42
43
|
|
43
44
|
# Setup
|
@@ -93,8 +94,8 @@ class DataVisualizationAgent(BaseAgent):
|
|
93
94
|
Asynchronously generates a visualization based on user instructions.
|
94
95
|
invoke_agent(user_instructions: str, data_raw: pd.DataFrame, max_retries=3, retry_count=0)
|
95
96
|
Synchronously generates a visualization based on user instructions.
|
96
|
-
|
97
|
-
|
97
|
+
get_workflow_summary()
|
98
|
+
Retrieves a summary of the agent's workflow.
|
98
99
|
get_log_summary()
|
99
100
|
Retrieves a summary of logged operations if logging is enabled.
|
100
101
|
get_plotly_graph()
|
@@ -195,7 +196,7 @@ class DataVisualizationAgent(BaseAgent):
|
|
195
196
|
# Rebuild the compiled graph
|
196
197
|
self._compiled_graph = self._make_compiled_graph()
|
197
198
|
|
198
|
-
def ainvoke_agent(self, data_raw: pd.DataFrame, user_instructions: str=None, max_retries:int=3, retry_count:int=0, **kwargs):
|
199
|
+
async def ainvoke_agent(self, data_raw: pd.DataFrame, user_instructions: str=None, max_retries:int=3, retry_count:int=0, **kwargs):
|
199
200
|
"""
|
200
201
|
Asynchronously invokes the agent to generate a visualization.
|
201
202
|
The response is stored in the 'response' attribute.
|
@@ -217,7 +218,7 @@ class DataVisualizationAgent(BaseAgent):
|
|
217
218
|
-------
|
218
219
|
None
|
219
220
|
"""
|
220
|
-
response = self._compiled_graph.ainvoke({
|
221
|
+
response = await self._compiled_graph.ainvoke({
|
221
222
|
"user_instructions": user_instructions,
|
222
223
|
"data_raw": data_raw.to_dict(),
|
223
224
|
"max_retries": max_retries,
|
@@ -257,40 +258,34 @@ class DataVisualizationAgent(BaseAgent):
|
|
257
258
|
self.response = response
|
258
259
|
return None
|
259
260
|
|
260
|
-
def
|
261
|
+
def get_workflow_summary(self, markdown=False):
|
261
262
|
"""
|
262
|
-
|
263
|
-
|
264
|
-
Returns
|
265
|
-
-------
|
266
|
-
str
|
267
|
-
Explanation of the visualization steps, if any are available.
|
263
|
+
Retrieves the agent's workflow summary, if logging is enabled.
|
268
264
|
"""
|
269
|
-
if self.response:
|
270
|
-
|
271
|
-
|
265
|
+
if self.response and self.response.get("messages"):
|
266
|
+
summary = get_generic_summary(json.loads(self.response.get("messages")[-1].content))
|
267
|
+
if markdown:
|
268
|
+
return Markdown(summary)
|
269
|
+
else:
|
270
|
+
return summary
|
272
271
|
|
273
272
|
def get_log_summary(self, markdown=False):
|
274
273
|
"""
|
275
274
|
Logs a summary of the agent's operations, if logging is enabled.
|
275
|
+
"""
|
276
|
+
if self.response:
|
277
|
+
if self.response.get('data_visualization_function_path'):
|
278
|
+
log_details = f"""
|
279
|
+
## Data Visualization Agent Log Summary:
|
276
280
|
|
277
|
-
|
278
|
-
----------
|
279
|
-
markdown : bool, optional
|
280
|
-
If True, returns Markdown-formatted output.
|
281
|
+
Function Path: {self.response.get('data_visualization_function_path')}
|
281
282
|
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
log_details = f"Log Path: {self.response.get('data_visualization_function_path')}"
|
289
|
-
if markdown:
|
290
|
-
return Markdown(log_details)
|
291
|
-
else:
|
292
|
-
return log_details
|
293
|
-
return None
|
283
|
+
Function Name: {self.response.get('data_visualization_function_name')}
|
284
|
+
"""
|
285
|
+
if markdown:
|
286
|
+
return Markdown(log_details)
|
287
|
+
else:
|
288
|
+
return log_details
|
294
289
|
|
295
290
|
def get_plotly_graph(self):
|
296
291
|
"""
|
@@ -719,20 +714,20 @@ def make_data_visualization_agent(
|
|
719
714
|
function_name=state.get("data_visualization_function_name"),
|
720
715
|
)
|
721
716
|
|
722
|
-
|
723
|
-
|
717
|
+
# Final reporting node
|
718
|
+
def report_agent_outputs(state: GraphState):
|
719
|
+
return node_func_report_agent_outputs(
|
724
720
|
state=state,
|
725
|
-
|
721
|
+
keys_to_include=[
|
722
|
+
"recommended_steps",
|
723
|
+
"data_visualization_function",
|
724
|
+
"data_visualization_function_path",
|
725
|
+
"data_visualization_function_name",
|
726
|
+
"data_visualization_error",
|
727
|
+
],
|
726
728
|
result_key="messages",
|
727
|
-
error_key="data_visualization_error",
|
728
|
-
llm=llm,
|
729
729
|
role=AGENT_NAME,
|
730
|
-
|
731
|
-
Explain the data visualization steps that the data visualization agent performed in this function.
|
732
|
-
Keep the summary succinct and to the point.\n\n# Data Visualization Agent:\n\n{code}
|
733
|
-
""",
|
734
|
-
success_prefix="# Data Visualization Agent:\n\n ",
|
735
|
-
error_message="The Data Visualization Agent encountered an error during data visualization. No explanation could be provided."
|
730
|
+
custom_title="Data Visualization Agent Outputs"
|
736
731
|
)
|
737
732
|
|
738
733
|
# Define the graph
|
@@ -742,7 +737,7 @@ def make_data_visualization_agent(
|
|
742
737
|
"chart_generator": chart_generator,
|
743
738
|
"execute_data_visualization_code": execute_data_visualization_code,
|
744
739
|
"fix_data_visualization_code": fix_data_visualization_code,
|
745
|
-
"
|
740
|
+
"report_agent_outputs": report_agent_outputs,
|
746
741
|
}
|
747
742
|
|
748
743
|
app = create_coding_agent_graph(
|
@@ -752,7 +747,7 @@ def make_data_visualization_agent(
|
|
752
747
|
create_code_node_name="chart_generator",
|
753
748
|
execute_code_node_name="execute_data_visualization_code",
|
754
749
|
fix_code_node_name="fix_data_visualization_code",
|
755
|
-
explain_code_node_name="
|
750
|
+
explain_code_node_name="report_agent_outputs",
|
756
751
|
error_key="data_visualization_error",
|
757
752
|
human_in_the_loop=human_in_the_loop, # or False
|
758
753
|
human_review_node_name="human_review",
|
@@ -7,6 +7,7 @@
|
|
7
7
|
from typing import TypedDict, Annotated, Sequence, Literal, Union, Optional
|
8
8
|
import operator
|
9
9
|
import os
|
10
|
+
import json
|
10
11
|
import pandas as pd
|
11
12
|
from IPython.display import Markdown
|
12
13
|
|
@@ -19,14 +20,20 @@ from ai_data_science_team.templates import(
|
|
19
20
|
node_func_execute_agent_code_on_data,
|
20
21
|
node_func_human_review,
|
21
22
|
node_func_fix_agent_code,
|
22
|
-
|
23
|
+
node_func_report_agent_outputs,
|
23
24
|
create_coding_agent_graph,
|
24
25
|
BaseAgent,
|
25
26
|
)
|
26
|
-
from ai_data_science_team.
|
27
|
-
from ai_data_science_team.
|
28
|
-
|
29
|
-
|
27
|
+
from ai_data_science_team.parsers.parsers import PythonOutputParser
|
28
|
+
from ai_data_science_team.utils.regex import (
|
29
|
+
relocate_imports_inside_function,
|
30
|
+
add_comments_to_top,
|
31
|
+
format_agent_name,
|
32
|
+
format_recommended_steps,
|
33
|
+
get_generic_summary,
|
34
|
+
)
|
35
|
+
from ai_data_science_team.tools.dataframe import get_dataframe_summary
|
36
|
+
from ai_data_science_team.utils.logging import log_ai_function
|
30
37
|
|
31
38
|
# Setup Logging Path
|
32
39
|
AGENT_NAME = "data_wrangling_agent"
|
@@ -88,8 +95,8 @@ class DataWranglingAgent(BaseAgent):
|
|
88
95
|
invoke_agent(user_instructions: str, data_raw: Union[dict, list], max_retries=3, retry_count=0)
|
89
96
|
Synchronously wrangles the provided dataset(s) based on user instructions.
|
90
97
|
|
91
|
-
|
92
|
-
|
98
|
+
get_workflow_summary()
|
99
|
+
Retrieves a summary of the agent's workflow.
|
93
100
|
|
94
101
|
get_log_summary()
|
95
102
|
Retrieves a summary of logged operations if logging is enabled.
|
@@ -206,7 +213,7 @@ class DataWranglingAgent(BaseAgent):
|
|
206
213
|
self._params[k] = v
|
207
214
|
self._compiled_graph = self._make_compiled_graph()
|
208
215
|
|
209
|
-
def ainvoke_agent(
|
216
|
+
async def ainvoke_agent(
|
210
217
|
self,
|
211
218
|
data_raw: Union[pd.DataFrame, dict, list],
|
212
219
|
user_instructions: str=None,
|
@@ -238,7 +245,7 @@ class DataWranglingAgent(BaseAgent):
|
|
238
245
|
None
|
239
246
|
"""
|
240
247
|
data_input = self._convert_data_input(data_raw)
|
241
|
-
response = self._compiled_graph.ainvoke({
|
248
|
+
response = await self._compiled_graph.ainvoke({
|
242
249
|
"user_instructions": user_instructions,
|
243
250
|
"data_raw": data_input,
|
244
251
|
"max_retries": max_retries,
|
@@ -287,40 +294,34 @@ class DataWranglingAgent(BaseAgent):
|
|
287
294
|
self.response = response
|
288
295
|
return None
|
289
296
|
|
290
|
-
def
|
297
|
+
def get_workflow_summary(self, markdown=False):
|
291
298
|
"""
|
292
|
-
|
293
|
-
|
294
|
-
Returns
|
295
|
-
-------
|
296
|
-
str or list
|
297
|
-
Explanation of the data wrangling steps.
|
299
|
+
Retrieves the agent's workflow summary, if logging is enabled.
|
298
300
|
"""
|
299
|
-
if self.response:
|
300
|
-
|
301
|
-
|
301
|
+
if self.response and self.response.get("messages"):
|
302
|
+
summary = get_generic_summary(json.loads(self.response.get("messages")[-1].content))
|
303
|
+
if markdown:
|
304
|
+
return Markdown(summary)
|
305
|
+
else:
|
306
|
+
return summary
|
302
307
|
|
303
308
|
def get_log_summary(self, markdown=False):
|
304
309
|
"""
|
305
310
|
Logs a summary of the agent's operations, if logging is enabled.
|
311
|
+
"""
|
312
|
+
if self.response:
|
313
|
+
if self.response.get('data_wrangler_function_path'):
|
314
|
+
log_details = f"""
|
315
|
+
## Data Wrangling Agent Log Summary:
|
306
316
|
|
307
|
-
|
308
|
-
----------
|
309
|
-
markdown : bool, optional
|
310
|
-
If True, returns the summary in Markdown.
|
317
|
+
Function Path: {self.response.get('data_wrangler_function_path')}
|
311
318
|
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
log_details = f"Log Path: {self.response.get('data_wrangler_function_path')}"
|
319
|
-
if markdown:
|
320
|
-
return Markdown(log_details)
|
321
|
-
else:
|
322
|
-
return log_details
|
323
|
-
return None
|
319
|
+
Function Name: {self.response.get('data_wrangler_function_name')}
|
320
|
+
"""
|
321
|
+
if markdown:
|
322
|
+
return Markdown(log_details)
|
323
|
+
else:
|
324
|
+
return log_details
|
324
325
|
|
325
326
|
def get_data_wrangled(self) -> Optional[pd.DataFrame]:
|
326
327
|
"""
|
@@ -597,7 +598,7 @@ def make_data_wrangling_agent(
|
|
597
598
|
Below are summaries of all datasets provided:
|
598
599
|
{all_datasets_summary}
|
599
600
|
|
600
|
-
Return
|
601
|
+
Return steps as a numbered list. You can return short code snippets to demonstrate actions. But do not return a fully coded solution. The code will be generated separately by a Coding Agent.
|
601
602
|
|
602
603
|
Avoid these:
|
603
604
|
1. Do not include steps to save files.
|
@@ -797,20 +798,20 @@ def make_data_wrangling_agent(
|
|
797
798
|
function_name=state.get("data_wrangler_function_name"),
|
798
799
|
)
|
799
800
|
|
800
|
-
|
801
|
-
|
801
|
+
# Final reporting node
|
802
|
+
def report_agent_outputs(state: GraphState):
|
803
|
+
return node_func_report_agent_outputs(
|
802
804
|
state=state,
|
803
|
-
|
805
|
+
keys_to_include=[
|
806
|
+
"recommended_steps",
|
807
|
+
"data_wrangler_function",
|
808
|
+
"data_wrangler_function_path",
|
809
|
+
"data_wrangler_function_name",
|
810
|
+
"data_wrangler_error",
|
811
|
+
],
|
804
812
|
result_key="messages",
|
805
|
-
error_key="data_wrangler_error",
|
806
|
-
llm=llm,
|
807
813
|
role=AGENT_NAME,
|
808
|
-
|
809
|
-
Explain the data wrangling steps that the data wrangling agent performed in this function.
|
810
|
-
Keep the summary succinct and to the point.\n\n# Data Wrangling Agent:\n\n{code}
|
811
|
-
""",
|
812
|
-
success_prefix="# Data Wrangling Agent:\n\n ",
|
813
|
-
error_message="The Data Wrangling Agent encountered an error during data wrangling. Data could not be explained."
|
814
|
+
custom_title="Data Wrangling Agent Outputs"
|
814
815
|
)
|
815
816
|
|
816
817
|
# Define the graph
|
@@ -820,7 +821,7 @@ def make_data_wrangling_agent(
|
|
820
821
|
"create_data_wrangler_code": create_data_wrangler_code,
|
821
822
|
"execute_data_wrangler_code": execute_data_wrangler_code,
|
822
823
|
"fix_data_wrangler_code": fix_data_wrangler_code,
|
823
|
-
"
|
824
|
+
"report_agent_outputs": report_agent_outputs,
|
824
825
|
}
|
825
826
|
|
826
827
|
app = create_coding_agent_graph(
|
@@ -830,7 +831,7 @@ def make_data_wrangling_agent(
|
|
830
831
|
create_code_node_name="create_data_wrangler_code",
|
831
832
|
execute_code_node_name="execute_data_wrangler_code",
|
832
833
|
fix_code_node_name="fix_data_wrangler_code",
|
833
|
-
explain_code_node_name="
|
834
|
+
explain_code_node_name="report_agent_outputs",
|
834
835
|
error_key="data_wrangler_error",
|
835
836
|
human_in_the_loop=human_in_the_loop,
|
836
837
|
human_review_node_name="human_review",
|