ai-data-science-team 0.0.0.9008__py3-none-any.whl → 0.0.0.9010__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (29) hide show
  1. ai_data_science_team/_version.py +1 -1
  2. ai_data_science_team/agents/__init__.py +0 -1
  3. ai_data_science_team/agents/data_cleaning_agent.py +50 -39
  4. ai_data_science_team/agents/data_loader_tools_agent.py +69 -0
  5. ai_data_science_team/agents/data_visualization_agent.py +45 -50
  6. ai_data_science_team/agents/data_wrangling_agent.py +50 -49
  7. ai_data_science_team/agents/feature_engineering_agent.py +48 -67
  8. ai_data_science_team/agents/sql_database_agent.py +130 -76
  9. ai_data_science_team/ml_agents/__init__.py +2 -0
  10. ai_data_science_team/ml_agents/h2o_ml_agent.py +852 -0
  11. ai_data_science_team/ml_agents/mlflow_tools_agent.py +327 -0
  12. ai_data_science_team/multiagents/sql_data_analyst.py +120 -9
  13. ai_data_science_team/parsers/__init__.py +0 -0
  14. ai_data_science_team/{tools → parsers}/parsers.py +0 -1
  15. ai_data_science_team/templates/__init__.py +1 -0
  16. ai_data_science_team/templates/agent_templates.py +78 -7
  17. ai_data_science_team/tools/data_loader.py +378 -0
  18. ai_data_science_team/tools/{metadata.py → dataframe.py} +0 -91
  19. ai_data_science_team/tools/h2o.py +643 -0
  20. ai_data_science_team/tools/mlflow.py +961 -0
  21. ai_data_science_team/tools/sql.py +126 -0
  22. ai_data_science_team/{tools → utils}/regex.py +59 -1
  23. {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/METADATA +56 -24
  24. ai_data_science_team-0.0.0.9010.dist-info/RECORD +35 -0
  25. ai_data_science_team-0.0.0.9008.dist-info/RECORD +0 -26
  26. /ai_data_science_team/{tools → utils}/logging.py +0 -0
  27. {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/LICENSE +0 -0
  28. {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/WHEEL +0 -0
  29. {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/top_level.txt +0 -0
@@ -14,6 +14,7 @@ from langgraph.types import Command
14
14
  from langgraph.checkpoint.memory import MemorySaver
15
15
 
16
16
  import os
17
+ import json
17
18
  import pandas as pd
18
19
 
19
20
  from IPython.display import Markdown
@@ -22,19 +23,20 @@ from ai_data_science_team.templates import(
22
23
  node_func_execute_agent_code_on_data,
23
24
  node_func_human_review,
24
25
  node_func_fix_agent_code,
25
- node_func_explain_agent_code,
26
+ node_func_report_agent_outputs,
26
27
  create_coding_agent_graph,
27
28
  BaseAgent,
28
29
  )
29
- from ai_data_science_team.tools.parsers import PythonOutputParser
30
- from ai_data_science_team.tools.regex import (
30
+ from ai_data_science_team.parsers.parsers import PythonOutputParser
31
+ from ai_data_science_team.utils.regex import (
31
32
  relocate_imports_inside_function,
32
33
  add_comments_to_top,
33
34
  format_agent_name,
34
- format_recommended_steps
35
+ format_recommended_steps,
36
+ get_generic_summary,
35
37
  )
36
- from ai_data_science_team.tools.metadata import get_dataframe_summary
37
- from ai_data_science_team.tools.logging import log_ai_function
38
+ from ai_data_science_team.tools.dataframe import get_dataframe_summary
39
+ from ai_data_science_team.utils.logging import log_ai_function
38
40
 
39
41
  # Setup
40
42
  AGENT_NAME = "feature_engineering_agent"
@@ -103,8 +105,8 @@ class FeatureEngineeringAgent(BaseAgent):
103
105
  retry_count=0
104
106
  )
105
107
  Engineers features from the provided dataset synchronously based on user instructions.
106
- explain_feature_engineering_steps()
107
- Returns an explanation of the feature engineering steps performed by the agent.
108
+ get_workflow_summary()
109
+ Retrieves a summary of the agent's workflow.
108
110
  get_log_summary()
109
111
  Retrieves a summary of logged operations if logging is enabled.
110
112
  get_data_engineered()
@@ -201,7 +203,7 @@ class FeatureEngineeringAgent(BaseAgent):
201
203
  self._params[k] = v
202
204
  self._compiled_graph = self._make_compiled_graph()
203
205
 
204
- def ainvoke_agent(
206
+ async def ainvoke_agent(
205
207
  self,
206
208
  data_raw: pd.DataFrame,
207
209
  user_instructions: str=None,
@@ -233,7 +235,7 @@ class FeatureEngineeringAgent(BaseAgent):
233
235
  -------
234
236
  None
235
237
  """
236
- response = self._compiled_graph.ainvoke({
238
+ response = await self._compiled_graph.ainvoke({
237
239
  "user_instructions": user_instructions,
238
240
  "data_raw": data_raw.to_dict(),
239
241
  "target_variable": target_variable,
@@ -285,40 +287,34 @@ class FeatureEngineeringAgent(BaseAgent):
285
287
  self.response = response
286
288
  return None
287
289
 
288
- def explain_feature_engineering_steps(self):
290
+ def get_workflow_summary(self, markdown=False):
289
291
  """
290
- Provides an explanation of the feature engineering steps performed by the agent.
291
-
292
- Returns
293
- -------
294
- str or list
295
- Explanation of the feature engineering steps.
292
+ Retrieves the agent's workflow summary, parsed from the last message of the response, if one is available.
296
293
  """
297
- if self.response:
298
- return self.response.get("messages", [])
299
- return []
294
+ if self.response and self.response.get("messages"):
295
+ summary = get_generic_summary(json.loads(self.response.get("messages")[-1].content))
296
+ if markdown:
297
+ return Markdown(summary)
298
+ else:
299
+ return summary
300
300
 
301
301
  def get_log_summary(self, markdown=False):
302
302
  """
303
303
  Returns a summary of the logged function path and function name, if a function path is present in the response.
304
+ """
305
+ if self.response:
306
+ if self.response.get('feature_engineer_function_path'):
307
+ log_details = f"""
308
+ ## Featuring Engineering Agent Log Summary:
304
309
 
305
- Parameters
306
- ----------
307
- markdown : bool, optional
308
- If True, returns Markdown-formatted output.
310
+ Function Path: {self.response.get('feature_engineer_function_path')}
309
311
 
310
- Returns
311
- -------
312
- str or None
313
- Summary of logs, or None if not available.
314
- """
315
- if self.response and self.response.get("feature_engineer_function_path"):
316
- log_details = f"Log Path: {self.response.get('feature_engineer_function_path')}"
317
- if markdown:
318
- return Markdown(log_details)
319
- else:
320
- return log_details
321
- return None
312
+ Function Name: {self.response.get('feature_engineer_function_name')}
313
+ """
314
+ if markdown:
315
+ return Markdown(log_details)
316
+ else:
317
+ return log_details
322
318
 
323
319
  def get_data_engineered(self):
324
320
  """
@@ -388,22 +384,7 @@ class FeatureEngineeringAgent(BaseAgent):
388
384
  return steps
389
385
  return None
390
386
 
391
- def get_response(self):
392
- """
393
- Returns the agent's full response dictionary.
394
-
395
- Returns
396
- -------
397
- dict or None
398
- The response dictionary if available, otherwise None.
399
- """
400
- return self.response
401
-
402
- def show(self):
403
- """
404
- Displays the agent's mermaid diagram for visual inspection of the compiled graph.
405
- """
406
- return self._compiled_graph.show()
387
+
407
388
 
408
389
 
409
390
  # * Feature Engineering Agent
@@ -576,7 +557,7 @@ def make_feature_engineering_agent(
576
557
  Below are summaries of all datasets provided:
577
558
  {all_datasets_summary}
578
559
 
579
- Return the steps as a numbered list (no code, just the steps).
560
+ Return steps as a numbered list. You can return short code snippets to demonstrate actions. But do not return a fully coded solution. The code will be generated separately by a Coding Agent.
580
561
 
581
562
  Avoid these:
582
563
  1. Do not include steps to save files.
@@ -649,7 +630,6 @@ def make_feature_engineering_agent(
649
630
 
650
631
  feature_engineering_prompt = PromptTemplate(
651
632
  template="""
652
-
653
633
  You are a Feature Engineering Agent. Your job is to create a {function_name}() function that can be run on the data provided using the following recommended steps.
654
634
 
655
635
  Recommended Steps:
@@ -763,21 +743,22 @@ def make_feature_engineering_agent(
763
743
  function_name=state.get("feature_engineer_function_name"),
764
744
  )
765
745
 
766
- def explain_feature_engineering_code(state: GraphState):
767
- return node_func_explain_agent_code(
746
+ # Final reporting node
747
+ def report_agent_outputs(state: GraphState):
748
+ return node_func_report_agent_outputs(
768
749
  state=state,
769
- code_snippet_key="feature_engineer_function",
750
+ keys_to_include=[
751
+ "recommended_steps",
752
+ "feature_engineer_function",
753
+ "feature_engineer_function_path",
754
+ "feature_engineer_function_name",
755
+ "feature_engineer_error",
756
+ ],
770
757
  result_key="messages",
771
- error_key="feature_engineer_error",
772
- llm=llm,
773
758
  role=AGENT_NAME,
774
- explanation_prompt_template="""
775
- Explain the feature engineering steps performed by this function. Keep the explanation clear and concise.\n\n# Feature Engineering Agent:\n\n{code}
776
- """,
777
- success_prefix="# Feature Engineering Agent:\n\n ",
778
- error_message="The Feature Engineering Agent encountered an error during feature engineering. Data could not be explained."
759
+ custom_title="Feature Engineering Agent Outputs"
779
760
  )
780
-
761
+
781
762
  # Create the graph
782
763
  node_functions = {
783
764
  "recommend_feature_engineering_steps": recommend_feature_engineering_steps,
@@ -785,7 +766,7 @@ def make_feature_engineering_agent(
785
766
  "create_feature_engineering_code": create_feature_engineering_code,
786
767
  "execute_feature_engineering_code": execute_feature_engineering_code,
787
768
  "fix_feature_engineering_code": fix_feature_engineering_code,
788
- "explain_feature_engineering_code": explain_feature_engineering_code
769
+ "report_agent_outputs": report_agent_outputs,
789
770
  }
790
771
 
791
772
  app = create_coding_agent_graph(
@@ -795,7 +776,7 @@ def make_feature_engineering_agent(
795
776
  create_code_node_name="create_feature_engineering_code",
796
777
  execute_code_node_name="execute_feature_engineering_code",
797
778
  fix_code_node_name="fix_feature_engineering_code",
798
- explain_code_node_name="explain_feature_engineering_code",
779
+ explain_code_node_name="report_agent_outputs",
799
780
  error_key="feature_engineer_error",
800
781
  max_retries_key = "max_retries",
801
782
  retry_count_key = "retry_count",
@@ -5,11 +5,13 @@ import operator
5
5
 
6
6
  from langchain.prompts import PromptTemplate
7
7
  from langchain_core.messages import BaseMessage
8
+ from langchain_core.output_parsers import JsonOutputParser
8
9
 
9
10
  from langgraph.types import Command
10
11
  from langgraph.checkpoint.memory import MemorySaver
11
12
 
12
13
  import os
14
+ import json
13
15
  import pandas as pd
14
16
  import sqlalchemy as sql
15
17
 
@@ -19,14 +21,19 @@ from ai_data_science_team.templates import(
19
21
  node_func_execute_agent_from_sql_connection,
20
22
  node_func_human_review,
21
23
  node_func_fix_agent_code,
22
- node_func_explain_agent_code,
24
+ node_func_report_agent_outputs,
23
25
  create_coding_agent_graph,
24
26
  BaseAgent,
25
27
  )
26
- from ai_data_science_team.tools.parsers import SQLOutputParser
27
- from ai_data_science_team.tools.regex import add_comments_to_top, format_agent_name, format_recommended_steps
28
- from ai_data_science_team.tools.metadata import get_database_metadata
29
- from ai_data_science_team.tools.logging import log_ai_function
28
+ from ai_data_science_team.parsers.parsers import SQLOutputParser
29
+ from ai_data_science_team.utils.regex import (
30
+ add_comments_to_top,
31
+ format_agent_name,
32
+ format_recommended_steps,
33
+ get_generic_summary,
34
+ )
35
+ from ai_data_science_team.tools.sql import get_database_metadata
36
+ from ai_data_science_team.utils.logging import log_ai_function
30
37
 
31
38
  # Setup
32
39
  AGENT_NAME = "sql_database_agent"
@@ -51,7 +58,7 @@ class SQLDatabaseAgent(BaseAgent):
51
58
  connection : sqlalchemy.engine.base.Engine or sqlalchemy.engine.base.Connection
52
59
  The SQLAlchemy connection (or engine) to the database.
53
60
  n_samples : int, optional
54
- Number of sample rows (per column) to retrieve when summarizing database metadata. Defaults to 10.
61
+ Number of sample rows (per column) to retrieve when summarizing database metadata. Defaults to 1.
55
62
  log : bool, optional
56
63
  Whether to log the generated code and errors. Defaults to False.
57
64
  log_path : str, optional
@@ -68,6 +75,8 @@ class SQLDatabaseAgent(BaseAgent):
68
75
  If True, skips the step that generates recommended SQL steps. Defaults to False.
69
76
  bypass_explain_code : bool, optional
70
77
  If True, skips the step that provides code explanations. Defaults to False.
78
+ smart_schema_pruning : bool, optional
79
+ If True, filters the tables and columns in the database metadata based on the user instructions, using an extra LLM step. Defaults to False.
71
80
 
72
81
  Methods
73
82
  -------
@@ -77,8 +86,8 @@ class SQLDatabaseAgent(BaseAgent):
77
86
  Asynchronously runs the agent to generate and execute a SQL query based on user instructions.
78
87
  invoke_agent(user_instructions: str, max_retries=3, retry_count=0)
79
88
  Synchronously runs the agent to generate and execute a SQL query based on user instructions.
80
- explain_sql_steps()
81
- Returns an explanation of the SQL steps performed by the agent.
89
+ get_workflow_summary()
90
+ Retrieves a summary of the agent's workflow.
82
91
  get_log_summary()
83
92
  Retrieves a summary of logged operations if logging is enabled.
84
93
  get_data_sql()
@@ -139,7 +148,7 @@ class SQLDatabaseAgent(BaseAgent):
139
148
  self,
140
149
  model,
141
150
  connection,
142
- n_samples=10,
151
+ n_samples=1,
143
152
  log=False,
144
153
  log_path=None,
145
154
  file_name="sql_database.py",
@@ -147,7 +156,8 @@ class SQLDatabaseAgent(BaseAgent):
147
156
  overwrite=True,
148
157
  human_in_the_loop=False,
149
158
  bypass_recommended_steps=False,
150
- bypass_explain_code=False
159
+ bypass_explain_code=False,
160
+ smart_schema_pruning=False,
151
161
  ):
152
162
  self._params = {
153
163
  "model": model,
@@ -160,7 +170,8 @@ class SQLDatabaseAgent(BaseAgent):
160
170
  "overwrite": overwrite,
161
171
  "human_in_the_loop": human_in_the_loop,
162
172
  "bypass_recommended_steps": bypass_recommended_steps,
163
- "bypass_explain_code": bypass_explain_code
173
+ "bypass_explain_code": bypass_explain_code,
174
+ "smart_schema_pruning": smart_schema_pruning,
164
175
  }
165
176
  self._compiled_graph = self._make_compiled_graph()
166
177
  self.response = None
@@ -182,7 +193,7 @@ class SQLDatabaseAgent(BaseAgent):
182
193
  self._params[k] = v
183
194
  self._compiled_graph = self._make_compiled_graph()
184
195
 
185
- def ainvoke_agent(self, user_instructions: str=None, max_retries=3, retry_count=0, **kwargs):
196
+ async def ainvoke_agent(self, user_instructions: str=None, max_retries=3, retry_count=0, **kwargs):
186
197
  """
187
198
  Asynchronously runs the SQL Database Agent based on user instructions.
188
199
 
@@ -201,7 +212,7 @@ class SQLDatabaseAgent(BaseAgent):
201
212
  -------
202
213
  None
203
214
  """
204
- response = self._compiled_graph.ainvoke({
215
+ response = await self._compiled_graph.ainvoke({
205
216
  "user_instructions": user_instructions,
206
217
  "max_retries": max_retries,
207
218
  "retry_count": retry_count
@@ -234,40 +245,34 @@ class SQLDatabaseAgent(BaseAgent):
234
245
  }, **kwargs)
235
246
  self.response = response
236
247
 
237
- def explain_sql_steps(self):
248
+ def get_workflow_summary(self, markdown=False):
238
249
  """
239
- Provides an explanation of the SQL steps performed by the agent
240
- if the explain step is not bypassed.
241
-
242
- Returns
243
- -------
244
- str or list
245
- An explanation of the SQL steps.
250
+ Retrieves the agent's workflow summary, parsed from the last message of the response, if one is available.
246
251
  """
247
- if self.response:
248
- return self.response.get("messages", [])
249
- return []
252
+ if self.response and self.response.get("messages"):
253
+ summary = get_generic_summary(json.loads(self.response.get("messages")[-1].content))
254
+ if markdown:
255
+ return Markdown(summary)
256
+ else:
257
+ return summary
250
258
 
251
259
  def get_log_summary(self, markdown=False):
252
260
  """
253
- Retrieves a summary of the logging details if logging is enabled.
261
+ Returns a summary of the logged function path and function name, if a function path is present in the response.
262
+ """
263
+ if self.response:
264
+ if self.response.get('sql_database_function_path'):
265
+ log_details = f"""
266
+ ## SQL Database Agent Log Summary:
254
267
 
255
- Parameters
256
- ----------
257
- markdown : bool, optional
258
- If True, returns the summary in Markdown format.
268
+ Function Path: {self.response.get('sql_database_function_path')}
259
269
 
260
- Returns
261
- -------
262
- str or None
263
- Log details or None if logging is not used or data is unavailable.
264
- """
265
- if self.response and self.response.get("sql_database_function_path"):
266
- log_details = f"Log Path: {self.response['sql_database_function_path']}"
267
- if markdown:
268
- return Markdown(log_details)
269
- return log_details
270
- return None
270
+ Function Name: {self.response.get('sql_database_function_name')}
271
+ """
272
+ if markdown:
273
+ return Markdown(log_details)
274
+ else:
275
+ return log_details
271
276
 
272
277
  def get_data_sql(self):
273
278
  """
@@ -351,14 +356,16 @@ class SQLDatabaseAgent(BaseAgent):
351
356
  def make_sql_database_agent(
352
357
  model,
353
358
  connection,
354
- n_samples = 10,
359
+ n_samples=1,
355
360
  log=False,
356
361
  log_path=None,
357
362
  file_name="sql_database.py",
358
363
  function_name="sql_database_pipeline",
359
364
  overwrite = True,
360
- human_in_the_loop=False, bypass_recommended_steps=False,
361
- bypass_explain_code=False
365
+ human_in_the_loop=False,
366
+ bypass_recommended_steps=False,
367
+ bypass_explain_code=False,
368
+ smart_schema_pruning=False,
362
369
  ):
363
370
  """
364
371
  Creates a SQL Database Agent that can recommend SQL steps and generate SQL code to query a database.
@@ -370,7 +377,7 @@ def make_sql_database_agent(
370
377
  connection : sqlalchemy.engine.base.Engine
371
378
  The connection to the SQL database.
372
379
  n_samples : int, optional
373
- The number of samples to retrieve for each column, by default 10.
380
+ The number of samples to retrieve for each column, by default 1.
374
381
  If you get an error due to maximum tokens, try reducing this number.
375
382
  > "This model's maximum context length is 128000 tokens. However, your messages resulted in 333858 tokens. Please reduce the length of the messages."
376
383
  log : bool, optional
@@ -387,6 +394,8 @@ def make_sql_database_agent(
387
394
  Bypass the recommendation step, by default False
388
395
  bypass_explain_code : bool, optional
389
396
  Bypass the code explanation step, by default False.
397
+ smart_schema_pruning : bool, optional
398
+ If True, filters the tables and columns with an extra LLM step to reduce tokens for large databases. Increases processing time but can avoid errors due to hitting max token limits with large databases. Defaults to False.
390
399
 
391
400
  Returns
392
401
  -------
@@ -419,12 +428,8 @@ def make_sql_database_agent(
419
428
  "retry_count":0
420
429
  })
421
430
  ```
422
-
423
431
  """
424
432
 
425
- is_engine = isinstance(connection, sql.engine.base.Engine)
426
- conn = connection.connect() if is_engine else connection
427
-
428
433
  llm = model
429
434
 
430
435
  # Human in the loop requires recommended steps
@@ -438,6 +443,11 @@ def make_sql_database_agent(
438
443
  log_path = LOG_PATH
439
444
  if not os.path.exists(log_path):
440
445
  os.makedirs(log_path)
446
+
447
+ # Get the database metadata
448
+ is_engine = isinstance(connection, sql.engine.base.Engine)
449
+ conn = connection.connect() if is_engine else connection
450
+
441
451
 
442
452
  class GraphState(TypedDict):
443
453
  messages: Annotated[Sequence[BaseMessage], operator.add]
@@ -457,6 +467,16 @@ def make_sql_database_agent(
457
467
  def recommend_sql_steps(state: GraphState):
458
468
 
459
469
  print(format_agent_name(AGENT_NAME))
470
+
471
+ all_sql_database_summary = get_database_metadata(conn, n_samples=n_samples)
472
+
473
+ all_sql_database_summary = smart_schema_filter(
474
+ llm,
475
+ state.get("user_instructions"),
476
+ all_sql_database_summary,
477
+ smart_filtering=smart_schema_pruning
478
+ )
479
+
460
480
  print(" * RECOMMEND STEPS")
461
481
 
462
482
 
@@ -485,7 +505,7 @@ def make_sql_database_agent(
485
505
  Below are summaries of the database metadata and the SQL tables:
486
506
  {all_sql_database_summary}
487
507
 
488
- Return the steps as a numbered point list (no code, just the steps).
508
+ Return steps as a numbered list. You can return short code snippets to demonstrate actions. But do not return a fully coded solution. The code will be generated separately by a Coding Agent.
489
509
 
490
510
  Consider these:
491
511
 
@@ -504,13 +524,6 @@ def make_sql_database_agent(
504
524
  input_variables=["user_instructions", "recommended_steps", "all_sql_database_summary"]
505
525
  )
506
526
 
507
- # Create a connection if needed
508
- is_engine = isinstance(connection, sql.engine.base.Engine)
509
- conn = connection.connect() if is_engine else connection
510
-
511
- # Get the database metadata
512
- all_sql_database_summary = get_database_metadata(conn, n_samples=n_samples)
513
-
514
527
  steps_agent = recommend_steps_prompt | llm
515
528
 
516
529
  recommended_steps = steps_agent.invoke({
@@ -527,6 +540,15 @@ def make_sql_database_agent(
527
540
  def create_sql_query_code(state: GraphState):
528
541
  if bypass_recommended_steps:
529
542
  print(format_agent_name(AGENT_NAME))
543
+ all_sql_database_summary = get_database_metadata(conn, n_samples=n_samples)
544
+ all_sql_database_summary = smart_schema_filter(
545
+ llm,
546
+ state.get("user_instructions"),
547
+ all_sql_database_summary,
548
+ smart_filtering=smart_schema_pruning
549
+ )
550
+ else:
551
+ all_sql_database_summary = state.get("all_sql_database_summary")
530
552
  print(" * CREATE SQL QUERY CODE")
531
553
 
532
554
  # Prompt to get the SQL code from the LLM
@@ -567,13 +589,6 @@ def make_sql_database_agent(
567
589
  input_variables=["user_instructions", "recommended_steps", "all_sql_database_summary"]
568
590
  )
569
591
 
570
- # Create a connection if needed
571
- is_engine = isinstance(connection, sql.engine.base.Engine)
572
- conn = connection.connect() if is_engine else connection
573
-
574
- # Get the database metadata
575
- all_sql_database_summary = get_database_metadata(conn, n_samples=n_samples)
576
-
577
592
  sql_query_code_agent = sql_query_code_prompt | llm | SQLOutputParser()
578
593
 
579
594
  sql_query_code = sql_query_code_agent.invoke({
@@ -690,20 +705,20 @@ def {function_name}(connection):
690
705
  function_name=state.get("sql_database_function_name"),
691
706
  )
692
707
 
693
- def explain_sql_database_code(state: GraphState):
694
- return node_func_explain_agent_code(
708
+ # Final reporting node
709
+ def report_agent_outputs(state: GraphState):
710
+ return node_func_report_agent_outputs(
695
711
  state=state,
696
- code_snippet_key="sql_database_function",
712
+ keys_to_include=[
713
+ "recommended_steps",
714
+ "sql_database_function",
715
+ "sql_database_function_path",
716
+ "sql_database_function_name",
717
+ "sql_database_error",
718
+ ],
697
719
  result_key="messages",
698
- error_key="sql_database_error",
699
- llm=llm,
700
720
  role=AGENT_NAME,
701
- explanation_prompt_template="""
702
- Explain the SQL steps that the SQL Database agent performed in this function.
703
- Keep the summary succinct and to the point.\n\n# SQL Database Agent:\n\n{code}
704
- """,
705
- success_prefix="# SQL Database Agent:\n\n",
706
- error_message="The SQL Database Agent encountered an error during SQL Query Analysis. No SQL function explanation is returned."
721
+ custom_title="SQL Database Agent Outputs"
707
722
  )
708
723
 
709
724
  # Create the graph
@@ -713,7 +728,7 @@ def {function_name}(connection):
713
728
  "create_sql_query_code": create_sql_query_code,
714
729
  "execute_sql_database_code": execute_sql_database_code,
715
730
  "fix_sql_database_code": fix_sql_database_code,
716
- "explain_sql_database_code": explain_sql_database_code
731
+ "report_agent_outputs": report_agent_outputs,
717
732
  }
718
733
 
719
734
  app = create_coding_agent_graph(
@@ -723,7 +738,7 @@ def {function_name}(connection):
723
738
  create_code_node_name="create_sql_query_code",
724
739
  execute_code_node_name="execute_sql_database_code",
725
740
  fix_code_node_name="fix_sql_database_code",
726
- explain_code_node_name="explain_sql_database_code",
741
+ explain_code_node_name="report_agent_outputs",
727
742
  error_key="sql_database_error",
728
743
  human_in_the_loop=human_in_the_loop,
729
744
  human_review_node_name="human_review",
@@ -737,7 +752,46 @@ def {function_name}(connection):
737
752
 
738
753
 
739
754
 
755
+ def smart_schema_filter(llm, user_instructions, all_sql_database_summary, smart_filtering = True):
756
+ """
757
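+ Filters the database metadata down to the tables and columns relevant to the user instructions by asking the LLM for a JSON subset when smart_filtering is True; otherwise returns the metadata unchanged.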
758
+ """
759
+ # Smart schema filtering
760
+ if smart_filtering:
761
+ print(" * SMART FILTER SCHEMA")
740
762
 
763
+ filter_schema_prompt = PromptTemplate(
764
+ template="""
765
+ You are a highly skilled data engineer. The user question is:
766
+
767
+ "{user_instructions}"
768
+
769
+ You have the full database metadata in JSON format below:
770
+
771
+ {all_sql_database_summary}
772
+
773
+
774
+ Please return ONLY the subset of this metadata that is relevant to answering the user’s question.
775
+ - Preserve the same JSON structure for "schemas" -> "tables" -> "columns".
776
+ - If any schemas/tables are irrelevant, omit them entirely.
777
+ - If some columns in a relevant table are not needed, you can still keep them if you aren't sure.
778
+ - However, try to keep only the minimum amount of data required to answer the user’s question.
779
+
780
+ Return a valid JSON object. Do not include any additional explanation or text outside of the JSON.
781
+ """,
782
+ input_variables=["user_instructions", "full_metadata_json"]
783
+ )
784
+
785
+ filter_schema_agent = filter_schema_prompt | llm | JsonOutputParser()
786
+
787
+ response = filter_schema_agent.invoke({
788
+ "user_instructions": user_instructions,
789
+ "all_sql_database_summary": all_sql_database_summary
790
+ })
791
+
792
+ return response
793
+ else:
794
+ return all_sql_database_summary
741
795
 
742
796
 
743
797
 
@@ -0,0 +1,2 @@
1
+ from ai_data_science_team.ml_agents.h2o_ml_agent import make_h2o_ml_agent, H2OMLAgent
2
+ from ai_data_science_team.ml_agents.mlflow_tools_agent import make_mlflow_tools_agent, MLflowToolsAgent