ai-data-science-team 0.0.0.9008__py3-none-any.whl → 0.0.0.9009__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +0 -1
- ai_data_science_team/agents/data_cleaning_agent.py +45 -34
- ai_data_science_team/agents/data_visualization_agent.py +39 -43
- ai_data_science_team/agents/data_wrangling_agent.py +45 -44
- ai_data_science_team/agents/feature_engineering_agent.py +42 -61
- ai_data_science_team/agents/sql_database_agent.py +125 -71
- ai_data_science_team/ml_agents/__init__.py +1 -0
- ai_data_science_team/ml_agents/h2o_ml_agent.py +1032 -0
- ai_data_science_team/multiagents/sql_data_analyst.py +119 -7
- ai_data_science_team/templates/__init__.py +1 -0
- ai_data_science_team/templates/agent_templates.py +73 -2
- ai_data_science_team/tools/metadata.py +94 -62
- ai_data_science_team/tools/regex.py +59 -1
- {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/METADATA +28 -14
- ai_data_science_team-0.0.0.9009.dist-info/RECORD +28 -0
- ai_data_science_team-0.0.0.9008.dist-info/RECORD +0 -26
- {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/WHEEL +0 -0
- {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/top_level.txt +0 -0
@@ -7,18 +7,19 @@ from langgraph.graph import START, END, StateGraph
|
|
7
7
|
from langgraph.graph.state import CompiledStateGraph
|
8
8
|
from langgraph.types import Command
|
9
9
|
|
10
|
-
from typing import TypedDict, Annotated, Sequence
|
10
|
+
from typing import TypedDict, Annotated, Sequence, Literal
|
11
11
|
import operator
|
12
12
|
|
13
|
-
from typing_extensions import TypedDict
|
13
|
+
from typing_extensions import TypedDict
|
14
14
|
|
15
15
|
import pandas as pd
|
16
|
+
import json
|
16
17
|
from IPython.display import Markdown
|
17
18
|
|
18
19
|
from ai_data_science_team.templates import BaseAgent
|
19
20
|
from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
|
20
21
|
from ai_data_science_team.utils.plotly import plotly_from_dict
|
21
|
-
|
22
|
+
from ai_data_science_team.tools.regex import remove_consecutive_duplicates, get_generic_summary
|
22
23
|
|
23
24
|
|
24
25
|
class SQLDataAnalyst(BaseAgent):
|
@@ -90,7 +91,7 @@ class SQLDataAnalyst(BaseAgent):
|
|
90
91
|
self._params[k] = v
|
91
92
|
self._compiled_graph = self._make_compiled_graph()
|
92
93
|
|
93
|
-
def ainvoke_agent(self, user_instructions, **kwargs):
|
94
|
+
def ainvoke_agent(self, user_instructions, max_retries:int=3, retry_count:int=0, **kwargs):
|
94
95
|
"""
|
95
96
|
Asynchronosly nvokes the SQL Data Analyst Multi-Agent.
|
96
97
|
|
@@ -108,15 +109,53 @@ class SQLDataAnalyst(BaseAgent):
|
|
108
109
|
Example:
|
109
110
|
--------
|
110
111
|
``` python
|
111
|
-
|
112
|
+
from langchain_openai import ChatOpenAI
|
113
|
+
import sqlalchemy as sql
|
114
|
+
from ai_data_science_team.multiagents import SQLDataAnalyst
|
115
|
+
from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
|
116
|
+
|
117
|
+
llm = ChatOpenAI(model = "gpt-4o-mini")
|
118
|
+
|
119
|
+
sql_engine = sql.create_engine("sqlite:///data/northwind.db")
|
120
|
+
|
121
|
+
conn = sql_engine.connect()
|
122
|
+
|
123
|
+
sql_data_analyst = SQLDataAnalyst(
|
124
|
+
model = llm,
|
125
|
+
sql_database_agent = SQLDatabaseAgent(
|
126
|
+
model = llm,
|
127
|
+
connection = conn,
|
128
|
+
n_samples = 1,
|
129
|
+
),
|
130
|
+
data_visualization_agent = DataVisualizationAgent(
|
131
|
+
model = llm,
|
132
|
+
n_samples = 10,
|
133
|
+
)
|
134
|
+
)
|
135
|
+
|
136
|
+
sql_data_analyst.ainvoke_agent(
|
137
|
+
user_instructions = "Make a plot of sales revenue by month by territory. Make a dropdown for the user to select the territory.",
|
138
|
+
)
|
139
|
+
|
140
|
+
sql_data_analyst.get_sql_query_code()
|
141
|
+
|
142
|
+
sql_data_analyst.get_data_sql()
|
143
|
+
|
144
|
+
sql_data_analyst.get_plotly_graph()
|
112
145
|
```
|
113
146
|
"""
|
114
147
|
response = self._compiled_graph.ainvoke({
|
115
148
|
"user_instructions": user_instructions,
|
149
|
+
"max_retries": max_retries,
|
150
|
+
"retry_count": retry_count,
|
116
151
|
}, **kwargs)
|
152
|
+
|
153
|
+
if response.get("messages"):
|
154
|
+
response["messages"] = remove_consecutive_duplicates(response["messages"])
|
155
|
+
|
117
156
|
self.response = response
|
118
157
|
|
119
|
-
def invoke_agent(self, user_instructions, **kwargs):
|
158
|
+
def invoke_agent(self, user_instructions, max_retries:int=3, retry_count:int=0, **kwargs):
|
120
159
|
"""
|
121
160
|
Invokes the SQL Data Analyst Multi-Agent.
|
122
161
|
|
@@ -124,6 +163,10 @@ class SQLDataAnalyst(BaseAgent):
|
|
124
163
|
----------
|
125
164
|
user_instructions: str
|
126
165
|
The user's instructions for the combined SQL and (optionally) Data Visualization agents.
|
166
|
+
max_retries (int):
|
167
|
+
Maximum retry attempts for cleaning.
|
168
|
+
retry_count (int):
|
169
|
+
Current retry attempt.
|
127
170
|
**kwargs:
|
128
171
|
Additional keyword arguments to pass to the compiled graph's `invoke` method.
|
129
172
|
|
@@ -134,14 +177,53 @@ class SQLDataAnalyst(BaseAgent):
|
|
134
177
|
Example:
|
135
178
|
--------
|
136
179
|
``` python
|
137
|
-
|
180
|
+
from langchain_openai import ChatOpenAI
|
181
|
+
import sqlalchemy as sql
|
182
|
+
from ai_data_science_team.multiagents import SQLDataAnalyst
|
183
|
+
from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
|
184
|
+
|
185
|
+
llm = ChatOpenAI(model = "gpt-4o-mini")
|
186
|
+
|
187
|
+
sql_engine = sql.create_engine("sqlite:///data/northwind.db")
|
188
|
+
|
189
|
+
conn = sql_engine.connect()
|
190
|
+
|
191
|
+
sql_data_analyst = SQLDataAnalyst(
|
192
|
+
model = llm,
|
193
|
+
sql_database_agent = SQLDatabaseAgent(
|
194
|
+
model = llm,
|
195
|
+
connection = conn,
|
196
|
+
n_samples = 1,
|
197
|
+
),
|
198
|
+
data_visualization_agent = DataVisualizationAgent(
|
199
|
+
model = llm,
|
200
|
+
n_samples = 10,
|
201
|
+
)
|
202
|
+
)
|
203
|
+
|
204
|
+
sql_data_analyst.invoke_agent(
|
205
|
+
user_instructions = "Make a plot of sales revenue by month by territory. Make a dropdown for the user to select the territory.",
|
206
|
+
)
|
207
|
+
|
208
|
+
sql_data_analyst.get_sql_query_code()
|
209
|
+
|
210
|
+
sql_data_analyst.get_data_sql()
|
211
|
+
|
212
|
+
sql_data_analyst.get_plotly_graph()
|
138
213
|
```
|
139
214
|
"""
|
140
215
|
response = self._compiled_graph.invoke({
|
141
216
|
"user_instructions": user_instructions,
|
217
|
+
"max_retries": max_retries,
|
218
|
+
"retry_count": retry_count,
|
142
219
|
}, **kwargs)
|
220
|
+
|
221
|
+
if response.get("messages"):
|
222
|
+
response["messages"] = remove_consecutive_duplicates(response["messages"])
|
223
|
+
|
143
224
|
self.response = response
|
144
225
|
|
226
|
+
|
145
227
|
def get_data_sql(self):
|
146
228
|
"""
|
147
229
|
Returns the SQL data as a Pandas DataFrame.
|
@@ -205,6 +287,34 @@ class SQLDataAnalyst(BaseAgent):
|
|
205
287
|
if markdown:
|
206
288
|
return Markdown(f"```python\n{self.response.get('data_visualization_function')}\n```")
|
207
289
|
return self.response.get("data_visualization_function")
|
290
|
+
|
291
|
+
def get_workflow_summary(self, markdown=False):
    """
    Returns a summary of the SQL Data Analyst workflow.

    Each message in the agent response is rendered as one report section.
    JSON dictionary reports go through ``get_generic_summary``; any other
    message content is included verbatim.

    Parameters:
    ----------
    markdown: bool
        If True, returns an IPython ``Markdown`` object containing a header
        that lists every agent, followed by the report sections. If False,
        returns only the report sections joined by blank lines.

    Returns:
    -------
    str, IPython.display.Markdown or None
        The summary, or None when the agent has not been invoked or its
        response holds no messages.
    """
    response = self.get_response() if self.response else None
    messages = (response or {}).get("messages") or []
    if not messages:
        return None

    # Report messages are expected to carry a ``role`` attribute; fall back to
    # ``name`` (or a placeholder) for messages created without one, instead of
    # raising AttributeError.
    agents = [
        getattr(msg, "role", None) or getattr(msg, "name", None) or "Unknown Agent"
        for msg in messages
    ]
    agent_labels = [
        f"- **Agent {position}:** {agent}"
        for position, agent in enumerate(agents, start=1)
    ]

    # Construct header
    header = (
        f"# SQL Data Analyst Workflow Summary Report\n\n"
        f"This agentic workflow contains {len(agents)} agents:\n\n"
        + "\n".join(agent_labels)
    )

    reports = []
    for msg in messages:
        try:
            report = json.loads(msg.content)
        except (TypeError, ValueError):
            # Not a JSON report (e.g. a plain-text error message): show as-is.
            reports.append(str(msg.content))
            continue
        if isinstance(report, dict):
            reports.append(get_generic_summary(report))
        else:
            reports.append(str(msg.content))

    if markdown:
        # Separate the header from the first report; without the blank line
        # the last agent label and the next "#" heading share one line.
        return Markdown(header + "\n\n" + "\n\n".join(reports))
    return "\n\n".join(reports)
|
208
318
|
|
209
319
|
|
210
320
|
|
@@ -250,6 +360,8 @@ def make_sql_data_analyst(
|
|
250
360
|
plot_required: bool
|
251
361
|
data_visualization_function: str
|
252
362
|
plotly_graph: dict
|
363
|
+
max_retries: int
|
364
|
+
retry_count: int
|
253
365
|
|
254
366
|
def route_to_visualization(state) -> Command[Literal["data_visualization_agent", "__end__"]]:
|
255
367
|
|
@@ -8,11 +8,16 @@ from langgraph.pregel.types import StreamMode
|
|
8
8
|
|
9
9
|
import pandas as pd
|
10
10
|
import sqlalchemy as sql
|
11
|
+
import json
|
11
12
|
|
12
|
-
from typing import Any, Callable, Dict, Type, Optional, Union
|
13
|
+
from typing import Any, Callable, Dict, Type, Optional, Union, List
|
13
14
|
|
14
15
|
from ai_data_science_team.tools.parsers import PythonOutputParser
|
15
|
-
from ai_data_science_team.tools.regex import
|
16
|
+
from ai_data_science_team.tools.regex import (
|
17
|
+
relocate_imports_inside_function,
|
18
|
+
add_comments_to_top,
|
19
|
+
remove_consecutive_duplicates
|
20
|
+
)
|
16
21
|
|
17
22
|
from IPython.display import Image, display
|
18
23
|
import pandas as pd
|
@@ -82,6 +87,10 @@ class BaseAgent(CompiledStateGraph):
|
|
82
87
|
Any: The agent's response.
|
83
88
|
"""
|
84
89
|
self.response = self._compiled_graph.invoke(input=input, config=config,**kwargs)
|
90
|
+
|
91
|
+
if self.response.get("messages"):
|
92
|
+
self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
|
93
|
+
|
85
94
|
return self.response
|
86
95
|
|
87
96
|
def ainvoke(
|
@@ -102,6 +111,10 @@ class BaseAgent(CompiledStateGraph):
|
|
102
111
|
Any: The agent's response.
|
103
112
|
"""
|
104
113
|
self.response = self._compiled_graph.ainvoke(input=input, config=config,**kwargs)
|
114
|
+
|
115
|
+
if self.response.get("messages"):
|
116
|
+
self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
|
117
|
+
|
105
118
|
return self.response
|
106
119
|
|
107
120
|
def stream(
|
@@ -129,6 +142,10 @@ class BaseAgent(CompiledStateGraph):
|
|
129
142
|
Any: The agent's response.
|
130
143
|
"""
|
131
144
|
self.response = self._compiled_graph.stream(input=input, config=config, stream_mode=stream_mode, **kwargs)
|
145
|
+
|
146
|
+
if self.response.get("messages"):
|
147
|
+
self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
|
148
|
+
|
132
149
|
return self.response
|
133
150
|
|
134
151
|
def astream(
|
@@ -156,6 +173,10 @@ class BaseAgent(CompiledStateGraph):
|
|
156
173
|
Any: The agent's response.
|
157
174
|
"""
|
158
175
|
self.response = self._compiled_graph.astream(input=input, config=config, stream_mode=stream_mode, **kwargs)
|
176
|
+
|
177
|
+
if self.response.get("messages"):
|
178
|
+
self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
|
179
|
+
|
159
180
|
return self.response
|
160
181
|
|
161
182
|
def get_state_keys(self):
|
@@ -183,6 +204,9 @@ class BaseAgent(CompiledStateGraph):
|
|
183
204
|
Returns:
|
184
205
|
Any: The agent's response.
|
185
206
|
"""
|
207
|
+
if self.response.get("messages"):
|
208
|
+
self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
|
209
|
+
|
186
210
|
return self.response
|
187
211
|
|
188
212
|
def show(self, xray: int = 0):
|
@@ -729,3 +753,50 @@ def node_func_explain_agent_code(
|
|
729
753
|
# Return an error message if there was a problem with the code
|
730
754
|
message = AIMessage(content=error_message)
|
731
755
|
return {result_key: [message]}
|
756
|
+
|
757
|
+
|
758
|
+
|
759
|
+
def node_func_report_agent_outputs(
    state: Dict[str, Any],
    keys_to_include: List[str],
    result_key: str,
    role: str,
    custom_title: str = "Agent Output Summary"
) -> Dict[str, Any]:
    """
    Gathers relevant data directly from the state (filtered by `keys_to_include`)
    and returns them as a structured message in `state[result_key]`.

    No LLM is used.

    Parameters
    ----------
    state : Dict[str, Any]
        The current state dictionary holding all agent variables.
    keys_to_include : List[str]
        The list of keys in `state` to include in the output. A key missing
        from `state` is reported with the placeholder "<key>_not_found_in_state>".
    result_key : str
        The key in `state` under which the final structured message is stored.
    role : str
        The role attached to the final AIMessage (e.g., "DataCleaningAgent").
    custom_title : str, optional
        A title or heading for the report. Defaults to "Agent Output Summary".

    Returns
    -------
    Dict[str, Any]
        ``{result_key: [AIMessage]}``. The message content is the report
        (``report_title`` plus each requested key) serialized as indented JSON.
    """
    print(" * REPORT AGENT OUTPUTS")

    final_report = {"report_title": custom_title}
    for key in keys_to_include:
        final_report[key] = state.get(key, f"<{key}_not_found_in_state>")

    # State values may not be JSON-native (e.g. pandas Timestamps or numpy
    # scalars inside dict-serialized data). Serialize those via str() rather
    # than letting json.dumps raise TypeError and abort the graph node.
    return {
        result_key: [
            AIMessage(
                content=json.dumps(final_report, indent=2, default=str),
                role=role,
            )
        ]
    }
|
802
|
+
|
@@ -1,6 +1,7 @@
|
|
1
1
|
import io
|
2
2
|
import pandas as pd
|
3
3
|
import sqlalchemy as sql
|
4
|
+
from sqlalchemy import inspect
|
4
5
|
from typing import Union, List, Dict
|
5
6
|
|
6
7
|
def get_dataframe_summary(
|
@@ -139,8 +140,7 @@ def _summarize_dataframe(df: pd.DataFrame, dataset_name: str, n_sample=30, skip_
|
|
139
140
|
|
140
141
|
|
141
142
|
|
142
|
-
def get_database_metadata(connection
|
143
|
-
n_samples: int = 10) -> str:
|
143
|
+
def get_database_metadata(connection, n_samples=10) -> dict:
|
144
144
|
"""
|
145
145
|
Collects metadata and sample data from a database, with safe identifier quoting and
|
146
146
|
basic dialect-aware row limiting. Prevents issues with spaces/reserved words in identifiers.
|
@@ -154,77 +154,109 @@ def get_database_metadata(connection: Union[sql.engine.base.Connection, sql.engi
|
|
154
154
|
|
155
155
|
Returns
|
156
156
|
-------
|
157
|
-
|
158
|
-
A
|
157
|
+
dict
|
158
|
+
A dictionary with database metadata, including some sample data from each column.
|
159
159
|
"""
|
160
|
-
|
161
|
-
# If a connection is passed, use it; if an engine is passed, connect to it
|
162
160
|
is_engine = isinstance(connection, sql.engine.base.Engine)
|
163
161
|
conn = connection.connect() if is_engine else connection
|
164
162
|
|
165
|
-
|
163
|
+
metadata = {
|
164
|
+
"dialect": None,
|
165
|
+
"driver": None,
|
166
|
+
"connection_url": None,
|
167
|
+
"schemas": [],
|
168
|
+
}
|
169
|
+
|
166
170
|
try:
|
167
|
-
# Grab the engine off the connection
|
168
171
|
sql_engine = conn.engine
|
169
172
|
dialect_name = sql_engine.dialect.name.lower()
|
170
173
|
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
# Inspect the database
|
176
|
-
inspector = sql.inspect(sql_engine)
|
177
|
-
tables = inspector.get_table_names()
|
178
|
-
output.append(f"Tables: {tables}")
|
179
|
-
output.append(f"Schemas: {inspector.get_schema_names()}")
|
180
|
-
|
181
|
-
# Helper to build a dialect-specific limit clause
|
182
|
-
def build_query(col_name_quoted: str, table_name_quoted: str, n: int) -> str:
|
183
|
-
"""
|
184
|
-
Returns a SQL query string to select N rows from the given column/table
|
185
|
-
across different dialects (SQLite, MySQL, Postgres, MSSQL, Oracle, etc.)
|
186
|
-
"""
|
187
|
-
if "sqlite" in dialect_name or "mysql" in dialect_name or "postgres" in dialect_name:
|
188
|
-
# Common dialects supporting LIMIT
|
189
|
-
return f"SELECT {col_name_quoted} FROM {table_name_quoted} LIMIT {n}"
|
190
|
-
elif "mssql" in dialect_name:
|
191
|
-
# Microsoft SQL Server syntax
|
192
|
-
return f"SELECT TOP {n} {col_name_quoted} FROM {table_name_quoted}"
|
193
|
-
elif "oracle" in dialect_name:
|
194
|
-
# Oracle syntax
|
195
|
-
return f"SELECT {col_name_quoted} FROM {table_name_quoted} WHERE ROWNUM <= {n}"
|
196
|
-
else:
|
197
|
-
# Fallback
|
198
|
-
return f"SELECT {col_name_quoted} FROM {table_name_quoted} LIMIT {n}"
|
199
|
-
|
200
|
-
# Prepare for quoting
|
201
|
-
preparer = inspector.bind.dialect.identifier_preparer
|
202
|
-
|
203
|
-
# For each table, get columns and sample data
|
204
|
-
for table_name in tables:
|
205
|
-
output.append(f"\nTable: {table_name}")
|
206
|
-
# Properly quote the table name
|
207
|
-
table_name_quoted = preparer.quote_identifier(table_name)
|
208
|
-
|
209
|
-
for column in inspector.get_columns(table_name):
|
210
|
-
col_name = column["name"]
|
211
|
-
col_type = column["type"]
|
212
|
-
output.append(f" Column: {col_name} Type: {col_type}")
|
174
|
+
metadata["dialect"] = sql_engine.dialect.name
|
175
|
+
metadata["driver"] = sql_engine.driver
|
176
|
+
metadata["connection_url"] = str(sql_engine.url)
|
213
177
|
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
# Build a dialect-aware query with safe quoting
|
218
|
-
query = build_query(col_name_quoted, table_name_quoted, n_samples)
|
219
|
-
|
220
|
-
# Read a few sample values
|
221
|
-
df = pd.read_sql(sql.text(query), conn)
|
222
|
-
first_values = df[col_name].tolist()
|
223
|
-
output.append(f" First {n_samples} Values: {first_values}")
|
178
|
+
inspector = inspect(sql_engine)
|
179
|
+
preparer = inspector.bind.dialect.identifier_preparer
|
224
180
|
|
181
|
+
# For each schema
|
182
|
+
for schema_name in inspector.get_schema_names():
|
183
|
+
schema_obj = {
|
184
|
+
"schema_name": schema_name,
|
185
|
+
"tables": []
|
186
|
+
}
|
187
|
+
|
188
|
+
tables = inspector.get_table_names(schema=schema_name)
|
189
|
+
for table_name in tables:
|
190
|
+
table_info = {
|
191
|
+
"table_name": table_name,
|
192
|
+
"columns": [],
|
193
|
+
"primary_key": [],
|
194
|
+
"foreign_keys": [],
|
195
|
+
"indexes": []
|
196
|
+
}
|
197
|
+
# Get columns
|
198
|
+
columns = inspector.get_columns(table_name, schema=schema_name)
|
199
|
+
for col in columns:
|
200
|
+
col_name = col["name"]
|
201
|
+
col_type = str(col["type"])
|
202
|
+
table_name_quoted = f"{preparer.quote_identifier(schema_name)}.{preparer.quote_identifier(table_name)}"
|
203
|
+
col_name_quoted = preparer.quote_identifier(col_name)
|
204
|
+
|
205
|
+
# Build query for sample data
|
206
|
+
query = build_query(col_name_quoted, table_name_quoted, n_samples, dialect_name)
|
207
|
+
|
208
|
+
# Retrieve sample data
|
209
|
+
try:
|
210
|
+
df = pd.read_sql(query, conn)
|
211
|
+
samples = df[col_name].head(n_samples).tolist()
|
212
|
+
except Exception as e:
|
213
|
+
samples = [f"Error retrieving data: {str(e)}"]
|
214
|
+
|
215
|
+
table_info["columns"].append({
|
216
|
+
"name": col_name,
|
217
|
+
"type": col_type,
|
218
|
+
"sample_values": samples
|
219
|
+
})
|
220
|
+
|
221
|
+
# Primary keys
|
222
|
+
pk_constraint = inspector.get_pk_constraint(table_name, schema=schema_name)
|
223
|
+
table_info["primary_key"] = pk_constraint.get("constrained_columns", [])
|
224
|
+
|
225
|
+
# Foreign keys
|
226
|
+
fks = inspector.get_foreign_keys(table_name, schema=schema_name)
|
227
|
+
table_info["foreign_keys"] = [
|
228
|
+
{
|
229
|
+
"local_cols": fk["constrained_columns"],
|
230
|
+
"referred_table": fk["referred_table"],
|
231
|
+
"referred_cols": fk["referred_columns"]
|
232
|
+
}
|
233
|
+
for fk in fks
|
234
|
+
]
|
235
|
+
|
236
|
+
# Indexes
|
237
|
+
idxs = inspector.get_indexes(table_name, schema=schema_name)
|
238
|
+
table_info["indexes"] = idxs
|
239
|
+
|
240
|
+
schema_obj["tables"].append(table_info)
|
241
|
+
|
242
|
+
metadata["schemas"].append(schema_obj)
|
243
|
+
|
225
244
|
finally:
|
226
|
-
# Close connection if created inside the function
|
227
245
|
if is_engine:
|
228
246
|
conn.close()
|
229
247
|
|
230
|
-
return
|
248
|
+
return metadata
|
249
|
+
|
250
|
+
def build_query(col_name_quoted: str, table_name_quoted: str, n: int, dialect_name: str) -> str:
    """
    Build a dialect-aware query returning up to ``n`` values of one column.

    PostgreSQL, SQLite, MySQL/MariaDB and SQL Server get a randomly ordered
    sample. Oracle uses ``ROWNUM`` (unordered). Any other dialect falls back to
    a plain ``LIMIT`` clause.

    Parameters
    ----------
    col_name_quoted : str
        Column identifier, already quoted for the target dialect.
    table_name_quoted : str
        Table identifier (optionally schema-qualified), already quoted.
    n : int
        Maximum number of rows to return.
    dialect_name : str
        SQLAlchemy dialect name, e.g. "postgresql", "sqlite", "mssql".

    Returns
    -------
    str
        The SQL query text.
    """
    select = f"SELECT {col_name_quoted} FROM {table_name_quoted}"
    dialect = dialect_name.lower()

    if "postgres" in dialect or "sqlite" in dialect:
        return f"{select} ORDER BY RANDOM() LIMIT {n}"
    if "mysql" in dialect or "mariadb" in dialect:
        return f"{select} ORDER BY RAND() LIMIT {n}"
    if "mssql" in dialect:
        return f"SELECT TOP {n} {col_name_quoted} FROM {table_name_quoted} ORDER BY NEWID()"
    if "oracle" in dialect:
        return f"{select} WHERE ROWNUM <= {n}"
    # Unknown dialects: ROWNUM is Oracle-only, whereas LIMIT is the most
    # widely supported way to cap the row count (e.g. DuckDB).
    return f"{select} LIMIT {n}"
|
262
|
+
|
@@ -103,4 +103,62 @@ def format_recommended_steps(raw_text: str, heading: str = "# Recommended Steps:
|
|
103
103
|
if not seen_heading:
|
104
104
|
new_lines.insert(0, heading)
|
105
105
|
|
106
|
-
return "\n".join(new_lines)
|
106
|
+
return "\n".join(new_lines)
|
107
|
+
|
108
|
+
def get_generic_summary(report_dict: dict, code_lang = "python") -> str:
    """
    Render an agent report dictionary as a Markdown-style text summary.

    Layout:
      * The ``report_title`` entry (default "Untitled Report") becomes a
        top-level ``#`` heading.
      * Every other entry becomes a ``##`` section. Its heading is the key
        passed through ``format_agent_name`` and upper-cased.
      * Entries whose key contains "code" or "function" (case-insensitive)
        are wrapped in a fenced code block tagged with ``code_lang``. All other
        values are written as their ``str()`` form.

    Parameters
    ----------
    report_dict : dict
        Report dictionary, e.g. the result of ``json.loads`` on an agent message.
    code_lang : str, optional
        Language tag placed after the opening code fence. Defaults to "python".

    Returns
    -------
    str
        The heading and sections joined with newlines.
    """
    def _is_code_key(name):
        lowered = name.lower()
        return "code" in lowered or "function" in lowered

    output = [f"# {report_dict.get('report_title', 'Untitled Report')}"]

    for name, content in report_dict.items():
        if name == "report_title":
            continue
        output.append(f"\n## {format_agent_name(name).upper()}")
        if _is_code_key(name):
            output.append(f"```{code_lang}\n{content!s}\n```")
        else:
            output.append(str(content))

    return "\n".join(output)
|
152
|
+
|
153
|
+
def remove_consecutive_duplicates(messages):
    """
    Drop messages whose ``content`` equals the content of the message right before them.

    Only adjacent repeats are collapsed. The same content may reappear later
    after a different message. For each run of identical content the first
    message is kept, and the original order is preserved.

    Parameters
    ----------
    messages : iterable or None
        Message objects exposing a ``content`` attribute. ``None`` is treated
        as an empty sequence.

    Returns
    -------
    list
        The de-duplicated messages (the same objects, not copies).
    """
    unique_messages = []
    # Private sentinel instead of None, so a first message whose content is
    # None is still kept rather than mistaken for a duplicate.
    prev_content = object()

    for msg in messages or []:
        if msg.content != prev_content:
            unique_messages.append(msg)
        prev_content = msg.content

    return unique_messages
|
163
|
+
|
164
|
+
|