ai-data-science-team 0.0.0.9008__py3-none-any.whl → 0.0.0.9009__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,18 +7,19 @@ from langgraph.graph import START, END, StateGraph
7
7
  from langgraph.graph.state import CompiledStateGraph
8
8
  from langgraph.types import Command
9
9
 
10
- from typing import TypedDict, Annotated, Sequence
10
+ from typing import TypedDict, Annotated, Sequence, Literal
11
11
  import operator
12
12
 
13
- from typing_extensions import TypedDict, Literal
13
+ from typing_extensions import TypedDict
14
14
 
15
15
  import pandas as pd
16
+ import json
16
17
  from IPython.display import Markdown
17
18
 
18
19
  from ai_data_science_team.templates import BaseAgent
19
20
  from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
20
21
  from ai_data_science_team.utils.plotly import plotly_from_dict
21
-
22
+ from ai_data_science_team.tools.regex import remove_consecutive_duplicates, get_generic_summary
22
23
 
23
24
 
24
25
  class SQLDataAnalyst(BaseAgent):
@@ -90,7 +91,7 @@ class SQLDataAnalyst(BaseAgent):
90
91
  self._params[k] = v
91
92
  self._compiled_graph = self._make_compiled_graph()
92
93
 
93
- def ainvoke_agent(self, user_instructions, **kwargs):
94
+ def ainvoke_agent(self, user_instructions, max_retries:int=3, retry_count:int=0, **kwargs):
94
95
  """
95
96
  Asynchronosly nvokes the SQL Data Analyst Multi-Agent.
96
97
 
@@ -108,15 +109,53 @@ class SQLDataAnalyst(BaseAgent):
108
109
  Example:
109
110
  --------
110
111
  ``` python
111
- # TODO
112
+ from langchain_openai import ChatOpenAI
113
+ import sqlalchemy as sql
114
+ from ai_data_science_team.multiagents import SQLDataAnalyst
115
+ from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
116
+
117
+ llm = ChatOpenAI(model = "gpt-4o-mini")
118
+
119
+ sql_engine = sql.create_engine("sqlite:///data/northwind.db")
120
+
121
+ conn = sql_engine.connect()
122
+
123
+ sql_data_analyst = SQLDataAnalyst(
124
+ model = llm,
125
+ sql_database_agent = SQLDatabaseAgent(
126
+ model = llm,
127
+ connection = conn,
128
+ n_samples = 1,
129
+ ),
130
+ data_visualization_agent = DataVisualizationAgent(
131
+ model = llm,
132
+ n_samples = 10,
133
+ )
134
+ )
135
+
136
+ sql_data_analyst.ainvoke_agent(
137
+ user_instructions = "Make a plot of sales revenue by month by territory. Make a dropdown for the user to select the territory.",
138
+ )
139
+
140
+ sql_data_analyst.get_sql_query_code()
141
+
142
+ sql_data_analyst.get_data_sql()
143
+
144
+ sql_data_analyst.get_plotly_graph()
112
145
  ```
113
146
  """
114
147
  response = self._compiled_graph.ainvoke({
115
148
  "user_instructions": user_instructions,
149
+ "max_retries": max_retries,
150
+ "retry_count": retry_count,
116
151
  }, **kwargs)
152
+
153
+ if response.get("messages"):
154
+ response["messages"] = remove_consecutive_duplicates(response["messages"])
155
+
117
156
  self.response = response
118
157
 
119
- def invoke_agent(self, user_instructions, **kwargs):
158
+ def invoke_agent(self, user_instructions, max_retries:int=3, retry_count:int=0, **kwargs):
120
159
  """
121
160
  Invokes the SQL Data Analyst Multi-Agent.
122
161
 
@@ -124,6 +163,10 @@ class SQLDataAnalyst(BaseAgent):
124
163
  ----------
125
164
  user_instructions: str
126
165
  The user's instructions for the combined SQL and (optionally) Data Visualization agents.
166
+ max_retries (int):
167
+ Maximum retry attempts for cleaning.
168
+ retry_count (int):
169
+ Current retry attempt.
127
170
  **kwargs:
128
171
  Additional keyword arguments to pass to the compiled graph's `invoke` method.
129
172
 
@@ -134,14 +177,53 @@ class SQLDataAnalyst(BaseAgent):
134
177
  Example:
135
178
  --------
136
179
  ``` python
137
- # TODO
180
+ from langchain_openai import ChatOpenAI
181
+ import sqlalchemy as sql
182
+ from ai_data_science_team.multiagents import SQLDataAnalyst
183
+ from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
184
+
185
+ llm = ChatOpenAI(model = "gpt-4o-mini")
186
+
187
+ sql_engine = sql.create_engine("sqlite:///data/northwind.db")
188
+
189
+ conn = sql_engine.connect()
190
+
191
+ sql_data_analyst = SQLDataAnalyst(
192
+ model = llm,
193
+ sql_database_agent = SQLDatabaseAgent(
194
+ model = llm,
195
+ connection = conn,
196
+ n_samples = 1,
197
+ ),
198
+ data_visualization_agent = DataVisualizationAgent(
199
+ model = llm,
200
+ n_samples = 10,
201
+ )
202
+ )
203
+
204
+ sql_data_analyst.invoke_agent(
205
+ user_instructions = "Make a plot of sales revenue by month by territory. Make a dropdown for the user to select the territory.",
206
+ )
207
+
208
+ sql_data_analyst.get_sql_query_code()
209
+
210
+ sql_data_analyst.get_data_sql()
211
+
212
+ sql_data_analyst.get_plotly_graph()
138
213
  ```
139
214
  """
140
215
  response = self._compiled_graph.invoke({
141
216
  "user_instructions": user_instructions,
217
+ "max_retries": max_retries,
218
+ "retry_count": retry_count,
142
219
  }, **kwargs)
220
+
221
+ if response.get("messages"):
222
+ response["messages"] = remove_consecutive_duplicates(response["messages"])
223
+
143
224
  self.response = response
144
225
 
226
+
145
227
  def get_data_sql(self):
146
228
  """
147
229
  Returns the SQL data as a Pandas DataFrame.
@@ -205,6 +287,34 @@ class SQLDataAnalyst(BaseAgent):
205
287
  if markdown:
206
288
  return Markdown(f"```python\n{self.response.get('data_visualization_function')}\n```")
207
289
  return self.response.get("data_visualization_function")
290
+
291
def get_workflow_summary(self, markdown=False):
    """
    Returns a summary of the SQL Data Analyst workflow.

    Every message in the stored response is read for its `role` (used as
    the agent label) and its `content`, which is parsed as JSON and
    rendered with `get_generic_summary`.

    Parameters:
    ----------
    markdown: bool
        If True, returns an IPython `Markdown` object made of a header
        (report title plus the list of agents) followed by the per-agent
        reports. If False, returns only the per-agent reports as a plain
        string (no header).

    Returns:
    -------
    Markdown, str or None
        The summary, or None when there is no response or the response
        holds no messages.
    """
    if not self.response:
        return None

    # Fetch the messages once; `.get` avoids a KeyError when the response
    # has no "messages" entry at all.
    messages = self.get_response().get("messages")
    if not messages:
        return None

    reports = "\n\n".join(
        get_generic_summary(json.loads(msg.content)) for msg in messages
    )

    if not markdown:
        return reports

    agent_labels = [
        f"- **Agent {position}:** {msg.role}"
        for position, msg in enumerate(messages, start=1)
    ]

    # Construct header
    header = (
        f"# SQL Data Analyst Workflow Summary Report\n\nThis agentic workflow contains {len(messages)} agents:\n\n"
        + "\n".join(agent_labels)
    )

    # A blank line must separate the agent list from the first report;
    # otherwise the report's "# title" heading is glued onto the last
    # bullet and is not rendered as a heading.
    return Markdown(header + "\n\n" + reports)
208
318
 
209
319
 
210
320
 
@@ -250,6 +360,8 @@ def make_sql_data_analyst(
250
360
  plot_required: bool
251
361
  data_visualization_function: str
252
362
  plotly_graph: dict
363
+ max_retries: int
364
+ retry_count: int
253
365
 
254
366
  def route_to_visualization(state) -> Command[Literal["data_visualization_agent", "__end__"]]:
255
367
 
@@ -3,6 +3,7 @@ from ai_data_science_team.templates.agent_templates import(
3
3
  node_func_human_review,
4
4
  node_func_fix_agent_code,
5
5
  node_func_explain_agent_code,
6
+ node_func_report_agent_outputs,
6
7
  node_func_execute_agent_from_sql_connection,
7
8
  create_coding_agent_graph,
8
9
  BaseAgent,
@@ -8,11 +8,16 @@ from langgraph.pregel.types import StreamMode
8
8
 
9
9
  import pandas as pd
10
10
  import sqlalchemy as sql
11
+ import json
11
12
 
12
- from typing import Any, Callable, Dict, Type, Optional, Union
13
+ from typing import Any, Callable, Dict, Type, Optional, Union, List
13
14
 
14
15
  from ai_data_science_team.tools.parsers import PythonOutputParser
15
- from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
16
+ from ai_data_science_team.tools.regex import (
17
+ relocate_imports_inside_function,
18
+ add_comments_to_top,
19
+ remove_consecutive_duplicates
20
+ )
16
21
 
17
22
  from IPython.display import Image, display
18
23
  import pandas as pd
@@ -82,6 +87,10 @@ class BaseAgent(CompiledStateGraph):
82
87
  Any: The agent's response.
83
88
  """
84
89
  self.response = self._compiled_graph.invoke(input=input, config=config,**kwargs)
90
+
91
+ if self.response.get("messages"):
92
+ self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
93
+
85
94
  return self.response
86
95
 
87
96
  def ainvoke(
@@ -102,6 +111,10 @@ class BaseAgent(CompiledStateGraph):
102
111
  Any: The agent's response.
103
112
  """
104
113
  self.response = self._compiled_graph.ainvoke(input=input, config=config,**kwargs)
114
+
115
+ if self.response.get("messages"):
116
+ self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
117
+
105
118
  return self.response
106
119
 
107
120
  def stream(
@@ -129,6 +142,10 @@ class BaseAgent(CompiledStateGraph):
129
142
  Any: The agent's response.
130
143
  """
131
144
  self.response = self._compiled_graph.stream(input=input, config=config, stream_mode=stream_mode, **kwargs)
145
+
146
+ if self.response.get("messages"):
147
+ self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
148
+
132
149
  return self.response
133
150
 
134
151
  def astream(
@@ -156,6 +173,10 @@ class BaseAgent(CompiledStateGraph):
156
173
  Any: The agent's response.
157
174
  """
158
175
  self.response = self._compiled_graph.astream(input=input, config=config, stream_mode=stream_mode, **kwargs)
176
+
177
+ if self.response.get("messages"):
178
+ self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
179
+
159
180
  return self.response
160
181
 
161
182
  def get_state_keys(self):
@@ -183,6 +204,9 @@ class BaseAgent(CompiledStateGraph):
183
204
  Returns:
184
205
  Any: The agent's response.
185
206
  """
207
+ if self.response.get("messages"):
208
+ self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
209
+
186
210
  return self.response
187
211
 
188
212
  def show(self, xray: int = 0):
@@ -729,3 +753,50 @@ def node_func_explain_agent_code(
729
753
  # Return an error message if there was a problem with the code
730
754
  message = AIMessage(content=error_message)
731
755
  return {result_key: [message]}
756
+
757
+
758
+
759
def node_func_report_agent_outputs(
    state: Dict[str, Any],
    keys_to_include: List[str],
    result_key: str,
    role: str,
    custom_title: str = "Agent Output Summary"
) -> Dict[str, Any]:
    """
    Builds a report from selected entries of the agent state and returns it
    as a single JSON-encoded message stored under `result_key`.

    No LLM is used.

    Parameters
    ----------
    state : Dict[str, Any]
        The current state dictionary holding all agent variables.
    keys_to_include : List[str]
        The keys of `state` to copy into the report. A key that is absent
        from `state` is reported with the placeholder
        "<{key}_not_found_in_state>".
    result_key : str
        The key of the returned dictionary under which the message list
        is placed.
    role : str
        The role passed to the AIMessage (e.g., "DataCleaningAgent").
    custom_title : str, optional
        Stored under "report_title" in the report. Defaults to
        "Agent Output Summary".

    Returns
    -------
    Dict[str, Any]
        `{result_key: [AIMessage]}` where the message content is the
        report serialized with `json.dumps(..., indent=2)`.
    """
    print(" * REPORT AGENT OUTPUTS")

    # Title goes first; requested keys follow in the order given.
    report = {"report_title": custom_title}
    report.update(
        (name, state.get(name, f"<{name}_not_found_in_state>"))
        for name in keys_to_include
    )

    # The report travels as a one-element message list, serialized to JSON.
    serialized = json.dumps(report, indent=2)
    message = AIMessage(content=serialized, role=role)
    return {result_key: [message]}
802
+
@@ -1,6 +1,7 @@
1
1
  import io
2
2
  import pandas as pd
3
3
  import sqlalchemy as sql
4
+ from sqlalchemy import inspect
4
5
  from typing import Union, List, Dict
5
6
 
6
7
  def get_dataframe_summary(
@@ -139,8 +140,7 @@ def _summarize_dataframe(df: pd.DataFrame, dataset_name: str, n_sample=30, skip_
139
140
 
140
141
 
141
142
 
142
- def get_database_metadata(connection: Union[sql.engine.base.Connection, sql.engine.base.Engine],
143
- n_samples: int = 10) -> str:
143
+ def get_database_metadata(connection, n_samples=10) -> dict:
144
144
  """
145
145
  Collects metadata and sample data from a database, with safe identifier quoting and
146
146
  basic dialect-aware row limiting. Prevents issues with spaces/reserved words in identifiers.
@@ -154,77 +154,109 @@ def get_database_metadata(connection: Union[sql.engine.base.Connection, sql.engi
154
154
 
155
155
  Returns
156
156
  -------
157
- str
158
- A formatted string with database metadata, including some sample data from each column.
157
+ dict
158
+ A dictionary with database metadata, including some sample data from each column.
159
159
  """
160
-
161
- # If a connection is passed, use it; if an engine is passed, connect to it
162
160
  is_engine = isinstance(connection, sql.engine.base.Engine)
163
161
  conn = connection.connect() if is_engine else connection
164
162
 
165
- output = []
163
+ metadata = {
164
+ "dialect": None,
165
+ "driver": None,
166
+ "connection_url": None,
167
+ "schemas": [],
168
+ }
169
+
166
170
  try:
167
- # Grab the engine off the connection
168
171
  sql_engine = conn.engine
169
172
  dialect_name = sql_engine.dialect.name.lower()
170
173
 
171
- output.append(f"Database Dialect: {sql_engine.dialect.name}")
172
- output.append(f"Driver: {sql_engine.driver}")
173
- output.append(f"Connection URL: {sql_engine.url}")
174
-
175
- # Inspect the database
176
- inspector = sql.inspect(sql_engine)
177
- tables = inspector.get_table_names()
178
- output.append(f"Tables: {tables}")
179
- output.append(f"Schemas: {inspector.get_schema_names()}")
180
-
181
- # Helper to build a dialect-specific limit clause
182
- def build_query(col_name_quoted: str, table_name_quoted: str, n: int) -> str:
183
- """
184
- Returns a SQL query string to select N rows from the given column/table
185
- across different dialects (SQLite, MySQL, Postgres, MSSQL, Oracle, etc.)
186
- """
187
- if "sqlite" in dialect_name or "mysql" in dialect_name or "postgres" in dialect_name:
188
- # Common dialects supporting LIMIT
189
- return f"SELECT {col_name_quoted} FROM {table_name_quoted} LIMIT {n}"
190
- elif "mssql" in dialect_name:
191
- # Microsoft SQL Server syntax
192
- return f"SELECT TOP {n} {col_name_quoted} FROM {table_name_quoted}"
193
- elif "oracle" in dialect_name:
194
- # Oracle syntax
195
- return f"SELECT {col_name_quoted} FROM {table_name_quoted} WHERE ROWNUM <= {n}"
196
- else:
197
- # Fallback
198
- return f"SELECT {col_name_quoted} FROM {table_name_quoted} LIMIT {n}"
199
-
200
- # Prepare for quoting
201
- preparer = inspector.bind.dialect.identifier_preparer
202
-
203
- # For each table, get columns and sample data
204
- for table_name in tables:
205
- output.append(f"\nTable: {table_name}")
206
- # Properly quote the table name
207
- table_name_quoted = preparer.quote_identifier(table_name)
208
-
209
- for column in inspector.get_columns(table_name):
210
- col_name = column["name"]
211
- col_type = column["type"]
212
- output.append(f" Column: {col_name} Type: {col_type}")
174
+ metadata["dialect"] = sql_engine.dialect.name
175
+ metadata["driver"] = sql_engine.driver
176
+ metadata["connection_url"] = str(sql_engine.url)
213
177
 
214
- # Properly quote the column name
215
- col_name_quoted = preparer.quote_identifier(col_name)
216
-
217
- # Build a dialect-aware query with safe quoting
218
- query = build_query(col_name_quoted, table_name_quoted, n_samples)
219
-
220
- # Read a few sample values
221
- df = pd.read_sql(sql.text(query), conn)
222
- first_values = df[col_name].tolist()
223
- output.append(f" First {n_samples} Values: {first_values}")
178
+ inspector = inspect(sql_engine)
179
+ preparer = inspector.bind.dialect.identifier_preparer
224
180
 
181
+ # For each schema
182
+ for schema_name in inspector.get_schema_names():
183
+ schema_obj = {
184
+ "schema_name": schema_name,
185
+ "tables": []
186
+ }
187
+
188
+ tables = inspector.get_table_names(schema=schema_name)
189
+ for table_name in tables:
190
+ table_info = {
191
+ "table_name": table_name,
192
+ "columns": [],
193
+ "primary_key": [],
194
+ "foreign_keys": [],
195
+ "indexes": []
196
+ }
197
+ # Get columns
198
+ columns = inspector.get_columns(table_name, schema=schema_name)
199
+ for col in columns:
200
+ col_name = col["name"]
201
+ col_type = str(col["type"])
202
+ table_name_quoted = f"{preparer.quote_identifier(schema_name)}.{preparer.quote_identifier(table_name)}"
203
+ col_name_quoted = preparer.quote_identifier(col_name)
204
+
205
+ # Build query for sample data
206
+ query = build_query(col_name_quoted, table_name_quoted, n_samples, dialect_name)
207
+
208
+ # Retrieve sample data
209
+ try:
210
+ df = pd.read_sql(query, conn)
211
+ samples = df[col_name].head(n_samples).tolist()
212
+ except Exception as e:
213
+ samples = [f"Error retrieving data: {str(e)}"]
214
+
215
+ table_info["columns"].append({
216
+ "name": col_name,
217
+ "type": col_type,
218
+ "sample_values": samples
219
+ })
220
+
221
+ # Primary keys
222
+ pk_constraint = inspector.get_pk_constraint(table_name, schema=schema_name)
223
+ table_info["primary_key"] = pk_constraint.get("constrained_columns", [])
224
+
225
+ # Foreign keys
226
+ fks = inspector.get_foreign_keys(table_name, schema=schema_name)
227
+ table_info["foreign_keys"] = [
228
+ {
229
+ "local_cols": fk["constrained_columns"],
230
+ "referred_table": fk["referred_table"],
231
+ "referred_cols": fk["referred_columns"]
232
+ }
233
+ for fk in fks
234
+ ]
235
+
236
+ # Indexes
237
+ idxs = inspector.get_indexes(table_name, schema=schema_name)
238
+ table_info["indexes"] = idxs
239
+
240
+ schema_obj["tables"].append(table_info)
241
+
242
+ metadata["schemas"].append(schema_obj)
243
+
225
244
  finally:
226
- # Close connection if created inside the function
227
245
  if is_engine:
228
246
  conn.close()
229
247
 
230
- return "\n".join(output)
248
+ return metadata
249
+
250
def build_query(col_name_quoted: str, table_name_quoted: str, n: int, dialect_name: str) -> str:
    """
    Builds a dialect-aware query that samples up to `n` values of one column.

    Parameters
    ----------
    col_name_quoted : str
        The column identifier, already quoted for the target dialect.
    table_name_quoted : str
        The (optionally schema-qualified) table identifier, already quoted.
    n : int
        Maximum number of rows to return.
    dialect_name : str
        Name of the SQLAlchemy dialect; matched by substring,
        case-insensitively.

    Returns
    -------
    str
        The SQL text. PostgreSQL, SQLite, MySQL and SQL Server get a
        randomly ordered sample; Oracle gets the first `n` rows via ROWNUM;
        any other dialect gets a plain `LIMIT n` query.
    """
    dialect = dialect_name.lower()
    select = f"SELECT {col_name_quoted} FROM {table_name_quoted}"

    if "postgres" in dialect or "sqlite" in dialect:
        return f"{select} ORDER BY RANDOM() LIMIT {n}"
    if "mysql" in dialect:
        return f"{select} ORDER BY RAND() LIMIT {n}"
    if "mssql" in dialect:
        return f"SELECT TOP {n} {col_name_quoted} FROM {table_name_quoted} ORDER BY NEWID()"
    if "oracle" in dialect:
        return f"{select} WHERE ROWNUM <= {n}"

    # ROWNUM is Oracle-specific, so unknown dialects fall back to LIMIT,
    # which far more engines accept.
    return f"{select} LIMIT {n}"
262
+
@@ -103,4 +103,62 @@ def format_recommended_steps(raw_text: str, heading: str = "# Recommended Steps:
103
103
  if not seen_heading:
104
104
  new_lines.insert(0, heading)
105
105
 
106
- return "\n".join(new_lines)
106
+ return "\n".join(new_lines)
107
+
108
def get_generic_summary(report_dict: dict, code_lang = "python") -> str:
    """
    Renders a report dictionary (e.g., from json.loads(...)) as text with
    Markdown-style headings.

    Layout:
      1) A top-level heading with the 'report_title' value, or
         "Untitled Report" when that key is absent.
      2) One second-level heading per remaining key, in dictionary order,
         labelled with the upper-cased result of `format_agent_name(key)`.
      3) Under each heading, the value as a string. When the key contains
         'code' or 'function' (case-insensitive), the value is wrapped in
         a fenced code block tagged with `code_lang`.

    Parameters
    ----------
    report_dict : dict
        The dictionary holding the agent output or user report.
    code_lang : str
        Language tag placed on fenced code blocks. Defaults to "python".

    Returns
    -------
    str
        The formatted summary, lines joined with newlines.
    """
    heading = report_dict.get("report_title", "Untitled Report")
    rendered = [f"# {heading}"]

    for field, content in report_dict.items():
        # The title has already been emitted as the top heading.
        if field == "report_title":
            continue

        lowered = field.lower()
        looks_like_code = "code" in lowered or "function" in lowered

        # Same section heading whether or not the value is code.
        rendered.append(f"\n## {format_agent_name(field).upper()}")

        body = str(content)
        if looks_like_code:
            body = f"```{code_lang}\n" + body + "\n```"
        rendered.append(body)

    return "\n".join(rendered)
152
+
153
def remove_consecutive_duplicates(messages):
    """
    Drops messages whose `content` equals that of the message right before it.

    Only adjacent repeats are removed; the same content may still appear
    again later in the sequence. The first message of each run is the one
    that is kept, and the input is not modified.

    Parameters
    ----------
    messages : iterable
        Objects exposing a `content` attribute.

    Returns
    -------
    list
        The retained message objects, in their original order.
    """
    # A private sentinel marks "no previous message yet", so a leading
    # message whose content is None is kept rather than mistaken for a
    # repeat of the initial state.
    nothing_seen = object()
    previous_content = nothing_seen
    deduplicated = []

    for msg in messages:
        if previous_content is nothing_seen or msg.content != previous_content:
            deduplicated.append(msg)
        previous_content = msg.content

    return deduplicated
163
+
164
+