ai-data-science-team 0.0.0.9007__py3-none-any.whl → 0.0.0.9009__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25) hide show
  1. ai_data_science_team/_version.py +1 -1
  2. ai_data_science_team/agents/__init__.py +4 -5
  3. ai_data_science_team/agents/data_cleaning_agent.py +268 -116
  4. ai_data_science_team/agents/data_visualization_agent.py +470 -41
  5. ai_data_science_team/agents/data_wrangling_agent.py +471 -31
  6. ai_data_science_team/agents/feature_engineering_agent.py +426 -41
  7. ai_data_science_team/agents/sql_database_agent.py +458 -58
  8. ai_data_science_team/ml_agents/__init__.py +1 -0
  9. ai_data_science_team/ml_agents/h2o_ml_agent.py +1032 -0
  10. ai_data_science_team/multiagents/__init__.py +1 -0
  11. ai_data_science_team/multiagents/sql_data_analyst.py +398 -0
  12. ai_data_science_team/multiagents/supervised_data_analyst.py +2 -0
  13. ai_data_science_team/templates/__init__.py +3 -1
  14. ai_data_science_team/templates/agent_templates.py +319 -43
  15. ai_data_science_team/tools/metadata.py +94 -62
  16. ai_data_science_team/tools/regex.py +86 -1
  17. ai_data_science_team/utils/__init__.py +0 -0
  18. ai_data_science_team/utils/plotly.py +24 -0
  19. ai_data_science_team-0.0.0.9009.dist-info/METADATA +245 -0
  20. ai_data_science_team-0.0.0.9009.dist-info/RECORD +28 -0
  21. ai_data_science_team-0.0.0.9007.dist-info/METADATA +0 -183
  22. ai_data_science_team-0.0.0.9007.dist-info/RECORD +0 -21
  23. {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/LICENSE +0 -0
  24. {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/WHEEL +0 -0
  25. {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/top_level.txt +0 -0
@@ -0,0 +1 @@
1
+ from ai_data_science_team.multiagents.sql_data_analyst import SQLDataAnalyst, make_sql_data_analyst
@@ -0,0 +1,398 @@
1
+
2
+ from langchain_core.messages import BaseMessage
3
+ from langgraph.checkpoint.memory import MemorySaver
4
+ from langgraph.types import Checkpointer
5
+
6
+ from langgraph.graph import START, END, StateGraph
7
+ from langgraph.graph.state import CompiledStateGraph
8
+ from langgraph.types import Command
9
+
10
+ from typing import TypedDict, Annotated, Sequence, Literal
11
+ import operator
12
+
13
+ from typing_extensions import TypedDict
14
+
15
+ import pandas as pd
16
+ import json
17
+ from IPython.display import Markdown
18
+
19
+ from ai_data_science_team.templates import BaseAgent
20
+ from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
21
+ from ai_data_science_team.utils.plotly import plotly_from_dict
22
+ from ai_data_science_team.tools.regex import remove_consecutive_duplicates, get_generic_summary
23
+
24
+
25
class SQLDataAnalyst(BaseAgent):
    """
    SQLDataAnalyst is a multi-agent class that combines SQL database querying and data visualization capabilities.

    Parameters:
    -----------
    model:
        The language model to be used for the agents.
    sql_database_agent: SQLDatabaseAgent
        The SQL Database Agent.
    data_visualization_agent: DataVisualizationAgent
        The Data Visualization Agent.
    checkpointer: Checkpointer (optional)
        Checkpointer passed through to the compiled multi-agent graph. Default: None

    Methods:
    --------
    ainvoke_agent(user_instructions, **kwargs)
        Coroutine that asynchronously invokes the SQL Data Analyst Multi-Agent with the
        given user instructions. Must be awaited.
    invoke_agent(user_instructions, **kwargs)
        Invokes the SQL Data Analyst Multi-Agent with the given user instructions.
    get_data_sql()
        Returns the SQL data as a Pandas DataFrame.
    get_plotly_graph()
        Returns the Plotly graph as a Plotly object.
    get_sql_query_code(markdown=False)
        Returns the SQL query code as a string, optionally formatted as a Markdown code block.
    get_sql_database_function(markdown=False)
        Returns the SQL database function as a string, optionally formatted as a Markdown code block.
    get_data_visualization_function(markdown=False)
        Returns the data visualization function as a string, optionally formatted as a Markdown code block.
    get_workflow_summary(markdown=False)
        Returns a summary built from the messages produced by each agent.
    """

    def __init__(
        self,
        model,
        sql_database_agent: SQLDatabaseAgent,
        data_visualization_agent: DataVisualizationAgent,
        checkpointer: Checkpointer = None,
    ):
        self._params = {
            "model": model,
            "sql_database_agent": sql_database_agent,
            "data_visualization_agent": data_visualization_agent,
            "checkpointer": checkpointer,
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Create or rebuild the compiled graph for the SQL Data Analyst Multi-Agent.
        Running this method resets the response to None.
        """
        self.response = None
        return make_sql_data_analyst(
            model=self._params["model"],
            # The sub-agents are wrapper objects; the graph needs their compiled graphs.
            sql_database_agent=self._params["sql_database_agent"]._compiled_graph,
            data_visualization_agent=self._params["data_visualization_agent"]._compiled_graph,
            checkpointer=self._params["checkpointer"],
        )

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters (e.g. model, sql_database_agent, etc.)
        and rebuilds the compiled graph.
        """
        for k, v in kwargs.items():
            self._params[k] = v
        self._compiled_graph = self._make_compiled_graph()

    def _build_input(self, user_instructions, max_retries, retry_count):
        # Initial graph state shared by invoke_agent and ainvoke_agent.
        return {
            "user_instructions": user_instructions,
            "max_retries": max_retries,
            "retry_count": retry_count,
        }

    def _store_response(self, response):
        # Collapse consecutive duplicate messages, then keep the final state on the instance.
        if response.get("messages"):
            response["messages"] = remove_consecutive_duplicates(response["messages"])
        self.response = response

    async def ainvoke_agent(self, user_instructions, max_retries: int = 3, retry_count: int = 0, **kwargs):
        """
        Asynchronously invokes the SQL Data Analyst Multi-Agent.

        This is a coroutine and must be awaited.

        Parameters:
        ----------
        user_instructions: str
            The user's instructions for the combined SQL and (optionally) Data Visualization agents.
        max_retries (int):
            Maximum retry attempts passed to the graph state. Default: 3
        retry_count (int):
            Current retry attempt passed to the graph state. Default: 0
        **kwargs:
            Additional keyword arguments to pass to the compiled graph's `ainvoke` method.

        Returns:
        -------
        None. The response is stored in the `response` attribute.

        Example:
        --------
        ``` python
        from langchain_openai import ChatOpenAI
        import sqlalchemy as sql
        from ai_data_science_team.multiagents import SQLDataAnalyst
        from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent

        llm = ChatOpenAI(model = "gpt-4o-mini")

        sql_engine = sql.create_engine("sqlite:///data/northwind.db")

        conn = sql_engine.connect()

        sql_data_analyst = SQLDataAnalyst(
            model = llm,
            sql_database_agent = SQLDatabaseAgent(
                model = llm,
                connection = conn,
                n_samples = 1,
            ),
            data_visualization_agent = DataVisualizationAgent(
                model = llm,
                n_samples = 10,
            )
        )

        # In a notebook (top-level await) or inside another coroutine:
        await sql_data_analyst.ainvoke_agent(
            user_instructions = "Make a plot of sales revenue by month by territory. Make a dropdown for the user to select the territory.",
        )

        sql_data_analyst.get_sql_query_code()

        sql_data_analyst.get_data_sql()

        sql_data_analyst.get_plotly_graph()
        ```
        """
        # `ainvoke` returns a coroutine; it must be awaited before reading the state.
        response = await self._compiled_graph.ainvoke(
            self._build_input(user_instructions, max_retries, retry_count),
            **kwargs,
        )
        self._store_response(response)

    def invoke_agent(self, user_instructions, max_retries: int = 3, retry_count: int = 0, **kwargs):
        """
        Invokes the SQL Data Analyst Multi-Agent.

        Parameters:
        ----------
        user_instructions: str
            The user's instructions for the combined SQL and (optionally) Data Visualization agents.
        max_retries (int):
            Maximum retry attempts passed to the graph state. Default: 3
        retry_count (int):
            Current retry attempt passed to the graph state. Default: 0
        **kwargs:
            Additional keyword arguments to pass to the compiled graph's `invoke` method.

        Returns:
        -------
        None. The response is stored in the `response` attribute.

        Example:
        --------
        ``` python
        from langchain_openai import ChatOpenAI
        import sqlalchemy as sql
        from ai_data_science_team.multiagents import SQLDataAnalyst
        from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent

        llm = ChatOpenAI(model = "gpt-4o-mini")

        sql_engine = sql.create_engine("sqlite:///data/northwind.db")

        conn = sql_engine.connect()

        sql_data_analyst = SQLDataAnalyst(
            model = llm,
            sql_database_agent = SQLDatabaseAgent(
                model = llm,
                connection = conn,
                n_samples = 1,
            ),
            data_visualization_agent = DataVisualizationAgent(
                model = llm,
                n_samples = 10,
            )
        )

        sql_data_analyst.invoke_agent(
            user_instructions = "Make a plot of sales revenue by month by territory. Make a dropdown for the user to select the territory.",
        )

        sql_data_analyst.get_sql_query_code()

        sql_data_analyst.get_data_sql()

        sql_data_analyst.get_plotly_graph()
        ```
        """
        response = self._compiled_graph.invoke(
            self._build_input(user_instructions, max_retries, retry_count),
            **kwargs,
        )
        self._store_response(response)

    def get_data_sql(self):
        """
        Returns the SQL data as a Pandas DataFrame, or None if no data is available.
        """
        if self.response:
            if self.response.get("data_sql"):
                return pd.DataFrame(self.response.get("data_sql"))

    def get_plotly_graph(self):
        """
        Returns the Plotly graph as a Plotly object, or None if no graph was produced.
        """
        if self.response:
            if self.response.get("plotly_graph"):
                return plotly_from_dict(self.response.get("plotly_graph"))

    def get_sql_query_code(self, markdown=False):
        """
        Returns the SQL query code as a string.

        Parameters:
        ----------
        markdown: bool
            If True, returns the code as a Markdown code block for Jupyter (IPython).
            For streamlit, use `st.code()` instead.
        """
        if self.response:
            if self.response.get("sql_query_code"):
                if markdown:
                    return Markdown(f"```sql\n{self.response.get('sql_query_code')}\n```")
                return self.response.get("sql_query_code")

    def get_sql_database_function(self, markdown=False):
        """
        Returns the SQL database function as a string.

        Parameters:
        ----------
        markdown: bool
            If True, returns the function as a Markdown code block for Jupyter (IPython).
            For streamlit, use `st.code()` instead.
        """
        if self.response:
            if self.response.get("sql_database_function"):
                if markdown:
                    return Markdown(f"```python\n{self.response.get('sql_database_function')}\n```")
                return self.response.get("sql_database_function")

    def get_data_visualization_function(self, markdown=False):
        """
        Returns the data visualization function as a string.

        Parameters:
        ----------
        markdown: bool
            If True, returns the function as a Markdown code block for Jupyter (IPython).
            For streamlit, use `st.code()` instead.
        """
        if self.response:
            if self.response.get("data_visualization_function"):
                if markdown:
                    return Markdown(f"```python\n{self.response.get('data_visualization_function')}\n```")
                return self.response.get("data_visualization_function")

    def get_workflow_summary(self, markdown=False):
        """
        Returns a summary of the SQL Data Analyst workflow.

        Parameters:
        ----------
        markdown: bool
            If True, returns an IPython `Markdown` object containing a header that lists
            each agent followed by the per-message reports. If False, returns only the
            per-message reports joined into a plain string.

        Returns:
        -------
        Markdown | str | None
            None when there is no response or it contains no messages.
        """
        if not self.response:
            return None
        messages = self.get_response().get("messages")
        if not messages:
            return None

        # NOTE(review): assumes each message carries a `role` attribute naming the agent.
        agents = [getattr(msg, "role", "unknown") for msg in messages]
        agent_labels = [f"- **Agent {i}:** {agent}" for i, agent in enumerate(agents, start=1)]

        # Construct header
        header = (
            "# SQL Data Analyst Workflow Summary Report\n\n"
            f"This agentic workflow contains {len(agents)} agents:\n\n"
            + "\n".join(agent_labels)
        )

        reports = []
        for msg in messages:
            # Message content is expected to be a JSON payload; fall back to the raw text
            # instead of crashing the whole summary when it is not.
            try:
                reports.append(get_generic_summary(json.loads(msg.content)))
            except (json.JSONDecodeError, TypeError):
                reports.append(str(msg.content))

        body = "\n\n".join(reports)
        if markdown:
            # Separate the header from the first report so they do not run together.
            return Markdown(header + "\n\n" + body)
        return body
318
+
319
+
320
+
321
def make_sql_data_analyst(
    model,
    sql_database_agent: CompiledStateGraph,
    data_visualization_agent: CompiledStateGraph,
    checkpointer: Checkpointer = None
):
    """
    Creates a multi-agent system that takes in natural-language user instructions, queries a
    SQL database, and returns a table or (when the user asks for one) a plot.

    - Agent 1: SQL Database Agent made with `make_sql_database_agent()`
    - Agent 2: Data Visualization Agent made with `make_data_visualization_agent()`

    Flow: START -> sql_database_agent -> route_to_visualization -> (data_visualization_agent | END).

    Parameters:
    ----------
    model:
        The language model used to decide whether a plot is required.
    sql_database_agent: CompiledStateGraph
        The SQL Database Agent made with `make_sql_database_agent()`.
    data_visualization_agent: CompiledStateGraph
        The Data Visualization Agent made with `make_data_visualization_agent()`.
    checkpointer: Checkpointer (optional)
        The checkpointer to save the state of the multi-agent system.
        Default: None

    Returns:
    -------
    CompiledStateGraph
        The compiled multi-agent system.
    """

    llm = model

    class PrimaryState(TypedDict):
        messages: Annotated[Sequence[BaseMessage], operator.add]
        user_instructions: str
        sql_query_code: str
        sql_database_function: str
        data_sql: dict
        data_raw: dict
        plot_required: bool
        data_visualization_function: str
        plotly_graph: dict
        max_retries: int
        retry_count: int

    def route_to_visualization(state) -> Command[Literal["data_visualization_agent", "__end__"]]:

        response = llm.invoke(f"Respond in 1 word ('plot' or 'table'). Is the user requesting a plot? If unknown, select 'table'. \n\n User Instructions:\n{state.get('user_instructions')}")

        # Chat models return a message with `.content`; plain LLMs may return a str.
        content = getattr(response, "content", response)
        # Normalize case, whitespace, quotes and trailing punctuation so that answers such
        # as "Plot", "plot." or "'plot'\n" still route to visualization.
        answer = str(content).strip().strip("'\"`.!").strip().lower()

        plot_required = answer == "plot"
        goto = "data_visualization_agent" if plot_required else "__end__"

        return Command(
            update={
                # The visualization agent reads its input from `data_raw`.
                'data_raw': state.get("data_sql"),
                'plot_required': plot_required,
            },
            goto=goto
        )

    workflow = StateGraph(PrimaryState)

    workflow.add_node("sql_database_agent", sql_database_agent)
    workflow.add_node("route_to_visualization", route_to_visualization)
    workflow.add_node("data_visualization_agent", data_visualization_agent)

    workflow.add_edge(START, "sql_database_agent")
    workflow.add_edge("sql_database_agent", "route_to_visualization")
    workflow.add_edge("data_visualization_agent", END)

    app = workflow.compile(checkpointer=checkpointer)

    return app
398
+
@@ -0,0 +1,2 @@
1
+ # TODO: Implement the supervised data analyst agent
2
+ # https://langchain-ai.github.io/langgraph/tutorials/multi_agent/agent_supervisor/#create-agent-supervisor
@@ -3,6 +3,8 @@ from ai_data_science_team.templates.agent_templates import(
3
3
  node_func_human_review,
4
4
  node_func_fix_agent_code,
5
5
  node_func_explain_agent_code,
6
+ node_func_report_agent_outputs,
6
7
  node_func_execute_agent_from_sql_connection,
7
- create_coding_agent_graph
8
+ create_coding_agent_graph,
9
+ BaseAgent,
8
10
  )