ai-data-science-team 0.0.0.9007__py3-none-any.whl → 0.0.0.9009__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +4 -5
- ai_data_science_team/agents/data_cleaning_agent.py +268 -116
- ai_data_science_team/agents/data_visualization_agent.py +470 -41
- ai_data_science_team/agents/data_wrangling_agent.py +471 -31
- ai_data_science_team/agents/feature_engineering_agent.py +426 -41
- ai_data_science_team/agents/sql_database_agent.py +458 -58
- ai_data_science_team/ml_agents/__init__.py +1 -0
- ai_data_science_team/ml_agents/h2o_ml_agent.py +1032 -0
- ai_data_science_team/multiagents/__init__.py +1 -0
- ai_data_science_team/multiagents/sql_data_analyst.py +398 -0
- ai_data_science_team/multiagents/supervised_data_analyst.py +2 -0
- ai_data_science_team/templates/__init__.py +3 -1
- ai_data_science_team/templates/agent_templates.py +319 -43
- ai_data_science_team/tools/metadata.py +94 -62
- ai_data_science_team/tools/regex.py +86 -1
- ai_data_science_team/utils/__init__.py +0 -0
- ai_data_science_team/utils/plotly.py +24 -0
- ai_data_science_team-0.0.0.9009.dist-info/METADATA +245 -0
- ai_data_science_team-0.0.0.9009.dist-info/RECORD +28 -0
- ai_data_science_team-0.0.0.9007.dist-info/METADATA +0 -183
- ai_data_science_team-0.0.0.9007.dist-info/RECORD +0 -21
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/WHEEL +0 -0
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/top_level.txt +0 -0
@@ -0,0 +1 @@
|
|
1
|
+
from ai_data_science_team.multiagents.sql_data_analyst import SQLDataAnalyst, make_sql_data_analyst
|
@@ -0,0 +1,398 @@
|
|
1
|
+
|
2
|
+
from langchain_core.messages import BaseMessage
|
3
|
+
from langgraph.checkpoint.memory import MemorySaver
|
4
|
+
from langgraph.types import Checkpointer
|
5
|
+
|
6
|
+
from langgraph.graph import START, END, StateGraph
|
7
|
+
from langgraph.graph.state import CompiledStateGraph
|
8
|
+
from langgraph.types import Command
|
9
|
+
|
10
|
+
from typing import TypedDict, Annotated, Sequence, Literal
|
11
|
+
import operator
|
12
|
+
|
13
|
+
from typing_extensions import TypedDict
|
14
|
+
|
15
|
+
import pandas as pd
|
16
|
+
import json
|
17
|
+
from IPython.display import Markdown
|
18
|
+
|
19
|
+
from ai_data_science_team.templates import BaseAgent
|
20
|
+
from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
|
21
|
+
from ai_data_science_team.utils.plotly import plotly_from_dict
|
22
|
+
from ai_data_science_team.tools.regex import remove_consecutive_duplicates, get_generic_summary
|
23
|
+
|
24
|
+
|
25
|
+
class SQLDataAnalyst(BaseAgent):
    """
    SQLDataAnalyst is a multi-agent class that combines SQL database querying and data visualization capabilities.

    Parameters:
    -----------
    model:
        The language model to be used for the agents.
    sql_database_agent: SQLDatabaseAgent
        The SQL Database Agent.
    data_visualization_agent: DataVisualizationAgent
        The Data Visualization Agent.
    checkpointer: Checkpointer (optional)
        Checkpointer forwarded to `make_sql_data_analyst()` when the graph is compiled.
        Default: None

    Methods:
    --------
    update_params(**kwargs)
        Updates the stored parameters and rebuilds the compiled graph.
    ainvoke_agent(user_instructions, **kwargs)
        Asynchronously invokes the SQL Data Analyst Multi-Agent with the given user instructions.
    invoke_agent(user_instructions, **kwargs)
        Invokes the SQL Data Analyst Multi-Agent with the given user instructions.
    get_data_sql()
        Returns the SQL data as a Pandas DataFrame.
    get_plotly_graph()
        Returns the Plotly graph as a Plotly object.
    get_sql_query_code(markdown=False)
        Returns the SQL query code as a string, optionally formatted as a Markdown code block.
    get_sql_database_function(markdown=False)
        Returns the SQL database function as a string, optionally formatted as a Markdown code block.
    get_data_visualization_function(markdown=False)
        Returns the data visualization function as a string, optionally formatted as a Markdown code block.
    get_workflow_summary(markdown=False)
        Returns a summary of the workflow built from the response messages.
    """

    def __init__(
        self,
        model,
        sql_database_agent: SQLDatabaseAgent,
        data_visualization_agent: DataVisualizationAgent,
        checkpointer: Checkpointer = None,
    ):
        # Parameters are kept in a dict so `update_params()` can change any of
        # them and rebuild the graph from the same source of truth.
        self._params = {
            "model": model,
            "sql_database_agent": sql_database_agent,
            "data_visualization_agent": data_visualization_agent,
            "checkpointer": checkpointer,
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Create or rebuild the compiled graph for the SQL Data Analyst Multi-Agent.
        Running this method resets the response to None.
        """
        self.response = None
        # The sub-agents are wrapper objects; the multi-agent graph is built
        # from their already-compiled graphs.
        return make_sql_data_analyst(
            model=self._params["model"],
            sql_database_agent=self._params["sql_database_agent"]._compiled_graph,
            data_visualization_agent=self._params["data_visualization_agent"]._compiled_graph,
            checkpointer=self._params["checkpointer"],
        )

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters (e.g. model, sql_database_agent, etc.)
        and rebuilds the compiled graph.
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    async def ainvoke_agent(self, user_instructions, max_retries: int = 3, retry_count: int = 0, **kwargs):
        """
        Asynchronously invokes the SQL Data Analyst Multi-Agent.

        Parameters:
        ----------
        user_instructions: str
            The user's instructions for the combined SQL and (optionally) Data Visualization agents.
        max_retries (int):
            Maximum retry attempts; forwarded in the graph's initial state.
        retry_count (int):
            Current retry attempt; forwarded in the graph's initial state.
        **kwargs:
            Additional keyword arguments to pass to the compiled graph's `ainvoke` method.

        Returns:
        -------
        None. The response is stored in the `response` attribute.

        Example:
        --------
        ``` python
        from langchain_openai import ChatOpenAI
        import sqlalchemy as sql
        from ai_data_science_team.multiagents import SQLDataAnalyst
        from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent

        llm = ChatOpenAI(model = "gpt-4o-mini")

        sql_engine = sql.create_engine("sqlite:///data/northwind.db")

        conn = sql_engine.connect()

        sql_data_analyst = SQLDataAnalyst(
            model = llm,
            sql_database_agent = SQLDatabaseAgent(
                model = llm,
                connection = conn,
                n_samples = 1,
            ),
            data_visualization_agent = DataVisualizationAgent(
                model = llm,
                n_samples = 10,
            )
        )

        await sql_data_analyst.ainvoke_agent(
            user_instructions = "Make a plot of sales revenue by month by territory. Make a dropdown for the user to select the territory.",
        )

        sql_data_analyst.get_sql_query_code()

        sql_data_analyst.get_data_sql()

        sql_data_analyst.get_plotly_graph()
        ```
        """
        # `ainvoke` returns a coroutine; it must be awaited to obtain the state
        # dict. Without the await, `response.get` below would fail.
        response = await self._compiled_graph.ainvoke({
            "user_instructions": user_instructions,
            "max_retries": max_retries,
            "retry_count": retry_count,
        }, **kwargs)

        if response.get("messages"):
            response["messages"] = remove_consecutive_duplicates(response["messages"])

        self.response = response

    def invoke_agent(self, user_instructions, max_retries: int = 3, retry_count: int = 0, **kwargs):
        """
        Invokes the SQL Data Analyst Multi-Agent.

        Parameters:
        ----------
        user_instructions: str
            The user's instructions for the combined SQL and (optionally) Data Visualization agents.
        max_retries (int):
            Maximum retry attempts; forwarded in the graph's initial state.
        retry_count (int):
            Current retry attempt; forwarded in the graph's initial state.
        **kwargs:
            Additional keyword arguments to pass to the compiled graph's `invoke` method.

        Returns:
        -------
        None. The response is stored in the `response` attribute.

        Example:
        --------
        ``` python
        from langchain_openai import ChatOpenAI
        import sqlalchemy as sql
        from ai_data_science_team.multiagents import SQLDataAnalyst
        from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent

        llm = ChatOpenAI(model = "gpt-4o-mini")

        sql_engine = sql.create_engine("sqlite:///data/northwind.db")

        conn = sql_engine.connect()

        sql_data_analyst = SQLDataAnalyst(
            model = llm,
            sql_database_agent = SQLDatabaseAgent(
                model = llm,
                connection = conn,
                n_samples = 1,
            ),
            data_visualization_agent = DataVisualizationAgent(
                model = llm,
                n_samples = 10,
            )
        )

        sql_data_analyst.invoke_agent(
            user_instructions = "Make a plot of sales revenue by month by territory. Make a dropdown for the user to select the territory.",
        )

        sql_data_analyst.get_sql_query_code()

        sql_data_analyst.get_data_sql()

        sql_data_analyst.get_plotly_graph()
        ```
        """
        response = self._compiled_graph.invoke({
            "user_instructions": user_instructions,
            "max_retries": max_retries,
            "retry_count": retry_count,
        }, **kwargs)

        if response.get("messages"):
            response["messages"] = remove_consecutive_duplicates(response["messages"])

        self.response = response

    def get_data_sql(self):
        """
        Returns the SQL data as a Pandas DataFrame.

        Returns None when there is no response yet or the response has no
        (non-empty) `data_sql` entry.
        """
        if self.response and self.response.get("data_sql"):
            return pd.DataFrame(self.response.get("data_sql"))
        return None

    def get_plotly_graph(self):
        """
        Returns the Plotly graph as a Plotly object.

        The stored `plotly_graph` entry is converted with `plotly_from_dict`.
        Returns None when there is no response yet or no graph was produced.
        """
        if self.response and self.response.get("plotly_graph"):
            return plotly_from_dict(self.response.get("plotly_graph"))
        return None

    def _get_response_code(self, key, language, markdown):
        """
        Shared lookup for the code getters: fetch `key` from the response and
        optionally wrap it in a Markdown code fence tagged with `language`.
        Returns None when there is no response or the entry is missing/empty.
        """
        if not self.response:
            return None
        code = self.response.get(key)
        if not code:
            return None
        if markdown:
            return Markdown(f"```{language}\n{code}\n```")
        return code

    def get_sql_query_code(self, markdown=False):
        """
        Returns the SQL query code as a string.

        Parameters:
        ----------
        markdown: bool
            If True, returns the code as a Markdown code block for Jupyter (IPython).
            For streamlit, use `st.code()` instead.
        """
        return self._get_response_code("sql_query_code", "sql", markdown)

    def get_sql_database_function(self, markdown=False):
        """
        Returns the SQL database function as a string.

        Parameters:
        ----------
        markdown: bool
            If True, returns the function as a Markdown code block for Jupyter (IPython).
            For streamlit, use `st.code()` instead.
        """
        return self._get_response_code("sql_database_function", "python", markdown)

    def get_data_visualization_function(self, markdown=False):
        """
        Returns the data visualization function as a string.

        Parameters:
        ----------
        markdown: bool
            If True, returns the function as a Markdown code block for Jupyter (IPython).
            For streamlit, use `st.code()` instead.
        """
        return self._get_response_code("data_visualization_function", "python", markdown)

    def get_workflow_summary(self, markdown=False):
        """
        Returns a summary of the SQL Data Analyst workflow.

        Parameters:
        ----------
        markdown: bool
            If True, returns the summary (with a header listing the agents) as
            an IPython Markdown object. If False, returns the per-message
            reports as a plain string without the header.

        Returns None when there is no response or it has no messages.
        """
        if not self.response:
            return None

        messages = self.get_response()['messages']
        if not messages:
            return None

        # Each message's `role` is used as the agent's display name.
        agents = [msg.role for msg in messages]
        agent_labels = [
            f"- **Agent {i}:** {agent}" for i, agent in enumerate(agents, start=1)
        ]

        # Construct header
        header = f"# SQL Data Analyst Workflow Summary Report\n\nThis agentic workflow contains {len(agents)} agents:\n\n" + "\n".join(agent_labels)

        # Message content is parsed as JSON before being summarized.
        reports = [get_generic_summary(json.loads(msg.content)) for msg in messages]

        if markdown:
            return Markdown(header + "\n\n".join(reports))
        return "\n\n".join(reports)
|
318
|
+
|
319
|
+
|
320
|
+
|
321
|
+
def make_sql_data_analyst(
    model,
    sql_database_agent: CompiledStateGraph,
    data_visualization_agent: CompiledStateGraph,
    checkpointer: Checkpointer = None
):
    """
    Creates a multi-agent system that takes in a SQL query and returns a plot or table.

    - Agent 1: SQL Database Agent made with `make_sql_database_agent()`
    - Agent 2: Data Visualization Agent made with `make_data_visualization_agent()`

    The SQL agent always runs first. A routing node then asks the model whether
    the user wants a plot; only in that case is the visualization agent run.

    Parameters:
    ----------
    model:
        The language model to be used for the agents.
    sql_database_agent: CompiledStateGraph
        The SQL Database Agent made with `make_sql_database_agent()`.
    data_visualization_agent: CompiledStateGraph
        The Data Visualization Agent made with `make_data_visualization_agent()`.
    checkpointer: Checkpointer (optional)
        The checkpointer to save the state of the multi-agent system.
        Default: None

    Returns:
    -------
    CompiledStateGraph
        The compiled multi-agent system.
    """

    llm = model

    class PrimaryState(TypedDict):
        # `operator.add` makes message lists from each node accumulate
        # instead of overwriting each other.
        messages: Annotated[Sequence[BaseMessage], operator.add]
        user_instructions: str
        sql_query_code: str
        sql_database_function: str
        data_sql: dict
        data_raw: dict
        plot_required: bool
        data_visualization_function: str
        plotly_graph: dict
        max_retries: int
        retry_count: int

    def route_to_visualization(state) -> Command[Literal["data_visualization_agent", "__end__"]]:
        """Ask the model whether a plot is wanted and route accordingly."""

        response = llm.invoke(f"Respond in 1 word ('plot' or 'table'). Is the user requesting a plot? If unknown, select 'table'. \n\n User Instructions:\n{state.get('user_instructions')}")

        # Models often answer "Plot", "plot." or "'plot'"; normalize whitespace,
        # surrounding quotes/punctuation and case before comparing so those
        # replies are not silently treated as "table".
        answer = str(response.content).strip().strip("'\"`.!").strip().lower()

        plot_required = answer == 'plot'
        goto = "data_visualization_agent" if plot_required else "__end__"

        return Command(
            update={
                # The SQL result is copied to `data_raw` in the shared state.
                'data_raw': state.get("data_sql"),
                'plot_required': plot_required,
            },
            goto=goto
        )

    workflow = StateGraph(PrimaryState)

    workflow.add_node("sql_database_agent", sql_database_agent)
    workflow.add_node("route_to_visualization", route_to_visualization)
    workflow.add_node("data_visualization_agent", data_visualization_agent)

    # The edge out of `route_to_visualization` is not declared here; it is
    # decided at runtime by the `Command(goto=...)` the node returns.
    workflow.add_edge(START, "sql_database_agent")
    workflow.add_edge("sql_database_agent", "route_to_visualization")
    workflow.add_edge("data_visualization_agent", END)

    app = workflow.compile(checkpointer=checkpointer)

    return app
|
398
|
+
|
@@ -3,6 +3,8 @@ from ai_data_science_team.templates.agent_templates import(
|
|
3
3
|
node_func_human_review,
|
4
4
|
node_func_fix_agent_code,
|
5
5
|
node_func_explain_agent_code,
|
6
|
+
node_func_report_agent_outputs,
|
6
7
|
node_func_execute_agent_from_sql_connection,
|
7
|
-
create_coding_agent_graph
|
8
|
+
create_coding_agent_graph,
|
9
|
+
BaseAgent,
|
8
10
|
)
|