ai-data-science-team 0.0.0.9007__py3-none-any.whl → 0.0.0.9009__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (25) hide show
  1. ai_data_science_team/_version.py +1 -1
  2. ai_data_science_team/agents/__init__.py +4 -5
  3. ai_data_science_team/agents/data_cleaning_agent.py +268 -116
  4. ai_data_science_team/agents/data_visualization_agent.py +470 -41
  5. ai_data_science_team/agents/data_wrangling_agent.py +471 -31
  6. ai_data_science_team/agents/feature_engineering_agent.py +426 -41
  7. ai_data_science_team/agents/sql_database_agent.py +458 -58
  8. ai_data_science_team/ml_agents/__init__.py +1 -0
  9. ai_data_science_team/ml_agents/h2o_ml_agent.py +1032 -0
  10. ai_data_science_team/multiagents/__init__.py +1 -0
  11. ai_data_science_team/multiagents/sql_data_analyst.py +398 -0
  12. ai_data_science_team/multiagents/supervised_data_analyst.py +2 -0
  13. ai_data_science_team/templates/__init__.py +3 -1
  14. ai_data_science_team/templates/agent_templates.py +319 -43
  15. ai_data_science_team/tools/metadata.py +94 -62
  16. ai_data_science_team/tools/regex.py +86 -1
  17. ai_data_science_team/utils/__init__.py +0 -0
  18. ai_data_science_team/utils/plotly.py +24 -0
  19. ai_data_science_team-0.0.0.9009.dist-info/METADATA +245 -0
  20. ai_data_science_team-0.0.0.9009.dist-info/RECORD +28 -0
  21. ai_data_science_team-0.0.0.9007.dist-info/METADATA +0 -183
  22. ai_data_science_team-0.0.0.9007.dist-info/RECORD +0 -21
  23. {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/LICENSE +0 -0
  24. {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/WHEEL +0 -0
  25. {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/top_level.txt +0 -0
@@ -0,0 +1 @@
1
+ from ai_data_science_team.multiagents.sql_data_analyst import SQLDataAnalyst, make_sql_data_analyst
@@ -0,0 +1,398 @@
1
+
2
+ from langchain_core.messages import BaseMessage
3
+ from langgraph.checkpoint.memory import MemorySaver
4
+ from langgraph.types import Checkpointer
5
+
6
+ from langgraph.graph import START, END, StateGraph
7
+ from langgraph.graph.state import CompiledStateGraph
8
+ from langgraph.types import Command
9
+
10
+ from typing import TypedDict, Annotated, Sequence, Literal
11
+ import operator
12
+
13
+ from typing_extensions import TypedDict
14
+
15
+ import pandas as pd
16
+ import json
17
+ from IPython.display import Markdown
18
+
19
+ from ai_data_science_team.templates import BaseAgent
20
+ from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
21
+ from ai_data_science_team.utils.plotly import plotly_from_dict
22
+ from ai_data_science_team.tools.regex import remove_consecutive_duplicates, get_generic_summary
23
+
24
+
25
class SQLDataAnalyst(BaseAgent):
    """
    SQLDataAnalyst is a multi-agent class that combines SQL database querying and data visualization capabilities.

    Parameters:
    -----------
    model:
        The language model to be used for the agents.
    sql_database_agent: SQLDatabaseAgent
        The SQL Database Agent.
    data_visualization_agent: DataVisualizationAgent
        The Data Visualization Agent.
    checkpointer: Checkpointer (optional)
        Passed through to `make_sql_data_analyst()` when the graph is compiled.
        Default: None

    Methods:
    --------
    update_params(**kwargs)
        Updates the stored parameters and rebuilds the compiled graph.
    ainvoke_agent(user_instructions, **kwargs)
        Asynchronously invokes the SQL Data Analyst Multi-Agent with the given user instructions.
    invoke_agent(user_instructions, **kwargs)
        Invokes the SQL Data Analyst Multi-Agent with the given user instructions.
    get_data_sql()
        Returns the SQL data as a Pandas DataFrame.
    get_plotly_graph()
        Returns the Plotly graph as a Plotly object.
    get_sql_query_code(markdown=False)
        Returns the SQL query code as a string, optionally formatted as a Markdown code block.
    get_sql_database_function(markdown=False)
        Returns the SQL database function as a string, optionally formatted as a Markdown code block.
    get_data_visualization_function(markdown=False)
        Returns the data visualization function as a string, optionally formatted as a Markdown code block.
    get_workflow_summary(markdown=False)
        Returns a summary built from the messages stored in the response.
    """

    def __init__(
        self,
        model,
        sql_database_agent: SQLDatabaseAgent,
        data_visualization_agent: DataVisualizationAgent,
        checkpointer: Checkpointer = None,
    ):
        # All constructor arguments live in one dict so `update_params()` can
        # overwrite any of them and recompile the graph.
        self._params = {
            "model": model,
            "sql_database_agent": sql_database_agent,
            "data_visualization_agent": data_visualization_agent,
            "checkpointer": checkpointer,
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Create or rebuild the compiled graph for the SQL Data Analyst Multi-Agent.
        Running this method resets the response to None.
        """
        self.response = None
        # The wrapper agents are unwrapped here: the multi-agent graph is built
        # from their already-compiled sub-graphs.
        return make_sql_data_analyst(
            model=self._params["model"],
            sql_database_agent=self._params["sql_database_agent"]._compiled_graph,
            data_visualization_agent=self._params["data_visualization_agent"]._compiled_graph,
            checkpointer=self._params["checkpointer"],
        )

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters (e.g. model, sql_database_agent, etc.)
        and rebuilds the compiled graph.
        """
        for k, v in kwargs.items():
            self._params[k] = v
        self._compiled_graph = self._make_compiled_graph()

    async def ainvoke_agent(self, user_instructions, max_retries: int = 3, retry_count: int = 0, **kwargs):
        """
        Asynchronously invokes the SQL Data Analyst Multi-Agent.

        This is a coroutine and must be awaited; the compiled graph's
        `ainvoke` result is awaited before it is post-processed.

        Parameters:
        ----------
        user_instructions: str
            The user's instructions for the combined SQL and (optionally) Data Visualization agents.
        max_retries (int):
            Value placed in the graph state under `max_retries`.
        retry_count (int):
            Value placed in the graph state under `retry_count`.
        **kwargs:
            Additional keyword arguments to pass to the compiled graph's `ainvoke` method.

        Returns:
        -------
        None. The response is stored in the `response` attribute.

        Example:
        --------
        ``` python
        from langchain_openai import ChatOpenAI
        import sqlalchemy as sql
        from ai_data_science_team.multiagents import SQLDataAnalyst
        from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent

        llm = ChatOpenAI(model = "gpt-4o-mini")

        sql_engine = sql.create_engine("sqlite:///data/northwind.db")

        conn = sql_engine.connect()

        sql_data_analyst = SQLDataAnalyst(
            model = llm,
            sql_database_agent = SQLDatabaseAgent(
                model = llm,
                connection = conn,
                n_samples = 1,
            ),
            data_visualization_agent = DataVisualizationAgent(
                model = llm,
                n_samples = 10,
            )
        )

        await sql_data_analyst.ainvoke_agent(
            user_instructions = "Make a plot of sales revenue by month by territory. Make a dropdown for the user to select the territory.",
        )

        sql_data_analyst.get_sql_query_code()

        sql_data_analyst.get_data_sql()

        sql_data_analyst.get_plotly_graph()
        ```
        """
        # `ainvoke` returns a coroutine; without `await` the `.get()` below
        # would be called on the coroutine object instead of the result.
        response = await self._compiled_graph.ainvoke({
            "user_instructions": user_instructions,
            "max_retries": max_retries,
            "retry_count": retry_count,
        }, **kwargs)

        if response.get("messages"):
            response["messages"] = remove_consecutive_duplicates(response["messages"])

        self.response = response

    def invoke_agent(self, user_instructions, max_retries: int = 3, retry_count: int = 0, **kwargs):
        """
        Invokes the SQL Data Analyst Multi-Agent.

        Parameters:
        ----------
        user_instructions: str
            The user's instructions for the combined SQL and (optionally) Data Visualization agents.
        max_retries (int):
            Value placed in the graph state under `max_retries`.
        retry_count (int):
            Value placed in the graph state under `retry_count`.
        **kwargs:
            Additional keyword arguments to pass to the compiled graph's `invoke` method.

        Returns:
        -------
        None. The response is stored in the `response` attribute.

        Example:
        --------
        ``` python
        from langchain_openai import ChatOpenAI
        import sqlalchemy as sql
        from ai_data_science_team.multiagents import SQLDataAnalyst
        from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent

        llm = ChatOpenAI(model = "gpt-4o-mini")

        sql_engine = sql.create_engine("sqlite:///data/northwind.db")

        conn = sql_engine.connect()

        sql_data_analyst = SQLDataAnalyst(
            model = llm,
            sql_database_agent = SQLDatabaseAgent(
                model = llm,
                connection = conn,
                n_samples = 1,
            ),
            data_visualization_agent = DataVisualizationAgent(
                model = llm,
                n_samples = 10,
            )
        )

        sql_data_analyst.invoke_agent(
            user_instructions = "Make a plot of sales revenue by month by territory. Make a dropdown for the user to select the territory.",
        )

        sql_data_analyst.get_sql_query_code()

        sql_data_analyst.get_data_sql()

        sql_data_analyst.get_plotly_graph()
        ```
        """
        response = self._compiled_graph.invoke({
            "user_instructions": user_instructions,
            "max_retries": max_retries,
            "retry_count": retry_count,
        }, **kwargs)

        if response.get("messages"):
            response["messages"] = remove_consecutive_duplicates(response["messages"])

        self.response = response

    def get_data_sql(self):
        """
        Returns the SQL data as a Pandas DataFrame.

        Returns None when no response is stored or `data_sql` is missing/empty.
        """
        if not self.response:
            return None
        data_sql = self.response.get("data_sql")
        if not data_sql:
            return None
        return pd.DataFrame(data_sql)

    def get_plotly_graph(self):
        """
        Returns the Plotly graph as a Plotly object.

        Returns None when no response is stored or `plotly_graph` is missing/empty.
        """
        if not self.response:
            return None
        plotly_graph = self.response.get("plotly_graph")
        if not plotly_graph:
            return None
        return plotly_from_dict(plotly_graph)

    def get_sql_query_code(self, markdown=False):
        """
        Returns the SQL query code as a string.

        Returns None when no response is stored or `sql_query_code` is missing/empty.

        Parameters:
        ----------
        markdown: bool
            If True, returns the code as a Markdown code block for Jupyter (IPython).
            For streamlit, use `st.code()` instead.
        """
        if not self.response:
            return None
        sql_query_code = self.response.get("sql_query_code")
        if not sql_query_code:
            return None
        if markdown:
            return Markdown(f"```sql\n{sql_query_code}\n```")
        return sql_query_code

    def get_sql_database_function(self, markdown=False):
        """
        Returns the SQL database function as a string.

        Returns None when no response is stored or `sql_database_function` is missing/empty.

        Parameters:
        ----------
        markdown: bool
            If True, returns the function as a Markdown code block for Jupyter (IPython).
            For streamlit, use `st.code()` instead.
        """
        if not self.response:
            return None
        sql_database_function = self.response.get("sql_database_function")
        if not sql_database_function:
            return None
        if markdown:
            return Markdown(f"```python\n{sql_database_function}\n```")
        return sql_database_function

    def get_data_visualization_function(self, markdown=False):
        """
        Returns the data visualization function as a string.

        Returns None when no response is stored or `data_visualization_function` is missing/empty.

        Parameters:
        ----------
        markdown: bool
            If True, returns the function as a Markdown code block for Jupyter (IPython).
            For streamlit, use `st.code()` instead.
        """
        if not self.response:
            return None
        data_visualization_function = self.response.get("data_visualization_function")
        if not data_visualization_function:
            return None
        if markdown:
            return Markdown(f"```python\n{data_visualization_function}\n```")
        return data_visualization_function

    def get_workflow_summary(self, markdown=False):
        """
        Returns a summary of the SQL Data Analyst workflow.

        One report is produced per stored message by JSON-decoding its content
        and passing it to `get_generic_summary`. Returns None when no response
        is stored or the message list is empty.

        Parameters:
        ----------
        markdown: bool
            If True, returns an IPython `Markdown` object containing the header
            (list of agents) followed by the reports.
            If False, returns the reports as a plain string without the header.
        """
        if not self.response:
            return None

        # `get_response()` is not defined in this class; presumably inherited from BaseAgent.
        messages = self.get_response()['messages']
        if not messages:
            return None

        agents = [msg.role for msg in messages]
        agent_labels = [
            f"- **Agent {position}:** {role}"
            for position, role in enumerate(agents, start=1)
        ]

        # Construct header
        header = f"# SQL Data Analyst Workflow Summary Report\n\nThis agentic workflow contains {len(agents)} agents:\n\n" + "\n".join(agent_labels)

        reports = [get_generic_summary(json.loads(msg.content)) for msg in messages]

        if markdown:
            return Markdown(header + "\n\n".join(reports))
        return "\n\n".join(reports)
318
+
319
+
320
+
321
def make_sql_data_analyst(
    model,
    sql_database_agent: CompiledStateGraph,
    data_visualization_agent: CompiledStateGraph,
    checkpointer: Checkpointer = None
):
    """
    Creates a multi-agent system that takes in a SQL query and returns a plot or table.

    - Agent 1: SQL Database Agent made with `make_sql_database_agent()`
    - Agent 2: Data Visualization Agent made with `make_data_visualization_agent()`

    The graph always runs the SQL agent first, then a routing node asks the
    model whether the user wants a plot. Only a 'plot' answer continues to the
    visualization agent; anything else ends the run.

    Parameters:
    ----------
    model:
        The language model to be used for the agents.
    sql_database_agent: CompiledStateGraph
        The SQL Database Agent made with `make_sql_database_agent()`.
    data_visualization_agent: CompiledStateGraph
        The Data Visualization Agent made with `make_data_visualization_agent()`.
    checkpointer: Checkpointer (optional)
        The checkpointer to save the state of the multi-agent system.
        Default: None

    Returns:
    -------
    CompiledStateGraph
        The compiled multi-agent system.
    """

    llm = model

    class PrimaryState(TypedDict):
        # `operator.add` is the reducer for `messages`.
        messages: Annotated[Sequence[BaseMessage], operator.add]
        user_instructions: str
        sql_query_code: str
        sql_database_function: str
        data_sql: dict
        data_raw: dict
        plot_required: bool
        data_visualization_function: str
        plotly_graph: dict
        max_retries: int
        retry_count: int

    def route_to_visualization(state) -> Command[Literal["data_visualization_agent", "__end__"]]:
        """Ask the model whether a plot is requested and route accordingly."""

        response = llm.invoke(f"Respond in 1 word ('plot' or 'table'). Is the user requesting a plot? If unknown, select 'table'. \n\n User Instructions:\n{state.get('user_instructions')}")

        # Models often decorate a one-word answer ("Plot", "plot.", "'plot'",
        # trailing newline). Normalise before comparing so those replies are
        # not silently treated as 'table'.
        answer = str(response.content).strip().strip("'\".").strip().lower()

        plot_required = answer == 'plot'
        goto = "data_visualization_agent" if plot_required else "__end__"

        return Command(
            update={
                # Copy the SQL result into `data_raw` before leaving this node.
                'data_raw': state.get("data_sql"),
                'plot_required': plot_required,
            },
            goto=goto
        )

    workflow = StateGraph(PrimaryState)

    workflow.add_node("sql_database_agent", sql_database_agent)
    workflow.add_node("route_to_visualization", route_to_visualization)
    workflow.add_node("data_visualization_agent", data_visualization_agent)

    # No static edge leaves `route_to_visualization`: its returned `Command`
    # selects the next node.
    workflow.add_edge(START, "sql_database_agent")
    workflow.add_edge("sql_database_agent", "route_to_visualization")
    workflow.add_edge("data_visualization_agent", END)

    app = workflow.compile(checkpointer=checkpointer)

    return app
398
+
@@ -0,0 +1,2 @@
1
+ # TODO: Implement the supervised data analyst agent
2
+ # https://langchain-ai.github.io/langgraph/tutorials/multi_agent/agent_supervisor/#create-agent-supervisor
@@ -3,6 +3,8 @@ from ai_data_science_team.templates.agent_templates import(
3
3
  node_func_human_review,
4
4
  node_func_fix_agent_code,
5
5
  node_func_explain_agent_code,
6
+ node_func_report_agent_outputs,
6
7
  node_func_execute_agent_from_sql_connection,
7
- create_coding_agent_graph
8
+ create_coding_agent_graph,
9
+ BaseAgent,
8
10
  )