ai-data-science-team 0.0.0.9006__py3-none-any.whl → 0.0.0.9007__py3-none-any.whl
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +2 -1
- ai_data_science_team/agents/data_cleaning_agent.py +204 -19
- ai_data_science_team/agents/data_visualization_agent.py +331 -0
- ai_data_science_team/agents/data_wrangling_agent.py +56 -11
- ai_data_science_team/agents/feature_engineering_agent.py +40 -11
- ai_data_science_team/agents/sql_database_agent.py +30 -12
- ai_data_science_team/templates/__init__.py +8 -0
- ai_data_science_team/tools/metadata.py +110 -47
- ai_data_science_team/tools/regex.py +6 -0
- {ai_data_science_team-0.0.0.9006.dist-info → ai_data_science_team-0.0.0.9007.dist-info}/METADATA +41 -23
- ai_data_science_team-0.0.0.9007.dist-info/RECORD +21 -0
- {ai_data_science_team-0.0.0.9006.dist-info → ai_data_science_team-0.0.0.9007.dist-info}/WHEEL +1 -1
- ai_data_science_team-0.0.0.9006.dist-info/RECORD +0 -20
- {ai_data_science_team-0.0.0.9006.dist-info → ai_data_science_team-0.0.0.9007.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9006.dist-info → ai_data_science_team-0.0.0.9007.dist-info}/top_level.txt +0 -0
ai_data_science_team/agents/data_wrangling_agent.py

@@ -15,7 +15,7 @@ from langchain_core.messages import BaseMessage
 from langgraph.types import Command
 from langgraph.checkpoint.memory import MemorySaver

-from ai_data_science_team.templates.agent_templates import(
+from ai_data_science_team.templates import(
     node_func_execute_agent_code_on_data,
     node_func_human_review,
     node_func_fix_agent_code,
@@ -23,7 +23,7 @@ from ai_data_science_team.templates.agent_templates import(
     create_coding_agent_graph
 )
 from ai_data_science_team.tools.parsers import PythonOutputParser
-from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
+from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
 from ai_data_science_team.tools.metadata import get_dataframe_summary
 from ai_data_science_team.tools.logging import log_ai_function

@@ -31,7 +31,17 @@ from ai_data_science_team.tools.logging import log_ai_function
 AGENT_NAME = "data_wrangling_agent"
 LOG_PATH = os.path.join(os.getcwd(), "logs/")

-def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
+def make_data_wrangling_agent(
+    model,
+    n_samples=30,
+    log=False,
+    log_path=None,
+    file_name="data_wrangler.py",
+    overwrite = True,
+    human_in_the_loop=False,
+    bypass_recommended_steps=False,
+    bypass_explain_code=False
+):
     """
     Creates a data wrangling agent that can be run on one or more datasets. The agent can be
     instructed to perform common data wrangling steps such as:
@@ -52,11 +62,17 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
     ----------
     model : langchain.llms.base.LLM
         The language model to use to generate code.
+    n_samples : int, optional
+        The number of samples to show in the data summary. Defaults to 30.
+        If you get an error due to maximum tokens, try reducing this number.
+        > "This model's maximum context length is 128000 tokens. However, your messages resulted in 333858 tokens. Please reduce the length of the messages."
     log : bool, optional
         Whether or not to log the code generated and any errors that occur.
         Defaults to False.
     log_path : str, optional
         The path to the directory where the log files should be stored. Defaults to "logs/".
+    file_name : str, optional
+        The name of the file to save the response to. Defaults to "data_wrangler.py".
     overwrite : bool, optional
         Whether or not to overwrite the log file if it already exists. If False, a unique file name will be created.
         Defaults to True.
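Reviewer note: a minimal sketch of the widened signature in use. The ChatOpenAI model is an assumption for illustration; the keyword arguments are the ones this diff adds or documents, and the import path follows the package layout in the file list above:

    from langchain_openai import ChatOpenAI
    from ai_data_science_team.agents import make_data_wrangling_agent

    llm = ChatOpenAI(model="gpt-4o-mini")  # assumed model; any langchain chat model

    app = make_data_wrangling_agent(
        model=llm,
        n_samples=10,                  # shrink the data summary if you hit the context limit
        log=True,
        file_name="data_wrangler.py",  # where the generated function is logged
        overwrite=True,
    )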
@@ -94,7 +110,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,

     Returns
     -------
-    app : langchain.graphs.
+    app : langchain.graphs.CompiledStateGraph
         The data wrangling agent as a state graph.
     """
     llm = model
@@ -122,7 +138,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
         retry_count: int

     def recommend_wrangling_steps(state: GraphState):
-        print(
+        print(format_agent_name(AGENT_NAME))
         print(" * RECOMMEND WRANGLING STEPS")

         data_raw = state.get("data_raw")
@@ -143,7 +159,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,

         # Create a summary for all datasets
         # We'll include a short sample and info for each dataset
-        all_datasets_summary = get_dataframe_summary(dataframes)
+        all_datasets_summary = get_dataframe_summary(dataframes, n_sample=n_samples)

         # Join all datasets summaries into one big text block
         all_datasets_summary_str = "\n\n".join(all_datasets_summary)
@@ -176,6 +192,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,

     Avoid these:
     1. Do not include steps to save files.
+    2. Do not include unrelated user instructions that are not related to the data wrangling.
     """,
             input_variables=["user_instructions", "recommended_steps", "all_datasets_summary"]
         )
@@ -195,7 +212,34 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,

     def create_data_wrangler_code(state: GraphState):
         if bypass_recommended_steps:
-            print(
+            print(format_agent_name(AGENT_NAME))
+
+            data_raw = state.get("data_raw")
+
+            if isinstance(data_raw, dict):
+                # Single dataset scenario
+                primary_dataset_name = "main"
+                datasets = {primary_dataset_name: data_raw}
+            elif isinstance(data_raw, list) and all(isinstance(item, dict) for item in data_raw):
+                # Multiple datasets scenario
+                datasets = {f"dataset_{i}": d for i, d in enumerate(data_raw, start=1)}
+                primary_dataset_name = "dataset_1"
+            else:
+                raise ValueError("data_raw must be a dict or a list of dicts.")
+
+            # Convert all datasets to DataFrames for inspection
+            dataframes = {name: pd.DataFrame.from_dict(d) for name, d in datasets.items()}
+
+            # Create a summary for all datasets
+            # We'll include a short sample and info for each dataset
+            all_datasets_summary = get_dataframe_summary(dataframes, n_sample=n_samples)
+
+            # Join all datasets summaries into one big text block
+            all_datasets_summary_str = "\n\n".join(all_datasets_summary)
+
+        else:
+            all_datasets_summary_str = state.get("all_datasets_summary")
+
         print(" * CREATE DATA WRANGLER CODE")

         data_wrangling_prompt = PromptTemplate(
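Reviewer note: the normalization in the bypass branch above is worth seeing in isolation. A standalone sketch that mirrors it (the helper name is ours, not the package's):

    import pandas as pd

    def normalize_datasets(data_raw):
        # Mirrors the branch above: a single dict becomes {"main": ...},
        # a list of dicts becomes {"dataset_1": ..., "dataset_2": ...}.
        if isinstance(data_raw, dict):
            datasets = {"main": data_raw}
        elif isinstance(data_raw, list) and all(isinstance(item, dict) for item in data_raw):
            datasets = {f"dataset_{i}": d for i, d in enumerate(data_raw, start=1)}
        else:
            raise ValueError("data_raw must be a dict or a list of dicts.")
        return {name: pd.DataFrame.from_dict(d) for name, d in datasets.items()}

    frames = normalize_datasets([{"a": [1, 2]}, {"b": [3, 4]}])
    print(list(frames))  # ['dataset_1', 'dataset_2']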
@@ -242,16 +286,16 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,

         response = data_wrangling_agent.invoke({
             "recommended_steps": state.get("recommended_steps"),
-            "all_datasets_summary":
+            "all_datasets_summary": all_datasets_summary_str
         })

         response = relocate_imports_inside_function(response)
         response = add_comments_to_top(response, agent_name=AGENT_NAME)

         # For logging: store the code generated
-        file_path,
+        file_path, file_name_2 = log_ai_function(
             response=response,
-            file_name=
+            file_name=file_name,
             log=log,
             log_path=log_path,
             overwrite=overwrite
@@ -260,7 +304,8 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
         return {
             "data_wrangler_function" : response,
             "data_wrangler_function_path": file_path,
-            "data_wrangler_function_name":
+            "data_wrangler_function_name": file_name_2,
+            "all_datasets_summary": all_datasets_summary_str
         }

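Reviewer note: with the return dict above, a caller can pull the generated function out of the final state. The invocation shape below is an assumption based on typical LangGraph usage; the state keys are the ones in this diff:

    import pandas as pd
    from langchain_openai import ChatOpenAI
    from ai_data_science_team.agents import make_data_wrangling_agent

    llm = ChatOpenAI(model="gpt-4o-mini")       # assumed model
    app = make_data_wrangling_agent(model=llm)

    df = pd.DataFrame({"price": [10, 20], "qty": [1, 3]})  # toy data

    # Assumed LangGraph-style invocation; the state keys below appear in this diff.
    result = app.invoke({
        "user_instructions": "Add a total = price * qty column.",
        "data_raw": df.to_dict(),
    })
    print(result["data_wrangler_function"])  # the generated wrangling code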
ai_data_science_team/agents/feature_engineering_agent.py

@@ -17,7 +17,7 @@ import os
 import io
 import pandas as pd

-from ai_data_science_team.templates.agent_templates import(
+from ai_data_science_team.templates import(
     node_func_execute_agent_code_on_data,
     node_func_human_review,
     node_func_fix_agent_code,
@@ -25,7 +25,7 @@ from ai_data_science_team.templates.agent_templates import(
     create_coding_agent_graph
 )
 from ai_data_science_team.tools.parsers import PythonOutputParser
-from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
+from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
 from ai_data_science_team.tools.metadata import get_dataframe_summary
 from ai_data_science_team.tools.logging import log_ai_function

@@ -35,7 +35,17 @@ LOG_PATH = os.path.join(os.getcwd(), "logs/")

 # * Feature Engineering Agent

-def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
+def make_feature_engineering_agent(
+    model,
+    n_samples=30,
+    log=False,
+    log_path=None,
+    file_name="feature_engineer.py",
+    overwrite = True,
+    human_in_the_loop=False,
+    bypass_recommended_steps=False,
+    bypass_explain_code=False,
+):
     """
     Creates a feature engineering agent that can be run on a dataset. The agent applies various feature engineering
     techniques, such as encoding categorical variables, scaling numeric variables, creating interaction terms,
@@ -61,11 +71,17 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
     ----------
     model : langchain.llms.base.LLM
         The language model to use to generate code.
+    n_samples : int, optional
+        The number of data samples to use for generating the feature engineering code. Defaults to 30.
+        If you get an error due to maximum tokens, try reducing this number.
+        > "This model's maximum context length is 128000 tokens. However, your messages resulted in 333858 tokens. Please reduce the length of the messages."
     log : bool, optional
         Whether or not to log the code generated and any errors that occur.
         Defaults to False.
     log_path : str, optional
         The path to the directory where the log files should be stored. Defaults to "logs/".
+    file_name : str, optional
+        The name of the file to save the log to. Defaults to "feature_engineer.py".
     overwrite : bool, optional
         Whether or not to overwrite the log file if it already exists. If False, a unique file name will be created.
         Defaults to True.
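Reviewer note: same pattern for the feature engineering factory. A sketch under the same assumptions as the data wrangling example above (the model is illustrative; the parameters come from this diff):

    from langchain_openai import ChatOpenAI
    from ai_data_science_team.agents import make_feature_engineering_agent

    llm = ChatOpenAI(model="gpt-4o-mini")  # assumed model

    app = make_feature_engineering_agent(
        model=llm,
        n_samples=30,
        log=True,
        file_name="feature_engineer.py",
        bypass_recommended_steps=True,  # summarize data_raw directly, skipping the recommend node
    )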
@@ -102,7 +118,7 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =

     Returns
     -------
-    app : langchain.graphs.
+    app : langchain.graphs.CompiledStateGraph
         The feature engineering agent as a state graph.
     """
     llm = model
@@ -135,7 +151,7 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
         Recommend a series of feature engineering steps based on the input data.
         These recommended steps will be appended to the user_instructions.
         """
-        print(
+        print(format_agent_name(AGENT_NAME))
         print(" * RECOMMEND FEATURE ENGINEERING STEPS")

         # Prompt to get recommended steps from the LLM
@@ -182,6 +198,7 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =

     Avoid these:
     1. Do not include steps to save files.
+    2. Do not include unrelated user instructions that are not related to the feature engineering.
     """,
             input_variables=["user_instructions", "recommended_steps", "all_datasets_summary"]
         )
@@ -189,7 +206,7 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
         data_raw = state.get("data_raw")
         df = pd.DataFrame.from_dict(data_raw)

-        all_datasets_summary = get_dataframe_summary([df])
+        all_datasets_summary = get_dataframe_summary([df], n_sample=n_samples)

         all_datasets_summary_str = "\n\n".join(all_datasets_summary)

@@ -217,7 +234,18 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =

     def create_feature_engineering_code(state: GraphState):
         if bypass_recommended_steps:
-            print(
+            print(format_agent_name(AGENT_NAME))
+
+            data_raw = state.get("data_raw")
+            df = pd.DataFrame.from_dict(data_raw)
+
+            all_datasets_summary = get_dataframe_summary([df], n_sample=n_samples)
+
+            all_datasets_summary_str = "\n\n".join(all_datasets_summary)
+
+        else:
+            all_datasets_summary_str = state.get("all_datasets_summary")
+
         print(" * CREATE FEATURE ENGINEERING CODE")

         feature_engineering_prompt = PromptTemplate(
@@ -272,16 +300,16 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
         response = feature_engineering_agent.invoke({
             "recommended_steps": state.get("recommended_steps"),
             "target_variable": state.get("target_variable"),
-            "all_datasets_summary":
+            "all_datasets_summary": all_datasets_summary_str,
         })

         response = relocate_imports_inside_function(response)
         response = add_comments_to_top(response, agent_name=AGENT_NAME)

         # For logging: store the code generated
-        file_path,
+        file_path, file_name_2 = log_ai_function(
             response=response,
-            file_name=
+            file_name=file_name,
             log=log,
             log_path=log_path,
             overwrite=overwrite
@@ -290,7 +318,8 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
         return {
             "feature_engineer_function": response,
             "feature_engineer_function_path": file_path,
-            "feature_engineer_function_name":
+            "feature_engineer_function_name": file_name_2,
+            "all_datasets_summary": all_datasets_summary_str
         }

ai_data_science_team/agents/sql_database_agent.py

@@ -14,7 +14,7 @@ import io
 import pandas as pd
 import sqlalchemy as sql

-from ai_data_science_team.templates.agent_templates import(
+from ai_data_science_team.templates import(
     node_func_execute_agent_from_sql_connection,
     node_func_human_review,
     node_func_fix_agent_code,
@@ -22,7 +22,7 @@ from ai_data_science_team.templates.agent_templates import(
     create_coding_agent_graph
 )
 from ai_data_science_team.tools.parsers import PythonOutputParser, SQLOutputParser
-from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
+from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
 from ai_data_science_team.tools.metadata import get_database_metadata
 from ai_data_science_team.tools.logging import log_ai_function

@@ -31,7 +31,16 @@ AGENT_NAME = "sql_database_agent"
 LOG_PATH = os.path.join(os.getcwd(), "logs/")


-def make_sql_database_agent(model, connection, log=False, log_path=None, overwri
+def make_sql_database_agent(
+    model, connection,
+    n_samples = 10,
+    log=False,
+    log_path=None,
+    file_name="sql_database.py",
+    overwrite = True,
+    human_in_the_loop=False, bypass_recommended_steps=False,
+    bypass_explain_code=False
+):
     """
     Creates a SQL Database Agent that can recommend SQL steps and generate SQL code to query a database.

@@ -41,10 +50,16 @@ def make_sql_database_agent(model, connection, log=False, log_path=None, overwri
         The language model to use for the agent.
     connection : sqlalchemy.engine.base.Engine
         The connection to the SQL database.
+    n_samples : int, optional
+        The number of samples to retrieve for each column, by default 10.
+        If you get an error due to maximum tokens, try reducing this number.
+        > "This model's maximum context length is 128000 tokens. However, your messages resulted in 333858 tokens. Please reduce the length of the messages."
     log : bool, optional
         Whether to log the generated code, by default False
     log_path : str, optional
         The path to the log directory, by default None
+    file_name : str, optional
+        The name of the file to save the generated code, by default "sql_database.py"
     overwrite : bool, optional
         Whether to overwrite the existing log file, by default True
     human_in_the_loop : bool, optional
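Reviewer note: a sketch of the updated SQL agent factory; the SQLite URL is a placeholder and the model is an assumption, while the keyword arguments are the ones this diff adds or documents:

    import sqlalchemy as sql
    from langchain_openai import ChatOpenAI
    from ai_data_science_team.agents import make_sql_database_agent

    llm = ChatOpenAI(model="gpt-4o-mini")                 # assumed model
    engine = sql.create_engine("sqlite:///northwind.db")  # placeholder database URL

    app = make_sql_database_agent(
        model=llm,
        connection=engine,   # an Engine or an active Connection, per the docstring
        n_samples=10,        # sample values per column in the metadata summary
        file_name="sql_database.py",
    )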
@@ -56,7 +71,7 @@ def make_sql_database_agent(model, connection, log=False, log_path=None, overwri

     Returns
     -------
-    app : langchain.graphs.
+    app : langchain.graphs.CompiledStateGraph
         The data cleaning agent as a state graph.

     Examples
@@ -116,8 +131,8 @@ def make_sql_database_agent(model, connection, log=False, log_path=None, overwri

     def recommend_sql_steps(state: GraphState):

-        print(
-        print(" * RECOMMEND
+        print(format_agent_name(AGENT_NAME))
+        print(" * RECOMMEND STEPS")


         # Prompt to get recommended steps from the LLM
@@ -156,6 +171,8 @@ def make_sql_database_agent(model, connection, log=False, log_path=None, overwri
     2. Do not include steps to modify existing tables, create new tables or modify the database schema.
     3. Do not include steps that alter the existing data in the database.
     4. Make sure not to include unsafe code that could cause data loss or corruption or SQL injections.
+    5. Make sure to not include irrelevant steps that do not help in the SQL agent's data collection and processing. Examples include steps to create new tables, modify the schema, save files, create charts, etc.
+

     """,
             input_variables=["user_instructions", "recommended_steps", "all_sql_database_summary"]
@@ -166,7 +183,7 @@ def make_sql_database_agent(model, connection, log=False, log_path=None, overwri
         conn = connection.connect() if is_engine else connection

         # Get the database metadata
-        all_sql_database_summary = get_database_metadata(conn,
+        all_sql_database_summary = get_database_metadata(conn, n_samples=n_samples)

         steps_agent = recommend_steps_prompt | llm

@@ -183,7 +200,7 @@ def make_sql_database_agent(model, connection, log=False, log_path=None, overwri

     def create_sql_query_code(state: GraphState):
         if bypass_recommended_steps:
-            print(
+            print(format_agent_name(AGENT_NAME))
         print(" * CREATE SQL QUERY CODE")

         # Prompt to get the SQL code from the LLM
@@ -228,7 +245,7 @@ def make_sql_database_agent(model, connection, log=False, log_path=None, overwri
         conn = connection.connect() if is_engine else connection

         # Get the database metadata
-        all_sql_database_summary = get_database_metadata(conn,
+        all_sql_database_summary = get_database_metadata(conn, n_samples=n_samples)

         sql_query_code_agent = sql_query_code_prompt | llm | SQLOutputParser()

@@ -259,9 +276,9 @@ def sql_database_pipeline(connection):
         response = add_comments_to_top(response, AGENT_NAME)

         # For logging: store the code generated
-        file_path,
+        file_path, file_name_2 = log_ai_function(
             response=response,
-            file_name=
+            file_name=file_name,
             log=log,
             log_path=log_path,
             overwrite=overwrite
@@ -271,7 +288,8 @@ def sql_database_pipeline(connection):
             "sql_query_code": sql_query_code,
             "sql_database_function": response,
             "sql_database_function_path": file_path,
-            "sql_database_function_name":
+            "sql_database_function_name": file_name_2,
+            "all_sql_database_summary": all_sql_database_summary
         }

     def human_review(state: GraphState) -> Command[Literal["recommend_sql_steps", "create_sql_query_code"]]:
ai_data_science_team/tools/metadata.py

@@ -4,7 +4,9 @@ import sqlalchemy as sql
 from typing import Union, List, Dict

 def get_dataframe_summary(
-    dataframes: Union[pd.DataFrame, List[pd.DataFrame], Dict[str, pd.DataFrame]]
+    dataframes: Union[pd.DataFrame, List[pd.DataFrame], Dict[str, pd.DataFrame]],
+    n_sample: int = 30,
+    skip_stats: bool = False,
 ) -> List[str]:
     """
     Generate a summary for one or more DataFrames. Accepts a single DataFrame, a list of DataFrames,
@@ -16,6 +18,10 @@ def get_dataframe_summary(
     - Single DataFrame: produce a single summary (returned within a one-element list).
     - List of DataFrames: produce a summary for each DataFrame, using index-based names.
     - Dictionary of DataFrames: produce a summary for each DataFrame, using dictionary keys as names.
+    n_sample : int, default 30
+        Number of rows to display in the "Data (first 30 rows)" section.
+    skip_stats : bool, default False
+        If True, skip the descriptive statistics and DataFrame info sections.

     Example:
     --------
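Reviewer note: the new parameters are easiest to see directly. A minimal, runnable example of the updated signature:

    import pandas as pd
    from ai_data_science_team.tools.metadata import get_dataframe_summary

    df = pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})

    # One summary string per dataframe; skip_stats=True keeps only the shape,
    # dtypes, and the first n_sample rows.
    summaries = get_dataframe_summary({"demo": df}, n_sample=2, skip_stats=True)
    print(summaries[0])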
@@ -49,17 +55,17 @@ def get_dataframe_summary(
     # --- Dictionary Case ---
     if isinstance(dataframes, dict):
         for dataset_name, df in dataframes.items():
-            summaries.append(_summarize_dataframe(df, dataset_name))
+            summaries.append(_summarize_dataframe(df, dataset_name, n_sample, skip_stats))

     # --- Single DataFrame Case ---
     elif isinstance(dataframes, pd.DataFrame):
-        summaries.append(_summarize_dataframe(dataframes, "Single_Dataset"))
+        summaries.append(_summarize_dataframe(dataframes, "Single_Dataset", n_sample, skip_stats))

     # --- List of DataFrames Case ---
     elif isinstance(dataframes, list):
         for idx, df in enumerate(dataframes):
             dataset_name = f"Dataset_{idx}"
-            summaries.append(_summarize_dataframe(df, dataset_name))
+            summaries.append(_summarize_dataframe(df, dataset_name, n_sample, skip_stats))

     else:
         raise TypeError(
@@ -69,7 +75,7 @@ def get_dataframe_summary(
     return summaries


-def _summarize_dataframe(df: pd.DataFrame, dataset_name: str) -> str:
+def _summarize_dataframe(df: pd.DataFrame, dataset_name: str, n_sample=30, skip_stats=False) -> str:
     """Generate a summary string for a single DataFrame."""
     # 1. Convert dictionary-type cells to strings
     # This prevents unhashable dict errors during df.nunique().
@@ -91,77 +97,134 @@ def _summarize_dataframe(df: pd.DataFrame, dataset_name: str) -> str:
     unique_counts = df.nunique()  # Will no longer fail on unhashable dict
     unique_counts_summary = "\n".join([f"{col}: {count}" for col, count in unique_counts.items()])

-
-
-
-
+    # 6. Generate the summary text
+    if not skip_stats:
+        summary_text = f"""
+        Dataset Name: {dataset_name}
+        ----------------------------
+        Shape: {df.shape[0]} rows x {df.shape[1]} columns

-
-
+        Column Data Types:
+        {column_types}

-
-
+        Missing Value Percentage:
+        {missing_summary}

-
-
+        Unique Value Counts:
+        {unique_counts_summary}

-
-
+        Data (first {n_sample} rows):
+        {df.head(n_sample).to_string()}

-
-
+        Data Description:
+        {df.describe().to_string()}

-
-
-
+        Data Info:
+        {info_text}
+        """
+    else:
+        summary_text = f"""
+        Dataset Name: {dataset_name}
+        ----------------------------
+        Shape: {df.shape[0]} rows x {df.shape[1]} columns
+
+        Column Data Types:
+        {column_types}
+
+        Data (first {n_sample} rows):
+        {df.head(n_sample).to_string()}
+        """
+
     return summary_text.strip()


-def get_database_metadata(connection: Union[sql.engine.base.Connection, sql.engine.base.Engine], n_values: int=10):
-    """
-    Collects metadata and sample data from a database.

-
-
-
+def get_database_metadata(connection: Union[sql.engine.base.Connection, sql.engine.base.Engine],
+                          n_samples: int = 10) -> str:
+    """
+    Collects metadata and sample data from a database, with safe identifier quoting and
+    basic dialect-aware row limiting. Prevents issues with spaces/reserved words in identifiers.
+
+    Parameters
+    ----------
+    connection : Union[sql.engine.base.Connection, sql.engine.base.Engine]
         An active SQLAlchemy connection or engine.
-
+    n_samples : int
         Number of sample values to retrieve for each column.

-    Returns
-
-    str
+    Returns
+    -------
+    str
+        A formatted string with database metadata, including some sample data from each column.
     """
+
     # If a connection is passed, use it; if an engine is passed, connect to it
     is_engine = isinstance(connection, sql.engine.base.Engine)
     conn = connection.connect() if is_engine else connection
-    output = []

+    output = []
     try:
-        #
+        # Grab the engine off the connection
         sql_engine = conn.engine
+        dialect_name = sql_engine.dialect.name.lower()
+
         output.append(f"Database Dialect: {sql_engine.dialect.name}")
         output.append(f"Driver: {sql_engine.driver}")
         output.append(f"Connection URL: {sql_engine.url}")
-
+
         # Inspect the database
         inspector = sql.inspect(sql_engine)
-
+        tables = inspector.get_table_names()
+        output.append(f"Tables: {tables}")
         output.append(f"Schemas: {inspector.get_schema_names()}")
-
-        #
-
+
+        # Helper to build a dialect-specific limit clause
+        def build_query(col_name_quoted: str, table_name_quoted: str, n: int) -> str:
+            """
+            Returns a SQL query string to select N rows from the given column/table
+            across different dialects (SQLite, MySQL, Postgres, MSSQL, Oracle, etc.)
+            """
+            if "sqlite" in dialect_name or "mysql" in dialect_name or "postgres" in dialect_name:
+                # Common dialects supporting LIMIT
+                return f"SELECT {col_name_quoted} FROM {table_name_quoted} LIMIT {n}"
+            elif "mssql" in dialect_name:
+                # Microsoft SQL Server syntax
+                return f"SELECT TOP {n} {col_name_quoted} FROM {table_name_quoted}"
+            elif "oracle" in dialect_name:
+                # Oracle syntax
+                return f"SELECT {col_name_quoted} FROM {table_name_quoted} WHERE ROWNUM <= {n}"
+            else:
+                # Fallback
+                return f"SELECT {col_name_quoted} FROM {table_name_quoted} LIMIT {n}"
+
+        # Prepare for quoting
+        preparer = inspector.bind.dialect.identifier_preparer
+
+        # For each table, get columns and sample data
+        for table_name in tables:
             output.append(f"\nTable: {table_name}")
+            # Properly quote the table name
+            table_name_quoted = preparer.quote_identifier(table_name)
+
             for column in inspector.get_columns(table_name):
-
-
-
-
-
+                col_name = column["name"]
+                col_type = column["type"]
+                output.append(f"  Column: {col_name} Type: {col_type}")
+
+                # Properly quote the column name
+                col_name_quoted = preparer.quote_identifier(col_name)
+
+                # Build a dialect-aware query with safe quoting
+                query = build_query(col_name_quoted, table_name_quoted, n_samples)
+
+                # Read a few sample values
+                df = pd.read_sql(sql.text(query), conn)
+                first_values = df[col_name].tolist()
+                output.append(f"  First {n_samples} Values: {first_values}")
+
     finally:
-        # Close connection if
+        # Close connection if created inside the function
         if is_engine:
             conn.close()
-
-    # Join all collected information into a single string
+
     return "\n".join(output)
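Reviewer note: a quick way to exercise the reworked get_database_metadata against a throwaway SQLite file (the demo_metadata.db path is a placeholder):

    import sqlalchemy as sql
    from ai_data_science_team.tools.metadata import get_database_metadata

    # Throwaway file-backed SQLite DB (an in-memory one would not survive the
    # fresh connection the function opens when given an Engine).
    engine = sql.create_engine("sqlite:///demo_metadata.db")
    with engine.connect() as conn:
        conn.execute(sql.text("CREATE TABLE items (id INTEGER, name TEXT)"))
        conn.execute(sql.text("INSERT INTO items VALUES (1, 'a'), (2, 'b')"))
        conn.commit()

    print(get_database_metadata(engine, n_samples=2))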
ai_data_science_team/tools/regex.py

@@ -71,3 +71,9 @@ def add_comments_to_top(code_text, agent_name="data_wrangler"):
     # Join the header with newlines, then prepend to the existing code_text
     header_block = "\n".join(header_comments)
     return header_block + code_text
+
+def format_agent_name(agent_name: str) -> str:
+
+    formatted_name = agent_name.strip().replace("_", " ").upper()
+
+    return f"---{formatted_name}----"
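Reviewer note: the new helper is trivial to sanity-check; note it emits three leading and four trailing dashes, as written in the source:

    from ai_data_science_team.tools.regex import format_agent_name

    print(format_agent_name("data_wrangling_agent"))
    # prints: ---DATA WRANGLING AGENT----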