ai-data-science-team 0.0.0.9007__py3-none-any.whl → 0.0.0.9009__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +4 -5
- ai_data_science_team/agents/data_cleaning_agent.py +268 -116
- ai_data_science_team/agents/data_visualization_agent.py +470 -41
- ai_data_science_team/agents/data_wrangling_agent.py +471 -31
- ai_data_science_team/agents/feature_engineering_agent.py +426 -41
- ai_data_science_team/agents/sql_database_agent.py +458 -58
- ai_data_science_team/ml_agents/__init__.py +1 -0
- ai_data_science_team/ml_agents/h2o_ml_agent.py +1032 -0
- ai_data_science_team/multiagents/__init__.py +1 -0
- ai_data_science_team/multiagents/sql_data_analyst.py +398 -0
- ai_data_science_team/multiagents/supervised_data_analyst.py +2 -0
- ai_data_science_team/templates/__init__.py +3 -1
- ai_data_science_team/templates/agent_templates.py +319 -43
- ai_data_science_team/tools/metadata.py +94 -62
- ai_data_science_team/tools/regex.py +86 -1
- ai_data_science_team/utils/__init__.py +0 -0
- ai_data_science_team/utils/plotly.py +24 -0
- ai_data_science_team-0.0.0.9009.dist-info/METADATA +245 -0
- ai_data_science_team-0.0.0.9009.dist-info/RECORD +28 -0
- ai_data_science_team-0.0.0.9007.dist-info/METADATA +0 -183
- ai_data_science_team-0.0.0.9007.dist-info/RECORD +0 -21
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/WHEEL +0 -0
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/top_level.txt +0 -0
@@ -0,0 +1 @@
|
|
1
|
+
from ai_data_science_team.multiagents.sql_data_analyst import SQLDataAnalyst, make_sql_data_analyst
|
@@ -0,0 +1,398 @@
|
|
1
|
+
|
2
|
+
from langchain_core.messages import BaseMessage
|
3
|
+
from langgraph.checkpoint.memory import MemorySaver
|
4
|
+
from langgraph.types import Checkpointer
|
5
|
+
|
6
|
+
from langgraph.graph import START, END, StateGraph
|
7
|
+
from langgraph.graph.state import CompiledStateGraph
|
8
|
+
from langgraph.types import Command
|
9
|
+
|
10
|
+
from typing import TypedDict, Annotated, Sequence, Literal
|
11
|
+
import operator
|
12
|
+
|
13
|
+
from typing_extensions import TypedDict
|
14
|
+
|
15
|
+
import pandas as pd
|
16
|
+
import json
|
17
|
+
from IPython.display import Markdown
|
18
|
+
|
19
|
+
from ai_data_science_team.templates import BaseAgent
|
20
|
+
from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
|
21
|
+
from ai_data_science_team.utils.plotly import plotly_from_dict
|
22
|
+
from ai_data_science_team.tools.regex import remove_consecutive_duplicates, get_generic_summary
|
23
|
+
|
24
|
+
|
25
|
+
class SQLDataAnalyst(BaseAgent):
    """
    SQLDataAnalyst is a multi-agent class that combines SQL database querying and data visualization capabilities.

    Parameters:
    -----------
    model:
        The language model to be used for the agents.
    sql_database_agent: SQLDatabaseAgent
        The SQL Database Agent.
    data_visualization_agent: DataVisualizationAgent
        The Data Visualization Agent.
    checkpointer: Checkpointer (optional)
        Passed to `make_sql_data_analyst()` and used when compiling the multi-agent graph.
        Default: None

    Methods:
    --------
    ainvoke_agent(user_instructions, **kwargs)
        Coroutine that asynchronously invokes the SQL Data Analyst Multi-Agent with the
        given user instructions. Must be awaited.
    invoke_agent(user_instructions, **kwargs)
        Invokes the SQL Data Analyst Multi-Agent with the given user instructions.
    get_data_sql()
        Returns the SQL data as a Pandas DataFrame.
    get_plotly_graph()
        Returns the Plotly graph as a Plotly object.
    get_sql_query_code(markdown=False)
        Returns the SQL query code as a string, optionally formatted as a Markdown code block.
    get_sql_database_function(markdown=False)
        Returns the SQL database function as a string, optionally formatted as a Markdown code block.
    get_data_visualization_function(markdown=False)
        Returns the data visualization function as a string, optionally formatted as a Markdown code block.
    get_workflow_summary(markdown=False)
        Returns a summary report built from the messages in the last response.
    """

    def __init__(
        self,
        model,
        sql_database_agent: SQLDatabaseAgent,
        data_visualization_agent: DataVisualizationAgent,
        checkpointer: Checkpointer = None,
    ):
        self._params = {
            "model": model,
            "sql_database_agent": sql_database_agent,
            "data_visualization_agent": data_visualization_agent,
            "checkpointer": checkpointer,
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Create or rebuild the compiled graph for the SQL Data Analyst Multi-Agent.
        Running this method resets the response to None.
        """
        self.response = None
        return make_sql_data_analyst(
            model=self._params["model"],
            sql_database_agent=self._params["sql_database_agent"]._compiled_graph,
            data_visualization_agent=self._params["data_visualization_agent"]._compiled_graph,
            checkpointer=self._params["checkpointer"],
        )

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters (e.g. model, sql_database_agent, etc.)
        and rebuilds the compiled graph.
        """
        for k, v in kwargs.items():
            self._params[k] = v
        self._compiled_graph = self._make_compiled_graph()

    # Build the initial graph state shared by invoke_agent and ainvoke_agent.
    @staticmethod
    def _build_graph_input(user_instructions, max_retries, retry_count):
        return {
            "user_instructions": user_instructions,
            "max_retries": max_retries,
            "retry_count": retry_count,
        }

    # Collapse consecutive duplicate messages (when present) and store the final state.
    def _store_response(self, response):
        if response.get("messages"):
            response["messages"] = remove_consecutive_duplicates(response["messages"])
        self.response = response

    async def ainvoke_agent(self, user_instructions, max_retries: int = 3, retry_count: int = 0, **kwargs):
        """
        Asynchronously invokes the SQL Data Analyst Multi-Agent.

        This is a coroutine: it must be awaited (e.g. `await agent.ainvoke_agent(...)`
        in an async function or a Jupyter cell, or via `asyncio.run(...)` in a script).

        Parameters:
        ----------
        user_instructions: str
            The user's instructions for the combined SQL and (optionally) Data Visualization agents.
        max_retries (int):
            Maximum retry attempts; placed in the graph's initial state
            (presumably consumed by the sub-agents' code-fix loops — confirm).
        retry_count (int):
            Current retry attempt; placed in the graph's initial state.
        **kwargs:
            Additional keyword arguments to pass to the compiled graph's `ainvoke` method.

        Returns:
        -------
        None. The response is stored in the `response` attribute.

        Example:
        --------
        ``` python
        from langchain_openai import ChatOpenAI
        import sqlalchemy as sql
        from ai_data_science_team.multiagents import SQLDataAnalyst
        from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent

        llm = ChatOpenAI(model = "gpt-4o-mini")

        sql_engine = sql.create_engine("sqlite:///data/northwind.db")

        conn = sql_engine.connect()

        sql_data_analyst = SQLDataAnalyst(
            model = llm,
            sql_database_agent = SQLDatabaseAgent(
                model = llm,
                connection = conn,
                n_samples = 1,
            ),
            data_visualization_agent = DataVisualizationAgent(
                model = llm,
                n_samples = 10,
            )
        )

        # Inside an async context (e.g. a Jupyter cell, which supports top-level await):
        await sql_data_analyst.ainvoke_agent(
            user_instructions = "Make a plot of sales revenue by month by territory. Make a dropdown for the user to select the territory.",
        )

        sql_data_analyst.get_sql_query_code()

        sql_data_analyst.get_data_sql()

        sql_data_analyst.get_plotly_graph()
        ```
        """
        # `ainvoke` returns a coroutine; it must be awaited to obtain the state dict.
        response = await self._compiled_graph.ainvoke(
            self._build_graph_input(user_instructions, max_retries, retry_count),
            **kwargs,
        )
        self._store_response(response)

    def invoke_agent(self, user_instructions, max_retries: int = 3, retry_count: int = 0, **kwargs):
        """
        Invokes the SQL Data Analyst Multi-Agent.

        Parameters:
        ----------
        user_instructions: str
            The user's instructions for the combined SQL and (optionally) Data Visualization agents.
        max_retries (int):
            Maximum retry attempts; placed in the graph's initial state
            (presumably consumed by the sub-agents' code-fix loops — confirm).
        retry_count (int):
            Current retry attempt; placed in the graph's initial state.
        **kwargs:
            Additional keyword arguments to pass to the compiled graph's `invoke` method.

        Returns:
        -------
        None. The response is stored in the `response` attribute.

        Example:
        --------
        ``` python
        from langchain_openai import ChatOpenAI
        import sqlalchemy as sql
        from ai_data_science_team.multiagents import SQLDataAnalyst
        from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent

        llm = ChatOpenAI(model = "gpt-4o-mini")

        sql_engine = sql.create_engine("sqlite:///data/northwind.db")

        conn = sql_engine.connect()

        sql_data_analyst = SQLDataAnalyst(
            model = llm,
            sql_database_agent = SQLDatabaseAgent(
                model = llm,
                connection = conn,
                n_samples = 1,
            ),
            data_visualization_agent = DataVisualizationAgent(
                model = llm,
                n_samples = 10,
            )
        )

        sql_data_analyst.invoke_agent(
            user_instructions = "Make a plot of sales revenue by month by territory. Make a dropdown for the user to select the territory.",
        )

        sql_data_analyst.get_sql_query_code()

        sql_data_analyst.get_data_sql()

        sql_data_analyst.get_plotly_graph()
        ```
        """
        response = self._compiled_graph.invoke(
            self._build_graph_input(user_instructions, max_retries, retry_count),
            **kwargs,
        )
        self._store_response(response)

    def get_data_sql(self):
        """
        Returns the SQL data as a Pandas DataFrame.

        Returns None when there is no response yet or the response has no "data_sql".
        """
        if self.response:
            if self.response.get("data_sql"):
                return pd.DataFrame(self.response.get("data_sql"))

    def get_plotly_graph(self):
        """
        Returns the Plotly graph as a Plotly object.

        Returns None when there is no response yet or no "plotly_graph" was produced
        (e.g. the router decided a table was sufficient).
        """
        if self.response:
            if self.response.get("plotly_graph"):
                return plotly_from_dict(self.response.get("plotly_graph"))

    # Return the string stored under `key` in the response, optionally as a fenced Markdown block.
    def _get_code(self, key, language, markdown):
        if not self.response:
            return None
        code = self.response.get(key)
        if not code:
            return None
        if markdown:
            return Markdown(f"```{language}\n{code}\n```")
        return code

    def get_sql_query_code(self, markdown=False):
        """
        Returns the SQL query code as a string.

        Parameters:
        ----------
        markdown: bool
            If True, returns the code as a Markdown code block for Jupyter (IPython).
            For streamlit, use `st.code()` instead.
        """
        return self._get_code("sql_query_code", "sql", markdown)

    def get_sql_database_function(self, markdown=False):
        """
        Returns the SQL database function as a string.

        Parameters:
        ----------
        markdown: bool
            If True, returns the function as a Markdown code block for Jupyter (IPython).
            For streamlit, use `st.code()` instead.
        """
        return self._get_code("sql_database_function", "python", markdown)

    def get_data_visualization_function(self, markdown=False):
        """
        Returns the data visualization function as a string.

        Parameters:
        ----------
        markdown: bool
            If True, returns the function as a Markdown code block for Jupyter (IPython).
            For streamlit, use `st.code()` instead.
        """
        return self._get_code("data_visualization_function", "python", markdown)

    def get_workflow_summary(self, markdown=False):
        """
        Returns a summary of the SQL Data Analyst workflow.

        Parameters:
        ----------
        markdown: bool
            If True, returns the summary (with a header listing the agents) as Markdown.
            Otherwise returns only the joined per-message reports as a plain string.

        Returns None when there is no response or it contains no messages.
        """
        messages = (self.response or {}).get("messages")
        if not messages:
            return None

        # NOTE(review): assumes each message carries a `role` attribute set by the
        # sub-agents and JSON-encoded `content` — confirm against the agent templates.
        agents = [msg.role for msg in messages]
        agent_labels = [f"- **Agent {i}:** {role}" for i, role in enumerate(agents, start=1)]

        # Construct header
        header = f"# SQL Data Analyst Workflow Summary Report\n\nThis agentic workflow contains {len(agents)} agents:\n\n" + "\n".join(agent_labels)

        reports = [get_generic_summary(json.loads(msg.content)) for msg in messages]
        body = "\n\n".join(reports)

        if markdown:
            # Separate the agent list from the first report so the report heading
            # does not run onto the last agent label line.
            return Markdown(header + "\n\n" + body)
        return body
|
318
|
+
|
319
|
+
|
320
|
+
|
321
|
+
def make_sql_data_analyst(
    model,
    sql_database_agent: CompiledStateGraph,
    data_visualization_agent: CompiledStateGraph,
    checkpointer: Checkpointer = None
):
    """
    Creates a multi-agent system that takes in a SQL query and returns a plot or table.

    - Agent 1: SQL Database Agent made with `make_sql_database_agent()`
    - Agent 2: Data Visualization Agent made with `make_data_visualization_agent()`

    The graph always runs the SQL Database Agent first. A routing node then asks the
    language model whether the user requested a plot; if so, the Data Visualization
    Agent runs on the SQL result, otherwise the workflow ends with the table.

    Parameters:
    ----------
    model:
        The language model to be used for the agents.
    sql_database_agent: CompiledStateGraph
        The SQL Database Agent made with `make_sql_database_agent()`.
    data_visualization_agent: CompiledStateGraph
        The Data Visualization Agent made with `make_data_visualization_agent()`.
    checkpointer: Checkpointer (optional)
        The checkpointer to save the state of the multi-agent system.
        Default: None

    Returns:
    -------
    CompiledStateGraph
        The compiled multi-agent system.
    """

    llm = model

    class PrimaryState(TypedDict):
        messages: Annotated[Sequence[BaseMessage], operator.add]
        user_instructions: str
        sql_query_code: str
        sql_database_function: str
        data_sql: dict
        data_raw: dict
        plot_required: bool
        data_visualization_function: str
        plotly_graph: dict
        max_retries: int
        retry_count: int

    def route_to_visualization(state) -> Command[Literal["data_visualization_agent", "__end__"]]:
        # Ask the LLM whether the user wants a plot; anything other than "plot" means a table.
        response = llm.invoke(f"Respond in 1 word ('plot' or 'table'). Is the user requesting a plot? If unknown, select 'table'. \n\n User Instructions:\n{state.get('user_instructions')}")

        # Models frequently add capitalization, whitespace, quotes or a trailing period
        # even when asked for one word ("Plot", "plot.", " plot\n"); normalize before
        # comparing so those answers are not silently routed to the table path.
        answer = str(response.content).strip().strip(".!'\"`").strip().lower()
        plot_required = answer == "plot"
        goto = "data_visualization_agent" if plot_required else "__end__"

        # Copy the SQL result into `data_raw`, the input key used by the visualization agent.
        return Command(
            update={
                'data_raw': state.get("data_sql"),
                'plot_required': plot_required,
            },
            goto=goto,
        )

    workflow = StateGraph(PrimaryState)

    workflow.add_node("sql_database_agent", sql_database_agent)
    workflow.add_node("route_to_visualization", route_to_visualization)
    workflow.add_node("data_visualization_agent", data_visualization_agent)

    workflow.add_edge(START, "sql_database_agent")
    workflow.add_edge("sql_database_agent", "route_to_visualization")
    # route_to_visualization picks its successor dynamically via Command(goto=...).
    workflow.add_edge("data_visualization_agent", END)

    app = workflow.compile(checkpointer=checkpointer)

    return app
|
398
|
+
|
@@ -3,6 +3,8 @@ from ai_data_science_team.templates.agent_templates import(
|
|
3
3
|
node_func_human_review,
|
4
4
|
node_func_fix_agent_code,
|
5
5
|
node_func_explain_agent_code,
|
6
|
+
node_func_report_agent_outputs,
|
6
7
|
node_func_execute_agent_from_sql_connection,
|
7
|
-
create_coding_agent_graph
|
8
|
+
create_coding_agent_graph,
|
9
|
+
BaseAgent,
|
8
10
|
)
|