ai-data-science-team 0.0.0.9009__py3-none-any.whl → 0.0.0.9011__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. ai_data_science_team/_version.py +1 -1
  2. ai_data_science_team/agents/__init__.py +1 -0
  3. ai_data_science_team/agents/data_cleaning_agent.py +6 -6
  4. ai_data_science_team/agents/data_loader_tools_agent.py +272 -0
  5. ai_data_science_team/agents/data_visualization_agent.py +6 -7
  6. ai_data_science_team/agents/data_wrangling_agent.py +6 -6
  7. ai_data_science_team/agents/feature_engineering_agent.py +6 -6
  8. ai_data_science_team/agents/sql_database_agent.py +6 -6
  9. ai_data_science_team/ml_agents/__init__.py +1 -0
  10. ai_data_science_team/ml_agents/h2o_ml_agent.py +206 -385
  11. ai_data_science_team/ml_agents/h2o_ml_tools_agent.py +0 -0
  12. ai_data_science_team/ml_agents/mlflow_tools_agent.py +350 -0
  13. ai_data_science_team/multiagents/sql_data_analyst.py +3 -4
  14. ai_data_science_team/parsers/__init__.py +0 -0
  15. ai_data_science_team/{tools → parsers}/parsers.py +0 -1
  16. ai_data_science_team/templates/agent_templates.py +6 -6
  17. ai_data_science_team/tools/data_loader.py +448 -0
  18. ai_data_science_team/tools/dataframe.py +139 -0
  19. ai_data_science_team/tools/h2o.py +643 -0
  20. ai_data_science_team/tools/mlflow.py +961 -0
  21. ai_data_science_team/tools/{metadata.py → sql.py} +1 -137
  22. {ai_data_science_team-0.0.0.9009.dist-info → ai_data_science_team-0.0.0.9011.dist-info}/METADATA +40 -19
  23. ai_data_science_team-0.0.0.9011.dist-info/RECORD +36 -0
  24. ai_data_science_team-0.0.0.9009.dist-info/RECORD +0 -28
  25. /ai_data_science_team/{tools → utils}/logging.py +0 -0
  26. /ai_data_science_team/{tools → utils}/regex.py +0 -0
  27. {ai_data_science_team-0.0.0.9009.dist-info → ai_data_science_team-0.0.0.9011.dist-info}/LICENSE +0 -0
  28. {ai_data_science_team-0.0.0.9009.dist-info → ai_data_science_team-0.0.0.9011.dist-info}/WHEEL +0 -0
  29. {ai_data_science_team-0.0.0.9009.dist-info → ai_data_science_team-0.0.0.9011.dist-info}/top_level.txt +0 -0
File without changes
@@ -0,0 +1,350 @@
1
+
2
+ from typing import Any, Optional, Annotated, Sequence, Dict
3
+ import operator
4
+
5
+ import pandas as pd
6
+
7
+ from IPython.display import Markdown
8
+
9
+ from langchain_core.messages import BaseMessage, AIMessage
10
+
11
+ from langgraph.prebuilt import create_react_agent, ToolNode
12
+ from langgraph.prebuilt.chat_agent_executor import AgentState
13
+ from langgraph.graph import START, END, StateGraph
14
+
15
+ from ai_data_science_team.templates import BaseAgent
16
+ from ai_data_science_team.utils.regex import format_agent_name
17
+ from ai_data_science_team.tools.mlflow import (
18
+ mlflow_search_experiments,
19
+ mlflow_search_runs,
20
+ mlflow_create_experiment,
21
+ mlflow_predict_from_run_id,
22
+ mlflow_launch_ui,
23
+ mlflow_stop_ui,
24
+ mlflow_list_artifacts,
25
+ mlflow_download_artifacts,
26
+ mlflow_list_registered_models,
27
+ mlflow_search_registered_models,
28
+ mlflow_get_model_version_details,
29
+ )
30
+
31
AGENT_NAME = "mlflow_tools_agent"

# TOOL SETUP
# Every MLflow tool the ReAct agent may call, grouped by the area of MLflow
# it touches.
tools = [
    # Experiments, runs and prediction
    mlflow_search_experiments,
    mlflow_search_runs,
    mlflow_create_experiment,
    mlflow_predict_from_run_id,
    # MLflow UI process control
    mlflow_launch_ui,
    mlflow_stop_ui,
    # Run artifacts
    mlflow_list_artifacts,
    mlflow_download_artifacts,
    # Model registry
    mlflow_list_registered_models,
    mlflow_search_registered_models,
    mlflow_get_model_version_details,
]
47
+
48
class MLflowToolsAgent(BaseAgent):
    """
    An agent that can interact with MLflow by calling tools.

    Current tools include:
    - Search Experiments
    - Search Runs
    - Create Experiment
    - Predict (from a Run ID)
    - Launch / Stop the MLflow UI
    - List / Download Artifacts
    - List / Search Registered Models
    - Get Model Version Details

    Parameters:
    ----------
    model : langchain.llms.base.LLM
        The language model used to generate the tool calling agent.
    mlflow_tracking_uri : str, optional
        The tracking URI for MLflow. Defaults to None.
    mlflow_registry_uri : str, optional
        The registry URI for MLflow. Defaults to None.
    create_react_agent_kwargs : dict, optional
        Additional keyword arguments to pass to the create_react_agent function.
        Defaults to an empty dict.
    invoke_react_agent_kwargs : dict, optional
        Forwarded to the react agent's invoke call (as its second positional
        argument). Defaults to an empty dict.

    Methods:
    --------
    update_params(**kwargs):
        Updates the agent's parameters and rebuilds the compiled graph.
    ainvoke_agent(user_instructions: str=None, data_raw: pd.DataFrame=None, **kwargs):
        Asynchronously runs the agent with the given user instructions.
    invoke_agent(user_instructions: str=None, data_raw: pd.DataFrame=None, **kwargs):
        Runs the agent with the given user instructions.
    get_internal_messages(markdown: bool=False):
        Returns the internal messages from the agent's response.
    get_mlflow_artifacts(as_dataframe: bool=False):
        Returns the MLflow artifacts from the agent's response.
    get_ai_message(markdown: bool=False):
        Returns the AI message from the agent's response.

    Examples:
    --------
    ```python
    from ai_data_science_team.ml_agents import MLflowToolsAgent

    mlflow_agent = MLflowToolsAgent(llm)

    mlflow_agent.invoke_agent(user_instructions="List the MLflow experiments")

    # get_response() is not defined in this class; presumably inherited
    # from BaseAgent - confirm.
    mlflow_agent.get_response()

    mlflow_agent.get_internal_messages(markdown=True)

    mlflow_agent.get_ai_message(markdown=True)

    mlflow_agent.get_mlflow_artifacts(as_dataframe=True)
    ```

    Notes
    -----
    Constructing the class immediately builds the compiled graph via
    ``make_mlflow_tools_agent``.
    """

    # Parameters forwarded to the graph builder as dicts; None is replaced by
    # a fresh empty dict so ``**create_react_agent_kwargs`` never sees None.
    _DICT_PARAM_NAMES = ("create_react_agent_kwargs", "invoke_react_agent_kwargs")

    def __init__(
        self,
        model: Any,
        mlflow_tracking_uri: Optional[str] = None,
        mlflow_registry_uri: Optional[str] = None,
        create_react_agent_kwargs: Optional[Dict] = None,
        invoke_react_agent_kwargs: Optional[Dict] = None,
    ):
        self._params = {
            "model": model,
            "mlflow_tracking_uri": mlflow_tracking_uri,
            "mlflow_registry_uri": mlflow_registry_uri,
            "create_react_agent_kwargs": create_react_agent_kwargs,
            "invoke_react_agent_kwargs": invoke_react_agent_kwargs,
        }
        self._fill_dict_param_defaults()
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _fill_dict_param_defaults(self):
        """Replace None values of the dict-typed params with a new empty dict."""
        for name in self._DICT_PARAM_NAMES:
            if self._params.get(name) is None:
                self._params[name] = {}

    def _make_compiled_graph(self):
        """
        Creates the compiled graph for the agent.

        Also clears any previous response, since it belonged to the old graph.
        """
        self.response = None
        return make_mlflow_tools_agent(**self._params)

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters and rebuilds the compiled graph.
        """
        self._params.update(kwargs)
        self._fill_dict_param_defaults()
        self._compiled_graph = self._make_compiled_graph()

    @staticmethod
    def _make_graph_input(user_instructions, data_raw):
        """Build the input state shared by invoke_agent and ainvoke_agent."""
        return {
            "user_instructions": user_instructions,
            # The graph state holds plain dicts, so convert the DataFrame.
            "data_raw": data_raw.to_dict() if data_raw is not None else None,
        }

    async def ainvoke_agent(
        self,
        user_instructions: Optional[str] = None,
        data_raw: Optional[pd.DataFrame] = None,
        **kwargs
    ):
        """
        Asynchronously runs the agent with the given user instructions.

        The result is stored on ``self.response``; the method returns None.

        Parameters:
        ----------
        user_instructions : str, optional
            The user instructions to pass to the agent.
        data_raw : pd.DataFrame, optional
            The data to pass to the agent. Used for prediction and tool calls where data is required.
        kwargs : dict, optional
            Additional keyword arguments to pass to the compiled graph's ainvoke method.
        """
        self.response = await self._compiled_graph.ainvoke(
            self._make_graph_input(user_instructions, data_raw),
            **kwargs
        )
        return None

    def invoke_agent(
        self,
        user_instructions: Optional[str] = None,
        data_raw: Optional[pd.DataFrame] = None,
        **kwargs
    ):
        """
        Runs the agent with the given user instructions.

        The result is stored on ``self.response``; the method returns None.

        Parameters:
        ----------
        user_instructions : str, optional
            The user instructions to pass to the agent.
        data_raw : pd.DataFrame, optional
            The raw data to pass to the agent. Used for prediction and tool calls where data is required.
        kwargs : dict, optional
            Additional keyword arguments to pass to the compiled graph's invoke method.
        """
        self.response = self._compiled_graph.invoke(
            self._make_graph_input(user_instructions, data_raw),
            **kwargs
        )
        return None

    def get_internal_messages(self, markdown: bool = False):
        """
        Returns the internal messages from the agent's response.

        Parameters:
        ----------
        markdown : bool, optional
            If True, return a Markdown rendering (type, id and content of each
            message) instead of the raw message list.
        """
        messages = self.response["internal_messages"]
        if not markdown:
            return messages
        # Only build the rendering when it is actually requested.
        pretty_print = "\n\n".join(
            f"### {msg.type.upper()}\n\nID: {msg.id}\n\nContent:\n\n{msg.content}"
            for msg in messages
        )
        return Markdown(pretty_print)

    def get_mlflow_artifacts(self, as_dataframe: bool = False):
        """
        Returns the MLflow artifacts from the agent's response.

        Parameters:
        ----------
        as_dataframe : bool, optional
            If True, wrap the artifact in ``pd.DataFrame``.
        """
        artifacts = self.response["mlflow_artifacts"]
        if as_dataframe:
            return pd.DataFrame(artifacts)
        return artifacts

    def get_ai_message(self, markdown: bool = False):
        """
        Returns the AI message from the agent's response.

        Parameters:
        ----------
        markdown : bool, optional
            If True, return the content wrapped in ``Markdown``.
        """
        # The graph node emits a single summarising AIMessage; read the most
        # recent one so this stays correct if more messages ever precede it.
        content = self.response["messages"][-1].content
        if markdown:
            return Markdown(content)
        return content
231
+
232
+
233
+
234
+
235
def make_mlflow_tools_agent(
    model: Any,
    mlflow_tracking_uri: Optional[str] = None,
    mlflow_registry_uri: Optional[str] = None,
    create_react_agent_kwargs: Optional[Dict] = None,
    invoke_react_agent_kwargs: Optional[Dict] = None,
):
    """
    MLflow Tool Calling Agent

    Parameters:
    ----------
    model : Any
        The language model used to generate the agent.
    mlflow_tracking_uri : str, optional
        The tracking URI for MLflow. Defaults to None. When given, it is set
        globally via ``mlflow.set_tracking_uri``.
    mlflow_registry_uri : str, optional
        The registry URI for MLflow. Defaults to None. When given, it is set
        globally via ``mlflow.set_registry_uri``.
    create_react_agent_kwargs : dict, optional
        Additional keyword arguments to pass to create_react_agent.
        Defaults to an empty dict.
    invoke_react_agent_kwargs : dict, optional
        Passed as the second positional argument of the react agent's invoke
        call. Defaults to an empty dict.

    Returns
    -------
    app : langchain.graphs.CompiledStateGraph
        A compiled state graph for the MLflow Tool Calling Agent.

    Raises
    ------
    ImportError
        If MLflow is not installed.
    """
    try:
        import mlflow
    except ImportError as e:
        # Previously this returned the message as a string, which callers then
        # stored as the "compiled graph" and failed on later. Fail fast instead.
        raise ImportError(
            "MLflow is not installed. Please install it by running: !pip install mlflow"
        ) from e

    # Avoid shared mutable defaults; ``**None`` would also raise.
    if create_react_agent_kwargs is None:
        create_react_agent_kwargs = {}
    if invoke_react_agent_kwargs is None:
        invoke_react_agent_kwargs = {}

    if mlflow_tracking_uri is not None:
        mlflow.set_tracking_uri(mlflow_tracking_uri)

    if mlflow_registry_uri is not None:
        mlflow.set_registry_uri(mlflow_registry_uri)

    class GraphState(AgentState):
        internal_messages: Annotated[Sequence[BaseMessage], operator.add]
        user_instructions: str
        # None when no DataFrame was supplied by the caller.
        data_raw: Optional[dict]
        # None when the run produced no tool artifact.
        mlflow_artifacts: Optional[dict]

    # Build the ReAct sub-agent once per compiled graph rather than on every
    # execution of the node below.
    tool_node = ToolNode(
        tools=tools
    )

    mlflow_agent = create_react_agent(
        model,
        tools=tool_node,
        state_schema=GraphState,
        **create_react_agent_kwargs,
    )

    def mlflow_tools_agent(state):
        """
        Runs the ReAct agent, then postprocesses its state: keeps the last
        message as the AI reply and extracts the last tool artifact.
        """
        print(format_agent_name(AGENT_NAME))
        print(" * RUN REACT TOOL-CALLING AGENT")

        response = mlflow_agent.invoke(
            {
                "messages": [("user", state["user_instructions"])],
                "data_raw": state["data_raw"],
            },
            invoke_react_agent_kwargs,
        )

        print(" * POST-PROCESS RESULTS")

        internal_messages = response['messages']

        # Nothing came back from the sub-agent: return empty results.
        if not internal_messages:
            return {
                "internal_messages": [],
                "mlflow_artifacts": None,
            }

        # The final message of the ReAct loop is the agent's answer.
        last_ai_message = AIMessage(internal_messages[-1].content, role=AGENT_NAME)

        # The message before the answer is checked for a tool artifact; it may
        # be a message object or a plain dict.
        last_tool_artifact = None
        if len(internal_messages) > 1:
            previous_message = internal_messages[-2]
            if hasattr(previous_message, "artifact"):
                last_tool_artifact = previous_message.artifact
            elif isinstance(previous_message, dict) and "artifact" in previous_message:
                last_tool_artifact = previous_message["artifact"]

        return {
            "messages": [last_ai_message],
            "internal_messages": internal_messages,
            "mlflow_artifacts": last_tool_artifact,
        }

    workflow = StateGraph(GraphState)

    workflow.add_node("mlflow_tools_agent", mlflow_tools_agent)

    workflow.add_edge(START, "mlflow_tools_agent")
    workflow.add_edge("mlflow_tools_agent", END)

    app = workflow.compile()

    return app
350
+
@@ -1,6 +1,5 @@
1
1
 
2
2
  from langchain_core.messages import BaseMessage
3
- from langgraph.checkpoint.memory import MemorySaver
4
3
  from langgraph.types import Checkpointer
5
4
 
6
5
  from langgraph.graph import START, END, StateGraph
@@ -19,7 +18,7 @@ from IPython.display import Markdown
19
18
  from ai_data_science_team.templates import BaseAgent
20
19
  from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
21
20
  from ai_data_science_team.utils.plotly import plotly_from_dict
22
- from ai_data_science_team.tools.regex import remove_consecutive_duplicates, get_generic_summary
21
+ from ai_data_science_team.utils.regex import remove_consecutive_duplicates, get_generic_summary
23
22
 
24
23
 
25
24
  class SQLDataAnalyst(BaseAgent):
@@ -91,7 +90,7 @@ class SQLDataAnalyst(BaseAgent):
91
90
  self._params[k] = v
92
91
  self._compiled_graph = self._make_compiled_graph()
93
92
 
94
- def ainvoke_agent(self, user_instructions, max_retries:int=3, retry_count:int=0, **kwargs):
93
+ async def ainvoke_agent(self, user_instructions, max_retries:int=3, retry_count:int=0, **kwargs):
95
94
  """
96
95
  Asynchronosly nvokes the SQL Data Analyst Multi-Agent.
97
96
 
@@ -144,7 +143,7 @@ class SQLDataAnalyst(BaseAgent):
144
143
  sql_data_analyst.get_plotly_graph()
145
144
  ```
146
145
  """
147
- response = self._compiled_graph.ainvoke({
146
+ response = await self._compiled_graph.ainvoke({
148
147
  "user_instructions": user_instructions,
149
148
  "max_retries": max_retries,
150
149
  "retry_count": retry_count,
File without changes
@@ -3,7 +3,6 @@
3
3
  # ***
4
4
  # Parsers
5
5
 
6
- from langchain_core.output_parsers import JsonOutputParser
7
6
  from langchain_core.output_parsers import BaseOutputParser
8
7
 
9
8
  import re
@@ -12,8 +12,8 @@ import json
12
12
 
13
13
  from typing import Any, Callable, Dict, Type, Optional, Union, List
14
14
 
15
- from ai_data_science_team.tools.parsers import PythonOutputParser
16
- from ai_data_science_team.tools.regex import (
15
+ from ai_data_science_team.parsers.parsers import PythonOutputParser
16
+ from ai_data_science_team.utils.regex import (
17
17
  relocate_imports_inside_function,
18
18
  add_comments_to_top,
19
19
  remove_consecutive_duplicates
@@ -93,7 +93,7 @@ class BaseAgent(CompiledStateGraph):
93
93
 
94
94
  return self.response
95
95
 
96
- def ainvoke(
96
+ async def ainvoke(
97
97
  self,
98
98
  input: Union[dict[str, Any], Any],
99
99
  config: Optional[RunnableConfig] = None,
@@ -110,7 +110,7 @@ class BaseAgent(CompiledStateGraph):
110
110
  Returns:
111
111
  Any: The agent's response.
112
112
  """
113
- self.response = self._compiled_graph.ainvoke(input=input, config=config,**kwargs)
113
+ self.response = await self._compiled_graph.ainvoke(input=input, config=config,**kwargs)
114
114
 
115
115
  if self.response.get("messages"):
116
116
  self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
@@ -148,7 +148,7 @@ class BaseAgent(CompiledStateGraph):
148
148
 
149
149
  return self.response
150
150
 
151
- def astream(
151
+ async def astream(
152
152
  self,
153
153
  input: dict[str, Any] | Any,
154
154
  config: RunnableConfig | None = None,
@@ -172,7 +172,7 @@ class BaseAgent(CompiledStateGraph):
172
172
  Returns:
173
173
  Any: The agent's response.
174
174
  """
175
- self.response = self._compiled_graph.astream(input=input, config=config, stream_mode=stream_mode, **kwargs)
175
+ self.response = await self._compiled_graph.astream(input=input, config=config, stream_mode=stream_mode, **kwargs)
176
176
 
177
177
  if self.response.get("messages"):
178
178
  self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])