ai-data-science-team 0.0.0.9009__py3-none-any.whl → 0.0.0.9011__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (29) hide show
  1. ai_data_science_team/_version.py +1 -1
  2. ai_data_science_team/agents/__init__.py +1 -0
  3. ai_data_science_team/agents/data_cleaning_agent.py +6 -6
  4. ai_data_science_team/agents/data_loader_tools_agent.py +272 -0
  5. ai_data_science_team/agents/data_visualization_agent.py +6 -7
  6. ai_data_science_team/agents/data_wrangling_agent.py +6 -6
  7. ai_data_science_team/agents/feature_engineering_agent.py +6 -6
  8. ai_data_science_team/agents/sql_database_agent.py +6 -6
  9. ai_data_science_team/ml_agents/__init__.py +1 -0
  10. ai_data_science_team/ml_agents/h2o_ml_agent.py +206 -385
  11. ai_data_science_team/ml_agents/h2o_ml_tools_agent.py +0 -0
  12. ai_data_science_team/ml_agents/mlflow_tools_agent.py +350 -0
  13. ai_data_science_team/multiagents/sql_data_analyst.py +3 -4
  14. ai_data_science_team/parsers/__init__.py +0 -0
  15. ai_data_science_team/{tools → parsers}/parsers.py +0 -1
  16. ai_data_science_team/templates/agent_templates.py +6 -6
  17. ai_data_science_team/tools/data_loader.py +448 -0
  18. ai_data_science_team/tools/dataframe.py +139 -0
  19. ai_data_science_team/tools/h2o.py +643 -0
  20. ai_data_science_team/tools/mlflow.py +961 -0
  21. ai_data_science_team/tools/{metadata.py → sql.py} +1 -137
  22. {ai_data_science_team-0.0.0.9009.dist-info → ai_data_science_team-0.0.0.9011.dist-info}/METADATA +40 -19
  23. ai_data_science_team-0.0.0.9011.dist-info/RECORD +36 -0
  24. ai_data_science_team-0.0.0.9009.dist-info/RECORD +0 -28
  25. /ai_data_science_team/{tools → utils}/logging.py +0 -0
  26. /ai_data_science_team/{tools → utils}/regex.py +0 -0
  27. {ai_data_science_team-0.0.0.9009.dist-info → ai_data_science_team-0.0.0.9011.dist-info}/LICENSE +0 -0
  28. {ai_data_science_team-0.0.0.9009.dist-info → ai_data_science_team-0.0.0.9011.dist-info}/WHEEL +0 -0
  29. {ai_data_science_team-0.0.0.9009.dist-info → ai_data_science_team-0.0.0.9011.dist-info}/top_level.txt +0 -0
File without changes
@@ -0,0 +1,350 @@
1
+
2
+ from typing import Any, Optional, Annotated, Sequence, Dict
3
+ import operator
4
+
5
+ import pandas as pd
6
+
7
+ from IPython.display import Markdown
8
+
9
+ from langchain_core.messages import BaseMessage, AIMessage
10
+
11
+ from langgraph.prebuilt import create_react_agent, ToolNode
12
+ from langgraph.prebuilt.chat_agent_executor import AgentState
13
+ from langgraph.graph import START, END, StateGraph
14
+
15
+ from ai_data_science_team.templates import BaseAgent
16
+ from ai_data_science_team.utils.regex import format_agent_name
17
+ from ai_data_science_team.tools.mlflow import (
18
+ mlflow_search_experiments,
19
+ mlflow_search_runs,
20
+ mlflow_create_experiment,
21
+ mlflow_predict_from_run_id,
22
+ mlflow_launch_ui,
23
+ mlflow_stop_ui,
24
+ mlflow_list_artifacts,
25
+ mlflow_download_artifacts,
26
+ mlflow_list_registered_models,
27
+ mlflow_search_registered_models,
28
+ mlflow_get_model_version_details,
29
+ )
30
+
31
# Name of this agent. Within this module it is passed to `format_agent_name`
# for console output and used as the `role` of the final AIMessage.
AGENT_NAME = "mlflow_tools_agent"

# TOOL SETUP
# MLflow tool callables (imported from ai_data_science_team.tools.mlflow) that
# are wrapped in a ToolNode and exposed to the tool-calling agent.
tools = [
    mlflow_search_experiments,
    mlflow_search_runs,
    mlflow_create_experiment,
    mlflow_predict_from_run_id,
    mlflow_launch_ui,
    mlflow_stop_ui,
    mlflow_list_artifacts,
    mlflow_download_artifacts,
    mlflow_list_registered_models,
    mlflow_search_registered_models,
    mlflow_get_model_version_details,
]
47
+
48
class MLflowToolsAgent(BaseAgent):
    """
    An agent that can interact with MLflow by calling tools.

    Current tools include:
    - Search Experiments
    - Search Runs
    - Create Experiment
    - Predict (from a Run ID)
    - Launch / Stop the MLflow UI
    - List / Download Artifacts
    - List / Search Registered Models
    - Get Model Version Details

    Parameters:
    ----------
    model : langchain.llms.base.LLM
        The language model used to generate the tool calling agent.
    mlflow_tracking_uri : str, optional
        The tracking URI for MLflow. Defaults to None.
    mlflow_registry_uri : str, optional
        The registry URI for MLflow. Defaults to None.
    create_react_agent_kwargs : dict, optional
        Additional keyword arguments to pass to the create_react_agent function.
        Defaults to None, which is treated as an empty dict.
    invoke_react_agent_kwargs : dict, optional
        Additional arguments forwarded to the react agent's invoke call.
        Defaults to None, which is treated as an empty dict.

    Methods:
    --------
    update_params(**kwargs):
        Updates the agent's parameters and rebuilds the compiled graph.
    ainvoke_agent(user_instructions: str=None, data_raw: pd.DataFrame=None, **kwargs):
        Asynchronously runs the agent with the given user instructions.
    invoke_agent(user_instructions: str=None, data_raw: pd.DataFrame=None, **kwargs):
        Runs the agent with the given user instructions.
    get_internal_messages(markdown: bool=False):
        Returns the internal messages from the agent's response.
    get_mlflow_artifacts(as_dataframe: bool=False):
        Returns the MLflow artifacts from the agent's response.
    get_ai_message(markdown: bool=False):
        Returns the AI message from the agent's response.

    Examples:
    --------
    ```python
    from ai_data_science_team.ml_agents import MLflowToolsAgent

    mlflow_agent = MLflowToolsAgent(llm)

    mlflow_agent.invoke_agent(user_instructions="List the MLflow experiments")

    mlflow_agent.get_response()

    mlflow_agent.get_internal_messages(markdown=True)

    mlflow_agent.get_ai_message(markdown=True)

    mlflow_agent.get_mlflow_artifacts(as_dataframe=True)

    ```

    Returns
    -------
    MLflowToolsAgent : langchain.graphs.CompiledStateGraph
        An instance of the MLflow Tools Agent.

    """

    # Parameters that must reach the graph factory as dicts, never as None.
    _DICT_PARAMS = ("create_react_agent_kwargs", "invoke_react_agent_kwargs")

    def __init__(
        self,
        model: Any,
        mlflow_tracking_uri: Optional[str] = None,
        mlflow_registry_uri: Optional[str] = None,
        create_react_agent_kwargs: Optional[Dict] = None,
        invoke_react_agent_kwargs: Optional[Dict] = None,
    ):
        # None defaults (instead of `{}`) so that instances never share one
        # mutable default dict; a fresh dict is created per instance here.
        self._params = {
            "model": model,
            "mlflow_tracking_uri": mlflow_tracking_uri,
            "mlflow_registry_uri": mlflow_registry_uri,
            "create_react_agent_kwargs": (
                {} if create_react_agent_kwargs is None else create_react_agent_kwargs
            ),
            "invoke_react_agent_kwargs": (
                {} if invoke_react_agent_kwargs is None else invoke_react_agent_kwargs
            ),
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Creates the compiled graph for the agent.

        Any previously stored response is discarded because it belongs to the
        graph that is being replaced.
        """
        self.response = None
        params = dict(self._params)
        # update_params() may have stored None for a kwargs dict; the graph
        # factory unpacks these values, so substitute an empty dict.
        for key in self._DICT_PARAMS:
            if params.get(key) is None:
                params[key] = {}
        return make_mlflow_tools_agent(**params)

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters and rebuilds the compiled graph.
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    @staticmethod
    def _build_inputs(user_instructions, data_raw):
        """
        Builds the input state passed to the compiled graph.

        The DataFrame is converted with `to_dict()`; a missing DataFrame is
        forwarded as None.
        """
        return {
            "user_instructions": user_instructions,
            "data_raw": data_raw.to_dict() if data_raw is not None else None,
        }

    async def ainvoke_agent(
        self,
        user_instructions: str = None,
        data_raw: pd.DataFrame = None,
        **kwargs
    ):
        """
        Asynchronously runs the agent with the given user instructions.

        The result is stored on `self.response`; nothing is returned.

        Parameters:
        ----------
        user_instructions : str, optional
            The user instructions to pass to the agent.
        data_raw : pd.DataFrame, optional
            The data to pass to the agent. Used for prediction and tool calls where data is required.
        kwargs : dict, optional
            Additional keyword arguments to pass to the agents ainvoke method.

        """
        self.response = await self._compiled_graph.ainvoke(
            self._build_inputs(user_instructions, data_raw),
            **kwargs
        )
        return None

    def invoke_agent(
        self,
        user_instructions: str = None,
        data_raw: pd.DataFrame = None,
        **kwargs
    ):
        """
        Runs the agent with the given user instructions.

        The result is stored on `self.response`; nothing is returned.

        Parameters:
        ----------
        user_instructions : str, optional
            The user instructions to pass to the agent.
        data_raw : pd.DataFrame, optional
            The raw data to pass to the agent. Used for prediction and tool calls where data is required.
        kwargs : dict, optional
            Additional keyword arguments to pass to the agents invoke method.

        """
        self.response = self._compiled_graph.invoke(
            self._build_inputs(user_instructions, data_raw),
            **kwargs
        )
        return None

    def get_internal_messages(self, markdown: bool = False):
        """
        Returns the internal messages from the agent's response.

        Returns None when the agent has not produced a response yet. With
        `markdown=True` the messages are rendered into a single Markdown
        object; otherwise the raw message sequence is returned.
        """
        if not self.response:
            return None
        internal_messages = self.response["internal_messages"]
        if not markdown:
            return internal_messages
        # Only build the (potentially large) string when it is requested.
        pretty_print = "\n\n".join(
            f"### {msg.type.upper()}\n\nID: {msg.id}\n\nContent:\n\n{msg.content}"
            for msg in internal_messages
        )
        return Markdown(pretty_print)

    def get_mlflow_artifacts(self, as_dataframe: bool = False):
        """
        Returns the MLflow artifacts from the agent's response.

        Returns None when the agent has not produced a response yet. With
        `as_dataframe=True` the artifact is wrapped in a pandas DataFrame.
        """
        if not self.response:
            return None
        artifacts = self.response["mlflow_artifacts"]
        if as_dataframe:
            return pd.DataFrame(artifacts)
        return artifacts

    def get_ai_message(self, markdown: bool = False):
        """
        Returns the AI message from the agent's response.

        Returns None when the agent has not produced a response yet or the
        response carries no messages.
        """
        if not self.response:
            return None
        messages = self.response.get("messages")
        if not messages:
            return None
        content = messages[0].content
        return Markdown(content) if markdown else content
231
+
232
+
233
+
234
+
235
def make_mlflow_tools_agent(
    model: Any,
    mlflow_tracking_uri: Optional[str] = None,
    mlflow_registry_uri: Optional[str] = None,
    create_react_agent_kwargs: Optional[Dict] = None,
    invoke_react_agent_kwargs: Optional[Dict] = None,
):
    """
    MLflow Tool Calling Agent

    Parameters:
    ----------
    model : Any
        The language model used to generate the agent.
    mlflow_tracking_uri : str, optional
        The tracking URI for MLflow. Defaults to None.
    mlflow_registry_uri : str, optional
        The registry URI for MLflow. Defaults to None.
    create_react_agent_kwargs : dict, optional
        Additional keyword arguments to pass to the agent's create_react_agent method.
        None is treated as an empty dict.
    invoke_react_agent_kwargs : dict, optional
        Additional arguments forwarded to the react agent's invoke call.
        None is treated as an empty dict.

    Returns
    -------
    app : langchain.graphs.CompiledStateGraph
        A compiled state graph for the MLflow Tool Calling Agent.

    Raises
    ------
    ImportError
        If the `mlflow` package is not installed.

    """

    try:
        import mlflow
    except ImportError as err:
        # Fail loudly: returning the message as a string would hand callers a
        # str where a compiled graph is expected.
        raise ImportError(
            "MLflow is not installed. Please install it by running: !pip install mlflow"
        ) from err

    # None defaults avoid a shared mutable default dict, and an explicit None
    # must not reach the `**` unpacking below.
    if create_react_agent_kwargs is None:
        create_react_agent_kwargs = {}
    if invoke_react_agent_kwargs is None:
        invoke_react_agent_kwargs = {}

    if mlflow_tracking_uri is not None:
        mlflow.set_tracking_uri(mlflow_tracking_uri)

    if mlflow_registry_uri is not None:
        mlflow.set_registry_uri(mlflow_registry_uri)

    class GraphState(AgentState):
        # Full message history of the inner react agent; updates are combined
        # with operator.add.
        internal_messages: Annotated[Sequence[BaseMessage], operator.add]
        user_instructions: str
        # The caller's DataFrame converted with `to_dict()`, or None.
        data_raw: dict
        # Artifact taken from the message preceding the final reply, or None.
        mlflow_artifacts: dict

    def mlflow_tools_agent(state):
        """
        Runs the react tool-calling agent, then postprocesses the result,
        keeping only the last message and extracting the last tool artifact.
        """
        print(format_agent_name(AGENT_NAME))
        print("    * RUN REACT TOOL-CALLING AGENT")

        tool_node = ToolNode(
            tools=tools
        )

        mlflow_agent = create_react_agent(
            model,
            tools=tool_node,
            state_schema=GraphState,
            **create_react_agent_kwargs,
        )

        # NOTE(review): invoke_react_agent_kwargs is passed positionally as the
        # second argument, which presumably lands in LangGraph's `config`
        # parameter rather than being unpacked as keywords - confirm intent.
        response = mlflow_agent.invoke(
            {
                "messages": [("user", state["user_instructions"])],
                "data_raw": state["data_raw"],
            },
            invoke_react_agent_kwargs,
        )

        print("    * POST-PROCESS RESULTS")

        internal_messages = response['messages']

        # Ensure there is at least one AI message
        if not internal_messages:
            return {
                "internal_messages": [],
                "mlflow_artifacts": None,
            }

        # Re-wrap the final reply so it is attributed to this agent.
        last_ai_message = AIMessage(internal_messages[-1].content, role=AGENT_NAME)

        # Get the last tool artifact safely: inspect the message right before
        # the final reply (presumably the last tool message, when tools ran).
        last_tool_artifact = None
        if len(internal_messages) > 1:
            previous_message = internal_messages[-2]
            if hasattr(previous_message, "artifact"):
                last_tool_artifact = previous_message.artifact
            elif isinstance(previous_message, dict) and "artifact" in previous_message:
                last_tool_artifact = previous_message["artifact"]

        return {
            "messages": [last_ai_message],
            "internal_messages": internal_messages,
            "mlflow_artifacts": last_tool_artifact,
        }

    # Single-node graph: START -> mlflow_tools_agent -> END
    workflow = StateGraph(GraphState)

    workflow.add_node("mlflow_tools_agent", mlflow_tools_agent)

    workflow.add_edge(START, "mlflow_tools_agent")
    workflow.add_edge("mlflow_tools_agent", END)

    app = workflow.compile()

    return app
350
+
@@ -1,6 +1,5 @@
1
1
 
2
2
  from langchain_core.messages import BaseMessage
3
- from langgraph.checkpoint.memory import MemorySaver
4
3
  from langgraph.types import Checkpointer
5
4
 
6
5
  from langgraph.graph import START, END, StateGraph
@@ -19,7 +18,7 @@ from IPython.display import Markdown
19
18
  from ai_data_science_team.templates import BaseAgent
20
19
  from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
21
20
  from ai_data_science_team.utils.plotly import plotly_from_dict
22
- from ai_data_science_team.tools.regex import remove_consecutive_duplicates, get_generic_summary
21
+ from ai_data_science_team.utils.regex import remove_consecutive_duplicates, get_generic_summary
23
22
 
24
23
 
25
24
  class SQLDataAnalyst(BaseAgent):
@@ -91,7 +90,7 @@ class SQLDataAnalyst(BaseAgent):
91
90
  self._params[k] = v
92
91
  self._compiled_graph = self._make_compiled_graph()
93
92
 
94
- def ainvoke_agent(self, user_instructions, max_retries:int=3, retry_count:int=0, **kwargs):
93
+ async def ainvoke_agent(self, user_instructions, max_retries:int=3, retry_count:int=0, **kwargs):
95
94
  """
96
95
  Asynchronosly nvokes the SQL Data Analyst Multi-Agent.
97
96
 
@@ -144,7 +143,7 @@ class SQLDataAnalyst(BaseAgent):
144
143
  sql_data_analyst.get_plotly_graph()
145
144
  ```
146
145
  """
147
- response = self._compiled_graph.ainvoke({
146
+ response = await self._compiled_graph.ainvoke({
148
147
  "user_instructions": user_instructions,
149
148
  "max_retries": max_retries,
150
149
  "retry_count": retry_count,
File without changes
@@ -3,7 +3,6 @@
3
3
  # ***
4
4
  # Parsers
5
5
 
6
- from langchain_core.output_parsers import JsonOutputParser
7
6
  from langchain_core.output_parsers import BaseOutputParser
8
7
 
9
8
  import re
@@ -12,8 +12,8 @@ import json
12
12
 
13
13
  from typing import Any, Callable, Dict, Type, Optional, Union, List
14
14
 
15
- from ai_data_science_team.tools.parsers import PythonOutputParser
16
- from ai_data_science_team.tools.regex import (
15
+ from ai_data_science_team.parsers.parsers import PythonOutputParser
16
+ from ai_data_science_team.utils.regex import (
17
17
  relocate_imports_inside_function,
18
18
  add_comments_to_top,
19
19
  remove_consecutive_duplicates
@@ -93,7 +93,7 @@ class BaseAgent(CompiledStateGraph):
93
93
 
94
94
  return self.response
95
95
 
96
- def ainvoke(
96
+ async def ainvoke(
97
97
  self,
98
98
  input: Union[dict[str, Any], Any],
99
99
  config: Optional[RunnableConfig] = None,
@@ -110,7 +110,7 @@ class BaseAgent(CompiledStateGraph):
110
110
  Returns:
111
111
  Any: The agent's response.
112
112
  """
113
- self.response = self._compiled_graph.ainvoke(input=input, config=config,**kwargs)
113
+ self.response = await self._compiled_graph.ainvoke(input=input, config=config,**kwargs)
114
114
 
115
115
  if self.response.get("messages"):
116
116
  self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
@@ -148,7 +148,7 @@ class BaseAgent(CompiledStateGraph):
148
148
 
149
149
  return self.response
150
150
 
151
- def astream(
151
+ async def astream(
152
152
  self,
153
153
  input: dict[str, Any] | Any,
154
154
  config: RunnableConfig | None = None,
@@ -172,7 +172,7 @@ class BaseAgent(CompiledStateGraph):
172
172
  Returns:
173
173
  Any: The agent's response.
174
174
  """
175
- self.response = self._compiled_graph.astream(input=input, config=config, stream_mode=stream_mode, **kwargs)
175
+ self.response = await self._compiled_graph.astream(input=input, config=config, stream_mode=stream_mode, **kwargs)
176
176
 
177
177
  if self.response.get("messages"):
178
178
  self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])