ai-data-science-team 0.0.0.9008__py3-none-any.whl → 0.0.0.9009__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +0 -1
- ai_data_science_team/agents/data_cleaning_agent.py +45 -34
- ai_data_science_team/agents/data_visualization_agent.py +39 -43
- ai_data_science_team/agents/data_wrangling_agent.py +45 -44
- ai_data_science_team/agents/feature_engineering_agent.py +42 -61
- ai_data_science_team/agents/sql_database_agent.py +125 -71
- ai_data_science_team/ml_agents/__init__.py +1 -0
- ai_data_science_team/ml_agents/h2o_ml_agent.py +1032 -0
- ai_data_science_team/multiagents/sql_data_analyst.py +119 -7
- ai_data_science_team/templates/__init__.py +1 -0
- ai_data_science_team/templates/agent_templates.py +73 -2
- ai_data_science_team/tools/metadata.py +94 -62
- ai_data_science_team/tools/regex.py +59 -1
- {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/METADATA +28 -14
- ai_data_science_team-0.0.0.9009.dist-info/RECORD +28 -0
- ai_data_science_team-0.0.0.9008.dist-info/RECORD +0 -26
- {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/WHEEL +0 -0
- {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/top_level.txt +0 -0
ai_data_science_team/multiagents/sql_data_analyst.py

```diff
@@ -7,18 +7,19 @@ from langgraph.graph import START, END, StateGraph
 from langgraph.graph.state import CompiledStateGraph
 from langgraph.types import Command
 
-from typing import TypedDict, Annotated, Sequence
+from typing import TypedDict, Annotated, Sequence, Literal
 import operator
 
-from typing_extensions import TypedDict
+from typing_extensions import TypedDict
 
 import pandas as pd
+import json
 from IPython.display import Markdown
 
 from ai_data_science_team.templates import BaseAgent
 from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
 from ai_data_science_team.utils.plotly import plotly_from_dict
-
+from ai_data_science_team.tools.regex import remove_consecutive_duplicates, get_generic_summary
 
 
 class SQLDataAnalyst(BaseAgent):
```
```diff
@@ -90,7 +91,7 @@ class SQLDataAnalyst(BaseAgent):
             self._params[k] = v
         self._compiled_graph = self._make_compiled_graph()
 
-    def ainvoke_agent(self, user_instructions, **kwargs):
+    def ainvoke_agent(self, user_instructions, max_retries:int=3, retry_count:int=0, **kwargs):
         """
         Asynchronously invokes the SQL Data Analyst Multi-Agent.
 
```
````diff
@@ -108,15 +109,53 @@ class SQLDataAnalyst(BaseAgent):
         Example:
         --------
         ``` python
-
+        from langchain_openai import ChatOpenAI
+        import sqlalchemy as sql
+        from ai_data_science_team.multiagents import SQLDataAnalyst
+        from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
+
+        llm = ChatOpenAI(model = "gpt-4o-mini")
+
+        sql_engine = sql.create_engine("sqlite:///data/northwind.db")
+
+        conn = sql_engine.connect()
+
+        sql_data_analyst = SQLDataAnalyst(
+            model = llm,
+            sql_database_agent = SQLDatabaseAgent(
+                model = llm,
+                connection = conn,
+                n_samples = 1,
+            ),
+            data_visualization_agent = DataVisualizationAgent(
+                model = llm,
+                n_samples = 10,
+            )
+        )
+
+        sql_data_analyst.ainvoke_agent(
+            user_instructions = "Make a plot of sales revenue by month by territory. Make a dropdown for the user to select the territory.",
+        )
+
+        sql_data_analyst.get_sql_query_code()
+
+        sql_data_analyst.get_data_sql()
+
+        sql_data_analyst.get_plotly_graph()
         ```
         """
         response = self._compiled_graph.ainvoke({
             "user_instructions": user_instructions,
+            "max_retries": max_retries,
+            "retry_count": retry_count,
         }, **kwargs)
+
+        if response.get("messages"):
+            response["messages"] = remove_consecutive_duplicates(response["messages"])
+
         self.response = response
 
-    def invoke_agent(self, user_instructions, **kwargs):
+    def invoke_agent(self, user_instructions, max_retries:int=3, retry_count:int=0, **kwargs):
         """
         Invokes the SQL Data Analyst Multi-Agent.
 
````
```diff
@@ -124,6 +163,10 @@ class SQLDataAnalyst(BaseAgent):
         ----------
         user_instructions: str
             The user's instructions for the combined SQL and (optionally) Data Visualization agents.
+        max_retries (int):
+            Maximum retry attempts.
+        retry_count (int):
+            Current retry attempt.
         **kwargs:
             Additional keyword arguments to pass to the compiled graph's `invoke` method.
 
```
````diff
@@ -134,14 +177,53 @@ class SQLDataAnalyst(BaseAgent):
         Example:
         --------
         ``` python
-
+        from langchain_openai import ChatOpenAI
+        import sqlalchemy as sql
+        from ai_data_science_team.multiagents import SQLDataAnalyst
+        from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
+
+        llm = ChatOpenAI(model = "gpt-4o-mini")
+
+        sql_engine = sql.create_engine("sqlite:///data/northwind.db")
+
+        conn = sql_engine.connect()
+
+        sql_data_analyst = SQLDataAnalyst(
+            model = llm,
+            sql_database_agent = SQLDatabaseAgent(
+                model = llm,
+                connection = conn,
+                n_samples = 1,
+            ),
+            data_visualization_agent = DataVisualizationAgent(
+                model = llm,
+                n_samples = 10,
+            )
+        )
+
+        sql_data_analyst.invoke_agent(
+            user_instructions = "Make a plot of sales revenue by month by territory. Make a dropdown for the user to select the territory.",
+        )
+
+        sql_data_analyst.get_sql_query_code()
+
+        sql_data_analyst.get_data_sql()
+
+        sql_data_analyst.get_plotly_graph()
         ```
         """
         response = self._compiled_graph.invoke({
             "user_instructions": user_instructions,
+            "max_retries": max_retries,
+            "retry_count": retry_count,
         }, **kwargs)
+
+        if response.get("messages"):
+            response["messages"] = remove_consecutive_duplicates(response["messages"])
+
         self.response = response
 
+
     def get_data_sql(self):
         """
         Returns the SQL data as a Pandas DataFrame.
````
```diff
@@ -205,6 +287,34 @@ class SQLDataAnalyst(BaseAgent):
         if markdown:
             return Markdown(f"```python\n{self.response.get('data_visualization_function')}\n```")
         return self.response.get("data_visualization_function")
+
+    def get_workflow_summary(self, markdown=False):
+        """
+        Returns a summary of the SQL Data Analyst workflow.
+
+        Parameters:
+        ----------
+        markdown: bool
+            If True, returns the summary as a Markdown-formatted string.
+        """
+        if self.response and self.get_response()['messages']:
+
+            agents = [self.get_response()['messages'][i].role for i in range(len(self.get_response()['messages']))]
+
+            agent_labels = []
+            for i in range(len(agents)):
+                agent_labels.append(f"- **Agent {i+1}:** {agents[i]}")
+
+            # Construct header
+            header = f"# SQL Data Analyst Workflow Summary Report\n\nThis agentic workflow contains {len(agents)} agents:\n\n" + "\n".join(agent_labels)
+
+            reports = []
+            for msg in self.get_response()['messages']:
+                reports.append(get_generic_summary(json.loads(msg.content)))
+
+            if markdown:
+                return Markdown(header + "\n\n".join(reports))
+            return "\n\n".join(reports)
 
 
 
```
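The new `get_workflow_summary` method assumes each message in the response carries a JSON report (see `node_func_report_agent_outputs` below) and renders one section per agent. A minimal sketch of how it might be called, continuing the `invoke_agent` docstring example above:

```python
# Hypothetical continuation of the docstring example: after invoke_agent()
# has populated sql_data_analyst.response, summarize the whole run.
summary = sql_data_analyst.get_workflow_summary(markdown=True)
# Renders "# SQL Data Analyst Workflow Summary Report" followed by one
# get_generic_summary() section per agent message.
```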
```diff
@@ -250,6 +360,8 @@ def make_sql_data_analyst(
         plot_required: bool
         data_visualization_function: str
         plotly_graph: dict
+        max_retries: int
+        retry_count: int
 
     def route_to_visualization(state) -> Command[Literal["data_visualization_agent", "__end__"]]:
 
```
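The two new state fields carry the retry budget through the graph state, and `route_to_visualization` is a LangGraph `Command`-style router. A rough sketch of what such a router looks like (the routing condition here is illustrative, not taken from this diff):

```python
from typing import Literal
from langgraph.types import Command

def route_to_visualization(state) -> Command[Literal["data_visualization_agent", "__end__"]]:
    # Illustrative condition: visit the visualization agent only when a plot was requested.
    goto = "data_visualization_agent" if state.get("plot_required") else "__end__"
    return Command(goto=goto)
```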
ai_data_science_team/templates/agent_templates.py

```diff
@@ -8,11 +8,16 @@ from langgraph.pregel.types import StreamMode
 
 import pandas as pd
 import sqlalchemy as sql
+import json
 
-from typing import Any, Callable, Dict, Type, Optional, Union
+from typing import Any, Callable, Dict, Type, Optional, Union, List
 
 from ai_data_science_team.tools.parsers import PythonOutputParser
-from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
+from ai_data_science_team.tools.regex import (
+    relocate_imports_inside_function,
+    add_comments_to_top,
+    remove_consecutive_duplicates
+)
 
 from IPython.display import Image, display
 import pandas as pd
```
```diff
@@ -82,6 +87,10 @@ class BaseAgent(CompiledStateGraph):
             Any: The agent's response.
         """
         self.response = self._compiled_graph.invoke(input=input, config=config,**kwargs)
+
+        if self.response.get("messages"):
+            self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
+
         return self.response
 
     def ainvoke(
```
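The same de-duplication hook is repeated in `ainvoke`, `stream`, `astream`, and `get_response` below. A quick sketch of what `remove_consecutive_duplicates` (added to `tools/regex.py` in the last hunk) does to a message list:

```python
from langchain_core.messages import AIMessage
from ai_data_science_team.tools.regex import remove_consecutive_duplicates

msgs = [AIMessage(content="report"), AIMessage(content="report"), AIMessage(content="done")]
deduped = remove_consecutive_duplicates(msgs)
print([m.content for m in deduped])  # ['report', 'done'] -- only back-to-back repeats are dropped
```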
```diff
@@ -102,6 +111,10 @@ class BaseAgent(CompiledStateGraph):
             Any: The agent's response.
         """
         self.response = self._compiled_graph.ainvoke(input=input, config=config,**kwargs)
+
+        if self.response.get("messages"):
+            self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
+
         return self.response
 
     def stream(
```
```diff
@@ -129,6 +142,10 @@ class BaseAgent(CompiledStateGraph):
             Any: The agent's response.
         """
         self.response = self._compiled_graph.stream(input=input, config=config, stream_mode=stream_mode, **kwargs)
+
+        if self.response.get("messages"):
+            self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
+
         return self.response
 
     def astream(
```
```diff
@@ -156,6 +173,10 @@ class BaseAgent(CompiledStateGraph):
             Any: The agent's response.
         """
         self.response = self._compiled_graph.astream(input=input, config=config, stream_mode=stream_mode, **kwargs)
+
+        if self.response.get("messages"):
+            self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
+
         return self.response
 
     def get_state_keys(self):
```
```diff
@@ -183,6 +204,9 @@ class BaseAgent(CompiledStateGraph):
         Returns:
             Any: The agent's response.
         """
+        if self.response.get("messages"):
+            self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
+
         return self.response
 
     def show(self, xray: int = 0):
```
```diff
@@ -729,3 +753,50 @@ def node_func_explain_agent_code(
         # Return an error message if there was a problem with the code
         message = AIMessage(content=error_message)
         return {result_key: [message]}
+
+
+
+def node_func_report_agent_outputs(
+    state: Dict[str, Any],
+    keys_to_include: List[str],
+    result_key: str,
+    role: str,
+    custom_title: str = "Agent Output Summary"
+) -> Dict[str, Any]:
+    """
+    Gathers relevant data directly from the state (filtered by `keys_to_include`)
+    and returns them as a structured message in `state[result_key]`.
+
+    No LLM is used.
+
+    Parameters
+    ----------
+    state : Dict[str, Any]
+        The current state dictionary holding all agent variables.
+    keys_to_include : List[str]
+        The list of keys in `state` to include in the output.
+    result_key : str
+        The key in `state` under which we'll store the final structured message.
+    role : str
+        The role that will be used in the final AIMessage (e.g., "DataCleaningAgent").
+    custom_title : str, optional
+        A title or heading for your report. Defaults to "Agent Output Summary".
+    """
+    print(" * REPORT AGENT OUTPUTS")
+
+    final_report = {"report_title": custom_title}
+
+    for key in keys_to_include:
+        final_report[key] = state.get(key, f"<{key}_not_found_in_state>")
+
+    # Wrap it in a list of messages (like the current "messages" pattern).
+    # You can serialize this dictionary as JSON or just cast it to string.
+    return {
+        result_key: [
+            AIMessage(
+                content=json.dumps(final_report, indent=2),
+                role=role
+            )
+        ]
+    }
+
```
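This reporting node is what the `get_workflow_summary` method above consumes: selected state keys are serialized into a JSON `AIMessage`. A minimal sketch with a hypothetical state (the key names are illustrative, not from the diff):

```python
# Hypothetical state keys, chosen for illustration only.
state = {
    "recommended_steps": "# Recommended Steps:\n1. Drop rows that are entirely NA.",
    "data_cleaner_function": "def data_cleaner(df):\n    return df.dropna()",
}

out = node_func_report_agent_outputs(
    state=state,
    keys_to_include=["recommended_steps", "data_cleaner_function"],
    result_key="messages",
    role="DataCleaningAgent",
)
print(out["messages"][0].content)  # JSON: {"report_title": "Agent Output Summary", ...}
```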
ai_data_science_team/tools/metadata.py

```diff
@@ -1,6 +1,7 @@
 import io
 import pandas as pd
 import sqlalchemy as sql
+from sqlalchemy import inspect
 from typing import Union, List, Dict
 
 def get_dataframe_summary(
```
```diff
@@ -139,8 +140,7 @@ def _summarize_dataframe(df: pd.DataFrame, dataset_name: str, n_sample=30, skip_
 
 
 
-def get_database_metadata(connection: Union[sql.engine.base.Connection, sql.engine.base.Engine],
-                          n_samples: int = 10) -> str:
+def get_database_metadata(connection, n_samples=10) -> dict:
     """
     Collects metadata and sample data from a database, with safe identifier quoting and
     basic dialect-aware row limiting. Prevents issues with spaces/reserved words in identifiers.
```
```diff
@@ -154,77 +154,109 @@ def get_database_metadata(connection: Union[sql.engine.base.Connection, sql.engi
 
     Returns
     -------
-
-    A
+    dict
+        A dictionary with database metadata, including some sample data from each column.
     """
-
-    # If a connection is passed, use it; if an engine is passed, connect to it
     is_engine = isinstance(connection, sql.engine.base.Engine)
     conn = connection.connect() if is_engine else connection
 
-
+    metadata = {
+        "dialect": None,
+        "driver": None,
+        "connection_url": None,
+        "schemas": [],
+    }
+
     try:
-        # Grab the engine off the connection
         sql_engine = conn.engine
         dialect_name = sql_engine.dialect.name.lower()
 
-
-
-
-
-        # Inspect the database
-        inspector = sql.inspect(sql_engine)
-        tables = inspector.get_table_names()
-        output.append(f"Tables: {tables}")
-        output.append(f"Schemas: {inspector.get_schema_names()}")
-
-        # Helper to build a dialect-specific limit clause
-        def build_query(col_name_quoted: str, table_name_quoted: str, n: int) -> str:
-            """
-            Returns a SQL query string to select N rows from the given column/table
-            across different dialects (SQLite, MySQL, Postgres, MSSQL, Oracle, etc.)
-            """
-            if "sqlite" in dialect_name or "mysql" in dialect_name or "postgres" in dialect_name:
-                # Common dialects supporting LIMIT
-                return f"SELECT {col_name_quoted} FROM {table_name_quoted} LIMIT {n}"
-            elif "mssql" in dialect_name:
-                # Microsoft SQL Server syntax
-                return f"SELECT TOP {n} {col_name_quoted} FROM {table_name_quoted}"
-            elif "oracle" in dialect_name:
-                # Oracle syntax
-                return f"SELECT {col_name_quoted} FROM {table_name_quoted} WHERE ROWNUM <= {n}"
-            else:
-                # Fallback
-                return f"SELECT {col_name_quoted} FROM {table_name_quoted} LIMIT {n}"
-
-        # Prepare for quoting
-        preparer = inspector.bind.dialect.identifier_preparer
-
-        # For each table, get columns and sample data
-        for table_name in tables:
-            output.append(f"\nTable: {table_name}")
-            # Properly quote the table name
-            table_name_quoted = preparer.quote_identifier(table_name)
-
-            for column in inspector.get_columns(table_name):
-                col_name = column["name"]
-                col_type = column["type"]
-                output.append(f"  Column: {col_name} Type: {col_type}")
+        metadata["dialect"] = sql_engine.dialect.name
+        metadata["driver"] = sql_engine.driver
+        metadata["connection_url"] = str(sql_engine.url)
 
-
-
-
-                # Build a dialect-aware query with safe quoting
-                query = build_query(col_name_quoted, table_name_quoted, n_samples)
-
-                # Read a few sample values
-                df = pd.read_sql(sql.text(query), conn)
-                first_values = df[col_name].tolist()
-                output.append(f"  First {n_samples} Values: {first_values}")
+        inspector = inspect(sql_engine)
+        preparer = inspector.bind.dialect.identifier_preparer
 
+        # For each schema
+        for schema_name in inspector.get_schema_names():
+            schema_obj = {
+                "schema_name": schema_name,
+                "tables": []
+            }
+
+            tables = inspector.get_table_names(schema=schema_name)
+            for table_name in tables:
+                table_info = {
+                    "table_name": table_name,
+                    "columns": [],
+                    "primary_key": [],
+                    "foreign_keys": [],
+                    "indexes": []
+                }
+                # Get columns
+                columns = inspector.get_columns(table_name, schema=schema_name)
+                for col in columns:
+                    col_name = col["name"]
+                    col_type = str(col["type"])
+                    table_name_quoted = f"{preparer.quote_identifier(schema_name)}.{preparer.quote_identifier(table_name)}"
+                    col_name_quoted = preparer.quote_identifier(col_name)
+
+                    # Build query for sample data
+                    query = build_query(col_name_quoted, table_name_quoted, n_samples, dialect_name)
+
+                    # Retrieve sample data
+                    try:
+                        df = pd.read_sql(query, conn)
+                        samples = df[col_name].head(n_samples).tolist()
+                    except Exception as e:
+                        samples = [f"Error retrieving data: {str(e)}"]
+
+                    table_info["columns"].append({
+                        "name": col_name,
+                        "type": col_type,
+                        "sample_values": samples
+                    })
+
+                # Primary keys
+                pk_constraint = inspector.get_pk_constraint(table_name, schema=schema_name)
+                table_info["primary_key"] = pk_constraint.get("constrained_columns", [])
+
+                # Foreign keys
+                fks = inspector.get_foreign_keys(table_name, schema=schema_name)
+                table_info["foreign_keys"] = [
+                    {
+                        "local_cols": fk["constrained_columns"],
+                        "referred_table": fk["referred_table"],
+                        "referred_cols": fk["referred_columns"]
+                    }
+                    for fk in fks
+                ]
+
+                # Indexes
+                idxs = inspector.get_indexes(table_name, schema=schema_name)
+                table_info["indexes"] = idxs
+
+                schema_obj["tables"].append(table_info)
+
+            metadata["schemas"].append(schema_obj)
+
     finally:
-        # Close connection if created inside the function
         if is_engine:
             conn.close()
 
-    return
+    return metadata
+
+def build_query(col_name_quoted: str, table_name_quoted: str, n: int, dialect_name: str) -> str:
+    # Example: expand your build_query to handle random sampling if possible
+    if "postgres" in dialect_name:
+        return f"SELECT {col_name_quoted} FROM {table_name_quoted} ORDER BY RANDOM() LIMIT {n}"
+    if "mysql" in dialect_name:
+        return f"SELECT {col_name_quoted} FROM {table_name_quoted} ORDER BY RAND() LIMIT {n}"
+    if "sqlite" in dialect_name:
+        return f"SELECT {col_name_quoted} FROM {table_name_quoted} ORDER BY RANDOM() LIMIT {n}"
+    if "mssql" in dialect_name:
+        return f"SELECT TOP {n} {col_name_quoted} FROM {table_name_quoted} ORDER BY NEWID()"
+    # Oracle or fallback
+    return f"SELECT {col_name_quoted} FROM {table_name_quoted} WHERE ROWNUM <= {n}"
+
```
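A sketch of how the reworked function might be used, reusing the Northwind SQLite database from the docstring examples above; the fields accessed follow the dictionary shape built in this hunk:

```python
import sqlalchemy as sql

engine = sql.create_engine("sqlite:///data/northwind.db")
metadata = get_database_metadata(engine, n_samples=3)

print(metadata["dialect"], metadata["driver"])
for schema in metadata["schemas"]:
    for table in schema["tables"]:
        cols = [c["name"] for c in table["columns"]]
        print(f'{schema["schema_name"]}.{table["table_name"]}: {cols}')
```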
ai_data_science_team/tools/regex.py

```diff
@@ -103,4 +103,62 @@ def format_recommended_steps(raw_text: str, heading: str = "# Recommended Steps:
     if not seen_heading:
         new_lines.insert(0, heading)
 
-    return "\n".join(new_lines)
+    return "\n".join(new_lines)
+
+def get_generic_summary(report_dict: dict, code_lang = "python") -> str:
+    """
+    Takes a dictionary of unknown structure (e.g., from json.loads(...))
+    and returns a textual summary. It assumes:
+    1) 'report_title' (if present) should be displayed first.
+    2) If a key includes 'code' or 'function',
+       the value is treated as a code block.
+    3) Otherwise, key-value pairs are displayed as text.
+
+    Parameters
+    ----------
+    report_dict : dict
+        The dictionary holding the agent output or user report.
+
+    Returns
+    -------
+    str
+        A formatted summary string.
+    """
+    # 1) Grab the report title (or default)
+    title = report_dict.get("report_title", "Untitled Report")
+
+    lines = []
+    lines.append(f"# {title}")
+
+    # 2) Iterate over all other keys
+    for key, value in report_dict.items():
+        # Skip the title key, since we already displayed it
+        if key == "report_title":
+            continue
+
+        # 3) Check if it's code or function
+        # (You can tweak this logic if you have different rules)
+        key_lower = key.lower()
+        if "code" in key_lower or "function" in key_lower:
+            # Treat as code
+            lines.append(f"\n## {format_agent_name(key).upper()}")
+            lines.append(f"```{code_lang}\n" + str(value) + "\n```")
+        else:
+            # 4) Otherwise, just display the key-value as text
+            lines.append(f"\n## {format_agent_name(key).upper()}")
+            lines.append(str(value))
+
+    return "\n".join(lines)
+
+def remove_consecutive_duplicates(messages):
+    unique_messages = []
+    prev_message = None
+
+    for msg in messages:
+        if msg.content != prev_message:
+            unique_messages.append(msg)
+            prev_message = msg.content  # Update previous message to current
+
+    return unique_messages
+
+
```
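Together these helpers close the loop: `node_func_report_agent_outputs` (agent_templates.py) emits JSON report messages, and `get_generic_summary` renders them back into readable text. A small sketch with a hand-built report dict (the keys are illustrative):

```python
report = {
    "report_title": "Data Cleaning Agent Outputs",
    "recommended_steps": "1. Drop rows that are entirely NA.",
    "data_cleaner_function": "def data_cleaner(df):\n    return df.dropna()",
}
print(get_generic_summary(report))
# -> "# Data Cleaning Agent Outputs", a text section for the steps, and a
#    fenced python code block for the function source.
```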