ai-data-science-team 0.0.0.9005__py3-none-any.whl → 0.0.0.9007__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +3 -1
- ai_data_science_team/agents/data_cleaning_agent.py +213 -20
- ai_data_science_team/agents/data_visualization_agent.py +331 -0
- ai_data_science_team/agents/data_wrangling_agent.py +66 -24
- ai_data_science_team/agents/feature_engineering_agent.py +50 -13
- ai_data_science_team/agents/sql_database_agent.py +397 -0
- ai_data_science_team/templates/__init__.py +8 -0
- ai_data_science_team/templates/agent_templates.py +154 -37
- ai_data_science_team/tools/logging.py +1 -1
- ai_data_science_team/tools/metadata.py +230 -0
- ai_data_science_team/tools/regex.py +7 -1
- {ai_data_science_team-0.0.0.9005.dist-info → ai_data_science_team-0.0.0.9007.dist-info}/METADATA +43 -22
- ai_data_science_team-0.0.0.9007.dist-info/RECORD +21 -0
- {ai_data_science_team-0.0.0.9005.dist-info → ai_data_science_team-0.0.0.9007.dist-info}/WHEEL +1 -1
- ai_data_science_team/tools/data_analysis.py +0 -116
- ai_data_science_team-0.0.0.9005.dist-info/RECORD +0 -19
- {ai_data_science_team-0.0.0.9005.dist-info → ai_data_science_team-0.0.0.9007.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9005.dist-info → ai_data_science_team-0.0.0.9007.dist-info}/top_level.txt +0 -0
@@ -17,7 +17,7 @@ import os
|
|
17
17
|
import io
|
18
18
|
import pandas as pd
|
19
19
|
|
20
|
-
from ai_data_science_team.templates
|
20
|
+
from ai_data_science_team.templates import(
|
21
21
|
node_func_execute_agent_code_on_data,
|
22
22
|
node_func_human_review,
|
23
23
|
node_func_fix_agent_code,
|
@@ -25,8 +25,8 @@ from ai_data_science_team.templates.agent_templates import(
|
|
25
25
|
create_coding_agent_graph
|
26
26
|
)
|
27
27
|
from ai_data_science_team.tools.parsers import PythonOutputParser
|
28
|
-
from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
|
29
|
-
from ai_data_science_team.tools.
|
28
|
+
from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
|
29
|
+
from ai_data_science_team.tools.metadata import get_dataframe_summary
|
30
30
|
from ai_data_science_team.tools.logging import log_ai_function
|
31
31
|
|
32
32
|
# Setup
|
@@ -35,7 +35,17 @@ LOG_PATH = os.path.join(os.getcwd(), "logs/")
|
|
35
35
|
|
36
36
|
# * Feature Engineering Agent
|
37
37
|
|
38
|
-
def make_feature_engineering_agent(
|
38
|
+
def make_feature_engineering_agent(
|
39
|
+
model,
|
40
|
+
n_samples=30,
|
41
|
+
log=False,
|
42
|
+
log_path=None,
|
43
|
+
file_name="feature_engineer.py",
|
44
|
+
overwrite = True,
|
45
|
+
human_in_the_loop=False,
|
46
|
+
bypass_recommended_steps=False,
|
47
|
+
bypass_explain_code=False,
|
48
|
+
):
|
39
49
|
"""
|
40
50
|
Creates a feature engineering agent that can be run on a dataset. The agent applies various feature engineering
|
41
51
|
techniques, such as encoding categorical variables, scaling numeric variables, creating interaction terms,
|
@@ -61,16 +71,26 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
|
|
61
71
|
----------
|
62
72
|
model : langchain.llms.base.LLM
|
63
73
|
The language model to use to generate code.
|
74
|
+
n_samples : int, optional
|
75
|
+
The number of data samples to use for generating the feature engineering code. Defaults to 30.
|
76
|
+
If you get an error due to maximum tokens, try reducing this number.
|
77
|
+
> "This model's maximum context length is 128000 tokens. However, your messages resulted in 333858 tokens. Please reduce the length of the messages."
|
64
78
|
log : bool, optional
|
65
79
|
Whether or not to log the code generated and any errors that occur.
|
66
80
|
Defaults to False.
|
67
81
|
log_path : str, optional
|
68
82
|
The path to the directory where the log files should be stored. Defaults to "logs/".
|
83
|
+
file_name : str, optional
|
84
|
+
The name of the file to save the log to. Defaults to "feature_engineer.py".
|
69
85
|
overwrite : bool, optional
|
70
86
|
Whether or not to overwrite the log file if it already exists. If False, a unique file name will be created.
|
71
87
|
Defaults to True.
|
72
88
|
human_in_the_loop : bool, optional
|
73
89
|
Whether or not to use human in the loop. If True, adds an interput and human in the loop step that asks the user to review the feature engineering instructions. Defaults to False.
|
90
|
+
bypass_recommended_steps : bool, optional
|
91
|
+
Bypass the recommendation step, by default False
|
92
|
+
bypass_explain_code : bool, optional
|
93
|
+
Bypass the code explanation step, by default False.
|
74
94
|
|
75
95
|
Examples
|
76
96
|
-------
|
@@ -98,7 +118,7 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
|
|
98
118
|
|
99
119
|
Returns
|
100
120
|
-------
|
101
|
-
app : langchain.graphs.
|
121
|
+
app : langchain.graphs.CompiledStateGraph
|
102
122
|
The feature engineering agent as a state graph.
|
103
123
|
"""
|
104
124
|
llm = model
|
@@ -131,7 +151,7 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
|
|
131
151
|
Recommend a series of feature engineering steps based on the input data.
|
132
152
|
These recommended steps will be appended to the user_instructions.
|
133
153
|
"""
|
134
|
-
print(
|
154
|
+
print(format_agent_name(AGENT_NAME))
|
135
155
|
print(" * RECOMMEND FEATURE ENGINEERING STEPS")
|
136
156
|
|
137
157
|
# Prompt to get recommended steps from the LLM
|
@@ -178,6 +198,7 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
|
|
178
198
|
|
179
199
|
Avoid these:
|
180
200
|
1. Do not include steps to save files.
|
201
|
+
2. Do not include unrelated user instructions that are not related to the feature engineering.
|
181
202
|
""",
|
182
203
|
input_variables=["user_instructions", "recommended_steps", "all_datasets_summary"]
|
183
204
|
)
|
@@ -185,7 +206,7 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
|
|
185
206
|
data_raw = state.get("data_raw")
|
186
207
|
df = pd.DataFrame.from_dict(data_raw)
|
187
208
|
|
188
|
-
all_datasets_summary =
|
209
|
+
all_datasets_summary = get_dataframe_summary([df], n_sample=n_samples)
|
189
210
|
|
190
211
|
all_datasets_summary_str = "\n\n".join(all_datasets_summary)
|
191
212
|
|
@@ -212,6 +233,19 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
|
|
212
233
|
)
|
213
234
|
|
214
235
|
def create_feature_engineering_code(state: GraphState):
|
236
|
+
if bypass_recommended_steps:
|
237
|
+
print(format_agent_name(AGENT_NAME))
|
238
|
+
|
239
|
+
data_raw = state.get("data_raw")
|
240
|
+
df = pd.DataFrame.from_dict(data_raw)
|
241
|
+
|
242
|
+
all_datasets_summary = get_dataframe_summary([df], n_sample=n_samples)
|
243
|
+
|
244
|
+
all_datasets_summary_str = "\n\n".join(all_datasets_summary)
|
245
|
+
|
246
|
+
else:
|
247
|
+
all_datasets_summary_str = state.get("all_datasets_summary")
|
248
|
+
|
215
249
|
print(" * CREATE FEATURE ENGINEERING CODE")
|
216
250
|
|
217
251
|
feature_engineering_prompt = PromptTemplate(
|
@@ -266,16 +300,16 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
|
|
266
300
|
response = feature_engineering_agent.invoke({
|
267
301
|
"recommended_steps": state.get("recommended_steps"),
|
268
302
|
"target_variable": state.get("target_variable"),
|
269
|
-
"all_datasets_summary":
|
303
|
+
"all_datasets_summary": all_datasets_summary_str,
|
270
304
|
})
|
271
305
|
|
272
306
|
response = relocate_imports_inside_function(response)
|
273
307
|
response = add_comments_to_top(response, agent_name=AGENT_NAME)
|
274
308
|
|
275
309
|
# For logging: store the code generated
|
276
|
-
file_path,
|
310
|
+
file_path, file_name_2 = log_ai_function(
|
277
311
|
response=response,
|
278
|
-
file_name=
|
312
|
+
file_name=file_name,
|
279
313
|
log=log,
|
280
314
|
log_path=log_path,
|
281
315
|
overwrite=overwrite
|
@@ -284,7 +318,8 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
|
|
284
318
|
return {
|
285
319
|
"feature_engineer_function": response,
|
286
320
|
"feature_engineer_function_path": file_path,
|
287
|
-
"feature_engineer_function_name":
|
321
|
+
"feature_engineer_function_name": file_name_2,
|
322
|
+
"all_datasets_summary": all_datasets_summary_str
|
288
323
|
}
|
289
324
|
|
290
325
|
|
@@ -298,7 +333,7 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
|
|
298
333
|
code_snippet_key="feature_engineer_function",
|
299
334
|
agent_function_name="feature_engineer",
|
300
335
|
pre_processing=lambda data: pd.DataFrame.from_dict(data),
|
301
|
-
post_processing=lambda df: df.to_dict(),
|
336
|
+
post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
|
302
337
|
error_message_prefix="An error occurred during feature engineering: "
|
303
338
|
)
|
304
339
|
|
@@ -362,7 +397,9 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
|
|
362
397
|
error_key="feature_engineer_error",
|
363
398
|
human_in_the_loop=human_in_the_loop,
|
364
399
|
human_review_node_name="human_review",
|
365
|
-
checkpointer=MemorySaver() if human_in_the_loop else None
|
400
|
+
checkpointer=MemorySaver() if human_in_the_loop else None,
|
401
|
+
bypass_recommended_steps=bypass_recommended_steps,
|
402
|
+
bypass_explain_code=bypass_explain_code,
|
366
403
|
)
|
367
404
|
|
368
405
|
return app
|
@@ -0,0 +1,397 @@
|
|
1
|
+
|
2
|
+
|
3
|
+
from typing import TypedDict, Annotated, Sequence, Literal
|
4
|
+
import operator
|
5
|
+
|
6
|
+
from langchain.prompts import PromptTemplate
|
7
|
+
from langchain_core.messages import BaseMessage
|
8
|
+
|
9
|
+
from langgraph.types import Command
|
10
|
+
from langgraph.checkpoint.memory import MemorySaver
|
11
|
+
|
12
|
+
import os
|
13
|
+
import io
|
14
|
+
import pandas as pd
|
15
|
+
import sqlalchemy as sql
|
16
|
+
|
17
|
+
from ai_data_science_team.templates import(
|
18
|
+
node_func_execute_agent_from_sql_connection,
|
19
|
+
node_func_human_review,
|
20
|
+
node_func_fix_agent_code,
|
21
|
+
node_func_explain_agent_code,
|
22
|
+
create_coding_agent_graph
|
23
|
+
)
|
24
|
+
from ai_data_science_team.tools.parsers import PythonOutputParser, SQLOutputParser
|
25
|
+
from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
|
26
|
+
from ai_data_science_team.tools.metadata import get_database_metadata
|
27
|
+
from ai_data_science_team.tools.logging import log_ai_function
|
28
|
+
|
29
|
+
# Setup
AGENT_NAME = "sql_database_agent"
LOG_PATH = os.path.join(os.getcwd(), "logs/")


def make_sql_database_agent(
    model, connection,
    n_samples=10,
    log=False,
    log_path=None,
    file_name="sql_database.py",
    overwrite=True,
    human_in_the_loop=False, bypass_recommended_steps=False,
    bypass_explain_code=False
):
    """
    Creates a SQL Database Agent that can recommend SQL steps and generate SQL code to query a database.

    Parameters
    ----------
    model : ChatOpenAI
        The language model to use for the agent.
    connection : sqlalchemy.engine.base.Engine or sqlalchemy.engine.base.Connection
        The connection to the SQL database. If an Engine is given, a short-lived
        connection is opened for each database operation and closed afterwards.
        If a Connection is given, it is used as-is and is never closed by the agent.
    n_samples : int, optional
        The number of samples to retrieve for each column, by default 10.
        If you get an error due to maximum tokens, try reducing this number.
        > "This model's maximum context length is 128000 tokens. However, your messages resulted in 333858 tokens. Please reduce the length of the messages."
    log : bool, optional
        Whether to log the generated code, by default False
    log_path : str, optional
        The path to the log directory, by default None (uses "logs/" in the current working directory).
    file_name : str, optional
        The name of the file to save the generated code, by default "sql_database.py"
    overwrite : bool, optional
        Whether to overwrite the existing log file, by default True
    human_in_the_loop : bool, optional
        Whether or not to use human in the loop. If True, adds an interrupt and human in the loop step that asks the user to review the SQL database querying steps. Defaults to False.
    bypass_recommended_steps : bool, optional
        Bypass the recommendation step, by default False
    bypass_explain_code : bool, optional
        Bypass the code explanation step, by default False.

    Returns
    -------
    app : langchain.graphs.CompiledStateGraph
        The SQL database agent as a state graph.

    Examples
    --------
    ```python
    from ai_data_science_team.agents import make_sql_database_agent
    import sqlalchemy as sql
    from langchain_openai import ChatOpenAI

    sql_engine = sql.create_engine("sqlite:///data/leads_scored.db")

    conn = sql_engine.connect()

    llm = ChatOpenAI(model="gpt-4o-mini")

    sql_agent = make_sql_database_agent(
        model=llm,
        connection=conn
    )

    sql_agent

    response = sql_agent.invoke({
        "user_instructions": "List the tables in the database",
        "max_retries":3,
        "retry_count":0
    })
    ```

    """
    from contextlib import contextmanager

    llm = model

    # Setup Log Directory
    if log:
        if log_path is None:
            log_path = LOG_PATH
        os.makedirs(log_path, exist_ok=True)

    @contextmanager
    def _database_connection():
        """Yield a usable connection, closing it afterwards only if it was opened here."""
        if isinstance(connection, sql.engine.base.Engine):
            # An Engine hands out pooled connections; return each one when done
            # instead of leaking it (the previous code never closed them).
            conn = connection.connect()
            try:
                yield conn
            finally:
                conn.close()
        else:
            # A caller-supplied Connection is owned by the caller: never close it.
            yield connection

    def _get_database_summary():
        """Return the database metadata summary used to ground the LLM prompts."""
        with _database_connection() as conn:
            return get_database_metadata(conn, n_samples=n_samples)

    class GraphState(TypedDict):
        messages: Annotated[Sequence[BaseMessage], operator.add]
        user_instructions: str
        recommended_steps: str
        data_sql: dict
        all_sql_database_summary: str
        sql_query_code: str
        sql_database_function: str
        sql_database_function_path: str
        sql_database_function_name: str
        sql_database_error: str
        max_retries: int
        retry_count: int

    def recommend_sql_steps(state: GraphState):
        """Ask the LLM for numbered querying steps based on the instructions and DB metadata."""
        print(format_agent_name(AGENT_NAME))
        print(" * RECOMMEND STEPS")

        # Prompt to get recommended steps from the LLM
        recommend_steps_prompt = PromptTemplate(
            template="""
            You are a SQL Database Instructions Expert. Given the following information about the SQL database,
            recommend a series of numbered steps to take to collect the data and process it according to user instructions.
            The steps should be tailored to the SQL database characteristics and should be helpful
            for a sql database coding agent that will write the SQL code.

            IMPORTANT INSTRUCTIONS:
            - Take into account the user instructions and the previously recommended steps.
            - If no user instructions are provided, just return the steps needed to understand the database.
            - Take into account the database dialect and the tables and columns in the database.
            - Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.


            User instructions / Question:
            {user_instructions}

            Previously Recommended Steps (if any):
            {recommended_steps}

            Below are summaries of the database metadata and the SQL tables:
            {all_sql_database_summary}

            Return the steps as a numbered point list (no code, just the steps).

            Consider these:

            1. Consider the database dialect and the tables and columns in the database.


            Avoid these:
            1. Do not include steps to save files.
            2. Do not include steps to modify existing tables, create new tables or modify the database schema.
            3. Do not include steps that alter the existing data in the database.
            4. Make sure not to include unsafe code that could cause data loss or corruption or SQL injections.
            5. Make sure to not include irrelevant steps that do not help in the SQL agent's data collection and processing. Examples include steps to create new tables, modify the schema, save files, create charts, etc.


            """,
            input_variables=["user_instructions", "recommended_steps", "all_sql_database_summary"]
        )

        # Get the database metadata (connection is closed again if we opened it)
        all_sql_database_summary = _get_database_summary()

        steps_agent = recommend_steps_prompt | llm

        recommended_steps = steps_agent.invoke({
            "user_instructions": state.get("user_instructions"),
            "recommended_steps": state.get("recommended_steps"),
            "all_sql_database_summary": all_sql_database_summary
        })

        return {
            "recommended_steps": "\n\n# Recommended SQL Database Steps:\n" + recommended_steps.content.strip(),
            "all_sql_database_summary": all_sql_database_summary
        }

    def create_sql_query_code(state: GraphState):
        """Generate the SQL query and wrap it in a sql_database_pipeline(connection) function."""
        if bypass_recommended_steps:
            print(format_agent_name(AGENT_NAME))
        print(" * CREATE SQL QUERY CODE")

        # Prompt to get the SQL code from the LLM
        sql_query_code_prompt = PromptTemplate(
            template="""
            You are a SQL Database Coding Expert. Given the following information about the SQL database,
            write the SQL code to collect the data and process it according to user instructions.
            The code should be tailored to the SQL database characteristics and should take into account user instructions, recommended steps, database and table characteristics.

            IMPORTANT INSTRUCTIONS:
            - Do not use a LIMIT clause unless a user specifies a limit to be returned.
            - Return SQL in ```sql ``` format.
            - Only return a single query if possible.
            - Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
            - Pay attention to the SQL dialect from the database summary metadata. Write the SQL code according to the dialect specified.


            User instructions / Question:
            {user_instructions}

            Recommended Steps:
            {recommended_steps}

            Below are summaries of the database metadata and the SQL tables:
            {all_sql_database_summary}

            Return:
            - The SQL code in ```sql ``` format to collect the data and process it according to the user instructions.

            Avoid these:
            - Do not include steps to save files.
            - Do not include steps to modify existing tables, create new tables or modify the database schema.
            - Make sure not to alter the existing data in the database.
            - Make sure not to include unsafe code that could cause data loss or corruption.

            """,
            input_variables=["user_instructions", "recommended_steps", "all_sql_database_summary"]
        )

        # Reuse the metadata gathered by recommend_sql_steps; only query the
        # database again when that step was bypassed or left nothing in state.
        all_sql_database_summary = state.get("all_sql_database_summary")
        if bypass_recommended_steps or not all_sql_database_summary:
            all_sql_database_summary = _get_database_summary()

        sql_query_code_agent = sql_query_code_prompt | llm | SQLOutputParser()

        sql_query_code = sql_query_code_agent.invoke({
            "user_instructions": state.get("user_instructions"),
            "recommended_steps": state.get("recommended_steps"),
            "all_sql_database_summary": all_sql_database_summary
        })

        print(" * CREATE PYTHON FUNCTION TO RUN SQL CODE")

        # NOTE: the SQL is embedded in a triple-quoted literal; a query containing
        # ''' or backslash escapes would need different quoting.
        response = f"""
def sql_database_pipeline(connection):
    import pandas as pd
    import sqlalchemy as sql

    # Create a connection if needed
    is_engine = isinstance(connection, sql.engine.base.Engine)
    conn = connection.connect() if is_engine else connection

    sql_query = '''
    {sql_query_code}
    '''

    try:
        return pd.read_sql(sql_query, conn)
    finally:
        # Only close the connection if this function opened it
        if is_engine:
            conn.close()
"""

        response = add_comments_to_top(response, AGENT_NAME)

        # For logging: store the code generated
        file_path, file_name_2 = log_ai_function(
            response=response,
            file_name=file_name,
            log=log,
            log_path=log_path,
            overwrite=overwrite
        )

        return {
            "sql_query_code": sql_query_code,
            "sql_database_function": response,
            "sql_database_function_path": file_path,
            "sql_database_function_name": file_name_2,
            "all_sql_database_summary": all_sql_database_summary
        }

    def human_review(state: GraphState) -> Command[Literal["recommend_sql_steps", "create_sql_query_code"]]:
        """Pause for the user to accept or modify the recommended steps."""
        return node_func_human_review(
            state=state,
            prompt_text="Are the following SQL database querying steps correct? (Answer 'yes' or provide modifications)\n{steps}",
            yes_goto="create_sql_query_code",
            no_goto="recommend_sql_steps",
            user_instructions_key="user_instructions",
            recommended_steps_key="recommended_steps"
        )

    def execute_sql_database_code(state: GraphState):
        """Run the generated sql_database_pipeline against a connection and store the result."""
        # Keep the connection open for the whole execution, then release it.
        with _database_connection() as conn:
            return node_func_execute_agent_from_sql_connection(
                state=state,
                connection=conn,
                result_key="data_sql",
                error_key="sql_database_error",
                code_snippet_key="sql_database_function",
                agent_function_name="sql_database_pipeline",
                post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
                error_message_prefix="An error occurred during executing the sql database pipeline: "
            )

    def fix_sql_database_code(state: GraphState):
        """Ask the LLM to repair the generated function using the last error."""
        prompt = """
        You are a SQL Database Agent code fixer. Your job is to create a sql_database_pipeline(connection) function that can be run on a sql connection. The function is currently broken and needs to be fixed.

        Make sure to only return the function definition for sql_database_pipeline().

        Return Python code in ```python``` format with a single function definition, sql_database_pipeline(connection), that includes all imports inside the function. The connection object is a SQLAlchemy connection object. Don't specify the class of the connection object, just use it as an argument to the function.

        This is the broken code (please fix):
        {code_snippet}

        Last Known Error:
        {error}
        """

        return node_func_fix_agent_code(
            state=state,
            code_snippet_key="sql_database_function",
            error_key="sql_database_error",
            llm=llm,
            prompt_template=prompt,
            agent_name=AGENT_NAME,
            log=log,
            file_path=state.get("sql_database_function_path", None),
        )

    def explain_sql_database_code(state: GraphState):
        """Produce a short natural-language explanation of the generated function."""
        return node_func_explain_agent_code(
            state=state,
            code_snippet_key="sql_database_function",
            result_key="messages",
            error_key="sql_database_error",
            llm=llm,
            role=AGENT_NAME,
            explanation_prompt_template="""
            Explain the SQL steps that the SQL Database agent performed in this function.
            Keep the summary succinct and to the point.\n\n# SQL Database Agent:\n\n{code}
            """,
            success_prefix="# SQL Database Agent:\n\n",
            error_message="The SQL Database Agent encountered an error during SQL Query Analysis. No SQL function explanation is returned."
        )

    # Create the graph
    node_functions = {
        "recommend_sql_steps": recommend_sql_steps,
        "human_review": human_review,
        "create_sql_query_code": create_sql_query_code,
        "execute_sql_database_code": execute_sql_database_code,
        "fix_sql_database_code": fix_sql_database_code,
        "explain_sql_database_code": explain_sql_database_code
    }

    app = create_coding_agent_graph(
        GraphState=GraphState,
        node_functions=node_functions,
        recommended_steps_node_name="recommend_sql_steps",
        create_code_node_name="create_sql_query_code",
        execute_code_node_name="execute_sql_database_code",
        fix_code_node_name="fix_sql_database_code",
        explain_code_node_name="explain_sql_database_code",
        error_key="sql_database_error",
        human_in_the_loop=human_in_the_loop,
        human_review_node_name="human_review",
        checkpointer=MemorySaver() if human_in_the_loop else None,
        bypass_recommended_steps=bypass_recommended_steps,
        bypass_explain_code=bypass_explain_code,
    )

    return app
|
390
|
+
|
391
|
+
|
392
|
+
|
393
|
+
|
394
|
+
|
395
|
+
|
396
|
+
|
397
|
+
|