ai-data-science-team 0.0.0.9009__py3-none-any.whl → 0.0.0.9010__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (27) hide show
  1. ai_data_science_team/_version.py +1 -1
  2. ai_data_science_team/agents/data_cleaning_agent.py +6 -6
  3. ai_data_science_team/agents/data_loader_tools_agent.py +69 -0
  4. ai_data_science_team/agents/data_visualization_agent.py +6 -7
  5. ai_data_science_team/agents/data_wrangling_agent.py +6 -6
  6. ai_data_science_team/agents/feature_engineering_agent.py +6 -6
  7. ai_data_science_team/agents/sql_database_agent.py +6 -6
  8. ai_data_science_team/ml_agents/__init__.py +1 -0
  9. ai_data_science_team/ml_agents/h2o_ml_agent.py +205 -385
  10. ai_data_science_team/ml_agents/mlflow_tools_agent.py +327 -0
  11. ai_data_science_team/multiagents/sql_data_analyst.py +3 -4
  12. ai_data_science_team/parsers/__init__.py +0 -0
  13. ai_data_science_team/{tools → parsers}/parsers.py +0 -1
  14. ai_data_science_team/templates/agent_templates.py +6 -6
  15. ai_data_science_team/tools/data_loader.py +378 -0
  16. ai_data_science_team/tools/dataframe.py +139 -0
  17. ai_data_science_team/tools/h2o.py +643 -0
  18. ai_data_science_team/tools/mlflow.py +961 -0
  19. ai_data_science_team/tools/{metadata.py → sql.py} +1 -137
  20. {ai_data_science_team-0.0.0.9009.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/METADATA +34 -16
  21. ai_data_science_team-0.0.0.9010.dist-info/RECORD +35 -0
  22. ai_data_science_team-0.0.0.9009.dist-info/RECORD +0 -28
  23. /ai_data_science_team/{tools → utils}/logging.py +0 -0
  24. /ai_data_science_team/{tools → utils}/regex.py +0 -0
  25. {ai_data_science_team-0.0.0.9009.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/LICENSE +0 -0
  26. {ai_data_science_team-0.0.0.9009.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/WHEEL +0 -0
  27. {ai_data_science_team-0.0.0.9009.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,327 @@
1
+
2
+ from typing import Any, Optional, Annotated, Sequence
3
+ import operator
4
+
5
+ import pandas as pd
6
+
7
+ from IPython.display import Markdown
8
+
9
+ from langchain_core.messages import BaseMessage, AIMessage
10
+
11
+ from langgraph.prebuilt import create_react_agent, ToolNode
12
+ from langgraph.prebuilt.chat_agent_executor import AgentState
13
+ from langgraph.graph import START, END, StateGraph
14
+
15
+ from ai_data_science_team.templates import BaseAgent
16
+ from ai_data_science_team.utils.regex import format_agent_name
17
+ from ai_data_science_team.tools.mlflow import (
18
+ mlflow_search_experiments,
19
+ mlflow_search_runs,
20
+ mlflow_create_experiment,
21
+ mlflow_predict_from_run_id,
22
+ mlflow_launch_ui,
23
+ mlflow_stop_ui,
24
+ mlflow_list_artifacts,
25
+ mlflow_download_artifacts,
26
+ mlflow_list_registered_models,
27
+ mlflow_search_registered_models,
28
+ mlflow_get_model_version_details,
29
+ )
30
+
31
# Name shown in this agent's console banner and attached (as ``role``) to the
# final AIMessage produced by the graph node.
AGENT_NAME: str = "mlflow_tools_agent"

# TOOL SETUP
# Registration order is kept exactly as-is; the comments only group the
# entries visually by the MLflow area their names refer to.
tools: list = [
    # experiments, runs, prediction
    mlflow_search_experiments,
    mlflow_search_runs,
    mlflow_create_experiment,
    mlflow_predict_from_run_id,
    # tracking UI
    mlflow_launch_ui,
    mlflow_stop_ui,
    # artifacts
    mlflow_list_artifacts,
    mlflow_download_artifacts,
    # model registry
    mlflow_list_registered_models,
    mlflow_search_registered_models,
    mlflow_get_model_version_details,
]
47
+
48
class MLflowToolsAgent(BaseAgent):
    """
    An agent that can interact with MLflow by calling tools.

    Current tools include:
    - Search Experiments
    - Search Runs
    - Create Experiment
    - Predict (from a Run ID)
    - Launch / Stop the MLflow UI
    - List / Download Artifacts
    - List / Search Registered Models
    - Get Model Version Details

    Parameters
    ----------
    model : langchain.llms.base.LLM
        The language model used to generate the tool calling agent.
    mlflow_tracking_uri : str, optional
        The tracking URI for MLflow. Defaults to None.
    mlflow_registry_uri : str, optional
        The registry URI for MLflow. Defaults to None.
    **react_agent_kwargs : dict, optional
        Additional keyword arguments to pass to the agent's react agent.

    Methods
    -------
    update_params(**kwargs):
        Updates the agent's parameters and rebuilds the compiled graph.
    ainvoke_agent(user_instructions: str=None, data_raw: pd.DataFrame=None, **kwargs):
        Asynchronously runs the agent with the given user instructions.
    invoke_agent(user_instructions: str=None, data_raw: pd.DataFrame=None, **kwargs):
        Runs the agent with the given user instructions.
    get_internal_messages(markdown: bool=False):
        Returns the internal messages from the agent's response.
    get_mlflow_artifacts(as_dataframe: bool=False):
        Returns the MLflow artifacts from the agent's response.
    get_ai_message(markdown: bool=False):
        Returns the final AI message from the agent's response.

    Examples
    --------
    ```python
    from ai_data_science_team.ml_agents import MLflowToolsAgent

    mlflow_agent = MLflowToolsAgent(llm)

    mlflow_agent.invoke_agent(user_instructions="List the MLflow experiments")

    mlflow_agent.get_response()

    mlflow_agent.get_internal_messages(markdown=True)

    mlflow_agent.get_ai_message(markdown=True)

    mlflow_agent.get_mlflow_artifacts(as_dataframe=True)
    ```

    Returns
    -------
    MLflowToolsAgent : langchain.graphs.CompiledStateGraph
        An instance of the MLflow Tools Agent.
    """

    def __init__(
        self,
        model: Any,
        mlflow_tracking_uri: Optional[str] = None,
        mlflow_registry_uri: Optional[str] = None,
        **react_agent_kwargs,
    ):
        self._params = {
            "model": model,
            "mlflow_tracking_uri": mlflow_tracking_uri,
            "mlflow_registry_uri": mlflow_registry_uri,
            **react_agent_kwargs,
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Creates the compiled graph for the agent from the current parameters.

        Clears any previous response, since it belonged to the old graph.
        """
        self.response = None
        return make_mlflow_tools_agent(**self._params)

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters and rebuilds the compiled graph.
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    @staticmethod
    def _build_agent_input(
        user_instructions: Optional[str],
        data_raw: Optional[pd.DataFrame],
    ) -> dict:
        """
        Build the input state for the compiled graph.

        The DataFrame is converted with ``DataFrame.to_dict()`` so the graph
        state holds a plain dict (``None`` is passed through unchanged).
        """
        return {
            "user_instructions": user_instructions,
            "data_raw": data_raw.to_dict() if data_raw is not None else None,
        }

    async def ainvoke_agent(
        self,
        user_instructions: Optional[str] = None,
        data_raw: Optional[pd.DataFrame] = None,
        **kwargs,
    ):
        """
        Asynchronously runs the agent with the given user instructions.

        The result is stored on ``self.response``; nothing is returned.

        Parameters
        ----------
        user_instructions : str, optional
            The user instructions to pass to the agent.
        data_raw : pd.DataFrame, optional
            The data to pass to the agent. Used for prediction and tool calls where data is required.
        kwargs : dict, optional
            Additional keyword arguments to pass to the agent's ainvoke method.
        """
        response = await self._compiled_graph.ainvoke(
            self._build_agent_input(user_instructions, data_raw),
            **kwargs,
        )
        self.response = response
        return None

    def invoke_agent(
        self,
        user_instructions: Optional[str] = None,
        data_raw: Optional[pd.DataFrame] = None,
        **kwargs,
    ):
        """
        Runs the agent with the given user instructions.

        The result is stored on ``self.response``; nothing is returned.

        Parameters
        ----------
        user_instructions : str, optional
            The user instructions to pass to the agent.
        data_raw : pd.DataFrame, optional
            The raw data to pass to the agent. Used for prediction and tool calls where data is required.
        kwargs : dict, optional
            Additional keyword arguments to pass to the agent's invoke method.
        """
        response = self._compiled_graph.invoke(
            self._build_agent_input(user_instructions, data_raw),
            **kwargs,
        )
        self.response = response
        return None

    def get_internal_messages(self, markdown: bool = False):
        """
        Returns the internal messages from the agent's response.

        With ``markdown=True``, returns an IPython ``Markdown`` rendering with one
        section per message (type, id, content); otherwise returns the raw list.
        """
        internal_messages = self.response["internal_messages"]
        if not markdown:
            return internal_messages
        # Only pay for formatting when a rendered view was actually requested.
        pretty_print = "\n\n".join(
            f"### {msg.type.upper()}\n\nID: {msg.id}\n\nContent:\n\n{msg.content}"
            for msg in internal_messages
        )
        return Markdown(pretty_print)

    def get_mlflow_artifacts(self, as_dataframe: bool = False):
        """
        Returns the MLflow artifacts from the agent's response.

        With ``as_dataframe=True`` the artifact is converted with ``pd.DataFrame``.
        A flat dict whose values are all scalars is treated as a single row,
        because ``pd.DataFrame`` rejects such a dict without an explicit index.
        """
        artifacts = self.response["mlflow_artifacts"]
        if not as_dataframe:
            return artifacts
        if (
            isinstance(artifacts, dict)
            and artifacts
            and all(pd.api.types.is_scalar(value) for value in artifacts.values())
        ):
            return pd.DataFrame([artifacts])
        return pd.DataFrame(artifacts)

    def get_ai_message(self, markdown: bool = False):
        """
        Returns the final AI message from the agent's response.

        The final message is the last entry of ``messages``. Indexing ``[0]``
        would return the user's input if the graph was invoked with existing
        messages.
        """
        content = self.response["messages"][-1].content
        if markdown:
            return Markdown(content)
        return content
229
+
230
+
231
+
232
+
233
def make_mlflow_tools_agent(
    model: Any,
    mlflow_tracking_uri: Optional[str] = None,
    mlflow_registry_uri: Optional[str] = None,
    **react_agent_kwargs,
):
    """
    Build and compile the MLflow tool-calling agent graph.

    The workflow has a single node, ``"mlflow_tools_agent"``, that runs a ReAct
    agent bound to the module-level ``tools`` list on the user's instructions.
    It then post-processes the result into ``messages`` (final reply only),
    ``internal_messages`` (full transcript) and ``mlflow_artifacts``.

    Parameters
    ----------
    model : Any
        The chat model passed to ``create_react_agent``.
    mlflow_tracking_uri : str, optional
        If given, passed to ``mlflow.set_tracking_uri`` (a process-global
        setting) before the graph is built.
    mlflow_registry_uri : str, optional
        If given, passed to ``mlflow.set_registry_uri`` (a process-global
        setting) before the graph is built.
    **react_agent_kwargs
        Extra keyword arguments forwarded to ``create_react_agent``.

    Returns
    -------
    The compiled graph returned by ``StateGraph.compile()``.

    Raises
    ------
    ImportError
        If ``mlflow`` cannot be imported.
    """
    try:
        import mlflow
    except ImportError as exc:
        # Raise instead of returning the message. A returned string would be
        # stored as the "compiled graph" by callers and fail later with an
        # unrelated AttributeError.
        raise ImportError(
            "MLflow is not installed. Please install it by running: !pip install mlflow"
        ) from exc

    if mlflow_tracking_uri is not None:
        mlflow.set_tracking_uri(mlflow_tracking_uri)

    if mlflow_registry_uri is not None:
        mlflow.set_registry_uri(mlflow_registry_uri)

    class GraphState(AgentState):
        # Full transcript from the inner ReAct agent; updates are concatenated
        # via operator.add.
        internal_messages: Annotated[Sequence[BaseMessage], operator.add]
        user_instructions: str
        data_raw: dict
        mlflow_artifacts: dict

    # Build the ReAct sub-agent once when the graph is constructed, instead of
    # rebuilding the ToolNode and agent graph on every node invocation.
    tool_node = ToolNode(tools=tools)
    mlflow_agent = create_react_agent(
        model,
        tools=tool_node,
        state_schema=GraphState,
        **react_agent_kwargs,
    )

    def mlflow_tools_agent_node(state):
        """
        Run the ReAct agent on ``state["user_instructions"]`` and reduce its output.

        The final reply is kept in ``messages`` and the full transcript in
        ``internal_messages``. The artifact of the message just before the final
        reply, if it has one, goes in ``mlflow_artifacts``.
        """
        print(format_agent_name(AGENT_NAME))
        print(" * RUN REACT TOOL-CALLING AGENT")

        response = mlflow_agent.invoke(
            {
                "messages": [("user", state["user_instructions"])],
                "data_raw": state["data_raw"],
            },
        )

        print(" * POST-PROCESS RESULTS")

        internal_messages = response["messages"]

        # Agent produced no messages: report empty results.
        if not internal_messages:
            return {
                "internal_messages": [],
                "mlflow_artifacts": None,
            }

        # The last message is the agent's final answer; re-wrap it tagged with
        # this agent's name.
        last_ai_message = AIMessage(internal_messages[-1].content, role=AGENT_NAME)

        # In a ReAct loop the message preceding the final answer is normally
        # the last tool result; take its artifact if it carries one.
        last_tool_artifact = None
        if len(internal_messages) > 1:
            previous_message = internal_messages[-2]
            if hasattr(previous_message, "artifact"):
                last_tool_artifact = previous_message.artifact
            elif isinstance(previous_message, dict) and "artifact" in previous_message:
                last_tool_artifact = previous_message["artifact"]

        return {
            "messages": [last_ai_message],
            "internal_messages": internal_messages,
            "mlflow_artifacts": last_tool_artifact,
        }

    workflow = StateGraph(GraphState)

    workflow.add_node("mlflow_tools_agent", mlflow_tools_agent_node)

    workflow.add_edge(START, "mlflow_tools_agent")
    workflow.add_edge("mlflow_tools_agent", END)

    app = workflow.compile()

    return app
327
+
@@ -1,6 +1,5 @@
1
1
 
2
2
  from langchain_core.messages import BaseMessage
3
- from langgraph.checkpoint.memory import MemorySaver
4
3
  from langgraph.types import Checkpointer
5
4
 
6
5
  from langgraph.graph import START, END, StateGraph
@@ -19,7 +18,7 @@ from IPython.display import Markdown
19
18
  from ai_data_science_team.templates import BaseAgent
20
19
  from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
21
20
  from ai_data_science_team.utils.plotly import plotly_from_dict
22
- from ai_data_science_team.tools.regex import remove_consecutive_duplicates, get_generic_summary
21
+ from ai_data_science_team.utils.regex import remove_consecutive_duplicates, get_generic_summary
23
22
 
24
23
 
25
24
  class SQLDataAnalyst(BaseAgent):
@@ -91,7 +90,7 @@ class SQLDataAnalyst(BaseAgent):
91
90
  self._params[k] = v
92
91
  self._compiled_graph = self._make_compiled_graph()
93
92
 
94
- def ainvoke_agent(self, user_instructions, max_retries:int=3, retry_count:int=0, **kwargs):
93
+ async def ainvoke_agent(self, user_instructions, max_retries:int=3, retry_count:int=0, **kwargs):
95
94
  """
96
95
  Asynchronosly nvokes the SQL Data Analyst Multi-Agent.
97
96
 
@@ -144,7 +143,7 @@ class SQLDataAnalyst(BaseAgent):
144
143
  sql_data_analyst.get_plotly_graph()
145
144
  ```
146
145
  """
147
- response = self._compiled_graph.ainvoke({
146
+ response = await self._compiled_graph.ainvoke({
148
147
  "user_instructions": user_instructions,
149
148
  "max_retries": max_retries,
150
149
  "retry_count": retry_count,
File without changes
@@ -3,7 +3,6 @@
3
3
  # ***
4
4
  # Parsers
5
5
 
6
- from langchain_core.output_parsers import JsonOutputParser
7
6
  from langchain_core.output_parsers import BaseOutputParser
8
7
 
9
8
  import re
@@ -12,8 +12,8 @@ import json
12
12
 
13
13
  from typing import Any, Callable, Dict, Type, Optional, Union, List
14
14
 
15
- from ai_data_science_team.tools.parsers import PythonOutputParser
16
- from ai_data_science_team.tools.regex import (
15
+ from ai_data_science_team.parsers.parsers import PythonOutputParser
16
+ from ai_data_science_team.utils.regex import (
17
17
  relocate_imports_inside_function,
18
18
  add_comments_to_top,
19
19
  remove_consecutive_duplicates
@@ -93,7 +93,7 @@ class BaseAgent(CompiledStateGraph):
93
93
 
94
94
  return self.response
95
95
 
96
- def ainvoke(
96
+ async def ainvoke(
97
97
  self,
98
98
  input: Union[dict[str, Any], Any],
99
99
  config: Optional[RunnableConfig] = None,
@@ -110,7 +110,7 @@ class BaseAgent(CompiledStateGraph):
110
110
  Returns:
111
111
  Any: The agent's response.
112
112
  """
113
- self.response = self._compiled_graph.ainvoke(input=input, config=config,**kwargs)
113
+ self.response = await self._compiled_graph.ainvoke(input=input, config=config,**kwargs)
114
114
 
115
115
  if self.response.get("messages"):
116
116
  self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
@@ -148,7 +148,7 @@ class BaseAgent(CompiledStateGraph):
148
148
 
149
149
  return self.response
150
150
 
151
- def astream(
151
+ async def astream(
152
152
  self,
153
153
  input: dict[str, Any] | Any,
154
154
  config: RunnableConfig | None = None,
@@ -172,7 +172,7 @@ class BaseAgent(CompiledStateGraph):
172
172
  Returns:
173
173
  Any: The agent's response.
174
174
  """
175
- self.response = self._compiled_graph.astream(input=input, config=config, stream_mode=stream_mode, **kwargs)
175
+ self.response = await self._compiled_graph.astream(input=input, config=config, stream_mode=stream_mode, **kwargs)
176
176
 
177
177
  if self.response.get("messages"):
178
178
  self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])