ai-data-science-team 0.0.0.9007__py3-none-any.whl → 0.0.0.9009__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (25) hide show
  1. ai_data_science_team/_version.py +1 -1
  2. ai_data_science_team/agents/__init__.py +4 -5
  3. ai_data_science_team/agents/data_cleaning_agent.py +268 -116
  4. ai_data_science_team/agents/data_visualization_agent.py +470 -41
  5. ai_data_science_team/agents/data_wrangling_agent.py +471 -31
  6. ai_data_science_team/agents/feature_engineering_agent.py +426 -41
  7. ai_data_science_team/agents/sql_database_agent.py +458 -58
  8. ai_data_science_team/ml_agents/__init__.py +1 -0
  9. ai_data_science_team/ml_agents/h2o_ml_agent.py +1032 -0
  10. ai_data_science_team/multiagents/__init__.py +1 -0
  11. ai_data_science_team/multiagents/sql_data_analyst.py +398 -0
  12. ai_data_science_team/multiagents/supervised_data_analyst.py +2 -0
  13. ai_data_science_team/templates/__init__.py +3 -1
  14. ai_data_science_team/templates/agent_templates.py +319 -43
  15. ai_data_science_team/tools/metadata.py +94 -62
  16. ai_data_science_team/tools/regex.py +86 -1
  17. ai_data_science_team/utils/__init__.py +0 -0
  18. ai_data_science_team/utils/plotly.py +24 -0
  19. ai_data_science_team-0.0.0.9009.dist-info/METADATA +245 -0
  20. ai_data_science_team-0.0.0.9009.dist-info/RECORD +28 -0
  21. ai_data_science_team-0.0.0.9007.dist-info/METADATA +0 -183
  22. ai_data_science_team-0.0.0.9007.dist-info/RECORD +0 -21
  23. {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/LICENSE +0 -0
  24. {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/WHEEL +0 -0
  25. {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/top_level.txt +0 -0
@@ -5,24 +5,33 @@ import operator
5
5
 
6
6
  from langchain.prompts import PromptTemplate
7
7
  from langchain_core.messages import BaseMessage
8
+ from langchain_core.output_parsers import JsonOutputParser
8
9
 
9
10
  from langgraph.types import Command
10
11
  from langgraph.checkpoint.memory import MemorySaver
11
12
 
12
13
  import os
13
- import io
14
+ import json
14
15
  import pandas as pd
15
16
  import sqlalchemy as sql
16
17
 
18
+ from IPython.display import Markdown
19
+
17
20
  from ai_data_science_team.templates import(
18
21
  node_func_execute_agent_from_sql_connection,
19
22
  node_func_human_review,
20
23
  node_func_fix_agent_code,
21
- node_func_explain_agent_code,
22
- create_coding_agent_graph
24
+ node_func_report_agent_outputs,
25
+ create_coding_agent_graph,
26
+ BaseAgent,
27
+ )
28
+ from ai_data_science_team.tools.parsers import SQLOutputParser
29
+ from ai_data_science_team.tools.regex import (
30
+ add_comments_to_top,
31
+ format_agent_name,
32
+ format_recommended_steps,
33
+ get_generic_summary,
23
34
  )
24
- from ai_data_science_team.tools.parsers import PythonOutputParser, SQLOutputParser
25
- from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
26
35
  from ai_data_science_team.tools.metadata import get_database_metadata
27
36
  from ai_data_science_team.tools.logging import log_ai_function
28
37
 
@@ -30,16 +39,333 @@ from ai_data_science_team.tools.logging import log_ai_function
30
39
  AGENT_NAME = "sql_database_agent"
31
40
  LOG_PATH = os.path.join(os.getcwd(), "logs/")
32
41
 
42
+ # Class
43
+
44
class SQLDatabaseAgent(BaseAgent):
    """
    Creates a SQL Database Agent that can recommend SQL steps and generate SQL code to query a database.
    The agent can:
    - Propose recommended steps to answer a user's query or instructions.
    - Generate a SQL query based on the recommended steps and user instructions.
    - Execute that SQL query against the provided database connection.
    - Return the resulting data, exposed through `get_data_sql()` as a pandas DataFrame.
    - Log generated code and errors if enabled.

    Parameters
    ----------
    model : ChatOpenAI or langchain.llms.base.LLM
        The language model used to generate the SQL code.
    connection : sqlalchemy.engine.base.Engine or sqlalchemy.engine.base.Connection
        The SQLAlchemy connection (or engine) to the database.
    n_samples : int, optional
        Number of sample rows (per column) to retrieve when summarizing database metadata. Defaults to 1.
    log : bool, optional
        Whether to log the generated code and errors. Defaults to False.
    log_path : str, optional
        Directory path for storing log files. Defaults to None.
    file_name : str, optional
        Name of the file for saving the generated response. Defaults to "sql_database.py".
    function_name : str, optional
        Name of the Python function that executes the SQL query. Defaults to "sql_database_pipeline".
    overwrite : bool, optional
        Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
    human_in_the_loop : bool, optional
        Enables user review of the recommended steps before generating code. Defaults to False.
    bypass_recommended_steps : bool, optional
        If True, skips the step that generates recommended SQL steps. Defaults to False.
    bypass_explain_code : bool, optional
        If True, skips the step that provides code explanations. Defaults to False.
    smart_schema_pruning : bool, optional
        If True, filters the tables and columns based on the user instructions and recommended steps. Defaults to False.

    Methods
    -------
    update_params(**kwargs)
        Updates the agent's parameters and rebuilds the compiled state graph.
    ainvoke_agent(user_instructions: str, max_retries=3, retry_count=0)
        Asynchronously runs the agent to generate and execute a SQL query based on user instructions.
        This is a coroutine and must be awaited.
    invoke_agent(user_instructions: str, max_retries=3, retry_count=0)
        Synchronously runs the agent to generate and execute a SQL query based on user instructions.
    get_workflow_summary()
        Retrieves a summary of the agent's workflow.
    get_log_summary()
        Retrieves a summary of logged operations if logging is enabled.
    get_data_sql()
        Retrieves the resulting data from the SQL query as a pandas DataFrame.
    get_sql_query_code()
        Retrieves the exact SQL query generated by the agent.
    get_sql_database_function()
        Retrieves the Python function that executes the SQL query.
    get_recommended_sql_steps()
        Retrieves the recommended steps for querying the SQL database.
    get_response()
        Returns the full response dictionary from the agent.
    show()
        Displays the agent's mermaid diagram for visual inspection of the compiled graph.

    Examples
    --------
    ```python
    import sqlalchemy as sql
    from langchain_openai import ChatOpenAI
    from ai_data_science_team.agents import SQLDatabaseAgent

    # Create the engine/connection
    sql_engine = sql.create_engine("sqlite:///data/my_database.db")
    conn = sql_engine.connect()

    llm = ChatOpenAI(model="gpt-4o-mini")

    sql_database_agent = SQLDatabaseAgent(
        model=llm,
        connection=conn,
        n_samples=10,
        log=True,
        log_path="logs",
        human_in_the_loop=True
    )

    # Example usage
    sql_database_agent.invoke_agent(
        user_instructions="List all the tables in the database.",
        max_retries=3,
        retry_count=0
    )

    data_result = sql_database_agent.get_data_sql()  # pandas DataFrame of rows returned
    sql_code = sql_database_agent.get_sql_query_code()
    response = sql_database_agent.get_response()
    ```

    Returns
    -------
    SQLDatabaseAgent
        An agent object wrapping the compiled state graph built by `make_sql_database_agent`.
    """

    def __init__(
        self,
        model,
        connection,
        n_samples=1,
        log=False,
        log_path=None,
        file_name="sql_database.py",
        function_name="sql_database_pipeline",
        overwrite=True,
        human_in_the_loop=False,
        bypass_recommended_steps=False,
        bypass_explain_code=False,
        smart_schema_pruning=False,
    ):
        # The parameters are kept verbatim so the graph can be rebuilt later
        # by update_params() with the same keyword arguments.
        self._params = {
            "model": model,
            "connection": connection,
            "n_samples": n_samples,
            "log": log,
            "log_path": log_path,
            "file_name": file_name,
            "function_name": function_name,
            "overwrite": overwrite,
            "human_in_the_loop": human_in_the_loop,
            "bypass_recommended_steps": bypass_recommended_steps,
            "bypass_explain_code": bypass_explain_code,
            "smart_schema_pruning": smart_schema_pruning,
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Create or rebuild the compiled graph for the SQL Database Agent.
        Running this method resets the response to None.
        """
        self.response = None
        return make_sql_database_agent(**self._params)

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters (e.g. connection, n_samples, log, etc.)
        and rebuilds the compiled graph.

        Parameters
        ----------
        kwargs : dict
            Parameter names and their new values. They are merged into the
            stored parameters before the graph is rebuilt.
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    async def ainvoke_agent(self, user_instructions: str=None, max_retries=3, retry_count=0, **kwargs):
        """
        Asynchronously runs the SQL Database Agent based on user instructions.

        This is a coroutine: call it with `await agent.ainvoke_agent(...)`.
        The graph's `ainvoke()` result is awaited before being stored, so
        `self.response` holds the response itself rather than a pending
        coroutine object.

        Parameters
        ----------
        user_instructions : str
            Instructions for the SQL query or metadata request.
        max_retries : int, optional
            Maximum retry attempts. Defaults to 3.
        retry_count : int, optional
            Current retry count. Defaults to 0.
        kwargs : dict
            Additional keyword arguments to pass to ainvoke().

        Returns
        -------
        None
            The result is stored on `self.response`.
        """
        response = await self._compiled_graph.ainvoke({
            "user_instructions": user_instructions,
            "max_retries": max_retries,
            "retry_count": retry_count
        }, **kwargs)
        self.response = response

    def invoke_agent(self, user_instructions: str=None, max_retries=3, retry_count=0, **kwargs):
        """
        Synchronously runs the SQL Database Agent based on user instructions.

        Parameters
        ----------
        user_instructions : str
            Instructions for the SQL query or metadata request.
        max_retries : int, optional
            Maximum retry attempts. Defaults to 3.
        retry_count : int, optional
            Current retry count. Defaults to 0.
        kwargs : dict
            Additional keyword arguments to pass to invoke().

        Returns
        -------
        None
            The result is stored on `self.response`.
        """
        response = self._compiled_graph.invoke({
            "user_instructions": user_instructions,
            "max_retries": max_retries,
            "retry_count": retry_count
        }, **kwargs)
        self.response = response

    def get_workflow_summary(self, markdown=False):
        """
        Retrieves the agent's workflow summary from the last message of the response.

        The content of the final message is parsed as JSON and passed to
        `get_generic_summary`.

        Parameters
        ----------
        markdown : bool, optional
            If True, wraps the summary in an IPython Markdown object.

        Returns
        -------
        IPython.display.Markdown or the value returned by get_generic_summary, or None
            None when there is no response or the response has no messages.
        """
        if self.response and self.response.get("messages"):
            last_message = self.response.get("messages")[-1]
            summary = get_generic_summary(json.loads(last_message.content))
            if markdown:
                return Markdown(summary)
            return summary

    def get_log_summary(self, markdown=False):
        """
        Retrieves a summary of where the generated function was logged.

        Parameters
        ----------
        markdown : bool, optional
            If True, wraps the summary in an IPython Markdown object.

        Returns
        -------
        str, IPython.display.Markdown, or None
            None when there is no response or the response carries no
            'sql_database_function_path'.
        """
        if self.response:
            if self.response.get('sql_database_function_path'):
                log_details = f"""
## SQL Database Agent Log Summary:

Function Path: {self.response.get('sql_database_function_path')}

Function Name: {self.response.get('sql_database_function_name')}
                """
                if markdown:
                    return Markdown(log_details)
                return log_details

    def get_data_sql(self):
        """
        Retrieves the SQL query result from the agent's response.

        Returns
        -------
        pandas.DataFrame or None
            The "data_sql" entry of the response converted to a DataFrame,
            or None if no data is found.
        """
        if self.response and "data_sql" in self.response:
            return pd.DataFrame(self.response["data_sql"])
        return None

    def get_sql_query_code(self, markdown=False):
        """
        Retrieves the raw SQL query code generated by the agent (if available).

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the code in a Markdown code block.

        Returns
        -------
        str, IPython.display.Markdown, or None
            The SQL query (a Markdown object when `markdown` is True),
            or None if not available.
        """
        if self.response and "sql_query_code" in self.response:
            if markdown:
                return Markdown(f"```sql\n{self.response['sql_query_code']}\n```")
            return self.response["sql_query_code"]
        return None

    def get_sql_database_function(self, markdown=False):
        """
        Retrieves the Python function code used to execute the SQL query.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the code in a Markdown code block.

        Returns
        -------
        str, IPython.display.Markdown, or None
            The function code (a Markdown object when `markdown` is True),
            otherwise None.
        """
        if self.response and "sql_database_function" in self.response:
            code = self.response["sql_database_function"]
            if markdown:
                return Markdown(f"```python\n{code}\n```")
            return code
        return None

    def get_recommended_sql_steps(self, markdown=False):
        """
        Retrieves the recommended SQL steps from the agent's response.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the steps in Markdown format.

        Returns
        -------
        str, IPython.display.Markdown, or None
            Recommended steps (a Markdown object when `markdown` is True),
            or None if not available.
        """
        if self.response and "recommended_steps" in self.response:
            if markdown:
                return Markdown(self.response["recommended_steps"])
            return self.response["recommended_steps"]
        return None
351
+
352
+
353
+
354
+ # Function
33
355
 
34
356
  def make_sql_database_agent(
35
- model, connection,
36
- n_samples = 10,
357
+ model,
358
+ connection,
359
+ n_samples=1,
37
360
  log=False,
38
361
  log_path=None,
39
362
  file_name="sql_database.py",
363
+ function_name="sql_database_pipeline",
40
364
  overwrite = True,
41
- human_in_the_loop=False, bypass_recommended_steps=False,
42
- bypass_explain_code=False
365
+ human_in_the_loop=False,
366
+ bypass_recommended_steps=False,
367
+ bypass_explain_code=False,
368
+ smart_schema_pruning=False,
43
369
  ):
44
370
  """
45
371
  Creates a SQL Database Agent that can recommend SQL steps and generate SQL code to query a database.
@@ -51,7 +377,7 @@ def make_sql_database_agent(
51
377
  connection : sqlalchemy.engine.base.Engine
52
378
  The connection to the SQL database.
53
379
  n_samples : int, optional
54
- The number of samples to retrieve for each column, by default 10.
380
+ The number of samples to retrieve for each column, by default 1.
55
381
  If you get an error due to maximum tokens, try reducing this number.
56
382
  > "This model's maximum context length is 128000 tokens. However, your messages resulted in 333858 tokens. Please reduce the length of the messages."
57
383
  log : bool, optional
@@ -68,6 +394,8 @@ def make_sql_database_agent(
68
394
  Bypass the recommendation step, by default False
69
395
  bypass_explain_code : bool, optional
70
396
  Bypass the code explanation step, by default False.
397
+ smart_schema_pruning : bool, optional
398
+ If True, filters the tables and columns with an extra LLM step to reduce tokens for large databases. Increases processing time but can avoid errors due to hitting max token limits with large databases. Defaults to False.
71
399
 
72
400
  Returns
73
401
  -------
@@ -100,20 +428,26 @@ def make_sql_database_agent(
100
428
  "retry_count":0
101
429
  })
102
430
  ```
103
-
104
431
  """
105
432
 
106
- is_engine = isinstance(connection, sql.engine.base.Engine)
107
- conn = connection.connect() if is_engine else connection
108
-
109
433
  llm = model
110
434
 
435
+ # Human in the loop requires recommended steps
436
+ if bypass_recommended_steps and human_in_the_loop:
437
+ bypass_recommended_steps = False
438
+ print("Bypass recommended steps set to False to enable human in the loop.")
439
+
111
440
  # Setup Log Directory
112
441
  if log:
113
442
  if log_path is None:
114
443
  log_path = LOG_PATH
115
444
  if not os.path.exists(log_path):
116
445
  os.makedirs(log_path)
446
+
447
+ # Get the database metadata
448
+ is_engine = isinstance(connection, sql.engine.base.Engine)
449
+ conn = connection.connect() if is_engine else connection
450
+
117
451
 
118
452
  class GraphState(TypedDict):
119
453
  messages: Annotated[Sequence[BaseMessage], operator.add]
@@ -124,6 +458,7 @@ def make_sql_database_agent(
124
458
  sql_query_code: str
125
459
  sql_database_function: str
126
460
  sql_database_function_path: str
461
+ sql_database_function_file_name: str
127
462
  sql_database_function_name: str
128
463
  sql_database_error: str
129
464
  max_retries: int
@@ -132,6 +467,16 @@ def make_sql_database_agent(
132
467
  def recommend_sql_steps(state: GraphState):
133
468
 
134
469
  print(format_agent_name(AGENT_NAME))
470
+
471
+ all_sql_database_summary = get_database_metadata(conn, n_samples=n_samples)
472
+
473
+ all_sql_database_summary = smart_schema_filter(
474
+ llm,
475
+ state.get("user_instructions"),
476
+ all_sql_database_summary,
477
+ smart_filtering=smart_schema_pruning
478
+ )
479
+
135
480
  print(" * RECOMMEND STEPS")
136
481
 
137
482
 
@@ -148,6 +493,7 @@ def make_sql_database_agent(
148
493
  - If no user instructions are provided, just return the steps needed to understand the database.
149
494
  - Take into account the database dialect and the tables and columns in the database.
150
495
  - Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
496
+ - IMPORTANT: Pay attention to the table names and column names in the database. Make sure to use the correct table and column names in the SQL code. If a space is present in the table name or column name, make sure to account for it.
151
497
 
152
498
 
153
499
  User instructions / Question:
@@ -159,7 +505,7 @@ def make_sql_database_agent(
159
505
  Below are summaries of the database metadata and the SQL tables:
160
506
  {all_sql_database_summary}
161
507
 
162
- Return the steps as a numbered point list (no code, just the steps).
508
+ Return steps as a numbered list. You can return short code snippets to demonstrate actions. But do not return a fully coded solution. The code will be generated separately by a Coding Agent.
163
509
 
164
510
  Consider these:
165
511
 
@@ -178,13 +524,6 @@ def make_sql_database_agent(
178
524
  input_variables=["user_instructions", "recommended_steps", "all_sql_database_summary"]
179
525
  )
180
526
 
181
- # Create a connection if needed
182
- is_engine = isinstance(connection, sql.engine.base.Engine)
183
- conn = connection.connect() if is_engine else connection
184
-
185
- # Get the database metadata
186
- all_sql_database_summary = get_database_metadata(conn, n_samples=n_samples)
187
-
188
527
  steps_agent = recommend_steps_prompt | llm
189
528
 
190
529
  recommended_steps = steps_agent.invoke({
@@ -194,13 +533,22 @@ def make_sql_database_agent(
194
533
  })
195
534
 
196
535
  return {
197
- "recommended_steps": "\n\n# Recommended SQL Database Steps:\n" + recommended_steps.content.strip(),
536
+ "recommended_steps": format_recommended_steps(recommended_steps.content.strip(), heading="# Recommended SQL Database Steps:"),
198
537
  "all_sql_database_summary": all_sql_database_summary
199
538
  }
200
539
 
201
540
  def create_sql_query_code(state: GraphState):
202
541
  if bypass_recommended_steps:
203
542
  print(format_agent_name(AGENT_NAME))
543
+ all_sql_database_summary = get_database_metadata(conn, n_samples=n_samples)
544
+ all_sql_database_summary = smart_schema_filter(
545
+ llm,
546
+ state.get("user_instructions"),
547
+ all_sql_database_summary,
548
+ smart_filtering=smart_schema_pruning
549
+ )
550
+ else:
551
+ all_sql_database_summary = state.get("all_sql_database_summary")
204
552
  print(" * CREATE SQL QUERY CODE")
205
553
 
206
554
  # Prompt to get the SQL code from the LLM
@@ -216,6 +564,7 @@ def make_sql_database_agent(
216
564
  - Only return a single query if possible.
217
565
  - Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
218
566
  - Pay attention to the SQL dialect from the database summary metadata. Write the SQL code according to the dialect specified.
567
+ - IMPORTANT: Pay attention to the table names and column names in the database. Make sure to use the correct table and column names in the SQL code. If a space is present in the table name or column name, make sure to account for it.
219
568
 
220
569
 
221
570
  User instructions / Question:
@@ -240,13 +589,6 @@ def make_sql_database_agent(
240
589
  input_variables=["user_instructions", "recommended_steps", "all_sql_database_summary"]
241
590
  )
242
591
 
243
- # Create a connection if needed
244
- is_engine = isinstance(connection, sql.engine.base.Engine)
245
- conn = connection.connect() if is_engine else connection
246
-
247
- # Get the database metadata
248
- all_sql_database_summary = get_database_metadata(conn, n_samples=n_samples)
249
-
250
592
  sql_query_code_agent = sql_query_code_prompt | llm | SQLOutputParser()
251
593
 
252
594
  sql_query_code = sql_query_code_agent.invoke({
@@ -258,7 +600,7 @@ def make_sql_database_agent(
258
600
  print(" * CREATE PYTHON FUNCTION TO RUN SQL CODE")
259
601
 
260
602
  response = f"""
261
- def sql_database_pipeline(connection):
603
+ def {function_name}(connection):
262
604
  import pandas as pd
263
605
  import sqlalchemy as sql
264
606
 
@@ -288,19 +630,37 @@ def sql_database_pipeline(connection):
288
630
  "sql_query_code": sql_query_code,
289
631
  "sql_database_function": response,
290
632
  "sql_database_function_path": file_path,
291
- "sql_database_function_name": file_name_2,
633
+ "sql_database_function_file_name": file_name_2,
634
+ "sql_database_function_name": function_name,
292
635
  "all_sql_database_summary": all_sql_database_summary
293
636
  }
294
637
 
295
- def human_review(state: GraphState) -> Command[Literal["recommend_sql_steps", "create_sql_query_code"]]:
296
- return node_func_human_review(
297
- state=state,
298
- prompt_text="Are the following SQL database querying steps correct? (Answer 'yes' or provide modifications)\n{steps}",
299
- yes_goto="create_sql_query_code",
300
- no_goto="recommend_sql_steps",
301
- user_instructions_key="user_instructions",
302
- recommended_steps_key="recommended_steps"
303
- )
638
+ # Human Review
639
+
640
+ prompt_text_human_review = "Are the following SQL agent instructions correct? (Answer 'yes' or provide modifications)\n{steps}"
641
+
642
+ if not bypass_explain_code:
643
+ def human_review(state: GraphState) -> Command[Literal["recommend_sql_steps", "explain_sql_database_code"]]:
644
+ return node_func_human_review(
645
+ state=state,
646
+ prompt_text=prompt_text_human_review,
647
+ yes_goto= 'explain_sql_database_code',
648
+ no_goto="recommend_sql_steps",
649
+ user_instructions_key="user_instructions",
650
+ recommended_steps_key="recommended_steps",
651
+ code_snippet_key="sql_database_function",
652
+ )
653
+ else:
654
+ def human_review(state: GraphState) -> Command[Literal["recommend_sql_steps", "__end__"]]:
655
+ return node_func_human_review(
656
+ state=state,
657
+ prompt_text=prompt_text_human_review,
658
+ yes_goto= '__end__',
659
+ no_goto="recommend_sql_steps",
660
+ user_instructions_key="user_instructions",
661
+ recommended_steps_key="recommended_steps",
662
+ code_snippet_key="sql_database_function",
663
+ )
304
664
 
305
665
  def execute_sql_database_code(state: GraphState):
306
666
 
@@ -313,18 +673,18 @@ def sql_database_pipeline(connection):
313
673
  result_key="data_sql",
314
674
  error_key="sql_database_error",
315
675
  code_snippet_key="sql_database_function",
316
- agent_function_name="sql_database_pipeline",
676
+ agent_function_name=state.get("sql_database_function_name"),
317
677
  post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
318
678
  error_message_prefix="An error occurred during executing the sql database pipeline: "
319
679
  )
320
680
 
321
681
  def fix_sql_database_code(state: GraphState):
322
682
  prompt = """
323
- You are a SQL Database Agent code fixer. Your job is to create a sql_database_pipeline(connection) function that can be run on a sql connection. The function is currently broken and needs to be fixed.
683
+ You are a SQL Database Agent code fixer. Your job is to create a {function_name}(connection) function that can be run on a sql connection. The function is currently broken and needs to be fixed.
324
684
 
325
- Make sure to only return the function definition for sql_database_pipeline().
685
+ Make sure to only return the function definition for {function_name}().
326
686
 
327
- Return Python code in ```python``` format with a single function definition, sql_database_pipeline(connection), that includes all imports inside the function. The connection object is a SQLAlchemy connection object. Don't specify the class of the connection object, just use it as an argument to the function.
687
+ Return Python code in ```python``` format with a single function definition, {function_name}(connection), that includes all imports inside the function. The connection object is a SQLAlchemy connection object. Don't specify the class of the connection object, just use it as an argument to the function.
328
688
 
329
689
  This is the broken code (please fix):
330
690
  {code_snippet}
@@ -342,22 +702,23 @@ def sql_database_pipeline(connection):
342
702
  agent_name=AGENT_NAME,
343
703
  log=log,
344
704
  file_path=state.get("sql_database_function_path", None),
705
+ function_name=state.get("sql_database_function_name"),
345
706
  )
346
707
 
347
- def explain_sql_database_code(state: GraphState):
348
- return node_func_explain_agent_code(
708
+ # Final reporting node
709
+ def report_agent_outputs(state: GraphState):
710
+ return node_func_report_agent_outputs(
349
711
  state=state,
350
- code_snippet_key="sql_database_function",
712
+ keys_to_include=[
713
+ "recommended_steps",
714
+ "sql_database_function",
715
+ "sql_database_function_path",
716
+ "sql_database_function_name",
717
+ "sql_database_error",
718
+ ],
351
719
  result_key="messages",
352
- error_key="sql_database_error",
353
- llm=llm,
354
720
  role=AGENT_NAME,
355
- explanation_prompt_template="""
356
- Explain the SQL steps that the SQL Database agent performed in this function.
357
- Keep the summary succinct and to the point.\n\n# SQL Database Agent:\n\n{code}
358
- """,
359
- success_prefix="# SQL Database Agent:\n\n",
360
- error_message="The SQL Database Agent encountered an error during SQL Query Analysis. No SQL function explanation is returned."
721
+ custom_title="SQL Database Agent Outputs"
361
722
  )
362
723
 
363
724
  # Create the graph
@@ -367,7 +728,7 @@ def sql_database_pipeline(connection):
367
728
  "create_sql_query_code": create_sql_query_code,
368
729
  "execute_sql_database_code": execute_sql_database_code,
369
730
  "fix_sql_database_code": fix_sql_database_code,
370
- "explain_sql_database_code": explain_sql_database_code
731
+ "report_agent_outputs": report_agent_outputs,
371
732
  }
372
733
 
373
734
  app = create_coding_agent_graph(
@@ -377,7 +738,7 @@ def sql_database_pipeline(connection):
377
738
  create_code_node_name="create_sql_query_code",
378
739
  execute_code_node_name="execute_sql_database_code",
379
740
  fix_code_node_name="fix_sql_database_code",
380
- explain_code_node_name="explain_sql_database_code",
741
+ explain_code_node_name="report_agent_outputs",
381
742
  error_key="sql_database_error",
382
743
  human_in_the_loop=human_in_the_loop,
383
744
  human_review_node_name="human_review",
@@ -391,7 +752,46 @@ def sql_database_pipeline(connection):
391
752
 
392
753
 
393
754
 
755
def smart_schema_filter(llm, user_instructions, all_sql_database_summary, smart_filtering = True):
    """
    Optionally prunes the database metadata down to what is relevant to the user's question.

    When `smart_filtering` is False the metadata is returned untouched and the
    LLM is never called. Otherwise the metadata and the user instructions are
    sent to the LLM, and its reply is parsed as JSON and returned.

    Parameters
    ----------
    llm :
        The language model placed in the prompt | llm | JsonOutputParser chain.
    user_instructions : str
        The user's question or instructions.
    all_sql_database_summary :
        The full database metadata to filter.
    smart_filtering : bool, optional
        Whether to run the LLM filtering step. Defaults to True.

    Returns
    -------
    The parsed JSON response of the LLM when `smart_filtering` is True,
    otherwise `all_sql_database_summary` unchanged.
    """
    # Guard clause: no filtering requested, so skip the extra LLM round trip.
    if not smart_filtering:
        return all_sql_database_summary

    print("    * SMART FILTER SCHEMA")

    filter_schema_prompt = PromptTemplate(
        template="""
        You are a highly skilled data engineer. The user question is:

        "{user_instructions}"

        You have the full database metadata in JSON format below:

        {all_sql_database_summary}


        Please return ONLY the subset of this metadata that is relevant to answering the user’s question.
        - Preserve the same JSON structure for "schemas" -> "tables" -> "columns".
        - If any schemas/tables are irrelevant, omit them entirely.
        - If some columns in a relevant table are not needed, you can still keep them if you aren't sure.
        - However, try to keep only the minimum amount of data required to answer the user’s question.

        Return a valid JSON object. Do not include any additional explanation or text outside of the JSON.
        """,
        # Must name the placeholders actually used in the template above and
        # supplied to invoke() below.
        input_variables=["user_instructions", "all_sql_database_summary"]
    )

    filter_schema_agent = filter_schema_prompt | llm | JsonOutputParser()

    response = filter_schema_agent.invoke({
        "user_instructions": user_instructions,
        "all_sql_database_summary": all_sql_database_summary
    })

    return response
395
795
 
396
796
 
397
797