ai-data-science-team 0.0.0.9008__py3-none-any.whl → 0.0.0.9009__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -7,18 +7,19 @@ from langgraph.graph import START, END, StateGraph
7
7
  from langgraph.graph.state import CompiledStateGraph
8
8
  from langgraph.types import Command
9
9
 
10
- from typing import TypedDict, Annotated, Sequence
10
+ from typing import TypedDict, Annotated, Sequence, Literal
11
11
  import operator
12
12
 
13
- from typing_extensions import TypedDict, Literal
13
+ from typing_extensions import TypedDict
14
14
 
15
15
  import pandas as pd
16
+ import json
16
17
  from IPython.display import Markdown
17
18
 
18
19
  from ai_data_science_team.templates import BaseAgent
19
20
  from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
20
21
  from ai_data_science_team.utils.plotly import plotly_from_dict
21
-
22
+ from ai_data_science_team.tools.regex import remove_consecutive_duplicates, get_generic_summary
22
23
 
23
24
 
24
25
  class SQLDataAnalyst(BaseAgent):
@@ -90,7 +91,7 @@ class SQLDataAnalyst(BaseAgent):
90
91
  self._params[k] = v
91
92
  self._compiled_graph = self._make_compiled_graph()
92
93
 
93
async def ainvoke_agent(self, user_instructions, max_retries:int=3, retry_count:int=0, **kwargs):
    """
    Asynchronously invokes the SQL Data Analyst Multi-Agent.

    This is a coroutine: it must be awaited (or scheduled on an event loop).
    The compiled graph's `ainvoke` result is awaited before it is inspected,
    so `self.response` holds the final state rather than a pending coroutine.

    Parameters:
    ----------
    user_instructions: str
        The user's instructions for the combined SQL and (optionally) Data Visualization agents.
    max_retries (int):
        Value placed in the graph input under the `max_retries` key.
    retry_count (int):
        Value placed in the graph input under the `retry_count` key.
    **kwargs:
        Additional keyword arguments to pass to the compiled graph's `ainvoke` method.

    Returns:
    -------
    None. The graph result is stored in `self.response`.

    Example:
    --------
    ``` python
    from langchain_openai import ChatOpenAI
    import sqlalchemy as sql
    from ai_data_science_team.multiagents import SQLDataAnalyst
    from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent

    llm = ChatOpenAI(model = "gpt-4o-mini")

    sql_engine = sql.create_engine("sqlite:///data/northwind.db")

    conn = sql_engine.connect()

    sql_data_analyst = SQLDataAnalyst(
        model = llm,
        sql_database_agent = SQLDatabaseAgent(
            model = llm,
            connection = conn,
            n_samples = 1,
        ),
        data_visualization_agent = DataVisualizationAgent(
            model = llm,
            n_samples = 10,
        )
    )

    await sql_data_analyst.ainvoke_agent(
        user_instructions = "Make a plot of sales revenue by month by territory. Make a dropdown for the user to select the territory.",
    )

    sql_data_analyst.get_sql_query_code()

    sql_data_analyst.get_data_sql()

    sql_data_analyst.get_plotly_graph()
    ```
    """
    # `ainvoke` returns a coroutine; it has to be awaited before the result
    # can be treated as a mapping.
    response = await self._compiled_graph.ainvoke({
        "user_instructions": user_instructions,
        "max_retries": max_retries,
        "retry_count": retry_count,
    }, **kwargs)

    # Collapse back-to-back messages that carry identical content.
    if response.get("messages"):
        response["messages"] = remove_consecutive_duplicates(response["messages"])

    self.response = response
118
157
 
119
def invoke_agent(self, user_instructions, max_retries:int=3, retry_count:int=0, **kwargs):
    """
    Invokes the SQL Data Analyst Multi-Agent.

    Parameters:
    ----------
    user_instructions: str
        The user's instructions for the combined SQL and (optionally) Data Visualization agents.
    max_retries (int):
        Value placed in the graph input under the `max_retries` key.
    retry_count (int):
        Value placed in the graph input under the `retry_count` key.
    **kwargs:
        Additional keyword arguments to pass to the compiled graph's `invoke` method.

    Returns:
    -------
    None. The graph result is stored in `self.response`.

    Example:
    --------
    ``` python
    from langchain_openai import ChatOpenAI
    import sqlalchemy as sql
    from ai_data_science_team.multiagents import SQLDataAnalyst
    from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent

    llm = ChatOpenAI(model = "gpt-4o-mini")

    sql_engine = sql.create_engine("sqlite:///data/northwind.db")

    conn = sql_engine.connect()

    sql_data_analyst = SQLDataAnalyst(
        model = llm,
        sql_database_agent = SQLDatabaseAgent(
            model = llm,
            connection = conn,
            n_samples = 1,
        ),
        data_visualization_agent = DataVisualizationAgent(
            model = llm,
            n_samples = 10,
        )
    )

    sql_data_analyst.invoke_agent(
        user_instructions = "Make a plot of sales revenue by month by territory. Make a dropdown for the user to select the territory.",
    )

    sql_data_analyst.get_sql_query_code()

    sql_data_analyst.get_data_sql()

    sql_data_analyst.get_plotly_graph()
    ```
    """
    # Initial graph state handed to the compiled workflow.
    graph_input = {
        "user_instructions": user_instructions,
        "max_retries": max_retries,
        "retry_count": retry_count,
    }
    result = self._compiled_graph.invoke(graph_input, **kwargs)

    # Collapse back-to-back messages that carry identical content.
    if result.get("messages"):
        result["messages"] = remove_consecutive_duplicates(result["messages"])

    self.response = result
144
225
 
226
+
145
227
  def get_data_sql(self):
146
228
  """
147
229
  Returns the SQL data as a Pandas DataFrame.
@@ -205,6 +287,34 @@ class SQLDataAnalyst(BaseAgent):
205
287
  if markdown:
206
288
  return Markdown(f"```python\n{self.response.get('data_visualization_function')}\n```")
207
289
  return self.response.get("data_visualization_function")
290
+
291
def get_workflow_summary(self, markdown=False):
    """
    Returns a summary of the SQL Data Analyst workflow.

    The summary starts with a header listing the role of every message in the
    response, followed by one report per message. Each report is produced by
    decoding the message content as JSON and passing it to `get_generic_summary`.

    Parameters:
    ----------
    markdown: bool
        If True, returns the summary wrapped in an IPython `Markdown` object.
        If False, returns the same text as a plain string.

    Returns:
    -------
    Markdown, str, or None
        The summary, or None when there is no response or it has no messages.
    """
    if not self.response:
        return None

    # `.get` avoids a KeyError when the response carries no "messages" entry.
    messages = self.get_response().get("messages")
    if not messages:
        return None

    agent_labels = [
        f"- **Agent {position}:** {msg.role}"
        for position, msg in enumerate(messages, start=1)
    ]

    # Construct header
    header = (
        f"# SQL Data Analyst Workflow Summary Report\n\nThis agentic workflow contains {len(messages)} agents:\n\n"
        + "\n".join(agent_labels)
    )

    reports = [get_generic_summary(json.loads(msg.content)) for msg in messages]

    # Join the header like any other section so a blank line separates it from
    # the first report, and so both output formats carry the same text.
    summary = "\n\n".join([header, *reports])

    if markdown:
        return Markdown(summary)
    return summary
208
318
 
209
319
 
210
320
 
@@ -250,6 +360,8 @@ def make_sql_data_analyst(
250
360
  plot_required: bool
251
361
  data_visualization_function: str
252
362
  plotly_graph: dict
363
+ max_retries: int
364
+ retry_count: int
253
365
 
254
366
  def route_to_visualization(state) -> Command[Literal["data_visualization_agent", "__end__"]]:
255
367
 
@@ -3,6 +3,7 @@ from ai_data_science_team.templates.agent_templates import(
3
3
  node_func_human_review,
4
4
  node_func_fix_agent_code,
5
5
  node_func_explain_agent_code,
6
+ node_func_report_agent_outputs,
6
7
  node_func_execute_agent_from_sql_connection,
7
8
  create_coding_agent_graph,
8
9
  BaseAgent,
@@ -8,11 +8,16 @@ from langgraph.pregel.types import StreamMode
8
8
 
9
9
  import pandas as pd
10
10
  import sqlalchemy as sql
11
+ import json
11
12
 
12
- from typing import Any, Callable, Dict, Type, Optional, Union
13
+ from typing import Any, Callable, Dict, Type, Optional, Union, List
13
14
 
14
15
  from ai_data_science_team.tools.parsers import PythonOutputParser
15
- from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
16
+ from ai_data_science_team.tools.regex import (
17
+ relocate_imports_inside_function,
18
+ add_comments_to_top,
19
+ remove_consecutive_duplicates
20
+ )
16
21
 
17
22
  from IPython.display import Image, display
18
23
  import pandas as pd
@@ -82,6 +87,10 @@ class BaseAgent(CompiledStateGraph):
82
87
  Any: The agent's response.
83
88
  """
84
89
  self.response = self._compiled_graph.invoke(input=input, config=config,**kwargs)
90
+
91
+ if self.response.get("messages"):
92
+ self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
93
+
85
94
  return self.response
86
95
 
87
96
  def ainvoke(
@@ -102,6 +111,10 @@ class BaseAgent(CompiledStateGraph):
102
111
  Any: The agent's response.
103
112
  """
104
113
  self.response = self._compiled_graph.ainvoke(input=input, config=config,**kwargs)
114
+
115
+ if self.response.get("messages"):
116
+ self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
117
+
105
118
  return self.response
106
119
 
107
120
  def stream(
@@ -129,6 +142,10 @@ class BaseAgent(CompiledStateGraph):
129
142
  Any: The agent's response.
130
143
  """
131
144
  self.response = self._compiled_graph.stream(input=input, config=config, stream_mode=stream_mode, **kwargs)
145
+
146
+ if self.response.get("messages"):
147
+ self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
148
+
132
149
  return self.response
133
150
 
134
151
  def astream(
@@ -156,6 +173,10 @@ class BaseAgent(CompiledStateGraph):
156
173
  Any: The agent's response.
157
174
  """
158
175
  self.response = self._compiled_graph.astream(input=input, config=config, stream_mode=stream_mode, **kwargs)
176
+
177
+ if self.response.get("messages"):
178
+ self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
179
+
159
180
  return self.response
160
181
 
161
182
  def get_state_keys(self):
@@ -183,6 +204,9 @@ class BaseAgent(CompiledStateGraph):
183
204
  Returns:
184
205
  Any: The agent's response.
185
206
  """
207
+ if self.response.get("messages"):
208
+ self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
209
+
186
210
  return self.response
187
211
 
188
212
  def show(self, xray: int = 0):
@@ -729,3 +753,50 @@ def node_func_explain_agent_code(
729
753
  # Return an error message if there was a problem with the code
730
754
  message = AIMessage(content=error_message)
731
755
  return {result_key: [message]}
756
+
757
+
758
+
759
def node_func_report_agent_outputs(
    state: Dict[str, Any],
    keys_to_include: List[str],
    result_key: str,
    role: str,
    custom_title: str = "Agent Output Summary"
) -> Dict[str, Any]:
    """
    Gathers relevant data directly from the state (filtered by `keys_to_include`)
    and returns them as a structured message in `state[result_key]`.

    No LLM is used.

    Parameters
    ----------
    state : Dict[str, Any]
        The current state dictionary holding all agent variables.
    keys_to_include : List[str]
        The list of keys in `state` to include in the output. A key missing from
        `state` is reported with the placeholder `<{key}_not_found_in_state>`.
    result_key : str
        The key in `state` under which we'll store the final structured message.
    role : str
        The role that will be used in the final AIMessage (e.g., "DataCleaningAgent").
    custom_title : str, optional
        A title or heading for your report. Defaults to "Agent Output Summary".

    Returns
    -------
    Dict[str, Any]
        A single-entry dictionary mapping `result_key` to a one-element list holding
        an AIMessage whose content is the JSON-encoded report.
    """
    print(" * REPORT AGENT OUTPUTS")

    final_report = {"report_title": custom_title}

    for key in keys_to_include:
        final_report[key] = state.get(key, f"<{key}_not_found_in_state>")

    # Wrap it in a list of messages (like the current "messages" pattern).
    # `default=str` is a deliberate fallback: state values that JSON cannot
    # encode natively are rendered with str() instead of raising TypeError,
    # so producing the report never aborts the graph run.
    return {
        result_key: [
            AIMessage(
                content=json.dumps(final_report, indent=2, default=str),
                role=role
            )
        ]
    }
802
+
@@ -1,6 +1,7 @@
1
1
  import io
2
2
  import pandas as pd
3
3
  import sqlalchemy as sql
4
+ from sqlalchemy import inspect
4
5
  from typing import Union, List, Dict
5
6
 
6
7
  def get_dataframe_summary(
@@ -139,8 +140,7 @@ def _summarize_dataframe(df: pd.DataFrame, dataset_name: str, n_sample=30, skip_
139
140
 
140
141
 
141
142
 
142
def get_database_metadata(connection, n_samples=10) -> dict:
    """
    Collects metadata and sample data from a database, with safe identifier quoting and
    basic dialect-aware row limiting. Prevents issues with spaces/reserved words in identifiers.

    Parameters
    ----------
    connection : sqlalchemy.engine.base.Connection or sqlalchemy.engine.base.Engine
        An open connection or an engine. When an engine is passed, a connection is
        opened for the duration of the call and closed before returning; a connection
        passed in by the caller is left open.
    n_samples : int, optional
        Number of sample values requested per column. Defaults to 10.

    Returns
    -------
    dict
        A dictionary with database metadata, including some sample data from each column.
        Top-level keys are "dialect", "driver", "connection_url" and "schemas". Each schema
        entry holds "schema_name" and "tables"; each table entry holds "table_name",
        "columns" (name, type, sample_values), "primary_key", "foreign_keys" and "indexes".
    """
    is_engine = isinstance(connection, sql.engine.base.Engine)
    conn = connection.connect() if is_engine else connection

    metadata = {
        "dialect": None,
        "driver": None,
        "connection_url": None,
        "schemas": [],
    }

    try:
        sql_engine = conn.engine
        dialect_name = sql_engine.dialect.name.lower()

        metadata["dialect"] = sql_engine.dialect.name
        metadata["driver"] = sql_engine.driver
        # Render the URL with the password masked explicitly, so credentials never
        # end up in the metadata regardless of how str(url) behaves.
        metadata["connection_url"] = sql_engine.url.render_as_string(hide_password=True)

        inspector = inspect(sql_engine)
        preparer = inspector.bind.dialect.identifier_preparer

        # For each schema
        for schema_name in inspector.get_schema_names():
            schema_obj = {
                "schema_name": schema_name,
                "tables": []
            }

            tables = inspector.get_table_names(schema=schema_name)
            for table_name in tables:
                table_info = {
                    "table_name": table_name,
                    "columns": [],
                    "primary_key": [],
                    "foreign_keys": [],
                    "indexes": []
                }

                # The schema-qualified, quoted table name is the same for every
                # column of the table, so build it once per table.
                table_name_quoted = f"{preparer.quote_identifier(schema_name)}.{preparer.quote_identifier(table_name)}"

                # Get columns
                columns = inspector.get_columns(table_name, schema=schema_name)
                for col in columns:
                    col_name = col["name"]
                    col_type = str(col["type"])
                    col_name_quoted = preparer.quote_identifier(col_name)

                    # Build query for sample data
                    query = build_query(col_name_quoted, table_name_quoted, n_samples, dialect_name)

                    # Retrieve sample data. Best effort: a failing sample query is
                    # recorded in place of the samples rather than aborting the scan.
                    try:
                        df = pd.read_sql(query, conn)
                        samples = df[col_name].head(n_samples).tolist()
                    except Exception as e:
                        samples = [f"Error retrieving data: {str(e)}"]

                    table_info["columns"].append({
                        "name": col_name,
                        "type": col_type,
                        "sample_values": samples
                    })

                # Primary keys
                pk_constraint = inspector.get_pk_constraint(table_name, schema=schema_name)
                table_info["primary_key"] = pk_constraint.get("constrained_columns", [])

                # Foreign keys
                fks = inspector.get_foreign_keys(table_name, schema=schema_name)
                table_info["foreign_keys"] = [
                    {
                        "local_cols": fk["constrained_columns"],
                        "referred_table": fk["referred_table"],
                        "referred_cols": fk["referred_columns"]
                    }
                    for fk in fks
                ]

                # Indexes
                idxs = inspector.get_indexes(table_name, schema=schema_name)
                table_info["indexes"] = idxs

                schema_obj["tables"].append(table_info)

            metadata["schemas"].append(schema_obj)

    finally:
        # Only close the connection this function opened itself.
        if is_engine:
            conn.close()

    return metadata
249
+
250
def build_query(col_name_quoted: str, table_name_quoted: str, n: int, dialect_name: str) -> str:
    """
    Builds a SQL query that selects up to `n` values of one column from one table.

    PostgreSQL, MySQL, SQLite and MSSQL get a randomly ordered sample; Oracle gets
    a `ROWNUM` filter; any other dialect falls back to a plain `LIMIT` clause.

    Parameters
    ----------
    col_name_quoted : str
        The column identifier, already quoted for the target dialect.
    table_name_quoted : str
        The table identifier, already quoted for the target dialect.
    n : int
        Maximum number of rows to select.
    dialect_name : str
        Name of the SQL dialect. Matched case-insensitively by substring.

    Returns
    -------
    str
        The SQL query text.
    """
    dialect = dialect_name.lower()
    select_from = f"SELECT {col_name_quoted} FROM {table_name_quoted}"

    if "postgres" in dialect:
        return f"{select_from} ORDER BY RANDOM() LIMIT {n}"
    if "mysql" in dialect:
        return f"{select_from} ORDER BY RAND() LIMIT {n}"
    if "sqlite" in dialect:
        return f"{select_from} ORDER BY RANDOM() LIMIT {n}"
    if "mssql" in dialect:
        return f"SELECT TOP {n} {col_name_quoted} FROM {table_name_quoted} ORDER BY NEWID()"
    if "oracle" in dialect:
        return f"{select_from} WHERE ROWNUM <= {n}"
    # ROWNUM is Oracle-specific; unknown dialects get the widely supported LIMIT form.
    return f"{select_from} LIMIT {n}"
262
+
@@ -103,4 +103,62 @@ def format_recommended_steps(raw_text: str, heading: str = "# Recommended Steps:
103
103
  if not seen_heading:
104
104
  new_lines.insert(0, heading)
105
105
 
106
- return "\n".join(new_lines)
106
+ return "\n".join(new_lines)
107
+
108
def get_generic_summary(report_dict: dict, code_lang = "python") -> str:
    """
    Renders a report dictionary (e.g., the result of json.loads(...)) as text.

    Layout rules:
      1) The value of 'report_title' is emitted first as a top-level heading;
         "Untitled Report" is used when the key is absent.
      2) Every other key becomes a second-level heading built from
         `format_agent_name(key).upper()`.
      3) When the key contains 'code' or 'function' (case-insensitive), the
         value is wrapped in a fenced code block tagged with `code_lang`;
         otherwise the value is emitted as plain text.

    Parameters
    ----------
    report_dict : dict
        The dictionary holding the agent output or user report.
    code_lang : str, optional
        Language tag placed on fenced code blocks. Defaults to "python".

    Returns
    -------
    str
        A formatted summary string.
    """
    heading = report_dict.get("report_title", "Untitled Report")
    sections = [f"# {heading}"]

    for name, content in report_dict.items():
        # The title has already been emitted as the top-level heading.
        if name == "report_title":
            continue

        lowered = name.lower()
        is_code = "code" in lowered or "function" in lowered

        sections.append(f"\n## {format_agent_name(name).upper()}")
        if is_code:
            sections.append(f"```{code_lang}\n" + str(content) + "\n```")
        else:
            sections.append(str(content))

    return "\n".join(sections)
152
+
153
def remove_consecutive_duplicates(messages):
    """
    Drops messages whose `content` equals the content of the message right before it.

    Only adjacent repeats are removed; the same content may reappear later in the
    result if a different message sits in between. Comparison uses `content` alone.

    Parameters
    ----------
    messages : iterable
        Objects exposing a `content` attribute.

    Returns
    -------
    list
        The kept message objects, in their original order.
    """
    kept = []
    last_content = None

    for message in messages:
        current = message.content
        if current != last_content:
            kept.append(message)
            last_content = current  # Remember what the next message is compared against

    return kept
163
+
164
+