ai-data-science-team 0.0.0.9009__py3-none-any.whl → 0.0.0.9010__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. ai_data_science_team/_version.py +1 -1
  2. ai_data_science_team/agents/data_cleaning_agent.py +6 -6
  3. ai_data_science_team/agents/data_loader_tools_agent.py +69 -0
  4. ai_data_science_team/agents/data_visualization_agent.py +6 -7
  5. ai_data_science_team/agents/data_wrangling_agent.py +6 -6
  6. ai_data_science_team/agents/feature_engineering_agent.py +6 -6
  7. ai_data_science_team/agents/sql_database_agent.py +6 -6
  8. ai_data_science_team/ml_agents/__init__.py +1 -0
  9. ai_data_science_team/ml_agents/h2o_ml_agent.py +205 -385
  10. ai_data_science_team/ml_agents/mlflow_tools_agent.py +327 -0
  11. ai_data_science_team/multiagents/sql_data_analyst.py +3 -4
  12. ai_data_science_team/parsers/__init__.py +0 -0
  13. ai_data_science_team/{tools → parsers}/parsers.py +0 -1
  14. ai_data_science_team/templates/agent_templates.py +6 -6
  15. ai_data_science_team/tools/data_loader.py +378 -0
  16. ai_data_science_team/tools/dataframe.py +139 -0
  17. ai_data_science_team/tools/h2o.py +643 -0
  18. ai_data_science_team/tools/mlflow.py +961 -0
  19. ai_data_science_team/tools/{metadata.py → sql.py} +1 -137
  20. {ai_data_science_team-0.0.0.9009.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/METADATA +34 -16
  21. ai_data_science_team-0.0.0.9010.dist-info/RECORD +35 -0
  22. ai_data_science_team-0.0.0.9009.dist-info/RECORD +0 -28
  23. /ai_data_science_team/{tools → utils}/logging.py +0 -0
  24. /ai_data_science_team/{tools → utils}/regex.py +0 -0
  25. {ai_data_science_team-0.0.0.9009.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/LICENSE +0 -0
  26. {ai_data_science_team-0.0.0.9009.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/WHEEL +0 -0
  27. {ai_data_science_team-0.0.0.9009.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,327 @@
1
+
2
+ from typing import Any, Optional, Annotated, Sequence
3
+ import operator
4
+
5
+ import pandas as pd
6
+
7
+ from IPython.display import Markdown
8
+
9
+ from langchain_core.messages import BaseMessage, AIMessage
10
+
11
+ from langgraph.prebuilt import create_react_agent, ToolNode
12
+ from langgraph.prebuilt.chat_agent_executor import AgentState
13
+ from langgraph.graph import START, END, StateGraph
14
+
15
+ from ai_data_science_team.templates import BaseAgent
16
+ from ai_data_science_team.utils.regex import format_agent_name
17
+ from ai_data_science_team.tools.mlflow import (
18
+ mlflow_search_experiments,
19
+ mlflow_search_runs,
20
+ mlflow_create_experiment,
21
+ mlflow_predict_from_run_id,
22
+ mlflow_launch_ui,
23
+ mlflow_stop_ui,
24
+ mlflow_list_artifacts,
25
+ mlflow_download_artifacts,
26
+ mlflow_list_registered_models,
27
+ mlflow_search_registered_models,
28
+ mlflow_get_model_version_details,
29
+ )
30
+
31
# Agent name: printed as the console banner when the agent runs and attached
# to the agent's final AIMessage.
AGENT_NAME = "mlflow_tools_agent"

# TOOL SETUP
# The complete set of MLflow tools exposed to the ReAct tool-calling agent.
# The order mirrors the import list above.
tools = list(
    (
        mlflow_search_experiments,
        mlflow_search_runs,
        mlflow_create_experiment,
        mlflow_predict_from_run_id,
        mlflow_launch_ui,
        mlflow_stop_ui,
        mlflow_list_artifacts,
        mlflow_download_artifacts,
        mlflow_list_registered_models,
        mlflow_search_registered_models,
        mlflow_get_model_version_details,
    )
)
47
+
48
class MLflowToolsAgent(BaseAgent):
    """
    An agent that can interact with MLflow by calling tools.

    Current tools include:
    - Search Experiments
    - Search Runs
    - Create Experiment
    - Predict (from a Run ID)
    - Launch UI / Stop UI
    - List Artifacts / Download Artifacts
    - List Registered Models / Search Registered Models
    - Get Model Version Details

    Parameters:
    ----------
    model : langchain.llms.base.LLM
        The language model used to generate the tool calling agent.
    mlflow_tracking_uri : str, optional
        The tracking URI for MLflow. Defaults to None, in which case MLflow's
        current tracking URI is left unchanged.
    mlflow_registry_uri : str, optional
        The registry URI for MLflow. Defaults to None, in which case MLflow's
        current registry URI is left unchanged.
    **react_agent_kwargs : dict, optional
        Additional keyword arguments to pass to the agent's react agent.

    Methods:
    --------
    update_params(**kwargs):
        Updates the agent's parameters and rebuilds the compiled graph.
    ainvoke_agent(user_instructions: str=None, data_raw: pd.DataFrame=None, **kwargs):
        Asynchronously runs the agent with the given user instructions.
    invoke_agent(user_instructions: str=None, data_raw: pd.DataFrame=None, **kwargs):
        Runs the agent with the given user instructions.
    get_internal_messages(markdown: bool=False):
        Returns the internal messages from the agent's response.
    get_mlflow_artifacts(as_dataframe: bool=False):
        Returns the MLflow artifacts from the agent's response.
    get_ai_message(markdown: bool=False):
        Returns the AI message from the agent's response.

    The ``get_*`` accessors read ``self.response``, which is only populated
    after ``invoke_agent`` / ``ainvoke_agent`` has run.

    Examples:
    --------
    ```python
    from ai_data_science_team.ml_agents import MLflowToolsAgent

    mlflow_agent = MLflowToolsAgent(llm)

    mlflow_agent.invoke_agent(user_instructions="List the MLflow experiments")

    mlflow_agent.get_response()

    mlflow_agent.get_internal_messages(markdown=True)

    mlflow_agent.get_ai_message(markdown=True)

    mlflow_agent.get_mlflow_artifacts(as_dataframe=True)
    ```

    Returns
    -------
    MLflowToolsAgent : langchain.graphs.CompiledStateGraph
        An instance of the MLflow Tools Agent.
    """

    def __init__(
        self,
        model: Any,
        mlflow_tracking_uri: Optional[str]=None,
        mlflow_registry_uri: Optional[str]=None,
        **react_agent_kwargs,
    ):
        self._params = {
            "model": model,
            "mlflow_tracking_uri": mlflow_tracking_uri,
            "mlflow_registry_uri": mlflow_registry_uri,
            **react_agent_kwargs,
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Creates the compiled graph for the agent from the current ``_params``.
        Any previous response is cleared because it belongs to the old graph.
        """
        self.response = None
        return make_mlflow_tools_agent(**self._params)

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters and rebuilds the compiled graph.
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    @staticmethod
    def _make_graph_input(user_instructions, data_raw):
        """Build the graph input state; a DataFrame is converted with ``to_dict()``."""
        return {
            "user_instructions": user_instructions,
            "data_raw": data_raw.to_dict() if data_raw is not None else None,
        }

    async def ainvoke_agent(
        self,
        user_instructions: str=None,
        data_raw: pd.DataFrame=None,
        **kwargs
    ):
        """
        Asynchronously runs the agent with the given user instructions.

        Parameters:
        ----------
        user_instructions : str, optional
            The user instructions to pass to the agent.
        data_raw : pd.DataFrame, optional
            The data to pass to the agent. Used for prediction and tool calls where data is required.
        kwargs : dict, optional
            Additional keyword arguments to pass to the agents ainvoke method.

        Returns
        -------
        None. The graph's final state is stored on ``self.response``.
        """
        self.response = await self._compiled_graph.ainvoke(
            self._make_graph_input(user_instructions, data_raw),
            **kwargs
        )
        return None

    def invoke_agent(
        self,
        user_instructions: str=None,
        data_raw: pd.DataFrame=None,
        **kwargs
    ):
        """
        Runs the agent with the given user instructions.

        Parameters:
        ----------
        user_instructions : str, optional
            The user instructions to pass to the agent.
        data_raw : pd.DataFrame, optional
            The raw data to pass to the agent. Used for prediction and tool calls where data is required.
        kwargs : dict, optional
            Additional keyword arguments to pass to the agents invoke method.

        Returns
        -------
        None. The graph's final state is stored on ``self.response``.
        """
        self.response = self._compiled_graph.invoke(
            self._make_graph_input(user_instructions, data_raw),
            **kwargs
        )
        return None

    def get_internal_messages(self, markdown: bool=False):
        """
        Returns the internal messages (the full ReAct transcript) from the
        agent's response. With ``markdown=True``, it returns a Markdown view
        showing each message's type, id and content.
        """
        internal_messages = self.response["internal_messages"]
        if not markdown:
            return internal_messages
        # Only build the pretty string when it is actually requested.
        pretty_print = "\n\n".join(
            f"### {msg.type.upper()}\n\nID: {msg.id}\n\nContent:\n\n{msg.content}"
            for msg in internal_messages
        )
        return Markdown(pretty_print)

    def get_mlflow_artifacts(self, as_dataframe: bool=False):
        """
        Returns the MLflow artifacts from the agent's response: the artifact
        attached to the last tool call, or None if there was none.
        """
        artifacts = self.response["mlflow_artifacts"]
        if as_dataframe:
            return pd.DataFrame(artifacts)
        return artifacts

    def get_ai_message(self, markdown: bool=False):
        """
        Returns the content of the final AI message from the agent's response.
        """
        # The final reply is the last entry of "messages".
        content = self.response["messages"][-1].content
        if markdown:
            return Markdown(content)
        return content
+
230
+
231
+
232
+
233
def make_mlflow_tools_agent(
    model: Any,
    mlflow_tracking_uri: str=None,
    mlflow_registry_uri: str=None,
    **react_agent_kwargs,
):
    """
    MLflow Tool Calling Agent

    Builds a single-node LangGraph workflow. The node runs a ReAct
    tool-calling agent over the module-level ``tools`` and post-processes
    the agent's output.

    Parameters
    ----------
    model : Any
        The language model passed to ``create_react_agent``.
    mlflow_tracking_uri : str, optional
        If given, it is set process-wide via ``mlflow.set_tracking_uri``.
    mlflow_registry_uri : str, optional
        If given, it is set process-wide via ``mlflow.set_registry_uri``.
    **react_agent_kwargs
        Extra keyword arguments forwarded to ``create_react_agent``.

    Returns
    -------
    The compiled graph. After a run, its state contains:
    - ``messages``: the final AI reply
    - ``internal_messages``: the full ReAct transcript
    - ``mlflow_artifacts``: the artifact of the message just before the reply, if any

    Raises
    ------
    ImportError
        If ``mlflow`` is not installed.
    """

    try:
        import mlflow
    except ImportError as err:
        # Raise instead of returning a message string. A returned string would
        # be stored as the caller's "compiled graph" and would only fail later,
        # with a confusing AttributeError on .invoke().
        raise ImportError(
            "MLflow is not installed. Please install it by running: !pip install mlflow"
        ) from err

    if mlflow_tracking_uri is not None:
        mlflow.set_tracking_uri(mlflow_tracking_uri)

    if mlflow_registry_uri is not None:
        mlflow.set_registry_uri(mlflow_registry_uri)

    class GraphState(AgentState):
        # Full ReAct transcript; updates are concatenated with operator.add.
        internal_messages: Annotated[Sequence[BaseMessage], operator.add]
        user_instructions: str
        data_raw: dict
        mlflow_artifacts: dict

    # Build the ReAct sub-agent once, at graph-construction time. Model, tools
    # and kwargs are all fixed by this closure, so rebuilding per call is waste.
    tool_node = ToolNode(
        tools=tools
    )

    mlflow_agent = create_react_agent(
        model,
        tools=tool_node,
        state_schema=GraphState,
        **react_agent_kwargs,
    )

    def mlflow_tools_agent_node(state):
        """
        Runs the ReAct tool-calling agent on the user's instructions, then:
        - keeps only the last message as the visible reply;
        - extracts the artifact of the message before it (the last tool
          message, when a tool was called).
        """
        print(format_agent_name(AGENT_NAME))
        print(" * RUN REACT TOOL-CALLING AGENT")

        response = mlflow_agent.invoke(
            {
                "messages": [("user", state["user_instructions"])],
                "data_raw": state["data_raw"],
            },
        )

        print(" * POST-PROCESS RESULTS")

        internal_messages = response['messages']

        # Nothing came back from the agent: return an empty transcript.
        if not internal_messages:
            return {
                "internal_messages": [],
                "mlflow_artifacts": None,
            }

        # Re-wrap the last message as an AIMessage tagged with this agent's name.
        last_ai_message = AIMessage(internal_messages[-1].content, role = AGENT_NAME)

        # The message just before the final reply carries the last tool
        # artifact, if any. It may be a message object or a plain dict.
        last_tool_artifact = None
        if len(internal_messages) > 1:
            last_message = internal_messages[-2]
            if hasattr(last_message, "artifact"):
                last_tool_artifact = last_message.artifact
            elif isinstance(last_message, dict) and "artifact" in last_message:
                last_tool_artifact = last_message["artifact"]

        return {
            "messages": [last_ai_message],
            "internal_messages": internal_messages,
            "mlflow_artifacts": last_tool_artifact,
        }

    workflow = StateGraph(GraphState)

    workflow.add_node("mlflow_tools_agent", mlflow_tools_agent_node)

    workflow.add_edge(START, "mlflow_tools_agent")
    workflow.add_edge("mlflow_tools_agent", END)

    app = workflow.compile()

    return app
+
@@ -1,6 +1,5 @@
1
1
 
2
2
  from langchain_core.messages import BaseMessage
3
- from langgraph.checkpoint.memory import MemorySaver
4
3
  from langgraph.types import Checkpointer
5
4
 
6
5
  from langgraph.graph import START, END, StateGraph
@@ -19,7 +18,7 @@ from IPython.display import Markdown
19
18
  from ai_data_science_team.templates import BaseAgent
20
19
  from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
21
20
  from ai_data_science_team.utils.plotly import plotly_from_dict
22
- from ai_data_science_team.tools.regex import remove_consecutive_duplicates, get_generic_summary
21
+ from ai_data_science_team.utils.regex import remove_consecutive_duplicates, get_generic_summary
23
22
 
24
23
 
25
24
  class SQLDataAnalyst(BaseAgent):
@@ -91,7 +90,7 @@ class SQLDataAnalyst(BaseAgent):
91
90
  self._params[k] = v
92
91
  self._compiled_graph = self._make_compiled_graph()
93
92
 
94
- def ainvoke_agent(self, user_instructions, max_retries:int=3, retry_count:int=0, **kwargs):
93
+ async def ainvoke_agent(self, user_instructions, max_retries:int=3, retry_count:int=0, **kwargs):
95
94
  """
96
95
  Asynchronosly nvokes the SQL Data Analyst Multi-Agent.
97
96
 
@@ -144,7 +143,7 @@ class SQLDataAnalyst(BaseAgent):
144
143
  sql_data_analyst.get_plotly_graph()
145
144
  ```
146
145
  """
147
- response = self._compiled_graph.ainvoke({
146
+ response = await self._compiled_graph.ainvoke({
148
147
  "user_instructions": user_instructions,
149
148
  "max_retries": max_retries,
150
149
  "retry_count": retry_count,
File without changes
@@ -3,7 +3,6 @@
3
3
  # ***
4
4
  # Parsers
5
5
 
6
- from langchain_core.output_parsers import JsonOutputParser
7
6
  from langchain_core.output_parsers import BaseOutputParser
8
7
 
9
8
  import re
@@ -12,8 +12,8 @@ import json
12
12
 
13
13
  from typing import Any, Callable, Dict, Type, Optional, Union, List
14
14
 
15
- from ai_data_science_team.tools.parsers import PythonOutputParser
16
- from ai_data_science_team.tools.regex import (
15
+ from ai_data_science_team.parsers.parsers import PythonOutputParser
16
+ from ai_data_science_team.utils.regex import (
17
17
  relocate_imports_inside_function,
18
18
  add_comments_to_top,
19
19
  remove_consecutive_duplicates
@@ -93,7 +93,7 @@ class BaseAgent(CompiledStateGraph):
93
93
 
94
94
  return self.response
95
95
 
96
- def ainvoke(
96
+ async def ainvoke(
97
97
  self,
98
98
  input: Union[dict[str, Any], Any],
99
99
  config: Optional[RunnableConfig] = None,
@@ -110,7 +110,7 @@ class BaseAgent(CompiledStateGraph):
110
110
  Returns:
111
111
  Any: The agent's response.
112
112
  """
113
- self.response = self._compiled_graph.ainvoke(input=input, config=config,**kwargs)
113
+ self.response = await self._compiled_graph.ainvoke(input=input, config=config,**kwargs)
114
114
 
115
115
  if self.response.get("messages"):
116
116
  self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
@@ -148,7 +148,7 @@ class BaseAgent(CompiledStateGraph):
148
148
 
149
149
  return self.response
150
150
 
151
- def astream(
151
+ async def astream(
152
152
  self,
153
153
  input: dict[str, Any] | Any,
154
154
  config: RunnableConfig | None = None,
@@ -172,7 +172,7 @@ class BaseAgent(CompiledStateGraph):
172
172
  Returns:
173
173
  Any: The agent's response.
174
174
  """
175
- self.response = self._compiled_graph.astream(input=input, config=config, stream_mode=stream_mode, **kwargs)
175
+ self.response = await self._compiled_graph.astream(input=input, config=config, stream_mode=stream_mode, **kwargs)
176
176
 
177
177
  if self.response.get("messages"):
178
178
  self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])