nvidia-nat-vanna 1.5.0a20260115__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nvidia-nat-vanna might be problematic. Click here for more details.

nat/meta/pypi.md ADDED
@@ -0,0 +1,129 @@
1
+ <!--
2
+ SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ SPDX-License-Identifier: Apache-2.0
4
+
5
+ Licensed under the Apache License, Version 2.0 (the "License");
6
+ you may not use this file except in compliance with the License.
7
+ You may obtain a copy of the License at
8
+
9
+ http://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ Unless required by applicable law or agreed to in writing, software
12
+ distributed under the License is distributed on an "AS IS" BASIS,
13
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ See the License for the specific language governing permissions and
15
+ limitations under the License.
16
+ -->
17
+
18
+ # NVIDIA NeMo Agent Toolkit Vanna
19
+
20
+ Vanna-based Text-to-SQL integration for NeMo Agent toolkit.
21
+
22
+ ## Overview
23
+
24
+ This package provides production-ready text-to-SQL capabilities using the Vanna framework with Databricks support.
25
+
26
+ ## Features
27
+
28
+ - **AI-Powered SQL Generation**: Convert natural language to SQL using LLMs
29
+ - **Databricks Support**: Optimized for Databricks SQL warehouses
30
+ - **Vector-Based Similarity Search**: Milvus integration for few-shot learning
31
+ - **Streaming Support**: Real-time progress updates
32
+ - **Query Execution**: Optional database execution with formatted results
33
+ - **Highly Configurable**: Customizable prompts, examples, and connections
34
+
35
+ ## Quick Start
36
+
37
+ Install the package:
38
+
39
+ ```bash
40
+ pip install nvidia-nat-vanna
41
+ ```
42
+
43
+ Create a workflow configuration:
44
+
45
+ ```yaml
46
+ functions:
47
+ text2sql:
48
+ _type: text2sql
49
+ llm_name: my_llm
50
+ embedder_name: my_embedder
51
+ milvus_retriever: my_retriever
52
+ database_type: databricks
53
+ connection_url: "${CONNECTION_URL}"
54
+ execute_sql: false
55
+
56
+ execute_db_query:
57
+ _type: execute_db_query
58
+ database_type: databricks
59
+ connection_url: "${CONNECTION_URL}"
60
+ max_rows: 100
61
+
62
+ llms:
63
+ my_llm:
64
+ _type: nim
65
+ model_name: meta/llama-3.1-70b-instruct
66
+ api_key: "${NVIDIA_API_KEY}"
67
+
68
+ embedders:
69
+ my_embedder:
70
+ _type: nim
71
+ model_name: nvidia/llama-3.2-nv-embedqa-1b-v2
72
+ api_key: "${NVIDIA_API_KEY}"
73
+
74
+ retrievers:
75
+ my_retriever:
76
+ _type: milvus_retriever
77
+ uri: "${MILVUS_URI}"
78
+ connection_args:
79
+ user: "developer"
80
+ password: "${MILVUS_PASSWORD}"
81
+ db_name: "default"
82
+ embedding_model: my_embedder
83
+ content_field: text
84
+ use_async_client: true
85
+
86
+ workflow:
87
+ _type: rewoo_agent
88
+ tool_names: [text2sql, execute_db_query]
89
+ llm_name: my_llm
90
+ ```
91
+
92
+ Run the workflow:
93
+
94
+ ```bash
95
+ nat run --config config.yml --input "How many customers do we have?"
96
+ ```
97
+
98
+ ## Components
99
+
100
+ ### `text2sql` Function
101
+
102
+ Generates SQL queries from natural language using:
103
+ - Few-shot learning with similar examples
104
+ - DDL (schema) information
105
+ - Custom documentation
106
+ - LLM-powered query generation
107
+
108
+ ### `execute_db_query` Function
109
+
110
+ Executes SQL queries and returns formatted results:
111
+ - Databricks SQL execution
112
+ - Result limiting and pagination
113
+ - Structured output format
114
+ - SQLAlchemy engine-based connection
115
+
116
+ ## Use Cases
117
+
118
+ - **Business Intelligence**: Enable non-technical users to query data
119
+ - **Data Exploration**: Rapid prototyping and analysis
120
+ - **Conversational Analytics**: Multi-turn Q&A about your data
121
+ - **SQL Assistance**: Help analysts write complex queries
122
+
123
+ ## Documentation
124
+
125
+ Full documentation: <https://docs.nvidia.com/nemo/agent-toolkit/latest/>
126
+
127
+ ## License
128
+
129
+ Part of NVIDIA NeMo Agent toolkit. See repository for license details.
@@ -0,0 +1,14 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
@@ -0,0 +1,296 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import asyncio
17
+ import json
18
+ import logging
19
+ import re
20
+ import typing
21
+ from enum import Enum
22
+ from typing import Any
23
+
24
+ from pydantic import BaseModel
25
+ from pydantic import Field
26
+ from pydantic import PlainSerializer
27
+ from pydantic import SecretStr
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
def _serialize_secret(v: SecretStr) -> str:
    """Return the plain-text value held by a ``SecretStr``.

    Args:
        v: Secret wrapper to unwrap.

    Returns:
        The unmasked secret string.
    """
    plain_value = v.get_secret_value()
    return plain_value
35
+
36
+
37
# Mandatory counterpart of the OptionalSecretStr pattern: the field is typed as
# SecretStr and serialized back to its plain string through _serialize_secret.
RequiredSecretStr = typing.Annotated[
    SecretStr,
    PlainSerializer(_serialize_secret),
]
39
+
40
+
41
class SupportedDatabase(str, Enum):
    """Database backends that the Vanna text-to-SQL integration can target.

    Members are also ``str`` instances, so a member compares equal to its raw
    string value (``SupportedDatabase.DATABRICKS == "databricks"``).
    """

    # The only backend defined at present.
    DATABRICKS = "databricks"
45
+
46
+
47
class QueryResult(BaseModel):
    """Result from executing a database query."""

    # One tuple per returned row.
    results: list[tuple[Any, ...]] = Field(
        description="List of tuples representing rows returned from the query",
    )
    # Column labels, in the same order as the values inside each row tuple.
    column_names: list[str] = Field(
        description="List of column names for the result set",
    )

    def to_dataframe(self) -> Any:
        """Build a pandas DataFrame from the rows, labelled with ``column_names``.

        pandas is imported inside the method, so it is only needed when this
        conversion is actually requested.
        """
        import pandas as pd

        frame = pd.DataFrame(self.results, columns=self.column_names)
        return frame

    def to_records(self) -> list[dict[str, Any]]:
        """Return one ``{column_name: value}`` dictionary per row.

        Pairing is non-strict: when a row and ``column_names`` differ in
        length, the surplus items are dropped without an error.
        """
        records: list[dict[str, Any]] = []
        for row in self.results:
            records.append(dict(zip(self.column_names, row, strict=False)))
        return records

    @property
    def row_count(self) -> int:
        """Number of rows held in ``results``."""
        total_rows = len(self.results)
        return total_rows
71
+
72
+
73
def extract_sql_from_message(sql_query: str | Any) -> str:
    """Extract a clean SQL query from the various shapes a tool input can take.

    Inputs are checked in this order:

    1. ``BaseModel`` instances: the ``sql`` attribute is returned when present
       (e.g. Text2SQLOutput); otherwise the model is dumped to JSON and handled
       as a string by the steps below.
    2. Dictionaries: the ``sql`` value, or ``str(dict)`` when the key is absent.
    3. Other non-string objects: if they expose ``content`` (tool messages), a
       dict content or the first dict of a list content is treated like step 2;
       everything else is stringified.
    4. Strings containing a legacy ``content="..."`` tool-message representation.
    5. JSON object strings that contain a ``sql`` key.
    6. ``sql='...' explanation='...'`` style representations.

    A string that matches none of the patterns is returned unchanged.

    Args:
        sql_query: SQL query in any of the formats listed above.

    Returns:
        The extracted SQL query string.
    """
    # Handle BaseModel objects (e.g., Text2SQLOutput)
    if isinstance(sql_query, BaseModel):
        if hasattr(sql_query, "sql"):
            return sql_query.sql
        # No 'sql' field: fall back to the JSON dump and keep processing it as text
        sql_query = sql_query.model_dump_json()

    # Handle dictionaries with 'sql' key
    if isinstance(sql_query, dict):
        return sql_query.get("sql", str(sql_query))

    # Handle objects with content attribute (ToolMessage)
    if not isinstance(sql_query, str):
        if hasattr(sql_query, "content"):
            content = sql_query.content
            # Content might be a dict or list
            if isinstance(content, dict):
                return content.get("sql", str(content))
            if isinstance(content, list) and len(content) > 0:
                first_item = content[0]
                if isinstance(first_item, dict):
                    return first_item.get("sql", str(first_item))
            sql_query = str(content)
        else:
            sql_query = str(sql_query)

    # Every branch above either returned or left a str in sql_query.

    # Extract from tool message format (legacy)
    if 'content="' in sql_query:
        # Capture up to the first *unescaped* double quote. `\\.` consumes a
        # backslash-escaped character (such as \"), so quotes embedded in the
        # content do not end the capture early. In a raw string a literal
        # backslash is written `\\`; doubling it again would require two
        # backslashes in the input and escaped quotes would never match.
        match = re.search(r'content="((?:[^"\\]|\\.)*)"', sql_query)
        if match:
            sql_query = match.group(1)
            sql_query = sql_query.replace("\\'", "'").replace('\\"', '"')

    # Try to parse as JSON if it looks like JSON
    if sql_query.strip().startswith("{"):
        try:
            parsed = json.loads(sql_query)
        except json.JSONDecodeError:
            # Not valid JSON after all; keep trying the remaining formats
            parsed = None
        if isinstance(parsed, dict) and "sql" in parsed:
            return parsed["sql"]

    # Handle format: sql='...' explanation='...'
    if "sql=" in sql_query:
        # Non-greedy so the match stops at the closing quote that is followed
        # by the explanation field (or by the end of the string)
        match = re.search(r"sql=['\"](.+?)['\"](?:\s+explanation=|$)", sql_query)
        if match:
            return match.group(1)

    return sql_query
141
+
142
+
143
def connect_to_databricks(connection_url: str) -> Any:
    """Create a SQLAlchemy engine for a Databricks SQL Warehouse.

    Args:
        connection_url: Database connection string handed to ``create_engine``.

    Returns:
        The object returned by ``sqlalchemy.create_engine`` (statement echo disabled).

    Raises:
        Exception: Any error raised while importing SQLAlchemy or creating the
            engine is logged and re-raised unchanged.
    """
    try:
        # Imported here so SQLAlchemy is only required once a connection is requested
        from sqlalchemy import create_engine

        engine = create_engine(url=connection_url, echo=False)
    except Exception as e:
        logger.error(f"Failed to connect to Databricks: {e}")
        raise
    else:
        logger.info("Connected to Databricks")
        return engine
161
+
162
+
163
def connect_to_database(
    database_type: str | SupportedDatabase,
    connection_url: str,
    **kwargs,
) -> Any:
    """Connect to a database based on type.

    Currently only Databricks is supported.

    Args:
        database_type: Type of database, either a ``SupportedDatabase`` member or its
            string value (case-insensitive). Currently only 'databricks' is supported.
        connection_url: Database connection string
        kwargs: Additional database-specific parameters. They are accepted but not
            used by this function.

    Returns:
        Database connection object

    Raises:
        ValueError: If database_type is not supported, including values that are
            neither a string nor a ``SupportedDatabase`` member (e.g. ``None``).
        NotImplementedError: If a ``SupportedDatabase`` member has no connector here.
    """
    # Resolve the requested type to an enum member. SupportedDatabase subclasses
    # str, so the enum check must come first; anything that is neither an enum
    # member nor a recognised string is rejected with the documented ValueError
    # instead of failing later with an AttributeError.
    db_type: SupportedDatabase | None = None
    if isinstance(database_type, SupportedDatabase):
        db_type = database_type
    elif isinstance(database_type, str):
        try:
            db_type = SupportedDatabase(database_type.lower())
        except ValueError:
            db_type = None

    if db_type is None:
        supported = ", ".join([f"'{db.value}'" for db in SupportedDatabase])
        msg = f"Unsupported database type: '{database_type}'. Supported types: {supported}"
        raise ValueError(msg)

    # Route to appropriate database connector
    if db_type == SupportedDatabase.DATABRICKS:
        return connect_to_databricks(connection_url=connection_url)

    # This should never be reached if enum is properly defined
    msg = f"Database type '{db_type.value}' has no connector implementation"
    raise NotImplementedError(msg)
201
+
202
+
203
def execute_query(connection: Any, query: str) -> QueryResult:
    """Run a SQL statement and collect every row it returns.

    Args:
        connection: Object whose ``connect()`` returns a context-managed connection
            (presumably a SQLAlchemy engine; confirm against callers).
        query: SQL text to execute; it is wrapped in ``sqlalchemy.text`` unchanged.

    Returns:
        QueryResult holding all fetched rows and the result's column names.

    Raises:
        Exception: Any failure while connecting, executing or fetching is logged
            and re-raised unchanged.
    """
    from sqlalchemy import text

    try:
        with connection.connect() as conn:
            logger.info(f"Executing query: {query}")
            cursor_result = conn.execute(text(query))
            fetched_rows = cursor_result.fetchall()

            # An empty key view yields an empty column list
            keys = cursor_result.keys()
            column_names = list(keys) if keys else []

            logger.info(f"Query completed, retrieved {len(fetched_rows)} rows")
            return QueryResult(results=fetched_rows, column_names=column_names)
    except Exception as e:
        logger.error(f"Error executing query: {e}")
        raise
227
+
228
+
229
async def async_execute_query(connection: Any, query: str) -> QueryResult:
    """Execute query asynchronously and return QueryResult.

    The synchronous ``execute_query`` is run in the event loop's default
    executor so the calling coroutine does not block the loop while the
    database call is in progress.

    Args:
        connection: Database connection object
        query: SQL query to execute

    Returns:
        QueryResult object containing results and column names

    Raises:
        Exception: Whatever ``execute_query`` raises is propagated to the awaiter.
    """
    # Inside a coroutine a loop is always running; get_running_loop() returns
    # exactly that loop and never creates one implicitly.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, execute_query, connection, query)
245
+
246
+
247
def setup_vanna_db_connection(
    vn: Any,
    database_type: str | SupportedDatabase,
    connection_url: str,
    **kwargs,
) -> None:
    """Attach a database engine and an async ``run_sql`` callable to a Vanna instance.

    Currently only Databricks is supported.

    The engine is kept on the Vanna instance as ``vn.db_engine``. If the instance
    already carries a non-None ``db_engine`` it is reused as-is: ``connection_url``
    and ``database_type`` are not compared against it.

    Args:
        vn: Vanna instance; ``db_engine``, ``run_sql`` and ``run_sql_is_set`` are set on it
        database_type: Type of database (currently only 'databricks' is supported)
        connection_url: Database connection string, used only when a new engine is created
        kwargs: Additional connection parameters. They are accepted but not used here.

    Raises:
        ValueError: If a new engine is needed and database_type is not supported
    """
    existing_engine = getattr(vn, "db_engine", None)
    if existing_engine is None:
        # No engine yet: create one (validation handled by connect_to_database)
        # and keep it on the Vanna instance for later calls
        engine = connect_to_database(database_type=database_type, connection_url=connection_url)
        vn.db_engine = engine
        logger.info(f"Created and stored database engine in Vanna instance for {database_type}")
    else:
        logger.info("Reusing existing database engine from Vanna instance")
        engine = existing_engine

    # Async callable installed on the Vanna instance as its SQL runner
    async def run_sql(sql_query: str) -> Any:
        """Execute SQL asynchronously and return DataFrame."""
        try:
            return (await async_execute_query(engine, sql_query)).to_dataframe()
        except Exception:
            logger.exception("Error executing SQL")
            raise

    # Install the runner and flag it as configured
    vn.run_sql = run_sql
    vn.run_sql_is_set = True

    logger.info(f"Database connection configured for {database_type}")
@@ -0,0 +1,237 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ import uuid
18
+ from collections.abc import AsyncGenerator
19
+ from typing import Any
20
+
21
+ from pydantic import BaseModel
22
+ from pydantic import Field
23
+
24
+ from nat.builder.builder import Builder
25
+ from nat.builder.framework_enum import LLMFrameworkEnum
26
+ from nat.builder.function_info import FunctionInfo
27
+ from nat.cli.register_workflow import register_function
28
+ from nat.data_models.api_server import ResponseIntermediateStep
29
+ from nat.data_models.function import FunctionBaseConfig
30
+ from nat.plugins.vanna.db_utils import RequiredSecretStr
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
class StatusPayload(BaseModel):
    """Payload for status intermediate steps."""

    # Human-readable status text carried by the intermediate step
    message: str
39
+
40
+
41
class ExecuteDBQueryInput(BaseModel):
    """Input schema for execute DB query function."""

    # Raw query text; may still be wrapped in a tool-message or JSON envelope
    sql_query: str = Field(
        description="SQL query to execute",
    )
45
+
46
+
47
class DataFrameInfo(BaseModel):
    """DataFrame structure information."""

    # Two-element list: number of rows, then number of columns
    shape: list[int] = Field(
        description="Shape [rows, columns]",
    )
    # Column name -> dtype, both rendered as strings
    dtypes: dict[str, str] = Field(
        description="Column data types",
    )
    columns: list[str] = Field(
        description="Column names",
    )
53
+
54
+
55
class ExecuteDBQueryOutput(BaseModel):
    """Output schema for execute DB query function."""

    # --- outcome ---
    success: bool = Field(
        description="Whether query executed successfully",
    )
    failure_reason: str | None = Field(
        default=None,
        description="Reason for failure if query failed",
    )

    # --- query text ---
    sql_query: str = Field(
        description="Original SQL query",
    )
    query_executed: str | None = Field(
        default=None,
        description="Actual SQL query executed (with prefixes)",
    )

    # --- result data ---
    columns: list[str] = Field(
        default_factory=list,
        description="Column names",
    )
    row_count: int = Field(
        default=0,
        description="Total rows returned",
    )
    dataframe_records: list[dict[str, Any]] = Field(
        default_factory=list,
        description="Results as list of dicts",
    )
    dataframe_info: DataFrameInfo | None = Field(
        default=None,
        description="DataFrame metadata",
    )

    # --- truncation markers; left as None when the result was not limited ---
    limited_to: int | None = Field(
        default=None,
        description="Number of rows limited to",
    )
    truncated: bool | None = Field(
        default=None,
        description="Whether truncated",
    )
68
+
69
+
70
class ExecuteDBQueryConfig(FunctionBaseConfig, name="execute_db_query"):
    """
    Database query execution configuration.

    Currently only Databricks is supported.
    """

    # Target backend; kept as a plain string and checked when a query is executed
    database_type: str = Field(
        default="databricks",
        description="Database type (currently only 'databricks' is supported)",
    )
    # Stored as a secret so the URL (which may embed credentials) is masked in reprs
    connection_url: RequiredSecretStr = Field(
        description="Database connection string",
    )

    # Upper bound on the number of rows included in a response
    max_rows: int = Field(
        default=100,
        description="Maximum rows to return",
    )
84
+
85
+
86
@register_function(
    config_type=ExecuteDBQueryConfig,
    framework_wrappers=[LLMFrameworkEnum.LANGCHAIN],
)
async def execute_db_query(
    config: ExecuteDBQueryConfig,
    _builder: Builder,
):
    """Register the Execute DB Query function.

    Yields a ``FunctionInfo`` exposing a streaming and a non-streaming entry
    point that execute a SQL query and return an ``ExecuteDBQueryOutput``.

    The database engine is created lazily on the first query, shared by all
    later calls of this function instance, and disposed when this generator is
    finalized, so that each invocation does not build (and abandon) a new
    connection pool.

    Args:
        config: Database type, connection URL and row limit for this function.
        _builder: Workflow builder (unused).
    """

    from nat.plugins.vanna.db_utils import async_execute_query
    from nat.plugins.vanna.db_utils import connect_to_database
    from nat.plugins.vanna.db_utils import extract_sql_from_message

    logger.info("Initializing Execute DB Query function")

    # Engine shared by every call; created on first use, disposed on cleanup
    engine: Any = None

    def _failure(reason: str, sql_query: str) -> ExecuteDBQueryOutput:
        """Build the failure response used by every error path."""
        return ExecuteDBQueryOutput(
            success=False,
            failure_reason=reason,
            sql_query=sql_query,
            dataframe_info=DataFrameInfo(shape=[0, 0], dtypes={}, columns=[]),
        )

    # Streaming version
    async def _execute_sql_query_stream(
        input_data: ExecuteDBQueryInput, ) -> AsyncGenerator[ResponseIntermediateStep | ExecuteDBQueryOutput, None]:
        """Stream SQL query execution progress and results."""
        nonlocal engine

        sql_query = extract_sql_from_message(input_data.sql_query)
        logger.info(f"Executing SQL: {sql_query}")

        # Generate parent_id for this function call
        parent_id = str(uuid.uuid4())

        try:
            # Clean up query: trim whitespace and one layer of wrapping quotes
            sql_query = sql_query.strip()
            if sql_query.startswith('"') and sql_query.endswith('"'):
                sql_query = sql_query[1:-1]
            if sql_query.startswith("'") and sql_query.endswith("'"):
                sql_query = sql_query[1:-1]

            yield ResponseIntermediateStep(
                id=str(uuid.uuid4()),
                parent_id=parent_id,
                type="markdown",
                name="execute_db_query_status",
                payload=StatusPayload(message="Connecting to database and executing query...").model_dump_json(),
            )

            # Validate database type
            if config.database_type.lower() != "databricks":
                yield _failure(
                    f"Only Databricks is currently supported. Got database_type: {config.database_type}",
                    sql_query,
                )
                return

            connection_url_value = config.connection_url.get_secret_value()
            if not connection_url_value:
                yield _failure("Missing required connection URL", sql_query)
                return

            # Create the engine once and reuse it afterwards. There is no await
            # between the check and the assignment, so concurrent calls on the
            # same event loop cannot create two engines.
            if engine is None:
                engine = connect_to_database(
                    database_type=config.database_type,
                    connection_url=connection_url_value,
                )

            if engine is None:
                yield _failure("Failed to connect to database", sql_query)
                return

            # Execute query
            query_result = await async_execute_query(engine, sql_query)
            df = query_result.to_dataframe()

            # Store original row count before limiting
            original_row_count = len(df)
            was_truncated = original_row_count > config.max_rows

            # Limit results
            if was_truncated:
                df = df.head(config.max_rows)

            # Create response; an empty frame is reported without shape, dtypes or columns
            has_rows = not df.empty
            dataframe_info = DataFrameInfo(
                shape=[len(df), len(df.columns)] if has_rows else [0, 0],
                dtypes=({
                    str(k): str(v)
                    for k, v in df.dtypes.to_dict().items()
                } if has_rows else {}),
                columns=df.columns.tolist() if has_rows else [],
            )

            response = ExecuteDBQueryOutput(
                success=True,
                columns=df.columns.tolist() if has_rows else [],
                row_count=original_row_count,
                sql_query=sql_query,
                query_executed=sql_query,
                dataframe_records=df.to_dict("records") if has_rows else [],
                dataframe_info=dataframe_info,
            )

            # Truncation markers stay None unless rows were actually dropped
            if was_truncated:
                response.limited_to = config.max_rows
                response.truncated = True

            # Yield final result as ExecuteDBQueryOutput
            yield response

        except Exception as e:
            # Details go to the server log only; the caller gets a generic reason
            logger.error("Error executing SQL query", exc_info=e)
            yield _failure("SQL execution failed. Please check server logs for details.", sql_query)

        logger.info("Execute DB Query completed")

    # Non-streaming version
    async def _execute_sql_query(input_data: ExecuteDBQueryInput) -> ExecuteDBQueryOutput:
        """Execute SQL query and return results."""
        async for update in _execute_sql_query_stream(input_data):
            # Skip ResponseIntermediateStep objects, only return ExecuteDBQueryOutput
            if isinstance(update, ExecuteDBQueryOutput):
                return update

        # Fallback if no result found
        return _failure("No result returned", "")

    description = (f"Execute SQL queries on {config.database_type} and return results. "
                   "Connects to the database, executes the provided SQL query, "
                   "and returns results in a structured format.")

    try:
        yield FunctionInfo.create(
            single_fn=_execute_sql_query,
            stream_fn=_execute_sql_query_stream,
            description=description,
            input_schema=ExecuteDBQueryInput,
        )
    finally:
        # Release pooled connections held by the shared engine, if one was created.
        # getattr keeps this safe should a connector return an object without dispose().
        dispose = getattr(engine, "dispose", None)
        if callable(dispose):
            dispose()
            logger.info("Disposed Execute DB Query database engine")
@@ -0,0 +1,22 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # flake8: noqa
17
+ # isort:skip_file
18
+
19
+ # Import any providers which need to be automatically registered here
20
+
21
+ from . import execute_db_query
22
+ from . import text2sql