ai-data-science-team 0.0.0.9010__py3-none-any.whl → 0.0.0.9012__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +1 -0
- ai_data_science_team/agents/data_loader_tools_agent.py +210 -7
- ai_data_science_team/ds_agents/__init__.py +1 -0
- ai_data_science_team/ds_agents/eda_tools_agent.py +245 -0
- ai_data_science_team/ds_agents/modeling_tools_agent.py +0 -0
- ai_data_science_team/ml_agents/h2o_ml_agent.py +2 -1
- ai_data_science_team/ml_agents/h2o_ml_tools_agent.py +0 -0
- ai_data_science_team/ml_agents/mlflow_tools_agent.py +32 -9
- ai_data_science_team/tools/data_loader.py +95 -25
- ai_data_science_team/tools/eda.py +293 -0
- ai_data_science_team/utils/html.py +27 -0
- ai_data_science_team/utils/matplotlib.py +46 -0
- {ai_data_science_team-0.0.0.9010.dist-info → ai_data_science_team-0.0.0.9012.dist-info}/METADATA +26 -9
- {ai_data_science_team-0.0.0.9010.dist-info → ai_data_science_team-0.0.0.9012.dist-info}/RECORD +18 -11
- {ai_data_science_team-0.0.0.9010.dist-info → ai_data_science_team-0.0.0.9012.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9010.dist-info → ai_data_science_team-0.0.0.9012.dist-info}/WHEEL +0 -0
- {ai_data_science_team-0.0.0.9010.dist-info → ai_data_science_team-0.0.0.9012.dist-info}/top_level.txt +0 -0
ai_data_science_team/_version.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.0.0.
|
1
|
+
__version__ = "0.0.0.9012"
|
@@ -3,3 +3,4 @@ from ai_data_science_team.agents.feature_engineering_agent import make_feature_e
|
|
3
3
|
from ai_data_science_team.agents.data_wrangling_agent import make_data_wrangling_agent, DataWranglingAgent
|
4
4
|
from ai_data_science_team.agents.sql_database_agent import make_sql_database_agent, SQLDatabaseAgent
|
5
5
|
from ai_data_science_team.agents.data_visualization_agent import make_data_visualization_agent, DataVisualizationAgent
|
6
|
+
from ai_data_science_team.agents.data_loader_tools_agent import make_data_loader_tools_agent, DataLoaderToolsAgent
|
@@ -37,11 +37,150 @@ tools = [
|
|
37
37
|
search_files_by_pattern,
|
38
38
|
]
|
39
39
|
|
40
|
+
class DataLoaderToolsAgent(BaseAgent):
    """
    A Data Loader Agent that can interact with data loading tools and search for files in your file system.

    Parameters:
    ----------
    model : langchain.llms.base.LLM
        The language model used to generate the tool calling agent.
    create_react_agent_kwargs : dict, optional
        Additional keyword arguments to pass to the create_react_agent function.
        None (the default) is treated as an empty dict.
    invoke_react_agent_kwargs : dict, optional
        Additional keyword arguments to pass to the invoke method of the react agent.
        None (the default) is treated as an empty dict.

    Methods:
    --------
    update_params(**kwargs)
        Updates the agent's parameters and rebuilds the compiled graph.
    ainvoke_agent(user_instructions: str=None, **kwargs)
        Runs the agent with the given user instructions asynchronously.
    invoke_agent(user_instructions: str=None, **kwargs)
        Runs the agent with the given user instructions.
    get_internal_messages(markdown: bool=False)
        Returns the internal messages from the agent's response.
    get_artifacts(as_dataframe: bool=False)
        Returns the data loader artifacts from the agent's response.
    get_ai_message(markdown: bool=False)
        Returns the AI message from the agent's response.

    Notes:
    ------
    The getter methods read from ``self.response``, which is None until
    invoke_agent or ainvoke_agent has been called.
    """

    def __init__(
        self,
        model: Any,
        create_react_agent_kwargs: Optional[Dict] = None,
        invoke_react_agent_kwargs: Optional[Dict] = None,
    ):
        # Defaults are None rather than `{}`: a `{}` default is created once at
        # definition time and would be shared by every instance through
        # self._params. None is normalised to a fresh dict per instance.
        self._params = {
            "model": model,
            "create_react_agent_kwargs": (
                {} if create_react_agent_kwargs is None else create_react_agent_kwargs
            ),
            "invoke_react_agent_kwargs": (
                {} if invoke_react_agent_kwargs is None else invoke_react_agent_kwargs
            ),
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Creates the compiled graph for the agent.

        Any previously stored response is cleared, since it was produced by
        the graph that is being replaced.
        """
        self.response = None
        return make_data_loader_tools_agent(**self._params)

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters and rebuilds the compiled graph.

        Parameters:
        ----------
        kwargs : dict
            Parameter names and their new values; each is written into the
            stored parameters before the graph is rebuilt.
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    async def ainvoke_agent(
        self,
        user_instructions: str = None,
        **kwargs
    ):
        """
        Runs the agent with the given user instructions asynchronously.

        Parameters:
        ----------
        user_instructions : str, optional
            The user instructions to pass to the agent.
        kwargs : dict, optional
            Additional keyword arguments to pass to the agents ainvoke method.

        Returns:
        --------
        None
            The result is stored on ``self.response``.
        """
        self.response = await self._compiled_graph.ainvoke(
            {"user_instructions": user_instructions},
            **kwargs
        )
        return None

    def invoke_agent(
        self,
        user_instructions: str = None,
        **kwargs
    ):
        """
        Runs the agent with the given user instructions.

        Parameters:
        ----------
        user_instructions : str, optional
            The user instructions to pass to the agent.
        kwargs : dict, optional
            Additional keyword arguments to pass to the agents invoke method.

        Returns:
        --------
        None
            The result is stored on ``self.response``.
        """
        self.response = self._compiled_graph.invoke(
            {"user_instructions": user_instructions},
            **kwargs
        )
        return None

    def get_internal_messages(self, markdown: bool = False):
        """
        Returns the internal messages from the agent's response.

        Parameters:
        ----------
        markdown : bool, optional
            If True, the messages are rendered into a single Markdown object
            (type, id and content per message). Otherwise the stored message
            sequence is returned as-is.
        """
        internal_messages = self.response["internal_messages"]
        if not markdown:
            return internal_messages

        # Only build the formatted text when it is actually returned.
        pretty_print = "\n\n".join(
            f"### {msg.type.upper()}\n\nID: {msg.id}\n\nContent:\n\n{msg.content}"
            for msg in internal_messages
        )
        return Markdown(pretty_print)

    def get_artifacts(self, as_dataframe: bool = False):
        """
        Returns the data loader artifacts from the agent's response.

        Parameters:
        ----------
        as_dataframe : bool, optional
            If True, the artifacts are wrapped in a pandas DataFrame.
        """
        artifacts = self.response["data_loader_artifacts"]
        if as_dataframe:
            return pd.DataFrame(artifacts)
        return artifacts

    def get_ai_message(self, markdown: bool = False):
        """
        Returns the AI message from the agent's response.

        Parameters:
        ----------
        markdown : bool, optional
            If True, the message content is wrapped in a Markdown object.
        """
        content = self.response["messages"][0].content
        if markdown:
            return Markdown(content)
        return content
|
40
177
|
|
178
|
+
|
41
179
|
|
42
180
|
def make_data_loader_tools_agent(
|
43
181
|
model: Any,
|
44
|
-
|
182
|
+
create_react_agent_kwargs: Optional[Dict]={},
|
183
|
+
invoke_react_agent_kwargs: Optional[Dict]={},
|
45
184
|
):
|
46
185
|
"""
|
47
186
|
Creates a Data Loader Agent that can interact with data loading tools.
|
@@ -50,20 +189,84 @@ def make_data_loader_tools_agent(
|
|
50
189
|
----------
|
51
190
|
model : langchain.llms.base.LLM
|
52
191
|
The language model used to generate the tool calling agent.
|
53
|
-
|
54
|
-
|
192
|
+
create_react_agent_kwargs : dict
|
193
|
+
Additional keyword arguments to pass to the create_react_agent function.
|
194
|
+
invoke_react_agent_kwargs : dict
|
195
|
+
Additional keyword arguments to pass to the invoke method of the react agent.
|
55
196
|
|
56
197
|
Returns:
|
57
198
|
--------
|
58
|
-
|
199
|
+
app : langchain.graphs.CompiledStateGraph
|
59
200
|
An agent that can interact with data loading tools.
|
60
201
|
"""
|
61
202
|
|
62
203
|
class GraphState(AgentState):
|
63
204
|
internal_messages: Annotated[Sequence[BaseMessage], operator.add]
|
64
|
-
directory: str
|
65
205
|
user_instructions: str
|
66
|
-
|
206
|
+
data_loader_artifacts: dict
|
207
|
+
|
208
|
+
def data_loader_agent(state):
    """
    Graph node: runs the react tool-calling agent and post-processes its output.

    Reads ``state["user_instructions"]``, sends it as a single user message to
    a react agent built over the module's ``tools``, and returns a state update
    with the final AI message, the full message list, and the artifact of the
    second-to-last message (if it has one).
    """
    print(format_agent_name(AGENT_NAME))
    print(" ")

    print(" * RUN REACT TOOL-CALLING AGENT")

    tool_node = ToolNode(tools=tools)

    # Named `react_agent` so the local does not shadow this function's own name.
    react_agent = create_react_agent(
        model,
        tools=tool_node,
        state_schema=GraphState,
        **create_react_agent_kwargs,
    )

    # NOTE(review): invoke_react_agent_kwargs is passed positionally, so it is
    # received as the second argument of invoke (presumably the run config),
    # not unpacked as keyword arguments — confirm this is intended.
    response = react_agent.invoke(
        {
            "messages": [("user", state["user_instructions"])],
        },
        invoke_react_agent_kwargs,
    )

    print(" * POST-PROCESS RESULTS")

    internal_messages = response['messages']

    # Ensure there is at least one AI message
    if not internal_messages:
        # The key must match the GraphState field; the previous
        # "mlflow_artifacts" key is not part of this graph's state.
        return {
            "internal_messages": [],
            "data_loader_artifacts": None,
        }

    # Get the last AI message
    last_ai_message = AIMessage(internal_messages[-1].content, role=AGENT_NAME)

    # Get the last tool artifact safely
    last_tool_artifact = None
    if len(internal_messages) > 1:
        previous_message = internal_messages[-2]  # second-to-last message
        if hasattr(previous_message, "artifact"):
            last_tool_artifact = previous_message.artifact
        elif isinstance(previous_message, dict) and "artifact" in previous_message:
            last_tool_artifact = previous_message["artifact"]

    return {
        "messages": [last_ai_message],
        "internal_messages": internal_messages,
        "data_loader_artifacts": last_tool_artifact,
    }
|
261
|
+
|
262
|
+
workflow = StateGraph(GraphState)
|
67
263
|
|
68
|
-
|
264
|
+
workflow.add_node("data_loader_agent", data_loader_agent)
|
265
|
+
|
266
|
+
workflow.add_edge(START, "data_loader_agent")
|
267
|
+
workflow.add_edge("data_loader_agent", END)
|
268
|
+
|
269
|
+
app = workflow.compile()
|
270
|
+
|
271
|
+
return app
|
69
272
|
|
@@ -0,0 +1 @@
|
|
1
|
+
from ai_data_science_team.ds_agents.eda_tools_agent import EDAToolsAgent, make_eda_tools_agent
|
@@ -0,0 +1,245 @@
|
|
1
|
+
|
2
|
+
|
3
|
+
from typing import Any, Optional, Annotated, Sequence, List, Dict, Tuple
|
4
|
+
import operator
|
5
|
+
import pandas as pd
|
6
|
+
import os
|
7
|
+
from io import StringIO, BytesIO
|
8
|
+
import base64
|
9
|
+
import matplotlib.pyplot as plt
|
10
|
+
|
11
|
+
from IPython.display import Markdown
|
12
|
+
|
13
|
+
from langchain_core.messages import BaseMessage, AIMessage
|
14
|
+
from langgraph.prebuilt import create_react_agent, ToolNode
|
15
|
+
from langgraph.prebuilt.chat_agent_executor import AgentState
|
16
|
+
from langgraph.graph import START, END, StateGraph
|
17
|
+
|
18
|
+
from ai_data_science_team.templates import BaseAgent
|
19
|
+
from ai_data_science_team.utils.regex import format_agent_name
|
20
|
+
|
21
|
+
from ai_data_science_team.tools.eda import (
|
22
|
+
describe_dataset,
|
23
|
+
visualize_missing,
|
24
|
+
correlation_funnel,
|
25
|
+
generate_sweetviz_report,
|
26
|
+
)
|
27
|
+
|
28
|
+
|
29
|
+
AGENT_NAME = "exploratory_data_analyst_agent"
|
30
|
+
|
31
|
+
# Updated tool list for EDA
|
32
|
+
EDA_TOOLS = [
|
33
|
+
describe_dataset,
|
34
|
+
visualize_missing,
|
35
|
+
correlation_funnel,
|
36
|
+
generate_sweetviz_report,
|
37
|
+
]
|
38
|
+
|
39
|
+
class EDAToolsAgent(BaseAgent):
    """
    An Exploratory Data Analysis Tools Agent that interacts with EDA tools to generate summary statistics,
    missing data visualizations, correlation funnels, EDA reports, etc.

    Parameters:
    ----------
    model : langchain.llms.base.LLM
        The language model for generating the tool-calling agent.
    create_react_agent_kwargs : dict, optional
        Additional kwargs for create_react_agent.
        None (the default) is treated as an empty dict.
    invoke_react_agent_kwargs : dict, optional
        Additional kwargs for agent invocation.
        None (the default) is treated as an empty dict.

    Notes:
    ------
    The getter methods read from ``self.response``, which is None until
    invoke_agent or ainvoke_agent has been called.
    """

    def __init__(
        self,
        model: Any,
        create_react_agent_kwargs: Optional[Dict] = None,
        invoke_react_agent_kwargs: Optional[Dict] = None,
    ):
        # Defaults are None rather than `{}`: a `{}` default is created once at
        # definition time and would be shared by every instance through
        # self._params. None is normalised to a fresh dict per instance.
        self._params = {
            "model": model,
            "create_react_agent_kwargs": (
                {} if create_react_agent_kwargs is None else create_react_agent_kwargs
            ),
            "invoke_react_agent_kwargs": (
                {} if invoke_react_agent_kwargs is None else invoke_react_agent_kwargs
            ),
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Creates the compiled state graph for the EDA agent.

        Any previously stored response is cleared, since it was produced by
        the graph that is being replaced.
        """
        self.response = None
        return make_eda_tools_agent(**self._params)

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters and rebuilds the compiled graph.

        Parameters:
        ----------
        kwargs : dict
            Parameter names and their new values; each is written into the
            stored parameters before the graph is rebuilt.
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    async def ainvoke_agent(
        self,
        user_instructions: str = None,
        data_raw: pd.DataFrame = None,
        **kwargs
    ):
        """
        Asynchronously runs the agent with user instructions and data.

        Parameters:
        ----------
        user_instructions : str, optional
            The instructions for the agent.
        data_raw : pd.DataFrame, optional
            The input data as a DataFrame. It is converted with ``to_dict()``
            before being placed in the graph state; None is passed through.
        kwargs : dict, optional
            Additional keyword arguments passed to the compiled graph's ainvoke method.

        Returns:
        --------
        None
            The result is stored on ``self.response``.
        """
        self.response = await self._compiled_graph.ainvoke(
            {
                "user_instructions": user_instructions,
                "data_raw": data_raw.to_dict() if data_raw is not None else None,
            },
            **kwargs
        )
        return None

    def invoke_agent(
        self,
        user_instructions: str = None,
        data_raw: pd.DataFrame = None,
        **kwargs
    ):
        """
        Synchronously runs the agent with user instructions and data.

        Parameters:
        ----------
        user_instructions : str, optional
            The instructions for the agent.
        data_raw : pd.DataFrame, optional
            The input data as a DataFrame. It is converted with ``to_dict()``
            before being placed in the graph state; None is passed through.
        kwargs : dict, optional
            Additional keyword arguments passed to the compiled graph's invoke method.

        Returns:
        --------
        None
            The result is stored on ``self.response``.
        """
        self.response = self._compiled_graph.invoke(
            {
                "user_instructions": user_instructions,
                "data_raw": data_raw.to_dict() if data_raw is not None else None,
            },
            **kwargs
        )
        return None

    def get_internal_messages(self, markdown: bool = False):
        """
        Returns internal messages from the agent response.

        Parameters:
        ----------
        markdown : bool, optional
            If True, the messages are rendered into a single Markdown object
            (type, id and content per message). Otherwise the stored message
            sequence is returned as-is.
        """
        internal_messages = self.response["internal_messages"]
        if not markdown:
            return internal_messages

        # Only build the formatted text when it is actually returned.
        pretty_print = "\n\n".join(
            f"### {msg.type.upper()}\n\nID: {msg.id}\n\nContent:\n\n{msg.content}"
            for msg in internal_messages
        )
        return Markdown(pretty_print)

    def get_artifacts(self, as_dataframe: bool = False):
        """
        Returns the EDA artifacts from the agent response.

        Parameters:
        ----------
        as_dataframe : bool, optional
            If True, the artifacts are wrapped in a pandas DataFrame.
        """
        artifacts = self.response["eda_artifacts"]
        if as_dataframe:
            return pd.DataFrame(artifacts)
        return artifacts

    def get_ai_message(self, markdown: bool = False):
        """
        Returns the AI message from the agent response.

        Parameters:
        ----------
        markdown : bool, optional
            If True, the message content is wrapped in a Markdown object.
        """
        content = self.response["messages"][0].content
        if markdown:
            return Markdown(content)
        return content
|
165
|
+
|
166
|
+
def make_eda_tools_agent(
    model: Any,
    create_react_agent_kwargs: Optional[Dict] = None,
    invoke_react_agent_kwargs: Optional[Dict] = None,
):
    """
    Creates an Exploratory Data Analyst Agent that can interact with EDA tools.

    The returned graph has a single node that builds a react agent over
    EDA_TOOLS, runs it on the user instructions, and stores the final AI
    message, the full message list, and the last tool artifact in the state.

    Parameters:
    ----------
    model : Any
        The language model used for tool-calling.
    create_react_agent_kwargs : dict, optional
        Additional kwargs for create_react_agent.
        None (the default) is treated as an empty dict.
    invoke_react_agent_kwargs : dict, optional
        Additional kwargs for agent invocation.
        None (the default) is treated as an empty dict.

    Returns:
    -------
    app : langgraph.graph.CompiledStateGraph
        The compiled state graph for the EDA agent.
    """
    # Normalise here instead of using `{}` defaults, which would be a single
    # dict shared across every call of this factory. It also makes an explicit
    # None (allowed by the Optional[...] hint) safe to unpack below.
    if create_react_agent_kwargs is None:
        create_react_agent_kwargs = {}
    if invoke_react_agent_kwargs is None:
        invoke_react_agent_kwargs = {}

    class GraphState(AgentState):
        # Full message history of the inner react agent; accumulated with operator.add.
        internal_messages: Annotated[Sequence[BaseMessage], operator.add]
        user_instructions: str
        data_raw: dict
        eda_artifacts: dict

    def exploratory_agent(state):
        """Graph node: run the react agent on the user instructions and collect results."""
        print(format_agent_name(AGENT_NAME))
        print(" * RUN REACT TOOL-CALLING AGENT FOR EDA")

        tool_node = ToolNode(tools=EDA_TOOLS)

        eda_agent = create_react_agent(
            model,
            tools=tool_node,
            state_schema=GraphState,
            **create_react_agent_kwargs,
        )

        # NOTE(review): invoke_react_agent_kwargs is passed positionally, so it
        # is received as the second argument of invoke (presumably the run
        # config), not unpacked as keyword arguments — confirm this is intended.
        response = eda_agent.invoke(
            {
                "messages": [("user", state["user_instructions"])],
                "data_raw": state["data_raw"],
            },
            invoke_react_agent_kwargs,
        )

        print(" * POST-PROCESSING EDA RESULTS")

        internal_messages = response['messages']
        if not internal_messages:
            return {"internal_messages": [], "eda_artifacts": None}

        last_ai_message = AIMessage(internal_messages[-1].content, role=AGENT_NAME)

        # The artifact, when present, is taken from the message just before
        # the final one.
        last_tool_artifact = None
        if len(internal_messages) > 1:
            previous_message = internal_messages[-2]
            if hasattr(previous_message, "artifact"):
                last_tool_artifact = previous_message.artifact
            elif isinstance(previous_message, dict) and "artifact" in previous_message:
                last_tool_artifact = previous_message["artifact"]

        return {
            "messages": [last_ai_message],
            "internal_messages": internal_messages,
            "eda_artifacts": last_tool_artifact,
        }

    workflow = StateGraph(GraphState)
    workflow.add_node("exploratory_agent", exploratory_agent)
    workflow.add_edge(START, "exploratory_agent")
    workflow.add_edge("exploratory_agent", END)

    app = workflow.compile()
    return app
|
File without changes
|
@@ -506,6 +506,7 @@ def make_h2o_ml_agent(
|
|
506
506
|
while remaining flexible to user instructions.
|
507
507
|
- Return a dict with keys: leaderboard, best_model_id, model_path, and model_results.
|
508
508
|
- If enable_mlfow is True, log the top metrics and save the model as an artifact. (See example function)
|
509
|
+
- IMPORTANT: if enable_mlflow is True, make sure to set enable_mlflow to True in the function definition.
|
509
510
|
|
510
511
|
Initial User Instructions (Disregard any instructions that are unrelated to modeling):
|
511
512
|
{user_instructions}
|
@@ -533,7 +534,7 @@ def make_h2o_ml_agent(
|
|
533
534
|
sort_metric: str ,
|
534
535
|
model_directory: Optional[str] = None,
|
535
536
|
log_path: Optional[str] = None,
|
536
|
-
enable_mlflow: bool,
|
537
|
+
enable_mlflow: bool, # If the user has specified to enable MLflow, make sure to make this True
|
537
538
|
mlflow_tracking_uri: Optional[str],
|
538
539
|
mlflow_experiment_name: str,
|
539
540
|
mlflow_run_name: str,
|
File without changes
|
@@ -1,5 +1,5 @@
|
|
1
1
|
|
2
|
-
from typing import Any, Optional, Annotated, Sequence
|
2
|
+
from typing import Any, Optional, Annotated, Sequence, Dict
|
3
3
|
import operator
|
4
4
|
|
5
5
|
import pandas as pd
|
@@ -63,8 +63,10 @@ class MLflowToolsAgent(BaseAgent):
|
|
63
63
|
The tracking URI for MLflow. Defaults to None.
|
64
64
|
mlflow_registry_uri : str, optional
|
65
65
|
The registry URI for MLflow. Defaults to None.
|
66
|
-
|
67
|
-
Additional keyword arguments to pass to the
|
66
|
+
create_react_agent_kwargs : dict
|
67
|
+
Additional keyword arguments to pass to the create_react_agent function.
|
68
|
+
invoke_react_agent_kwargs : dict
|
69
|
+
Additional keyword arguments to pass to the invoke method of the react agent.
|
68
70
|
|
69
71
|
Methods:
|
70
72
|
--------
|
@@ -114,13 +116,15 @@ class MLflowToolsAgent(BaseAgent):
|
|
114
116
|
model: Any,
|
115
117
|
mlflow_tracking_uri: Optional[str]=None,
|
116
118
|
mlflow_registry_uri: Optional[str]=None,
|
117
|
-
|
119
|
+
create_react_agent_kwargs: Optional[Dict]={},
|
120
|
+
invoke_react_agent_kwargs: Optional[Dict]={},
|
118
121
|
):
|
119
122
|
self._params = {
|
120
123
|
"model": model,
|
121
124
|
"mlflow_tracking_uri": mlflow_tracking_uri,
|
122
125
|
"mlflow_registry_uri": mlflow_registry_uri,
|
123
|
-
|
126
|
+
"create_react_agent_kwargs": create_react_agent_kwargs,
|
127
|
+
"invoke_react_agent_kwargs": invoke_react_agent_kwargs,
|
124
128
|
}
|
125
129
|
self._compiled_graph = self._make_compiled_graph()
|
126
130
|
self.response = None
|
@@ -185,8 +189,6 @@ class MLflowToolsAgent(BaseAgent):
|
|
185
189
|
The user instructions to pass to the agent.
|
186
190
|
data_raw : pd.DataFrame, optional
|
187
191
|
The raw data to pass to the agent. Used for prediction and tool calls where data is required.
|
188
|
-
kwargs : dict, optional
|
189
|
-
Additional keyword arguments to pass to the agents invoke method.
|
190
192
|
|
191
193
|
"""
|
192
194
|
response = self._compiled_graph.invoke(
|
@@ -234,10 +236,30 @@ def make_mlflow_tools_agent(
|
|
234
236
|
model: Any,
|
235
237
|
mlflow_tracking_uri: str=None,
|
236
238
|
mlflow_registry_uri: str=None,
|
237
|
-
|
239
|
+
create_react_agent_kwargs: Optional[Dict]={},
|
240
|
+
invoke_react_agent_kwargs: Optional[Dict]={},
|
238
241
|
):
|
239
242
|
"""
|
240
243
|
MLflow Tool Calling Agent
|
244
|
+
|
245
|
+
Parameters:
|
246
|
+
----------
|
247
|
+
model : Any
|
248
|
+
The language model used to generate the agent.
|
249
|
+
mlflow_tracking_uri : str, optional
|
250
|
+
The tracking URI for MLflow. Defaults to None.
|
251
|
+
mlflow_registry_uri : str, optional
|
252
|
+
The registry URI for MLflow. Defaults to None.
|
253
|
+
create_react_agent_kwargs : dict, optional
|
254
|
+
Additional keyword arguments to pass to the agent's create_react_agent method.
|
255
|
+
invoke_react_agent_kwargs : dict, optional
|
256
|
+
Additional keyword arguments to pass to the agent's invoke method.
|
257
|
+
|
258
|
+
Returns
|
259
|
+
-------
|
260
|
+
app : langchain.graphs.CompiledStateGraph
|
261
|
+
A compiled state graph for the MLflow Tool Calling Agent.
|
262
|
+
|
241
263
|
"""
|
242
264
|
|
243
265
|
try:
|
@@ -274,7 +296,7 @@ def make_mlflow_tools_agent(
|
|
274
296
|
model,
|
275
297
|
tools=tool_node,
|
276
298
|
state_schema=GraphState,
|
277
|
-
**
|
299
|
+
**create_react_agent_kwargs,
|
278
300
|
)
|
279
301
|
|
280
302
|
response = mlflow_agent.invoke(
|
@@ -282,6 +304,7 @@ def make_mlflow_tools_agent(
|
|
282
304
|
"messages": [("user", state["user_instructions"])],
|
283
305
|
"data_raw": state["data_raw"],
|
284
306
|
},
|
307
|
+
invoke_react_agent_kwargs,
|
285
308
|
)
|
286
309
|
|
287
310
|
print(" * POST-PROCESS RESULTS")
|