ai-data-science-team 0.0.0.9006__py3-none-any.whl → 0.0.0.9008__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +5 -4
- ai_data_science_team/agents/data_cleaning_agent.py +371 -45
- ai_data_science_team/agents/data_visualization_agent.py +764 -0
- ai_data_science_team/agents/data_wrangling_agent.py +507 -23
- ai_data_science_team/agents/feature_engineering_agent.py +467 -34
- ai_data_science_team/agents/sql_database_agent.py +394 -30
- ai_data_science_team/multiagents/__init__.py +1 -0
- ai_data_science_team/multiagents/sql_data_analyst.py +286 -0
- ai_data_science_team/multiagents/supervised_data_analyst.py +2 -0
- ai_data_science_team/templates/__init__.py +9 -0
- ai_data_science_team/templates/agent_templates.py +247 -42
- ai_data_science_team/tools/metadata.py +110 -47
- ai_data_science_team/tools/regex.py +33 -0
- ai_data_science_team/utils/__init__.py +0 -0
- ai_data_science_team/utils/plotly.py +24 -0
- ai_data_science_team-0.0.0.9008.dist-info/METADATA +231 -0
- ai_data_science_team-0.0.0.9008.dist-info/RECORD +26 -0
- {ai_data_science_team-0.0.0.9006.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/WHEEL +1 -1
- ai_data_science_team-0.0.0.9006.dist-info/METADATA +0 -165
- ai_data_science_team-0.0.0.9006.dist-info/RECORD +0 -20
- {ai_data_science_team-0.0.0.9006.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9006.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/top_level.txt +0 -0
@@ -10,19 +10,21 @@ from langgraph.types import Command
|
|
10
10
|
from langgraph.checkpoint.memory import MemorySaver
|
11
11
|
|
12
12
|
import os
|
13
|
-
import io
|
14
13
|
import pandas as pd
|
15
14
|
import sqlalchemy as sql
|
16
15
|
|
17
|
-
from
|
16
|
+
from IPython.display import Markdown
|
17
|
+
|
18
|
+
from ai_data_science_team.templates import(
|
18
19
|
node_func_execute_agent_from_sql_connection,
|
19
20
|
node_func_human_review,
|
20
21
|
node_func_fix_agent_code,
|
21
22
|
node_func_explain_agent_code,
|
22
|
-
create_coding_agent_graph
|
23
|
+
create_coding_agent_graph,
|
24
|
+
BaseAgent,
|
23
25
|
)
|
24
|
-
from ai_data_science_team.tools.parsers import
|
25
|
-
from ai_data_science_team.tools.regex import
|
26
|
+
from ai_data_science_team.tools.parsers import SQLOutputParser
|
27
|
+
from ai_data_science_team.tools.regex import add_comments_to_top, format_agent_name, format_recommended_steps
|
26
28
|
from ai_data_science_team.tools.metadata import get_database_metadata
|
27
29
|
from ai_data_science_team.tools.logging import log_ai_function
|
28
30
|
|
@@ -30,8 +32,334 @@ from ai_data_science_team.tools.logging import log_ai_function
|
|
30
32
|
AGENT_NAME = "sql_database_agent"
|
31
33
|
LOG_PATH = os.path.join(os.getcwd(), "logs/")
|
32
34
|
|
35
|
+
# Class
|
36
|
+
|
37
|
+
class SQLDatabaseAgent(BaseAgent):
    """
    Creates a SQL Database Agent that can recommend SQL steps and generate SQL code to query a database.
    The agent can:
    - Propose recommended steps to answer a user's query or instructions.
    - Generate a SQL query based on the recommended steps and user instructions.
    - Execute that SQL query against the provided database connection.
    - Return the resulting data, exposed as a pandas DataFrame by ``get_data_sql()``.
    - Log generated code and errors if enabled.

    Parameters
    ----------
    model : ChatOpenAI or langchain.llms.base.LLM
        The language model used to generate the SQL code.
    connection : sqlalchemy.engine.base.Engine or sqlalchemy.engine.base.Connection
        The SQLAlchemy connection (or engine) to the database.
    n_samples : int, optional
        Number of sample rows (per column) to retrieve when summarizing database metadata. Defaults to 10.
    log : bool, optional
        Whether to log the generated code and errors. Defaults to False.
    log_path : str, optional
        Directory path for storing log files. Defaults to None.
    file_name : str, optional
        Name of the file for saving the generated response. Defaults to "sql_database.py".
    function_name : str, optional
        Name of the Python function that executes the SQL query. Defaults to "sql_database_pipeline".
    overwrite : bool, optional
        Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
    human_in_the_loop : bool, optional
        Enables user review of the recommended steps before generating code. Defaults to False.
    bypass_recommended_steps : bool, optional
        If True, skips the step that generates recommended SQL steps. Defaults to False.
    bypass_explain_code : bool, optional
        If True, skips the step that provides code explanations. Defaults to False.

    Methods
    -------
    update_params(**kwargs)
        Updates the agent's parameters and rebuilds the compiled state graph.
    ainvoke_agent(user_instructions: str, max_retries=3, retry_count=0)
        Coroutine. Asynchronously runs the agent to generate and execute a SQL query
        based on user instructions. Must be awaited.
    invoke_agent(user_instructions: str, max_retries=3, retry_count=0)
        Synchronously runs the agent to generate and execute a SQL query based on user instructions.
    explain_sql_steps()
        Returns an explanation of the SQL steps performed by the agent.
    get_log_summary()
        Retrieves a summary of logged operations if logging is enabled.
    get_data_sql()
        Retrieves the resulting data from the SQL query as a pandas DataFrame.
    get_sql_query_code()
        Retrieves the exact SQL query generated by the agent.
    get_sql_database_function()
        Retrieves the Python function that executes the SQL query.
    get_recommended_sql_steps()
        Retrieves the recommended steps for querying the SQL database.
    get_response()
        Returns the full response dictionary from the agent.
        (Not defined in this class; presumably inherited from BaseAgent.)
    show()
        Displays the agent's mermaid diagram for visual inspection of the compiled graph.
        (Not defined in this class; presumably inherited from BaseAgent.)

    Examples
    --------
    ```python
    import sqlalchemy as sql
    from langchain_openai import ChatOpenAI
    from ai_data_science_team.agents import SQLDatabaseAgent

    # Create the engine/connection
    sql_engine = sql.create_engine("sqlite:///data/my_database.db")
    conn = sql_engine.connect()

    llm = ChatOpenAI(model="gpt-4o-mini")

    sql_database_agent = SQLDatabaseAgent(
        model=llm,
        connection=conn,
        n_samples=10,
        log=True,
        log_path="logs",
        human_in_the_loop=True
    )

    # Example usage
    sql_database_agent.invoke_agent(
        user_instructions="List all the tables in the database.",
        max_retries=3,
        retry_count=0
    )

    data_result = sql_database_agent.get_data_sql()  # pandas DataFrame of rows returned
    sql_code = sql_database_agent.get_sql_query_code()
    response = sql_database_agent.get_response()
    ```

    Returns
    -------
    SQLDatabaseAgent : langchain.graphs.CompiledStateGraph
        A SQL database agent implemented as a compiled state graph.
    """

    def __init__(
        self,
        model,
        connection,
        n_samples=10,
        log=False,
        log_path=None,
        file_name="sql_database.py",
        function_name="sql_database_pipeline",
        overwrite=True,
        human_in_the_loop=False,
        bypass_recommended_steps=False,
        bypass_explain_code=False
    ):
        # Keep every constructor argument in one dict so the graph can be
        # rebuilt later from the same (possibly updated) set of parameters.
        self._params = {
            "model": model,
            "connection": connection,
            "n_samples": n_samples,
            "log": log,
            "log_path": log_path,
            "file_name": file_name,
            "function_name": function_name,
            "overwrite": overwrite,
            "human_in_the_loop": human_in_the_loop,
            "bypass_recommended_steps": bypass_recommended_steps,
            "bypass_explain_code": bypass_explain_code
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Create or rebuild the compiled graph for the SQL Database Agent.
        Running this method resets the response to None.
        """
        self.response = None
        return make_sql_database_agent(**self._params)

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters (e.g. connection, n_samples, log, etc.)
        and rebuilds the compiled graph.
        """
        for k, v in kwargs.items():
            self._params[k] = v
        self._compiled_graph = self._make_compiled_graph()

    async def ainvoke_agent(self, user_instructions: str = None, max_retries=3, retry_count=0, **kwargs):
        """
        Asynchronously runs the SQL Database Agent based on user instructions.

        This is a coroutine and must be awaited
        (``await agent.ainvoke_agent(...)``). The graph's ``ainvoke()`` result
        is awaited before being stored, so ``self.response`` holds the actual
        response rather than an un-awaited coroutine object.

        Parameters
        ----------
        user_instructions : str
            Instructions for the SQL query or metadata request.
        max_retries : int, optional
            Maximum retry attempts. Defaults to 3.
        retry_count : int, optional
            Current retry count. Defaults to 0.
        kwargs : dict
            Additional keyword arguments to pass to ainvoke().

        Returns
        -------
        None
        """
        response = await self._compiled_graph.ainvoke({
            "user_instructions": user_instructions,
            "max_retries": max_retries,
            "retry_count": retry_count
        }, **kwargs)
        self.response = response

    def invoke_agent(self, user_instructions: str = None, max_retries=3, retry_count=0, **kwargs):
        """
        Synchronously runs the SQL Database Agent based on user instructions.

        Parameters
        ----------
        user_instructions : str
            Instructions for the SQL query or metadata request.
        max_retries : int, optional
            Maximum retry attempts. Defaults to 3.
        retry_count : int, optional
            Current retry count. Defaults to 0.
        kwargs : dict
            Additional keyword arguments to pass to invoke().

        Returns
        -------
        None
        """
        response = self._compiled_graph.invoke({
            "user_instructions": user_instructions,
            "max_retries": max_retries,
            "retry_count": retry_count
        }, **kwargs)
        self.response = response

    def explain_sql_steps(self):
        """
        Provides an explanation of the SQL steps performed by the agent
        if the explain step is not bypassed.

        Returns
        -------
        str or list
            The "messages" entry of the response, or an empty list when no
            response (or no "messages" entry) is available.
        """
        if self.response:
            return self.response.get("messages", [])
        return []

    def get_log_summary(self, markdown=False):
        """
        Retrieves a summary of the logging details if logging is enabled.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the summary in Markdown format.

        Returns
        -------
        str or None
            Log details or None if logging is not used or data is unavailable.
        """
        if self.response and self.response.get("sql_database_function_path"):
            log_details = f"Log Path: {self.response['sql_database_function_path']}"
            if markdown:
                return Markdown(log_details)
            return log_details
        return None

    def get_data_sql(self):
        """
        Retrieves the SQL query result from the agent's response.

        Returns
        -------
        pandas.DataFrame or None
            A DataFrame built from the "data_sql" entry of the response,
            or None if no response or no "data_sql" entry is available.
        """
        if self.response and "data_sql" in self.response:
            return pd.DataFrame(self.response["data_sql"])
        return None

    def get_sql_query_code(self, markdown=False):
        """
        Retrieves the raw SQL query code generated by the agent (if available).

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the code in a Markdown code block.

        Returns
        -------
        str or None
            The SQL query as a string, or None if not available.
        """
        if self.response and "sql_query_code" in self.response:
            if markdown:
                return Markdown(f"```sql\n{self.response['sql_query_code']}\n```")
            return self.response["sql_query_code"]
        return None

    def get_sql_database_function(self, markdown=False):
        """
        Retrieves the Python function code used to execute the SQL query.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the code in a Markdown code block.

        Returns
        -------
        str or None
            The function code if available, otherwise None.
        """
        if self.response and "sql_database_function" in self.response:
            code = self.response["sql_database_function"]
            if markdown:
                return Markdown(f"```python\n{code}\n```")
            return code
        return None

    def get_recommended_sql_steps(self, markdown=False):
        """
        Retrieves the recommended SQL steps from the agent's response.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the steps in Markdown format.

        Returns
        -------
        str or None
            Recommended steps or None if not available.
        """
        if self.response and "recommended_steps" in self.response:
            if markdown:
                return Markdown(self.response["recommended_steps"])
            return self.response["recommended_steps"]
        return None
|
346
|
+
|
347
|
+
|
348
|
+
|
349
|
+
# Function
|
350
|
+
|
351
|
+
def make_sql_database_agent(
|
352
|
+
model,
|
353
|
+
connection,
|
354
|
+
n_samples = 10,
|
355
|
+
log=False,
|
356
|
+
log_path=None,
|
357
|
+
file_name="sql_database.py",
|
358
|
+
function_name="sql_database_pipeline",
|
359
|
+
overwrite = True,
|
360
|
+
human_in_the_loop=False, bypass_recommended_steps=False,
|
361
|
+
bypass_explain_code=False
|
362
|
+
):
|
35
363
|
"""
|
36
364
|
Creates a SQL Database Agent that can recommend SQL steps and generate SQL code to query a database.
|
37
365
|
|
@@ -41,10 +369,16 @@ def make_sql_database_agent(model, connection, log=False, log_path=None, overwri
|
|
41
369
|
The language model to use for the agent.
|
42
370
|
connection : sqlalchemy.engine.base.Engine
|
43
371
|
The connection to the SQL database.
|
372
|
+
n_samples : int, optional
|
373
|
+
The number of samples to retrieve for each column, by default 10.
|
374
|
+
If you get an error due to maximum tokens, try reducing this number.
|
375
|
+
> "This model's maximum context length is 128000 tokens. However, your messages resulted in 333858 tokens. Please reduce the length of the messages."
|
44
376
|
log : bool, optional
|
45
377
|
Whether to log the generated code, by default False
|
46
378
|
log_path : str, optional
|
47
379
|
The path to the log directory, by default None
|
380
|
+
file_name : str, optional
|
381
|
+
The name of the file to save the generated code, by default "sql_database.py"
|
48
382
|
overwrite : bool, optional
|
49
383
|
Whether to overwrite the existing log file, by default True
|
50
384
|
human_in_the_loop : bool, optional
|
@@ -56,7 +390,7 @@ def make_sql_database_agent(model, connection, log=False, log_path=None, overwri
|
|
56
390
|
|
57
391
|
Returns
|
58
392
|
-------
|
59
|
-
app : langchain.graphs.
|
393
|
+
app : langchain.graphs.CompiledStateGraph
|
60
394
|
The SQL database agent as a state graph.
|
61
395
|
|
62
396
|
Examples
|
@@ -93,6 +427,11 @@ def make_sql_database_agent(model, connection, log=False, log_path=None, overwri
|
|
93
427
|
|
94
428
|
llm = model
|
95
429
|
|
430
|
+
# Human in the loop requires recommended steps
|
431
|
+
if bypass_recommended_steps and human_in_the_loop:
|
432
|
+
bypass_recommended_steps = False
|
433
|
+
print("Bypass recommended steps set to False to enable human in the loop.")
|
434
|
+
|
96
435
|
# Setup Log Directory
|
97
436
|
if log:
|
98
437
|
if log_path is None:
|
@@ -109,6 +448,7 @@ def make_sql_database_agent(model, connection, log=False, log_path=None, overwri
|
|
109
448
|
sql_query_code: str
|
110
449
|
sql_database_function: str
|
111
450
|
sql_database_function_path: str
|
451
|
+
sql_database_function_file_name: str
|
112
452
|
sql_database_function_name: str
|
113
453
|
sql_database_error: str
|
114
454
|
max_retries: int
|
@@ -116,8 +456,8 @@ def make_sql_database_agent(model, connection, log=False, log_path=None, overwri
|
|
116
456
|
|
117
457
|
def recommend_sql_steps(state: GraphState):
|
118
458
|
|
119
|
-
print(
|
120
|
-
print(" * RECOMMEND
|
459
|
+
print(format_agent_name(AGENT_NAME))
|
460
|
+
print(" * RECOMMEND STEPS")
|
121
461
|
|
122
462
|
|
123
463
|
# Prompt to get recommended steps from the LLM
|
@@ -133,6 +473,7 @@ def make_sql_database_agent(model, connection, log=False, log_path=None, overwri
|
|
133
473
|
- If no user instructions are provided, just return the steps needed to understand the database.
|
134
474
|
- Take into account the database dialect and the tables and columns in the database.
|
135
475
|
- Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
|
476
|
+
- IMPORTANT: Pay attention to the table names and column names in the database. Make sure to use the correct table and column names in the SQL code. If a space is present in the table name or column name, make sure to account for it.
|
136
477
|
|
137
478
|
|
138
479
|
User instructions / Question:
|
@@ -156,6 +497,8 @@ def make_sql_database_agent(model, connection, log=False, log_path=None, overwri
|
|
156
497
|
2. Do not include steps to modify existing tables, create new tables or modify the database schema.
|
157
498
|
3. Do not include steps that alter the existing data in the database.
|
158
499
|
4. Make sure not to include unsafe code that could cause data loss or corruption or SQL injections.
|
500
|
+
5. Make sure to not include irrelevant steps that do not help in the SQL agent's data collection and processing. Examples include steps to create new tables, modify the schema, save files, create charts, etc.
|
501
|
+
|
159
502
|
|
160
503
|
""",
|
161
504
|
input_variables=["user_instructions", "recommended_steps", "all_sql_database_summary"]
|
@@ -166,7 +509,7 @@ def make_sql_database_agent(model, connection, log=False, log_path=None, overwri
|
|
166
509
|
conn = connection.connect() if is_engine else connection
|
167
510
|
|
168
511
|
# Get the database metadata
|
169
|
-
all_sql_database_summary = get_database_metadata(conn,
|
512
|
+
all_sql_database_summary = get_database_metadata(conn, n_samples=n_samples)
|
170
513
|
|
171
514
|
steps_agent = recommend_steps_prompt | llm
|
172
515
|
|
@@ -177,13 +520,13 @@ def make_sql_database_agent(model, connection, log=False, log_path=None, overwri
|
|
177
520
|
})
|
178
521
|
|
179
522
|
return {
|
180
|
-
"recommended_steps": "
|
523
|
+
"recommended_steps": format_recommended_steps(recommended_steps.content.strip(), heading="# Recommended SQL Database Steps:"),
|
181
524
|
"all_sql_database_summary": all_sql_database_summary
|
182
525
|
}
|
183
526
|
|
184
527
|
def create_sql_query_code(state: GraphState):
|
185
528
|
if bypass_recommended_steps:
|
186
|
-
print(
|
529
|
+
print(format_agent_name(AGENT_NAME))
|
187
530
|
print(" * CREATE SQL QUERY CODE")
|
188
531
|
|
189
532
|
# Prompt to get the SQL code from the LLM
|
@@ -199,6 +542,7 @@ def make_sql_database_agent(model, connection, log=False, log_path=None, overwri
|
|
199
542
|
- Only return a single query if possible.
|
200
543
|
- Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
|
201
544
|
- Pay attention to the SQL dialect from the database summary metadata. Write the SQL code according to the dialect specified.
|
545
|
+
- IMPORTANT: Pay attention to the table names and column names in the database. Make sure to use the correct table and column names in the SQL code. If a space is present in the table name or column name, make sure to account for it.
|
202
546
|
|
203
547
|
|
204
548
|
User instructions / Question:
|
@@ -228,7 +572,7 @@ def make_sql_database_agent(model, connection, log=False, log_path=None, overwri
|
|
228
572
|
conn = connection.connect() if is_engine else connection
|
229
573
|
|
230
574
|
# Get the database metadata
|
231
|
-
all_sql_database_summary = get_database_metadata(conn,
|
575
|
+
all_sql_database_summary = get_database_metadata(conn, n_samples=n_samples)
|
232
576
|
|
233
577
|
sql_query_code_agent = sql_query_code_prompt | llm | SQLOutputParser()
|
234
578
|
|
@@ -241,7 +585,7 @@ def make_sql_database_agent(model, connection, log=False, log_path=None, overwri
|
|
241
585
|
print(" * CREATE PYTHON FUNCTION TO RUN SQL CODE")
|
242
586
|
|
243
587
|
response = f"""
|
244
|
-
def
|
588
|
+
def {function_name}(connection):
|
245
589
|
import pandas as pd
|
246
590
|
import sqlalchemy as sql
|
247
591
|
|
@@ -259,9 +603,9 @@ def sql_database_pipeline(connection):
|
|
259
603
|
response = add_comments_to_top(response, AGENT_NAME)
|
260
604
|
|
261
605
|
# For logging: store the code generated
|
262
|
-
file_path,
|
606
|
+
file_path, file_name_2 = log_ai_function(
|
263
607
|
response=response,
|
264
|
-
file_name=
|
608
|
+
file_name=file_name,
|
265
609
|
log=log,
|
266
610
|
log_path=log_path,
|
267
611
|
overwrite=overwrite
|
@@ -271,18 +615,37 @@ def sql_database_pipeline(connection):
|
|
271
615
|
"sql_query_code": sql_query_code,
|
272
616
|
"sql_database_function": response,
|
273
617
|
"sql_database_function_path": file_path,
|
274
|
-
"
|
618
|
+
"sql_database_function_file_name": file_name_2,
|
619
|
+
"sql_database_function_name": function_name,
|
620
|
+
"all_sql_database_summary": all_sql_database_summary
|
275
621
|
}
|
276
622
|
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
623
|
+
# Human Review
|
624
|
+
|
625
|
+
prompt_text_human_review = "Are the following SQL agent instructions correct? (Answer 'yes' or provide modifications)\n{steps}"
|
626
|
+
|
627
|
+
if not bypass_explain_code:
|
628
|
+
def human_review(state: GraphState) -> Command[Literal["recommend_sql_steps", "explain_sql_database_code"]]:
|
629
|
+
return node_func_human_review(
|
630
|
+
state=state,
|
631
|
+
prompt_text=prompt_text_human_review,
|
632
|
+
yes_goto= 'explain_sql_database_code',
|
633
|
+
no_goto="recommend_sql_steps",
|
634
|
+
user_instructions_key="user_instructions",
|
635
|
+
recommended_steps_key="recommended_steps",
|
636
|
+
code_snippet_key="sql_database_function",
|
637
|
+
)
|
638
|
+
else:
|
639
|
+
def human_review(state: GraphState) -> Command[Literal["recommend_sql_steps", "__end__"]]:
|
640
|
+
return node_func_human_review(
|
641
|
+
state=state,
|
642
|
+
prompt_text=prompt_text_human_review,
|
643
|
+
yes_goto= '__end__',
|
644
|
+
no_goto="recommend_sql_steps",
|
645
|
+
user_instructions_key="user_instructions",
|
646
|
+
recommended_steps_key="recommended_steps",
|
647
|
+
code_snippet_key="sql_database_function",
|
648
|
+
)
|
286
649
|
|
287
650
|
def execute_sql_database_code(state: GraphState):
|
288
651
|
|
@@ -295,18 +658,18 @@ def sql_database_pipeline(connection):
|
|
295
658
|
result_key="data_sql",
|
296
659
|
error_key="sql_database_error",
|
297
660
|
code_snippet_key="sql_database_function",
|
298
|
-
agent_function_name="
|
661
|
+
agent_function_name=state.get("sql_database_function_name"),
|
299
662
|
post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
|
300
663
|
error_message_prefix="An error occurred during executing the sql database pipeline: "
|
301
664
|
)
|
302
665
|
|
303
666
|
def fix_sql_database_code(state: GraphState):
|
304
667
|
prompt = """
|
305
|
-
You are a SQL Database Agent code fixer. Your job is to create a
|
668
|
+
You are a SQL Database Agent code fixer. Your job is to create a {function_name}(connection) function that can be run on a sql connection. The function is currently broken and needs to be fixed.
|
306
669
|
|
307
|
-
Make sure to only return the function definition for
|
670
|
+
Make sure to only return the function definition for {function_name}().
|
308
671
|
|
309
|
-
Return Python code in ```python``` format with a single function definition,
|
672
|
+
Return Python code in ```python``` format with a single function definition, {function_name}(connection), that includes all imports inside the function. The connection object is a SQLAlchemy connection object. Don't specify the class of the connection object, just use it as an argument to the function.
|
310
673
|
|
311
674
|
This is the broken code (please fix):
|
312
675
|
{code_snippet}
|
@@ -324,6 +687,7 @@ def sql_database_pipeline(connection):
|
|
324
687
|
agent_name=AGENT_NAME,
|
325
688
|
log=log,
|
326
689
|
file_path=state.get("sql_database_function_path", None),
|
690
|
+
function_name=state.get("sql_database_function_name"),
|
327
691
|
)
|
328
692
|
|
329
693
|
def explain_sql_database_code(state: GraphState):
|
@@ -0,0 +1 @@
|
|
1
|
+
from ai_data_science_team.multiagents.sql_data_analyst import SQLDataAnalyst, make_sql_data_analyst
|