ai-data-science-team 0.0.0.9007__py3-none-any.whl → 0.0.0.9009__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +4 -5
- ai_data_science_team/agents/data_cleaning_agent.py +268 -116
- ai_data_science_team/agents/data_visualization_agent.py +470 -41
- ai_data_science_team/agents/data_wrangling_agent.py +471 -31
- ai_data_science_team/agents/feature_engineering_agent.py +426 -41
- ai_data_science_team/agents/sql_database_agent.py +458 -58
- ai_data_science_team/ml_agents/__init__.py +1 -0
- ai_data_science_team/ml_agents/h2o_ml_agent.py +1032 -0
- ai_data_science_team/multiagents/__init__.py +1 -0
- ai_data_science_team/multiagents/sql_data_analyst.py +398 -0
- ai_data_science_team/multiagents/supervised_data_analyst.py +2 -0
- ai_data_science_team/templates/__init__.py +3 -1
- ai_data_science_team/templates/agent_templates.py +319 -43
- ai_data_science_team/tools/metadata.py +94 -62
- ai_data_science_team/tools/regex.py +86 -1
- ai_data_science_team/utils/__init__.py +0 -0
- ai_data_science_team/utils/plotly.py +24 -0
- ai_data_science_team-0.0.0.9009.dist-info/METADATA +245 -0
- ai_data_science_team-0.0.0.9009.dist-info/RECORD +28 -0
- ai_data_science_team-0.0.0.9007.dist-info/METADATA +0 -183
- ai_data_science_team-0.0.0.9007.dist-info/RECORD +0 -21
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/WHEEL +0 -0
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/top_level.txt +0 -0
@@ -5,24 +5,33 @@ import operator
|
|
5
5
|
|
6
6
|
from langchain.prompts import PromptTemplate
|
7
7
|
from langchain_core.messages import BaseMessage
|
8
|
+
from langchain_core.output_parsers import JsonOutputParser
|
8
9
|
|
9
10
|
from langgraph.types import Command
|
10
11
|
from langgraph.checkpoint.memory import MemorySaver
|
11
12
|
|
12
13
|
import os
|
13
|
-
import
|
14
|
+
import json
|
14
15
|
import pandas as pd
|
15
16
|
import sqlalchemy as sql
|
16
17
|
|
18
|
+
from IPython.display import Markdown
|
19
|
+
|
17
20
|
from ai_data_science_team.templates import(
|
18
21
|
node_func_execute_agent_from_sql_connection,
|
19
22
|
node_func_human_review,
|
20
23
|
node_func_fix_agent_code,
|
21
|
-
|
22
|
-
create_coding_agent_graph
|
24
|
+
node_func_report_agent_outputs,
|
25
|
+
create_coding_agent_graph,
|
26
|
+
BaseAgent,
|
27
|
+
)
|
28
|
+
from ai_data_science_team.tools.parsers import SQLOutputParser
|
29
|
+
from ai_data_science_team.tools.regex import (
|
30
|
+
add_comments_to_top,
|
31
|
+
format_agent_name,
|
32
|
+
format_recommended_steps,
|
33
|
+
get_generic_summary,
|
23
34
|
)
|
24
|
-
from ai_data_science_team.tools.parsers import PythonOutputParser, SQLOutputParser
|
25
|
-
from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
|
26
35
|
from ai_data_science_team.tools.metadata import get_database_metadata
|
27
36
|
from ai_data_science_team.tools.logging import log_ai_function
|
28
37
|
|
@@ -30,16 +39,333 @@ from ai_data_science_team.tools.logging import log_ai_function
|
|
30
39
|
AGENT_NAME = "sql_database_agent"
|
31
40
|
LOG_PATH = os.path.join(os.getcwd(), "logs/")
|
32
41
|
|
42
|
+
# Class
|
43
|
+
|
44
|
+
class SQLDatabaseAgent(BaseAgent):
    """
    Object-oriented wrapper around the graph built by `make_sql_database_agent()`.

    The constructor stores its arguments in `self._params`, builds the agent graph
    from them, and keeps the most recent invocation result in `self.response`.
    The getter methods read individual keys out of that response.

    The agent can:
    - Propose recommended steps to answer a user's query or instructions.
    - Generate a SQL query based on the recommended steps and user instructions.
    - Execute that SQL query against the provided database connection.
    - Return the resulting data, which `get_data_sql()` exposes as a pandas DataFrame.
    - Log generated code and errors if enabled.

    Parameters
    ----------
    model : ChatOpenAI or langchain.llms.base.LLM
        The language model used to generate the SQL code.
    connection : sqlalchemy.engine.base.Engine or sqlalchemy.engine.base.Connection
        The SQLAlchemy connection (or engine) to the database.
    n_samples : int, optional
        Number of sample rows (per column) to retrieve when summarizing database metadata. Defaults to 1.
    log : bool, optional
        Whether to log the generated code and errors. Defaults to False.
    log_path : str, optional
        Directory path for storing log files. Defaults to None.
    file_name : str, optional
        Name of the file for saving the generated response. Defaults to "sql_database.py".
    function_name : str, optional
        Name of the Python function that executes the SQL query. Defaults to "sql_database_pipeline".
    overwrite : bool, optional
        Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
    human_in_the_loop : bool, optional
        Enables user review of the recommended steps before generating code. Defaults to False.
    bypass_recommended_steps : bool, optional
        If True, skips the step that generates recommended SQL steps. Defaults to False.
    bypass_explain_code : bool, optional
        If True, skips the step that provides code explanations. Defaults to False.
    smart_schema_pruning : bool, optional
        If True, filters the tables and columns based on the user instructions. Defaults to False.

    Methods
    -------
    update_params(**kwargs)
        Updates the agent's parameters and rebuilds the compiled state graph.
    ainvoke_agent(user_instructions: str, max_retries=3, retry_count=0)
        Coroutine. Runs the agent asynchronously; must be awaited.
    invoke_agent(user_instructions: str, max_retries=3, retry_count=0)
        Synchronously runs the agent to generate and execute a SQL query based on user instructions.
    get_workflow_summary()
        Retrieves a summary of the agent's workflow from the last response message.
    get_log_summary()
        Retrieves the logged function path and name, if the response contains a function path.
    get_data_sql()
        Retrieves the resulting data from the SQL query as a pandas DataFrame.
    get_sql_query_code()
        Retrieves the exact SQL query generated by the agent.
    get_sql_database_function()
        Retrieves the Python function that executes the SQL query.
    get_recommended_sql_steps()
        Retrieves the recommended steps for querying the SQL database.
    get_response()
        Returns the full response dictionary from the agent.
    show()
        Displays the agent's mermaid diagram for visual inspection of the compiled graph.

    NOTE(review): `get_response()` and `show()` are not defined in this class;
    presumably they are inherited from `BaseAgent` - confirm.

    Examples
    --------
    ```python
    import sqlalchemy as sql
    from langchain_openai import ChatOpenAI
    from ai_data_science_team.agents import SQLDatabaseAgent

    # Create the engine/connection
    sql_engine = sql.create_engine("sqlite:///data/my_database.db")
    conn = sql_engine.connect()

    llm = ChatOpenAI(model="gpt-4o-mini")

    sql_database_agent = SQLDatabaseAgent(
        model=llm,
        connection=conn,
        n_samples=10,
        log=True,
        log_path="logs",
        human_in_the_loop=True
    )

    # Example usage
    sql_database_agent.invoke_agent(
        user_instructions="List all the tables in the database.",
        max_retries=3,
        retry_count=0
    )

    data_result = sql_database_agent.get_data_sql()  # pandas DataFrame of the rows returned
    sql_code = sql_database_agent.get_sql_query_code()
    response = sql_database_agent.get_response()
    ```

    Returns
    -------
    SQLDatabaseAgent
        An agent object holding the graph returned by `make_sql_database_agent()`.
    """

    def __init__(
        self,
        model,
        connection,
        n_samples=1,
        log=False,
        log_path=None,
        file_name="sql_database.py",
        function_name="sql_database_pipeline",
        overwrite=True,
        human_in_the_loop=False,
        bypass_recommended_steps=False,
        bypass_explain_code=False,
        smart_schema_pruning=False,
    ):
        # Keep every constructor argument so update_params() can rebuild the
        # graph later with a modified copy of the same configuration.
        self._params = {
            "model": model,
            "connection": connection,
            "n_samples": n_samples,
            "log": log,
            "log_path": log_path,
            "file_name": file_name,
            "function_name": function_name,
            "overwrite": overwrite,
            "human_in_the_loop": human_in_the_loop,
            "bypass_recommended_steps": bypass_recommended_steps,
            "bypass_explain_code": bypass_explain_code,
            "smart_schema_pruning": smart_schema_pruning,
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Create or rebuild the compiled graph for the SQL Database Agent.
        Running this method resets the response to None.
        """
        # A response produced by the previous graph no longer matches the new
        # configuration, so it is discarded before rebuilding.
        self.response = None
        return make_sql_database_agent(**self._params)

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters (e.g. connection, n_samples, log, etc.)
        and rebuilds the compiled graph.

        Parameters
        ----------
        kwargs : dict
            Parameter names and their new values. They are written into the
            stored parameters and forwarded to `make_sql_database_agent()`.
        """
        for k, v in kwargs.items():
            self._params[k] = v
        self._compiled_graph = self._make_compiled_graph()

    async def ainvoke_agent(self, user_instructions: str = None, max_retries=3, retry_count=0, **kwargs):
        """
        Asynchronously runs the SQL Database Agent based on user instructions.

        This is a coroutine and must be awaited (`await agent.ainvoke_agent(...)`).
        The graph's `ainvoke()` returns an awaitable; it is awaited here so that
        `self.response` holds the actual result rather than a pending coroutine.

        Parameters
        ----------
        user_instructions : str
            Instructions for the SQL query or metadata request.
        max_retries : int, optional
            Maximum retry attempts. Defaults to 3.
        retry_count : int, optional
            Current retry count. Defaults to 0.
        kwargs : dict
            Additional keyword arguments to pass to ainvoke().

        Returns
        -------
        None
            The result is stored in `self.response`.
        """
        response = await self._compiled_graph.ainvoke({
            "user_instructions": user_instructions,
            "max_retries": max_retries,
            "retry_count": retry_count
        }, **kwargs)
        self.response = response

    def invoke_agent(self, user_instructions: str = None, max_retries=3, retry_count=0, **kwargs):
        """
        Synchronously runs the SQL Database Agent based on user instructions.

        Parameters
        ----------
        user_instructions : str
            Instructions for the SQL query or metadata request.
        max_retries : int, optional
            Maximum retry attempts. Defaults to 3.
        retry_count : int, optional
            Current retry count. Defaults to 0.
        kwargs : dict
            Additional keyword arguments to pass to invoke().

        Returns
        -------
        None
            The result is stored in `self.response`.
        """
        response = self._compiled_graph.invoke({
            "user_instructions": user_instructions,
            "max_retries": max_retries,
            "retry_count": retry_count
        }, **kwargs)
        self.response = response

    def get_workflow_summary(self, markdown=False):
        """
        Retrieves the agent's workflow summary from the last response message.

        The content of the final entry in `response["messages"]` is parsed as
        JSON and passed to `get_generic_summary()`.

        Parameters
        ----------
        markdown : bool, optional
            If True, wraps the summary in an IPython `Markdown` object.

        Returns
        -------
        str, Markdown or None
            The summary, or None if there is no response or no messages.
        """
        if self.response and self.response.get("messages"):
            summary = get_generic_summary(json.loads(self.response.get("messages")[-1].content))
            if markdown:
                return Markdown(summary)
            else:
                return summary

    def get_log_summary(self, markdown=False):
        """
        Returns a summary of where the generated function was saved.

        Parameters
        ----------
        markdown : bool, optional
            If True, wraps the summary in an IPython `Markdown` object.

        Returns
        -------
        str, Markdown or None
            The function path and name, or None if there is no response or the
            response has no 'sql_database_function_path'.
        """
        if self.response:
            if self.response.get('sql_database_function_path'):
                log_details = f"""
## SQL Database Agent Log Summary:

Function Path: {self.response.get('sql_database_function_path')}

Function Name: {self.response.get('sql_database_function_name')}
"""
                if markdown:
                    return Markdown(log_details)
                else:
                    return log_details

    def get_data_sql(self):
        """
        Retrieves the SQL query result from the agent's response.

        Returns
        -------
        pandas.DataFrame or None
            A DataFrame built from the "data_sql" entry of the response,
            or None if no data is found.
        """
        if self.response and "data_sql" in self.response:
            return pd.DataFrame(self.response["data_sql"])
        return None

    def get_sql_query_code(self, markdown=False):
        """
        Retrieves the raw SQL query code generated by the agent (if available).

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the code in a Markdown code block.

        Returns
        -------
        str, Markdown or None
            The SQL query (wrapped in `Markdown` when requested), or None if not available.
        """
        if self.response and "sql_query_code" in self.response:
            if markdown:
                return Markdown(f"```sql\n{self.response['sql_query_code']}\n```")
            return self.response["sql_query_code"]
        return None

    def get_sql_database_function(self, markdown=False):
        """
        Retrieves the Python function code used to execute the SQL query.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the code in a Markdown code block.

        Returns
        -------
        str, Markdown or None
            The function code (wrapped in `Markdown` when requested), otherwise None.
        """
        if self.response and "sql_database_function" in self.response:
            code = self.response["sql_database_function"]
            if markdown:
                return Markdown(f"```python\n{code}\n```")
            return code
        return None

    def get_recommended_sql_steps(self, markdown=False):
        """
        Retrieves the recommended SQL steps from the agent's response.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the steps in Markdown format.

        Returns
        -------
        str, Markdown or None
            Recommended steps (wrapped in `Markdown` when requested) or None if not available.
        """
        if self.response and "recommended_steps" in self.response:
            if markdown:
                return Markdown(self.response["recommended_steps"])
            return self.response["recommended_steps"]
        return None
|
351
|
+
|
352
|
+
|
353
|
+
|
354
|
+
# Function
|
33
355
|
|
34
356
|
def make_sql_database_agent(
|
35
|
-
model,
|
36
|
-
|
357
|
+
model,
|
358
|
+
connection,
|
359
|
+
n_samples=1,
|
37
360
|
log=False,
|
38
361
|
log_path=None,
|
39
362
|
file_name="sql_database.py",
|
363
|
+
function_name="sql_database_pipeline",
|
40
364
|
overwrite = True,
|
41
|
-
human_in_the_loop=False,
|
42
|
-
|
365
|
+
human_in_the_loop=False,
|
366
|
+
bypass_recommended_steps=False,
|
367
|
+
bypass_explain_code=False,
|
368
|
+
smart_schema_pruning=False,
|
43
369
|
):
|
44
370
|
"""
|
45
371
|
Creates a SQL Database Agent that can recommend SQL steps and generate SQL code to query a database.
|
@@ -51,7 +377,7 @@ def make_sql_database_agent(
|
|
51
377
|
connection : sqlalchemy.engine.base.Engine
|
52
378
|
The connection to the SQL database.
|
53
379
|
n_samples : int, optional
|
54
|
-
The number of samples to retrieve for each column, by default
|
380
|
+
The number of samples to retrieve for each column, by default 1.
|
55
381
|
If you get an error due to maximum tokens, try reducing this number.
|
56
382
|
> "This model's maximum context length is 128000 tokens. However, your messages resulted in 333858 tokens. Please reduce the length of the messages."
|
57
383
|
log : bool, optional
|
@@ -68,6 +394,8 @@ def make_sql_database_agent(
|
|
68
394
|
Bypass the recommendation step, by default False
|
69
395
|
bypass_explain_code : bool, optional
|
70
396
|
Bypass the code explanation step, by default False.
|
397
|
+
smart_schema_pruning : bool, optional
|
398
|
+
If True, filters the tables and columns with an extra LLM step to reduce tokens for large databases. Increases processing time but can avoid errors due to hitting max token limits with large databases. Defaults to False.
|
71
399
|
|
72
400
|
Returns
|
73
401
|
-------
|
@@ -100,20 +428,26 @@ def make_sql_database_agent(
|
|
100
428
|
"retry_count":0
|
101
429
|
})
|
102
430
|
```
|
103
|
-
|
104
431
|
"""
|
105
432
|
|
106
|
-
is_engine = isinstance(connection, sql.engine.base.Engine)
|
107
|
-
conn = connection.connect() if is_engine else connection
|
108
|
-
|
109
433
|
llm = model
|
110
434
|
|
435
|
+
# Human in the loop requires recommended steps
|
436
|
+
if bypass_recommended_steps and human_in_the_loop:
|
437
|
+
bypass_recommended_steps = False
|
438
|
+
print("Bypass recommended steps set to False to enable human in the loop.")
|
439
|
+
|
111
440
|
# Setup Log Directory
|
112
441
|
if log:
|
113
442
|
if log_path is None:
|
114
443
|
log_path = LOG_PATH
|
115
444
|
if not os.path.exists(log_path):
|
116
445
|
os.makedirs(log_path)
|
446
|
+
|
447
|
+
# Get the database metadata
|
448
|
+
is_engine = isinstance(connection, sql.engine.base.Engine)
|
449
|
+
conn = connection.connect() if is_engine else connection
|
450
|
+
|
117
451
|
|
118
452
|
class GraphState(TypedDict):
|
119
453
|
messages: Annotated[Sequence[BaseMessage], operator.add]
|
@@ -124,6 +458,7 @@ def make_sql_database_agent(
|
|
124
458
|
sql_query_code: str
|
125
459
|
sql_database_function: str
|
126
460
|
sql_database_function_path: str
|
461
|
+
sql_database_function_file_name: str
|
127
462
|
sql_database_function_name: str
|
128
463
|
sql_database_error: str
|
129
464
|
max_retries: int
|
@@ -132,6 +467,16 @@ def make_sql_database_agent(
|
|
132
467
|
def recommend_sql_steps(state: GraphState):
|
133
468
|
|
134
469
|
print(format_agent_name(AGENT_NAME))
|
470
|
+
|
471
|
+
all_sql_database_summary = get_database_metadata(conn, n_samples=n_samples)
|
472
|
+
|
473
|
+
all_sql_database_summary = smart_schema_filter(
|
474
|
+
llm,
|
475
|
+
state.get("user_instructions"),
|
476
|
+
all_sql_database_summary,
|
477
|
+
smart_filtering=smart_schema_pruning
|
478
|
+
)
|
479
|
+
|
135
480
|
print(" * RECOMMEND STEPS")
|
136
481
|
|
137
482
|
|
@@ -148,6 +493,7 @@ def make_sql_database_agent(
|
|
148
493
|
- If no user instructions are provided, just return the steps needed to understand the database.
|
149
494
|
- Take into account the database dialect and the tables and columns in the database.
|
150
495
|
- Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
|
496
|
+
- IMPORTANT: Pay attention to the table names and column names in the database. Make sure to use the correct table and column names in the SQL code. If a space is present in the table name or column name, make sure to account for it.
|
151
497
|
|
152
498
|
|
153
499
|
User instructions / Question:
|
@@ -159,7 +505,7 @@ def make_sql_database_agent(
|
|
159
505
|
Below are summaries of the database metadata and the SQL tables:
|
160
506
|
{all_sql_database_summary}
|
161
507
|
|
162
|
-
Return
|
508
|
+
Return steps as a numbered list. You can return short code snippets to demonstrate actions. But do not return a fully coded solution. The code will be generated separately by a Coding Agent.
|
163
509
|
|
164
510
|
Consider these:
|
165
511
|
|
@@ -178,13 +524,6 @@ def make_sql_database_agent(
|
|
178
524
|
input_variables=["user_instructions", "recommended_steps", "all_sql_database_summary"]
|
179
525
|
)
|
180
526
|
|
181
|
-
# Create a connection if needed
|
182
|
-
is_engine = isinstance(connection, sql.engine.base.Engine)
|
183
|
-
conn = connection.connect() if is_engine else connection
|
184
|
-
|
185
|
-
# Get the database metadata
|
186
|
-
all_sql_database_summary = get_database_metadata(conn, n_samples=n_samples)
|
187
|
-
|
188
527
|
steps_agent = recommend_steps_prompt | llm
|
189
528
|
|
190
529
|
recommended_steps = steps_agent.invoke({
|
@@ -194,13 +533,22 @@ def make_sql_database_agent(
|
|
194
533
|
})
|
195
534
|
|
196
535
|
return {
|
197
|
-
"recommended_steps": "
|
536
|
+
"recommended_steps": format_recommended_steps(recommended_steps.content.strip(), heading="# Recommended SQL Database Steps:"),
|
198
537
|
"all_sql_database_summary": all_sql_database_summary
|
199
538
|
}
|
200
539
|
|
201
540
|
def create_sql_query_code(state: GraphState):
|
202
541
|
if bypass_recommended_steps:
|
203
542
|
print(format_agent_name(AGENT_NAME))
|
543
|
+
all_sql_database_summary = get_database_metadata(conn, n_samples=n_samples)
|
544
|
+
all_sql_database_summary = smart_schema_filter(
|
545
|
+
llm,
|
546
|
+
state.get("user_instructions"),
|
547
|
+
all_sql_database_summary,
|
548
|
+
smart_filtering=smart_schema_pruning
|
549
|
+
)
|
550
|
+
else:
|
551
|
+
all_sql_database_summary = state.get("all_sql_database_summary")
|
204
552
|
print(" * CREATE SQL QUERY CODE")
|
205
553
|
|
206
554
|
# Prompt to get the SQL code from the LLM
|
@@ -216,6 +564,7 @@ def make_sql_database_agent(
|
|
216
564
|
- Only return a single query if possible.
|
217
565
|
- Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
|
218
566
|
- Pay attention to the SQL dialect from the database summary metadata. Write the SQL code according to the dialect specified.
|
567
|
+
- IMPORTANT: Pay attention to the table names and column names in the database. Make sure to use the correct table and column names in the SQL code. If a space is present in the table name or column name, make sure to account for it.
|
219
568
|
|
220
569
|
|
221
570
|
User instructions / Question:
|
@@ -240,13 +589,6 @@ def make_sql_database_agent(
|
|
240
589
|
input_variables=["user_instructions", "recommended_steps", "all_sql_database_summary"]
|
241
590
|
)
|
242
591
|
|
243
|
-
# Create a connection if needed
|
244
|
-
is_engine = isinstance(connection, sql.engine.base.Engine)
|
245
|
-
conn = connection.connect() if is_engine else connection
|
246
|
-
|
247
|
-
# Get the database metadata
|
248
|
-
all_sql_database_summary = get_database_metadata(conn, n_samples=n_samples)
|
249
|
-
|
250
592
|
sql_query_code_agent = sql_query_code_prompt | llm | SQLOutputParser()
|
251
593
|
|
252
594
|
sql_query_code = sql_query_code_agent.invoke({
|
@@ -258,7 +600,7 @@ def make_sql_database_agent(
|
|
258
600
|
print(" * CREATE PYTHON FUNCTION TO RUN SQL CODE")
|
259
601
|
|
260
602
|
response = f"""
|
261
|
-
def
|
603
|
+
def {function_name}(connection):
|
262
604
|
import pandas as pd
|
263
605
|
import sqlalchemy as sql
|
264
606
|
|
@@ -288,19 +630,37 @@ def sql_database_pipeline(connection):
|
|
288
630
|
"sql_query_code": sql_query_code,
|
289
631
|
"sql_database_function": response,
|
290
632
|
"sql_database_function_path": file_path,
|
291
|
-
"
|
633
|
+
"sql_database_function_file_name": file_name_2,
|
634
|
+
"sql_database_function_name": function_name,
|
292
635
|
"all_sql_database_summary": all_sql_database_summary
|
293
636
|
}
|
294
637
|
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
638
|
+
# Human Review
|
639
|
+
|
640
|
+
prompt_text_human_review = "Are the following SQL agent instructions correct? (Answer 'yes' or provide modifications)\n{steps}"
|
641
|
+
|
642
|
+
if not bypass_explain_code:
|
643
|
+
def human_review(state: GraphState) -> Command[Literal["recommend_sql_steps", "explain_sql_database_code"]]:
|
644
|
+
return node_func_human_review(
|
645
|
+
state=state,
|
646
|
+
prompt_text=prompt_text_human_review,
|
647
|
+
yes_goto= 'explain_sql_database_code',
|
648
|
+
no_goto="recommend_sql_steps",
|
649
|
+
user_instructions_key="user_instructions",
|
650
|
+
recommended_steps_key="recommended_steps",
|
651
|
+
code_snippet_key="sql_database_function",
|
652
|
+
)
|
653
|
+
else:
|
654
|
+
def human_review(state: GraphState) -> Command[Literal["recommend_sql_steps", "__end__"]]:
|
655
|
+
return node_func_human_review(
|
656
|
+
state=state,
|
657
|
+
prompt_text=prompt_text_human_review,
|
658
|
+
yes_goto= '__end__',
|
659
|
+
no_goto="recommend_sql_steps",
|
660
|
+
user_instructions_key="user_instructions",
|
661
|
+
recommended_steps_key="recommended_steps",
|
662
|
+
code_snippet_key="sql_database_function",
|
663
|
+
)
|
304
664
|
|
305
665
|
def execute_sql_database_code(state: GraphState):
|
306
666
|
|
@@ -313,18 +673,18 @@ def sql_database_pipeline(connection):
|
|
313
673
|
result_key="data_sql",
|
314
674
|
error_key="sql_database_error",
|
315
675
|
code_snippet_key="sql_database_function",
|
316
|
-
agent_function_name="
|
676
|
+
agent_function_name=state.get("sql_database_function_name"),
|
317
677
|
post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
|
318
678
|
error_message_prefix="An error occurred during executing the sql database pipeline: "
|
319
679
|
)
|
320
680
|
|
321
681
|
def fix_sql_database_code(state: GraphState):
|
322
682
|
prompt = """
|
323
|
-
You are a SQL Database Agent code fixer. Your job is to create a
|
683
|
+
You are a SQL Database Agent code fixer. Your job is to create a {function_name}(connection) function that can be run on a sql connection. The function is currently broken and needs to be fixed.
|
324
684
|
|
325
|
-
Make sure to only return the function definition for
|
685
|
+
Make sure to only return the function definition for {function_name}().
|
326
686
|
|
327
|
-
Return Python code in ```python``` format with a single function definition,
|
687
|
+
Return Python code in ```python``` format with a single function definition, {function_name}(connection), that includes all imports inside the function. The connection object is a SQLAlchemy connection object. Don't specify the class of the connection object, just use it as an argument to the function.
|
328
688
|
|
329
689
|
This is the broken code (please fix):
|
330
690
|
{code_snippet}
|
@@ -342,22 +702,23 @@ def sql_database_pipeline(connection):
|
|
342
702
|
agent_name=AGENT_NAME,
|
343
703
|
log=log,
|
344
704
|
file_path=state.get("sql_database_function_path", None),
|
705
|
+
function_name=state.get("sql_database_function_name"),
|
345
706
|
)
|
346
707
|
|
347
|
-
|
348
|
-
|
708
|
+
# Final reporting node
|
709
|
+
def report_agent_outputs(state: GraphState):
|
710
|
+
return node_func_report_agent_outputs(
|
349
711
|
state=state,
|
350
|
-
|
712
|
+
keys_to_include=[
|
713
|
+
"recommended_steps",
|
714
|
+
"sql_database_function",
|
715
|
+
"sql_database_function_path",
|
716
|
+
"sql_database_function_name",
|
717
|
+
"sql_database_error",
|
718
|
+
],
|
351
719
|
result_key="messages",
|
352
|
-
error_key="sql_database_error",
|
353
|
-
llm=llm,
|
354
720
|
role=AGENT_NAME,
|
355
|
-
|
356
|
-
Explain the SQL steps that the SQL Database agent performed in this function.
|
357
|
-
Keep the summary succinct and to the point.\n\n# SQL Database Agent:\n\n{code}
|
358
|
-
""",
|
359
|
-
success_prefix="# SQL Database Agent:\n\n",
|
360
|
-
error_message="The SQL Database Agent encountered an error during SQL Query Analysis. No SQL function explanation is returned."
|
721
|
+
custom_title="SQL Database Agent Outputs"
|
361
722
|
)
|
362
723
|
|
363
724
|
# Create the graph
|
@@ -367,7 +728,7 @@ def sql_database_pipeline(connection):
|
|
367
728
|
"create_sql_query_code": create_sql_query_code,
|
368
729
|
"execute_sql_database_code": execute_sql_database_code,
|
369
730
|
"fix_sql_database_code": fix_sql_database_code,
|
370
|
-
"
|
731
|
+
"report_agent_outputs": report_agent_outputs,
|
371
732
|
}
|
372
733
|
|
373
734
|
app = create_coding_agent_graph(
|
@@ -377,7 +738,7 @@ def sql_database_pipeline(connection):
|
|
377
738
|
create_code_node_name="create_sql_query_code",
|
378
739
|
execute_code_node_name="execute_sql_database_code",
|
379
740
|
fix_code_node_name="fix_sql_database_code",
|
380
|
-
explain_code_node_name="
|
741
|
+
explain_code_node_name="report_agent_outputs",
|
381
742
|
error_key="sql_database_error",
|
382
743
|
human_in_the_loop=human_in_the_loop,
|
383
744
|
human_review_node_name="human_review",
|
@@ -391,7 +752,46 @@ def sql_database_pipeline(connection):
|
|
391
752
|
|
392
753
|
|
393
754
|
|
755
|
+
def smart_schema_filter(llm, user_instructions, all_sql_database_summary, smart_filtering=True):
    """
    Optionally asks the LLM to prune the database metadata down to the parts
    relevant to the user instructions.

    Parameters
    ----------
    llm : LLM
        The language model piped after the filtering prompt.
    user_instructions : str
        The user's question; inserted into the prompt.
    all_sql_database_summary : object
        The full database metadata; inserted into the prompt.
    smart_filtering : bool, optional
        If False, no LLM call is made and the metadata is returned unchanged.
        Defaults to True.

    Returns
    -------
    object
        With `smart_filtering=True`, the LLM output parsed by `JsonOutputParser`;
        otherwise `all_sql_database_summary` exactly as passed in.
    """
    # Guard clause: filtering is opt-in, so the common path returns immediately.
    if not smart_filtering:
        return all_sql_database_summary

    print(" * SMART FILTER SCHEMA")

    filter_schema_prompt = PromptTemplate(
        template="""
        You are a highly skilled data engineer. The user question is:

        "{user_instructions}"

        You have the full database metadata in JSON format below:

        {all_sql_database_summary}


        Please return ONLY the subset of this metadata that is relevant to answering the user’s question.
        - Preserve the same JSON structure for "schemas" -> "tables" -> "columns".
        - If any schemas/tables are irrelevant, omit them entirely.
        - If some columns in a relevant table are not needed, you can still keep them if you aren't sure.
        - However, try to keep only the minimum amount of data required to answer the user’s question.

        Return a valid JSON object. Do not include any additional explanation or text outside of the JSON.
        """,
        # Must name the placeholders actually used in the template above and
        # supplied in the invoke() payload below.
        input_variables=["user_instructions", "all_sql_database_summary"]
    )

    filter_schema_agent = filter_schema_prompt | llm | JsonOutputParser()

    response = filter_schema_agent.invoke({
        "user_instructions": user_instructions,
        "all_sql_database_summary": all_sql_database_summary
    })

    return response
|
395
795
|
|
396
796
|
|
397
797
|
|