ai-data-science-team 0.0.0.9008__py3-none-any.whl → 0.0.0.9009__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +0 -1
- ai_data_science_team/agents/data_cleaning_agent.py +45 -34
- ai_data_science_team/agents/data_visualization_agent.py +39 -43
- ai_data_science_team/agents/data_wrangling_agent.py +45 -44
- ai_data_science_team/agents/feature_engineering_agent.py +42 -61
- ai_data_science_team/agents/sql_database_agent.py +125 -71
- ai_data_science_team/ml_agents/__init__.py +1 -0
- ai_data_science_team/ml_agents/h2o_ml_agent.py +1032 -0
- ai_data_science_team/multiagents/sql_data_analyst.py +119 -7
- ai_data_science_team/templates/__init__.py +1 -0
- ai_data_science_team/templates/agent_templates.py +73 -2
- ai_data_science_team/tools/metadata.py +94 -62
- ai_data_science_team/tools/regex.py +59 -1
- {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/METADATA +28 -14
- ai_data_science_team-0.0.0.9009.dist-info/RECORD +28 -0
- ai_data_science_team-0.0.0.9008.dist-info/RECORD +0 -26
- {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/WHEEL +0 -0
- {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/top_level.txt +0 -0
@@ -14,6 +14,7 @@ from langgraph.types import Command
|
|
14
14
|
from langgraph.checkpoint.memory import MemorySaver
|
15
15
|
|
16
16
|
import os
|
17
|
+
import json
|
17
18
|
import pandas as pd
|
18
19
|
|
19
20
|
from IPython.display import Markdown
|
@@ -22,7 +23,7 @@ from ai_data_science_team.templates import(
|
|
22
23
|
node_func_execute_agent_code_on_data,
|
23
24
|
node_func_human_review,
|
24
25
|
node_func_fix_agent_code,
|
25
|
-
|
26
|
+
node_func_report_agent_outputs,
|
26
27
|
create_coding_agent_graph,
|
27
28
|
BaseAgent,
|
28
29
|
)
|
@@ -31,7 +32,8 @@ from ai_data_science_team.tools.regex import (
|
|
31
32
|
relocate_imports_inside_function,
|
32
33
|
add_comments_to_top,
|
33
34
|
format_agent_name,
|
34
|
-
format_recommended_steps
|
35
|
+
format_recommended_steps,
|
36
|
+
get_generic_summary,
|
35
37
|
)
|
36
38
|
from ai_data_science_team.tools.metadata import get_dataframe_summary
|
37
39
|
from ai_data_science_team.tools.logging import log_ai_function
|
@@ -103,8 +105,8 @@ class FeatureEngineeringAgent(BaseAgent):
|
|
103
105
|
retry_count=0
|
104
106
|
)
|
105
107
|
Engineers features from the provided dataset synchronously based on user instructions.
|
106
|
-
|
107
|
-
|
108
|
+
get_workflow_summary()
|
109
|
+
Retrieves a summary of the agent's workflow.
|
108
110
|
get_log_summary()
|
109
111
|
Retrieves a summary of logged operations if logging is enabled.
|
110
112
|
get_data_engineered()
|
@@ -285,40 +287,34 @@ class FeatureEngineeringAgent(BaseAgent):
|
|
285
287
|
self.response = response
|
286
288
|
return None
|
287
289
|
|
288
|
-
def
|
290
|
+
def get_workflow_summary(self, markdown=False):
|
289
291
|
"""
|
290
|
-
|
291
|
-
|
292
|
-
Returns
|
293
|
-
-------
|
294
|
-
str or list
|
295
|
-
Explanation of the feature engineering steps.
|
292
|
+
Retrieves the agent's workflow summary, if a response containing messages is available.
|
296
293
|
"""
|
297
|
-
if self.response:
|
298
|
-
|
299
|
-
|
294
|
+
if self.response and self.response.get("messages"):
|
295
|
+
summary = get_generic_summary(json.loads(self.response.get("messages")[-1].content))
|
296
|
+
if markdown:
|
297
|
+
return Markdown(summary)
|
298
|
+
else:
|
299
|
+
return summary
|
300
300
|
|
301
301
|
def get_log_summary(self, markdown=False):
|
302
302
|
"""
|
303
303
|
Returns a summary of the agent's logged operations, if logging is enabled.
|
304
|
+
"""
|
305
|
+
if self.response:
|
306
|
+
if self.response.get('feature_engineer_function_path'):
|
307
|
+
log_details = f"""
|
308
|
+
## Featuring Engineering Agent Log Summary:
|
304
309
|
|
305
|
-
|
306
|
-
----------
|
307
|
-
markdown : bool, optional
|
308
|
-
If True, returns Markdown-formatted output.
|
310
|
+
Function Path: {self.response.get('feature_engineer_function_path')}
|
309
311
|
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
log_details = f"Log Path: {self.response.get('feature_engineer_function_path')}"
|
317
|
-
if markdown:
|
318
|
-
return Markdown(log_details)
|
319
|
-
else:
|
320
|
-
return log_details
|
321
|
-
return None
|
312
|
+
Function Name: {self.response.get('feature_engineer_function_name')}
|
313
|
+
"""
|
314
|
+
if markdown:
|
315
|
+
return Markdown(log_details)
|
316
|
+
else:
|
317
|
+
return log_details
|
322
318
|
|
323
319
|
def get_data_engineered(self):
|
324
320
|
"""
|
@@ -388,22 +384,7 @@ class FeatureEngineeringAgent(BaseAgent):
|
|
388
384
|
return steps
|
389
385
|
return None
|
390
386
|
|
391
|
-
|
392
|
-
"""
|
393
|
-
Returns the agent's full response dictionary.
|
394
|
-
|
395
|
-
Returns
|
396
|
-
-------
|
397
|
-
dict or None
|
398
|
-
The response dictionary if available, otherwise None.
|
399
|
-
"""
|
400
|
-
return self.response
|
401
|
-
|
402
|
-
def show(self):
|
403
|
-
"""
|
404
|
-
Displays the agent's mermaid diagram for visual inspection of the compiled graph.
|
405
|
-
"""
|
406
|
-
return self._compiled_graph.show()
|
387
|
+
|
407
388
|
|
408
389
|
|
409
390
|
# * Feature Engineering Agent
|
@@ -576,7 +557,7 @@ def make_feature_engineering_agent(
|
|
576
557
|
Below are summaries of all datasets provided:
|
577
558
|
{all_datasets_summary}
|
578
559
|
|
579
|
-
Return
|
560
|
+
Return steps as a numbered list. You can return short code snippets to demonstrate actions. But do not return a fully coded solution. The code will be generated separately by a Coding Agent.
|
580
561
|
|
581
562
|
Avoid these:
|
582
563
|
1. Do not include steps to save files.
|
@@ -649,7 +630,6 @@ def make_feature_engineering_agent(
|
|
649
630
|
|
650
631
|
feature_engineering_prompt = PromptTemplate(
|
651
632
|
template="""
|
652
|
-
|
653
633
|
You are a Feature Engineering Agent. Your job is to create a {function_name}() function that can be run on the data provided using the following recommended steps.
|
654
634
|
|
655
635
|
Recommended Steps:
|
@@ -763,21 +743,22 @@ def make_feature_engineering_agent(
|
|
763
743
|
function_name=state.get("feature_engineer_function_name"),
|
764
744
|
)
|
765
745
|
|
766
|
-
|
767
|
-
|
746
|
+
# Final reporting node
|
747
|
+
def report_agent_outputs(state: GraphState):
|
748
|
+
return node_func_report_agent_outputs(
|
768
749
|
state=state,
|
769
|
-
|
750
|
+
keys_to_include=[
|
751
|
+
"recommended_steps",
|
752
|
+
"feature_engineer_function",
|
753
|
+
"feature_engineer_function_path",
|
754
|
+
"feature_engineer_function_name",
|
755
|
+
"feature_engineer_error",
|
756
|
+
],
|
770
757
|
result_key="messages",
|
771
|
-
error_key="feature_engineer_error",
|
772
|
-
llm=llm,
|
773
758
|
role=AGENT_NAME,
|
774
|
-
|
775
|
-
Explain the feature engineering steps performed by this function. Keep the explanation clear and concise.\n\n# Feature Engineering Agent:\n\n{code}
|
776
|
-
""",
|
777
|
-
success_prefix="# Feature Engineering Agent:\n\n ",
|
778
|
-
error_message="The Feature Engineering Agent encountered an error during feature engineering. Data could not be explained."
|
759
|
+
custom_title="Feature Engineering Agent Outputs"
|
779
760
|
)
|
780
|
-
|
761
|
+
|
781
762
|
# Create the graph
|
782
763
|
node_functions = {
|
783
764
|
"recommend_feature_engineering_steps": recommend_feature_engineering_steps,
|
@@ -785,7 +766,7 @@ def make_feature_engineering_agent(
|
|
785
766
|
"create_feature_engineering_code": create_feature_engineering_code,
|
786
767
|
"execute_feature_engineering_code": execute_feature_engineering_code,
|
787
768
|
"fix_feature_engineering_code": fix_feature_engineering_code,
|
788
|
-
"
|
769
|
+
"report_agent_outputs": report_agent_outputs,
|
789
770
|
}
|
790
771
|
|
791
772
|
app = create_coding_agent_graph(
|
@@ -795,7 +776,7 @@ def make_feature_engineering_agent(
|
|
795
776
|
create_code_node_name="create_feature_engineering_code",
|
796
777
|
execute_code_node_name="execute_feature_engineering_code",
|
797
778
|
fix_code_node_name="fix_feature_engineering_code",
|
798
|
-
explain_code_node_name="
|
779
|
+
explain_code_node_name="report_agent_outputs",
|
799
780
|
error_key="feature_engineer_error",
|
800
781
|
max_retries_key = "max_retries",
|
801
782
|
retry_count_key = "retry_count",
|
@@ -5,11 +5,13 @@ import operator
|
|
5
5
|
|
6
6
|
from langchain.prompts import PromptTemplate
|
7
7
|
from langchain_core.messages import BaseMessage
|
8
|
+
from langchain_core.output_parsers import JsonOutputParser
|
8
9
|
|
9
10
|
from langgraph.types import Command
|
10
11
|
from langgraph.checkpoint.memory import MemorySaver
|
11
12
|
|
12
13
|
import os
|
14
|
+
import json
|
13
15
|
import pandas as pd
|
14
16
|
import sqlalchemy as sql
|
15
17
|
|
@@ -19,12 +21,17 @@ from ai_data_science_team.templates import(
|
|
19
21
|
node_func_execute_agent_from_sql_connection,
|
20
22
|
node_func_human_review,
|
21
23
|
node_func_fix_agent_code,
|
22
|
-
|
24
|
+
node_func_report_agent_outputs,
|
23
25
|
create_coding_agent_graph,
|
24
26
|
BaseAgent,
|
25
27
|
)
|
26
28
|
from ai_data_science_team.tools.parsers import SQLOutputParser
|
27
|
-
from ai_data_science_team.tools.regex import
|
29
|
+
from ai_data_science_team.tools.regex import (
|
30
|
+
add_comments_to_top,
|
31
|
+
format_agent_name,
|
32
|
+
format_recommended_steps,
|
33
|
+
get_generic_summary,
|
34
|
+
)
|
28
35
|
from ai_data_science_team.tools.metadata import get_database_metadata
|
29
36
|
from ai_data_science_team.tools.logging import log_ai_function
|
30
37
|
|
@@ -51,7 +58,7 @@ class SQLDatabaseAgent(BaseAgent):
|
|
51
58
|
connection : sqlalchemy.engine.base.Engine or sqlalchemy.engine.base.Connection
|
52
59
|
The SQLAlchemy connection (or engine) to the database.
|
53
60
|
n_samples : int, optional
|
54
|
-
Number of sample rows (per column) to retrieve when summarizing database metadata. Defaults to
|
61
|
+
Number of sample rows (per column) to retrieve when summarizing database metadata. Defaults to 1.
|
55
62
|
log : bool, optional
|
56
63
|
Whether to log the generated code and errors. Defaults to False.
|
57
64
|
log_path : str, optional
|
@@ -68,6 +75,8 @@ class SQLDatabaseAgent(BaseAgent):
|
|
68
75
|
If True, skips the step that generates recommended SQL steps. Defaults to False.
|
69
76
|
bypass_explain_code : bool, optional
|
70
77
|
If True, skips the step that provides code explanations. Defaults to False.
|
78
|
+
smart_schema_pruning : bool, optional
|
79
|
+
If True, filters the tables and columns based on the user instructions and recommended steps. Defaults to False.
|
71
80
|
|
72
81
|
Methods
|
73
82
|
-------
|
@@ -77,8 +86,8 @@ class SQLDatabaseAgent(BaseAgent):
|
|
77
86
|
Asynchronously runs the agent to generate and execute a SQL query based on user instructions.
|
78
87
|
invoke_agent(user_instructions: str, max_retries=3, retry_count=0)
|
79
88
|
Synchronously runs the agent to generate and execute a SQL query based on user instructions.
|
80
|
-
|
81
|
-
|
89
|
+
get_workflow_summary()
|
90
|
+
Retrieves a summary of the agent's workflow.
|
82
91
|
get_log_summary()
|
83
92
|
Retrieves a summary of logged operations if logging is enabled.
|
84
93
|
get_data_sql()
|
@@ -139,7 +148,7 @@ class SQLDatabaseAgent(BaseAgent):
|
|
139
148
|
self,
|
140
149
|
model,
|
141
150
|
connection,
|
142
|
-
n_samples=
|
151
|
+
n_samples=1,
|
143
152
|
log=False,
|
144
153
|
log_path=None,
|
145
154
|
file_name="sql_database.py",
|
@@ -147,7 +156,8 @@ class SQLDatabaseAgent(BaseAgent):
|
|
147
156
|
overwrite=True,
|
148
157
|
human_in_the_loop=False,
|
149
158
|
bypass_recommended_steps=False,
|
150
|
-
bypass_explain_code=False
|
159
|
+
bypass_explain_code=False,
|
160
|
+
smart_schema_pruning=False,
|
151
161
|
):
|
152
162
|
self._params = {
|
153
163
|
"model": model,
|
@@ -160,7 +170,8 @@ class SQLDatabaseAgent(BaseAgent):
|
|
160
170
|
"overwrite": overwrite,
|
161
171
|
"human_in_the_loop": human_in_the_loop,
|
162
172
|
"bypass_recommended_steps": bypass_recommended_steps,
|
163
|
-
"bypass_explain_code": bypass_explain_code
|
173
|
+
"bypass_explain_code": bypass_explain_code,
|
174
|
+
"smart_schema_pruning": smart_schema_pruning,
|
164
175
|
}
|
165
176
|
self._compiled_graph = self._make_compiled_graph()
|
166
177
|
self.response = None
|
@@ -234,40 +245,34 @@ class SQLDatabaseAgent(BaseAgent):
|
|
234
245
|
}, **kwargs)
|
235
246
|
self.response = response
|
236
247
|
|
237
|
-
def
|
248
|
+
def get_workflow_summary(self, markdown=False):
|
238
249
|
"""
|
239
|
-
|
240
|
-
if the explain step is not bypassed.
|
241
|
-
|
242
|
-
Returns
|
243
|
-
-------
|
244
|
-
str or list
|
245
|
-
An explanation of the SQL steps.
|
250
|
+
Retrieves the agent's workflow summary, if a response containing messages is available.
|
246
251
|
"""
|
247
|
-
if self.response:
|
248
|
-
|
249
|
-
|
252
|
+
if self.response and self.response.get("messages"):
|
253
|
+
summary = get_generic_summary(json.loads(self.response.get("messages")[-1].content))
|
254
|
+
if markdown:
|
255
|
+
return Markdown(summary)
|
256
|
+
else:
|
257
|
+
return summary
|
250
258
|
|
251
259
|
def get_log_summary(self, markdown=False):
|
252
260
|
"""
|
253
|
-
|
261
|
+
Returns a summary of the agent's logged operations, if logging is enabled.
|
262
|
+
"""
|
263
|
+
if self.response:
|
264
|
+
if self.response.get('sql_database_function_path'):
|
265
|
+
log_details = f"""
|
266
|
+
## SQL Database Agent Log Summary:
|
254
267
|
|
255
|
-
|
256
|
-
----------
|
257
|
-
markdown : bool, optional
|
258
|
-
If True, returns the summary in Markdown format.
|
268
|
+
Function Path: {self.response.get('sql_database_function_path')}
|
259
269
|
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
log_details = f"Log Path: {self.response['sql_database_function_path']}"
|
267
|
-
if markdown:
|
268
|
-
return Markdown(log_details)
|
269
|
-
return log_details
|
270
|
-
return None
|
270
|
+
Function Name: {self.response.get('sql_database_function_name')}
|
271
|
+
"""
|
272
|
+
if markdown:
|
273
|
+
return Markdown(log_details)
|
274
|
+
else:
|
275
|
+
return log_details
|
271
276
|
|
272
277
|
def get_data_sql(self):
|
273
278
|
"""
|
@@ -351,14 +356,16 @@ class SQLDatabaseAgent(BaseAgent):
|
|
351
356
|
def make_sql_database_agent(
|
352
357
|
model,
|
353
358
|
connection,
|
354
|
-
n_samples
|
359
|
+
n_samples=1,
|
355
360
|
log=False,
|
356
361
|
log_path=None,
|
357
362
|
file_name="sql_database.py",
|
358
363
|
function_name="sql_database_pipeline",
|
359
364
|
overwrite = True,
|
360
|
-
human_in_the_loop=False,
|
361
|
-
|
365
|
+
human_in_the_loop=False,
|
366
|
+
bypass_recommended_steps=False,
|
367
|
+
bypass_explain_code=False,
|
368
|
+
smart_schema_pruning=False,
|
362
369
|
):
|
363
370
|
"""
|
364
371
|
Creates a SQL Database Agent that can recommend SQL steps and generate SQL code to query a database.
|
@@ -370,7 +377,7 @@ def make_sql_database_agent(
|
|
370
377
|
connection : sqlalchemy.engine.base.Engine
|
371
378
|
The connection to the SQL database.
|
372
379
|
n_samples : int, optional
|
373
|
-
The number of samples to retrieve for each column, by default
|
380
|
+
The number of samples to retrieve for each column, by default 1.
|
374
381
|
If you get an error due to maximum tokens, try reducing this number.
|
375
382
|
> "This model's maximum context length is 128000 tokens. However, your messages resulted in 333858 tokens. Please reduce the length of the messages."
|
376
383
|
log : bool, optional
|
@@ -387,6 +394,8 @@ def make_sql_database_agent(
|
|
387
394
|
Bypass the recommendation step, by default False
|
388
395
|
bypass_explain_code : bool, optional
|
389
396
|
Bypass the code explanation step, by default False.
|
397
|
+
smart_schema_pruning : bool, optional
|
398
|
+
If True, filters the tables and columns with an extra LLM step to reduce tokens for large databases. Increases processing time but can avoid errors due to hitting max token limits with large databases. Defaults to False.
|
390
399
|
|
391
400
|
Returns
|
392
401
|
-------
|
@@ -419,12 +428,8 @@ def make_sql_database_agent(
|
|
419
428
|
"retry_count":0
|
420
429
|
})
|
421
430
|
```
|
422
|
-
|
423
431
|
"""
|
424
432
|
|
425
|
-
is_engine = isinstance(connection, sql.engine.base.Engine)
|
426
|
-
conn = connection.connect() if is_engine else connection
|
427
|
-
|
428
433
|
llm = model
|
429
434
|
|
430
435
|
# Human in the loop requires recommended steps
|
@@ -438,6 +443,11 @@ def make_sql_database_agent(
|
|
438
443
|
log_path = LOG_PATH
|
439
444
|
if not os.path.exists(log_path):
|
440
445
|
os.makedirs(log_path)
|
446
|
+
|
447
|
+
# Get the database metadata
|
448
|
+
is_engine = isinstance(connection, sql.engine.base.Engine)
|
449
|
+
conn = connection.connect() if is_engine else connection
|
450
|
+
|
441
451
|
|
442
452
|
class GraphState(TypedDict):
|
443
453
|
messages: Annotated[Sequence[BaseMessage], operator.add]
|
@@ -457,6 +467,16 @@ def make_sql_database_agent(
|
|
457
467
|
def recommend_sql_steps(state: GraphState):
|
458
468
|
|
459
469
|
print(format_agent_name(AGENT_NAME))
|
470
|
+
|
471
|
+
all_sql_database_summary = get_database_metadata(conn, n_samples=n_samples)
|
472
|
+
|
473
|
+
all_sql_database_summary = smart_schema_filter(
|
474
|
+
llm,
|
475
|
+
state.get("user_instructions"),
|
476
|
+
all_sql_database_summary,
|
477
|
+
smart_filtering=smart_schema_pruning
|
478
|
+
)
|
479
|
+
|
460
480
|
print(" * RECOMMEND STEPS")
|
461
481
|
|
462
482
|
|
@@ -485,7 +505,7 @@ def make_sql_database_agent(
|
|
485
505
|
Below are summaries of the database metadata and the SQL tables:
|
486
506
|
{all_sql_database_summary}
|
487
507
|
|
488
|
-
Return
|
508
|
+
Return steps as a numbered list. You can return short code snippets to demonstrate actions. But do not return a fully coded solution. The code will be generated separately by a Coding Agent.
|
489
509
|
|
490
510
|
Consider these:
|
491
511
|
|
@@ -504,13 +524,6 @@ def make_sql_database_agent(
|
|
504
524
|
input_variables=["user_instructions", "recommended_steps", "all_sql_database_summary"]
|
505
525
|
)
|
506
526
|
|
507
|
-
# Create a connection if needed
|
508
|
-
is_engine = isinstance(connection, sql.engine.base.Engine)
|
509
|
-
conn = connection.connect() if is_engine else connection
|
510
|
-
|
511
|
-
# Get the database metadata
|
512
|
-
all_sql_database_summary = get_database_metadata(conn, n_samples=n_samples)
|
513
|
-
|
514
527
|
steps_agent = recommend_steps_prompt | llm
|
515
528
|
|
516
529
|
recommended_steps = steps_agent.invoke({
|
@@ -527,6 +540,15 @@ def make_sql_database_agent(
|
|
527
540
|
def create_sql_query_code(state: GraphState):
|
528
541
|
if bypass_recommended_steps:
|
529
542
|
print(format_agent_name(AGENT_NAME))
|
543
|
+
all_sql_database_summary = get_database_metadata(conn, n_samples=n_samples)
|
544
|
+
all_sql_database_summary = smart_schema_filter(
|
545
|
+
llm,
|
546
|
+
state.get("user_instructions"),
|
547
|
+
all_sql_database_summary,
|
548
|
+
smart_filtering=smart_schema_pruning
|
549
|
+
)
|
550
|
+
else:
|
551
|
+
all_sql_database_summary = state.get("all_sql_database_summary")
|
530
552
|
print(" * CREATE SQL QUERY CODE")
|
531
553
|
|
532
554
|
# Prompt to get the SQL code from the LLM
|
@@ -567,13 +589,6 @@ def make_sql_database_agent(
|
|
567
589
|
input_variables=["user_instructions", "recommended_steps", "all_sql_database_summary"]
|
568
590
|
)
|
569
591
|
|
570
|
-
# Create a connection if needed
|
571
|
-
is_engine = isinstance(connection, sql.engine.base.Engine)
|
572
|
-
conn = connection.connect() if is_engine else connection
|
573
|
-
|
574
|
-
# Get the database metadata
|
575
|
-
all_sql_database_summary = get_database_metadata(conn, n_samples=n_samples)
|
576
|
-
|
577
592
|
sql_query_code_agent = sql_query_code_prompt | llm | SQLOutputParser()
|
578
593
|
|
579
594
|
sql_query_code = sql_query_code_agent.invoke({
|
@@ -690,20 +705,20 @@ def {function_name}(connection):
|
|
690
705
|
function_name=state.get("sql_database_function_name"),
|
691
706
|
)
|
692
707
|
|
693
|
-
|
694
|
-
|
708
|
+
# Final reporting node
|
709
|
+
def report_agent_outputs(state: GraphState):
|
710
|
+
return node_func_report_agent_outputs(
|
695
711
|
state=state,
|
696
|
-
|
712
|
+
keys_to_include=[
|
713
|
+
"recommended_steps",
|
714
|
+
"sql_database_function",
|
715
|
+
"sql_database_function_path",
|
716
|
+
"sql_database_function_name",
|
717
|
+
"sql_database_error",
|
718
|
+
],
|
697
719
|
result_key="messages",
|
698
|
-
error_key="sql_database_error",
|
699
|
-
llm=llm,
|
700
720
|
role=AGENT_NAME,
|
701
|
-
|
702
|
-
Explain the SQL steps that the SQL Database agent performed in this function.
|
703
|
-
Keep the summary succinct and to the point.\n\n# SQL Database Agent:\n\n{code}
|
704
|
-
""",
|
705
|
-
success_prefix="# SQL Database Agent:\n\n",
|
706
|
-
error_message="The SQL Database Agent encountered an error during SQL Query Analysis. No SQL function explanation is returned."
|
721
|
+
custom_title="SQL Database Agent Outputs"
|
707
722
|
)
|
708
723
|
|
709
724
|
# Create the graph
|
@@ -713,7 +728,7 @@ def {function_name}(connection):
|
|
713
728
|
"create_sql_query_code": create_sql_query_code,
|
714
729
|
"execute_sql_database_code": execute_sql_database_code,
|
715
730
|
"fix_sql_database_code": fix_sql_database_code,
|
716
|
-
"
|
731
|
+
"report_agent_outputs": report_agent_outputs,
|
717
732
|
}
|
718
733
|
|
719
734
|
app = create_coding_agent_graph(
|
@@ -723,7 +738,7 @@ def {function_name}(connection):
|
|
723
738
|
create_code_node_name="create_sql_query_code",
|
724
739
|
execute_code_node_name="execute_sql_database_code",
|
725
740
|
fix_code_node_name="fix_sql_database_code",
|
726
|
-
explain_code_node_name="
|
741
|
+
explain_code_node_name="report_agent_outputs",
|
727
742
|
error_key="sql_database_error",
|
728
743
|
human_in_the_loop=human_in_the_loop,
|
729
744
|
human_review_node_name="human_review",
|
@@ -737,7 +752,46 @@ def {function_name}(connection):
|
|
737
752
|
|
738
753
|
|
739
754
|
|
755
|
+
def smart_schema_filter(llm, user_instructions, all_sql_database_summary, smart_filtering = True):
|
756
|
+
"""
|
757
|
+
This function filters the database metadata (tables and columns) down to what is relevant to the user instructions. When smart_filtering is False, the metadata is returned unchanged.
|
758
|
+
"""
|
759
|
+
# Smart schema filtering
|
760
|
+
if smart_filtering:
|
761
|
+
print(" * SMART FILTER SCHEMA")
|
740
762
|
|
763
|
+
filter_schema_prompt = PromptTemplate(
|
764
|
+
template="""
|
765
|
+
You are a highly skilled data engineer. The user question is:
|
766
|
+
|
767
|
+
"{user_instructions}"
|
768
|
+
|
769
|
+
You have the full database metadata in JSON format below:
|
770
|
+
|
771
|
+
{all_sql_database_summary}
|
772
|
+
|
773
|
+
|
774
|
+
Please return ONLY the subset of this metadata that is relevant to answering the user’s question.
|
775
|
+
- Preserve the same JSON structure for "schemas" -> "tables" -> "columns".
|
776
|
+
- If any schemas/tables are irrelevant, omit them entirely.
|
777
|
+
- If some columns in a relevant table are not needed, you can still keep them if you aren't sure.
|
778
|
+
- However, try to keep only the minimum amount of data required to answer the user’s question.
|
779
|
+
|
780
|
+
Return a valid JSON object. Do not include any additional explanation or text outside of the JSON.
|
781
|
+
""",
|
782
|
+
input_variables=["user_instructions", "full_metadata_json"]
|
783
|
+
)
|
784
|
+
|
785
|
+
filter_schema_agent = filter_schema_prompt | llm | JsonOutputParser()
|
786
|
+
|
787
|
+
response = filter_schema_agent.invoke({
|
788
|
+
"user_instructions": user_instructions,
|
789
|
+
"all_sql_database_summary": all_sql_database_summary
|
790
|
+
})
|
791
|
+
|
792
|
+
return response
|
793
|
+
else:
|
794
|
+
return all_sql_database_summary
|
741
795
|
|
742
796
|
|
743
797
|
|
@@ -0,0 +1 @@
|
|
1
|
+
from ai_data_science_team.ml_agents.h2o_ml_agent import make_h2o_ml_agent, H2OMLAgent
|