nvidia-nat-vanna 1.5.0a20260115__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nvidia-nat-vanna might be problematic. Click here for more details.
- nat/meta/pypi.md +129 -0
- nat/plugins/vanna/__init__.py +14 -0
- nat/plugins/vanna/db_utils.py +296 -0
- nat/plugins/vanna/execute_db_query.py +237 -0
- nat/plugins/vanna/register.py +22 -0
- nat/plugins/vanna/text2sql.py +250 -0
- nat/plugins/vanna/training_db_schema.py +75 -0
- nat/plugins/vanna/vanna_utils.py +843 -0
- nvidia_nat_vanna-1.5.0a20260115.dist-info/METADATA +149 -0
- nvidia_nat_vanna-1.5.0a20260115.dist-info/RECORD +13 -0
- nvidia_nat_vanna-1.5.0a20260115.dist-info/WHEEL +5 -0
- nvidia_nat_vanna-1.5.0a20260115.dist-info/entry_points.txt +2 -0
- nvidia_nat_vanna-1.5.0a20260115.dist-info/top_level.txt +1 -0
nat/meta/pypi.md
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
<!--
|
|
2
|
+
SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
you may not use this file except in compliance with the License.
|
|
7
|
+
You may obtain a copy of the License at
|
|
8
|
+
|
|
9
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
|
|
11
|
+
Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
See the License for the specific language governing permissions and
|
|
15
|
+
limitations under the License.
|
|
16
|
+
-->
|
|
17
|
+
|
|
18
|
+
# NVIDIA NeMo Agent Toolkit Vanna
|
|
19
|
+
|
|
20
|
+
Vanna-based Text-to-SQL integration for NeMo Agent toolkit.
|
|
21
|
+
|
|
22
|
+
## Overview
|
|
23
|
+
|
|
24
|
+
This package provides production-ready text-to-SQL capabilities using the Vanna framework with Databricks support.
|
|
25
|
+
|
|
26
|
+
## Features
|
|
27
|
+
|
|
28
|
+
- **AI-Powered SQL Generation**: Convert natural language to SQL using LLMs
|
|
29
|
+
- **Databricks Support**: Optimized for Databricks SQL warehouses
|
|
30
|
+
- **Vector-Based Similarity Search**: Milvus integration for few-shot learning
|
|
31
|
+
- **Streaming Support**: Real-time progress updates
|
|
32
|
+
- **Query Execution**: Optional database execution with formatted results
|
|
33
|
+
- **Highly Configurable**: Customizable prompts, examples, and connections
|
|
34
|
+
|
|
35
|
+
## Quick Start
|
|
36
|
+
|
|
37
|
+
Install the package:
|
|
38
|
+
|
|
39
|
+
```bash
|
|
40
|
+
pip install nvidia-nat-vanna
|
|
41
|
+
```
|
|
42
|
+
|
|
43
|
+
Create a workflow configuration:
|
|
44
|
+
|
|
45
|
+
```yaml
|
|
46
|
+
functions:
|
|
47
|
+
text2sql:
|
|
48
|
+
_type: text2sql
|
|
49
|
+
llm_name: my_llm
|
|
50
|
+
embedder_name: my_embedder
|
|
51
|
+
milvus_retriever: my_retriever
|
|
52
|
+
database_type: databricks
|
|
53
|
+
connection_url: "${CONNECTION_URL}"
|
|
54
|
+
execute_sql: false
|
|
55
|
+
|
|
56
|
+
execute_db_query:
|
|
57
|
+
_type: execute_db_query
|
|
58
|
+
database_type: databricks
|
|
59
|
+
connection_url: "${CONNECTION_URL}"
|
|
60
|
+
max_rows: 100
|
|
61
|
+
|
|
62
|
+
llms:
|
|
63
|
+
my_llm:
|
|
64
|
+
_type: nim
|
|
65
|
+
model_name: meta/llama-3.1-70b-instruct
|
|
66
|
+
api_key: "${NVIDIA_API_KEY}"
|
|
67
|
+
|
|
68
|
+
embedders:
|
|
69
|
+
my_embedder:
|
|
70
|
+
_type: nim
|
|
71
|
+
model_name: nvidia/llama-3.2-nv-embedqa-1b-v2
|
|
72
|
+
api_key: "${NVIDIA_API_KEY}"
|
|
73
|
+
|
|
74
|
+
retrievers:
|
|
75
|
+
my_retriever:
|
|
76
|
+
_type: milvus_retriever
|
|
77
|
+
uri: "${MILVUS_URI}"
|
|
78
|
+
connection_args:
|
|
79
|
+
user: "developer"
|
|
80
|
+
password: "${MILVUS_PASSWORD}"
|
|
81
|
+
db_name: "default"
|
|
82
|
+
embedding_model: my_embedder
|
|
83
|
+
content_field: text
|
|
84
|
+
use_async_client: true
|
|
85
|
+
|
|
86
|
+
workflow:
|
|
87
|
+
_type: rewoo_agent
|
|
88
|
+
tool_names: [text2sql, execute_db_query]
|
|
89
|
+
llm_name: my_llm
|
|
90
|
+
```
|
|
91
|
+
|
|
92
|
+
Run the workflow:
|
|
93
|
+
|
|
94
|
+
```bash
|
|
95
|
+
nat run --config config.yml --input "How many customers do we have?"
|
|
96
|
+
```
|
|
97
|
+
|
|
98
|
+
## Components
|
|
99
|
+
|
|
100
|
+
### `text2sql` Function
|
|
101
|
+
|
|
102
|
+
Generates SQL queries from natural language using:
|
|
103
|
+
- Few-shot learning with similar examples
|
|
104
|
+
- DDL (schema) information
|
|
105
|
+
- Custom documentation
|
|
106
|
+
- LLM-powered query generation
|
|
107
|
+
|
|
108
|
+
### `execute_db_query` Function
|
|
109
|
+
|
|
110
|
+
Executes SQL queries and returns formatted results:
|
|
111
|
+
- Databricks SQL execution
|
|
112
|
+
- Result limiting (rows beyond `max_rows` are truncated and flagged)
|
|
113
|
+
- Structured output format
|
|
114
|
+
- SQLAlchemy engine-based connection (raw SQL via SQLAlchemy Core, not the ORM)
|
|
115
|
+
|
|
116
|
+
## Use Cases
|
|
117
|
+
|
|
118
|
+
- **Business Intelligence**: Enable non-technical users to query data
|
|
119
|
+
- **Data Exploration**: Rapid prototyping and analysis
|
|
120
|
+
- **Conversational Analytics**: Multi-turn Q&A about your data
|
|
121
|
+
- **SQL Assistance**: Help analysts write complex queries
|
|
122
|
+
|
|
123
|
+
## Documentation
|
|
124
|
+
|
|
125
|
+
Full documentation: <https://docs.nvidia.com/nemo/agent-toolkit/latest/>
|
|
126
|
+
|
|
127
|
+
## License
|
|
128
|
+
|
|
129
|
+
Part of NVIDIA NeMo Agent toolkit. Licensed under the Apache License, Version 2.0; see the repository for details.
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
@@ -0,0 +1,296 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import json
|
|
18
|
+
import logging
|
|
19
|
+
import re
|
|
20
|
+
import typing
|
|
21
|
+
from enum import Enum
|
|
22
|
+
from typing import Any
|
|
23
|
+
|
|
24
|
+
from pydantic import BaseModel
|
|
25
|
+
from pydantic import Field
|
|
26
|
+
from pydantic import PlainSerializer
|
|
27
|
+
from pydantic import SecretStr
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _serialize_secret(v: SecretStr) -> str:
    """Unwrap a ``SecretStr`` into its plain-text value.

    Used as the serializer of ``RequiredSecretStr`` so that serializing a
    model emits the underlying secret string.
    """
    plain_value = v.get_secret_value()
    return plain_value


# Mandatory secret field type. It mirrors the OptionalSecretStr convention by
# attaching a PlainSerializer that reveals the wrapped value on dump.
RequiredSecretStr = typing.Annotated[SecretStr, PlainSerializer(_serialize_secret)]
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class SupportedDatabase(str, Enum):
    """Database backends the Vanna text-to-SQL plugin can target.

    Members subclass ``str``, so each one compares equal to its lowercase
    string value (``SupportedDatabase.DATABRICKS == "databricks"``).
    """

    # Databricks SQL warehouses; currently the only supported backend.
    DATABRICKS = "databricks"
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class QueryResult(BaseModel):
    """Rows and column names produced by running a SQL statement."""

    results: list[tuple[Any, ...]] = Field(description="List of tuples representing rows returned from the query")
    column_names: list[str] = Field(description="List of column names for the result set")

    def to_dataframe(self) -> Any:
        """Build a pandas ``DataFrame`` whose columns are ``column_names``.

        pandas is imported inside the method rather than at module level.
        """
        import pandas

        frame = pandas.DataFrame(self.results, columns=self.column_names)
        return frame

    def to_records(self) -> list[dict[str, Any]]:
        """Return one ``{column_name: value}`` dict per row.

        Pairing is non-strict: if a row and ``column_names`` differ in length,
        the extra items of the longer one are ignored.
        """
        names = self.column_names
        return [{name: value for name, value in zip(names, row, strict=False)} for row in self.results]

    @property
    def row_count(self) -> int:
        """Number of rows held in ``results``."""
        return len(self.results)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def extract_sql_from_message(sql_query: str | Any) -> str:
    """Extract a clean SQL query string from the shapes an agent may pass in.

    Handles, in order:
      1. ``BaseModel`` objects with a ``sql`` field (e.g. ``Text2SQLOutput``);
         other models are converted with ``model_dump_json`` and parsed below.
      2. Dictionaries with an ``sql`` key.
      3. Objects exposing a ``content`` attribute (e.g. tool messages) whose
         content is a dict, a list of dicts or strings, or a string.
      4. String representations of tool messages: ``content="..."``.
      5. JSON strings such as ``{"sql": "...", ...}``.
      6. Repr-style strings such as ``sql='...' explanation='...'``.
      7. Plain SQL strings, which are returned unchanged.

    Args:
        sql_query: SQL query in any of the formats above.

    Returns:
        The extracted SQL query string.
    """
    # Structured output model (e.g. Text2SQLOutput)
    if isinstance(sql_query, BaseModel):
        if hasattr(sql_query, "sql"):
            return sql_query.sql
        # No ``sql`` field: continue with the model's JSON representation.
        sql_query = sql_query.model_dump_json()

    # Dictionaries with an 'sql' key
    if isinstance(sql_query, dict):
        return sql_query.get("sql", str(sql_query))

    # Non-string objects exposing ``content`` (e.g. ToolMessage)
    if not isinstance(sql_query, str):
        if hasattr(sql_query, "content"):
            content = sql_query.content
            if isinstance(content, dict):
                return content.get("sql", str(content))
            if isinstance(content, list) and content:
                first_item = content[0]
                if isinstance(first_item, dict):
                    return first_item.get("sql", str(first_item))
                # A list of plain text parts is joined; its repr
                # (e.g. "['SELECT 1']") would not be valid SQL.
                if all(isinstance(part, str) for part in content):
                    sql_query = "".join(content)
                else:
                    sql_query = str(content)
            else:
                sql_query = str(content)
        else:
            sql_query = str(sql_query)

    # From here on sql_query is always a str.

    # Legacy: stringified tool message such as content="SELECT ..." name=...
    # ``\\.`` consumes exactly one backslash escape (e.g. ``\"``) so escaped
    # quotes inside the content do not end the match early.
    if 'content="' in sql_query:
        match = re.search(r'content="((?:[^"\\]|\\.)*)"', sql_query)
        if match:
            sql_query = match.group(1).replace("\\'", "'").replace('\\"', '"')

    # JSON payload such as {"sql": "...", "explanation": "..."}
    if sql_query.strip().startswith("{"):
        try:
            parsed = json.loads(sql_query)
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, dict) and "sql" in parsed:
            return parsed["sql"]

    # Repr-style: sql='...' explanation='...'. ``\b`` prevents matching inside
    # identifiers such as ``mysql=``; the back-reference requires the closing
    # quote to match the opening one. Non-greedy so it stops before explanation.
    if "sql=" in sql_query:
        match = re.search(r"""\bsql=(['"])(.+?)\1(?:\s+explanation=|$)""", sql_query)
        if match:
            return match.group(2)

    return sql_query
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def connect_to_databricks(connection_url: str) -> Any:
    """Create a SQLAlchemy engine for a Databricks SQL Warehouse.

    ``create_engine`` builds the engine and its connection pool; per the
    SQLAlchemy documentation no DBAPI connection is opened until first use,
    so connectivity or credential problems surface when a query is executed.

    Args:
        connection_url: SQLAlchemy connection URL for the warehouse. It may
            embed credentials, so it is never written to the log here.

    Returns:
        A SQLAlchemy ``Engine``.

    Raises:
        Exception: Anything raised by ``create_engine`` (e.g. a malformed URL or
            a missing dialect) is logged by type and re-raised unchanged.
    """
    try:
        from sqlalchemy import create_engine

        engine = create_engine(url=connection_url, echo=False)
    except Exception as e:
        # Log only the exception type: some SQLAlchemy errors (such as URL
        # parse failures) may quote the connection string, which can hold tokens.
        logger.error("Failed to create Databricks engine: %s", type(e).__name__)
        raise

    logger.info("Created Databricks engine")
    return engine
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def connect_to_database(
    database_type: str | SupportedDatabase,
    connection_url: str,
    **kwargs,
) -> Any:
    """Dispatch to the connector for ``database_type`` and return its result.

    Only Databricks has a connector today.

    Args:
        database_type: Backend name (matched case-insensitively) or a
            ``SupportedDatabase`` member; currently only 'databricks'.
        connection_url: Database connection string forwarded to the connector.
        kwargs: Accepted but currently ignored.

    Returns:
        Database connection object produced by the selected connector.

    Raises:
        ValueError: If ``database_type`` names an unsupported backend.
        NotImplementedError: If a ``SupportedDatabase`` member has no connector.
    """
    if not isinstance(database_type, str):
        db_type = database_type
    else:
        # Normalize and validate the string against the enum.
        try:
            db_type = SupportedDatabase(database_type.lower())
        except ValueError:
            choices = ", ".join(f"'{member.value}'" for member in SupportedDatabase)
            msg = f"Unsupported database type: '{database_type}'. Supported types: {choices}"
            raise ValueError(msg) from None

    if db_type == SupportedDatabase.DATABRICKS:
        return connect_to_databricks(connection_url=connection_url)

    # Reached only if an enum member is added without a connector.
    msg = f"Database type '{db_type.value}' has no connector implementation"
    raise NotImplementedError(msg)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def execute_query(connection: Any, query: str) -> QueryResult:
    """Run ``query`` on a connection from ``connection`` and fetch every row.

    Args:
        connection: Object exposing ``connect()`` (the engine built by
            ``connect_to_database``).
        query: Raw SQL text; it is wrapped with ``sqlalchemy.text``.

    Returns:
        QueryResult holding the fetched rows and their column names.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    from sqlalchemy import text

    try:
        with connection.connect() as conn:
            logger.info(f"Executing query: {query}")
            cursor_result = conn.execute(text(query))
            fetched_rows = cursor_result.fetchall()
            keys = cursor_result.keys()
            column_names = list(keys) if keys else []

            logger.info(f"Query completed, retrieved {len(fetched_rows)} rows")
            return QueryResult(results=fetched_rows, column_names=column_names)
    except Exception as e:
        logger.error(f"Error executing query: {e}")
        raise
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
async def async_execute_query(connection: Any, query: str) -> QueryResult:
    """Execute query asynchronously and return QueryResult.

    The blocking ``execute_query`` call runs in the event loop's default
    executor so the loop is not blocked while the database works.

    Args:
        connection: Database connection object (the engine from connect_to_database)
        query: SQL query to execute

    Returns:
        QueryResult object containing results and column names

    Raises:
        Exception: Whatever ``execute_query`` raises is propagated.
    """
    # Inside a coroutine, get_running_loop() is the recommended way to obtain
    # the current loop; get_event_loop() has policy-dependent behavior.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, execute_query, connection, query)
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def setup_vanna_db_connection(
    vn: Any,
    database_type: str | SupportedDatabase,
    connection_url: str,
    **kwargs,
) -> None:
    """Set up database connection for Vanna instance.

    Currently only Databricks is supported.

    The database Engine is stored on the Vanna instance (``vn.db_engine``)
    together with the URL it was created from (``vn._db_connection_url``).
    A later call with the same URL reuses that engine. A call with a different
    URL disposes the previous engine and creates a new one, so ``run_sql``
    never queries a stale database. An engine attached by other code (no
    recorded URL) is reused as-is.

    Args:
        vn: Vanna instance; receives ``db_engine``, ``run_sql`` and
            ``run_sql_is_set`` attributes.
        database_type: Type of database (currently only 'databricks' is supported)
        connection_url: Database connection string
        kwargs: Additional connection parameters (currently unused)

    Raises:
        ValueError: If database_type is not supported
    """
    existing_engine = getattr(vn, "db_engine", None)
    existing_url = getattr(vn, "_db_connection_url", None)

    # Reuse the existing engine only when it targets the same URL.
    if existing_engine is not None and existing_url in (None, connection_url):
        logger.info("Reusing existing database engine from Vanna instance")
        engine = existing_engine
    else:
        if existing_engine is not None:
            # Different URL: release the old connection pool before replacing it.
            logger.info("Database connection URL changed; disposing previous database engine")
            try:
                existing_engine.dispose()
            except Exception:
                logger.exception("Failed to dispose previous database engine")
        # Connect to database (validation handled by connect_to_database)
        engine = connect_to_database(database_type=database_type, connection_url=connection_url)
        # Store engine in Vanna instance - lifecycle matches singleton
        vn.db_engine = engine
        vn._db_connection_url = connection_url
        logger.info(f"Created and stored database engine in Vanna instance for {database_type}")

    # Define async run_sql function for Vanna
    async def run_sql(sql_query: str) -> Any:
        """Execute SQL asynchronously and return DataFrame."""
        try:
            query_result = await async_execute_query(engine, sql_query)
            return query_result.to_dataframe()
        except Exception:
            logger.exception("Error executing SQL")
            raise

    # Set up Vanna
    vn.run_sql = run_sql
    vn.run_sql_is_set = True

    logger.info(f"Database connection configured for {database_type}")
|
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
import uuid
|
|
18
|
+
from collections.abc import AsyncGenerator
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
from pydantic import BaseModel
|
|
22
|
+
from pydantic import Field
|
|
23
|
+
|
|
24
|
+
from nat.builder.builder import Builder
|
|
25
|
+
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
26
|
+
from nat.builder.function_info import FunctionInfo
|
|
27
|
+
from nat.cli.register_workflow import register_function
|
|
28
|
+
from nat.data_models.api_server import ResponseIntermediateStep
|
|
29
|
+
from nat.data_models.function import FunctionBaseConfig
|
|
30
|
+
from nat.plugins.vanna.db_utils import RequiredSecretStr
|
|
31
|
+
|
|
32
|
+
logger = logging.getLogger(__name__)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class StatusPayload(BaseModel):
    """Body of a progress ("status") intermediate step.

    Instances are JSON-serialized into ``ResponseIntermediateStep.payload``.
    """

    # Human-readable progress text.
    message: str
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class ExecuteDBQueryInput(BaseModel):
    """Arguments accepted by the ``execute_db_query`` tool."""

    # The function passes this value through extract_sql_from_message before
    # executing it, so wrapped forms (JSON, tool-message text) are tolerated.
    sql_query: str = Field(description="SQL query to execute")
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class DataFrameInfo(BaseModel):
    """Shape, dtypes and column names describing a result DataFrame."""

    # Two-element list: [row_count, column_count].
    shape: list[int] = Field(description="Shape [rows, columns]")
    # Column name -> string form of its pandas dtype.
    dtypes: dict[str, str] = Field(description="Column data types")
    # Column names in result order.
    columns: list[str] = Field(description="Column names")
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class ExecuteDBQueryOutput(BaseModel):
    """Structured result returned by the ``execute_db_query`` tool.

    Failed executions carry ``success=False``, a ``failure_reason`` and empty
    result metadata. ``limited_to``/``truncated`` are populated only when the
    result set was cut to the configured row limit.
    """

    # Outcome flag.
    success: bool = Field(description="Whether query executed successfully")
    # Result columns (empty on failure).
    columns: list[str] = Field(default_factory=list, description="Column names")
    # Row count before any truncation.
    row_count: int = Field(default=0, description="Total rows returned")
    # Query text as received (after extraction/cleanup).
    sql_query: str = Field(description="Original SQL query")
    query_executed: str | None = Field(default=None, description="Actual SQL query executed (with prefixes)")
    # Row dicts, possibly truncated.
    dataframe_records: list[dict[str, Any]] = Field(default_factory=list, description="Results as list of dicts")
    dataframe_info: DataFrameInfo | None = Field(default=None, description="DataFrame metadata")
    failure_reason: str | None = Field(default=None, description="Reason for failure if query failed")
    # Truncation markers.
    limited_to: int | None = Field(default=None, description="Number of rows limited to")
    truncated: bool | None = Field(default=None, description="Whether truncated")
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class ExecuteDBQueryConfig(FunctionBaseConfig, name="execute_db_query"):
    """
    Database query execution configuration.

    Currently only Databricks is supported.
    """

    # Database configuration
    database_type: str = Field(default="databricks",
                               description="Database type (currently only 'databricks' is supported)")
    connection_url: RequiredSecretStr = Field(description="Database connection string")

    # Query configuration
    # Must be positive: results are cut with DataFrame.head(max_rows), and a
    # negative value would drop rows from the end instead of limiting them.
    max_rows: int = Field(default=100, gt=0, description="Maximum rows to return")
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@register_function(
    config_type=ExecuteDBQueryConfig,
    framework_wrappers=[LLMFrameworkEnum.LANGCHAIN],
)
async def execute_db_query(
    config: ExecuteDBQueryConfig,
    _builder: Builder,
):
    """Register the Execute DB Query function.

    One database engine is created lazily on the first query and shared by
    every later call, so its connection pool is reused rather than a new
    engine (and pool) being built per query. The engine is disposed when this
    registration generator is closed.
    """

    from nat.plugins.vanna.db_utils import async_execute_query
    from nat.plugins.vanna.db_utils import connect_to_database
    from nat.plugins.vanna.db_utils import extract_sql_from_message

    logger.info("Initializing Execute DB Query function")

    # Shared engine. It is created on first use so that configuration problems
    # are still reported as failed query results instead of failing registration.
    engine: Any = None

    def _failure_output(reason: str, sql_query: str) -> ExecuteDBQueryOutput:
        """Build an unsuccessful result with empty DataFrame metadata."""
        return ExecuteDBQueryOutput(
            success=False,
            failure_reason=reason,
            sql_query=sql_query,
            dataframe_info=DataFrameInfo(shape=[0, 0], dtypes={}, columns=[]),
        )

    def _strip_wrapping_quotes(sql_query: str) -> str:
        """Trim whitespace, then one pair of surrounding double quotes, then one pair of single quotes."""
        sql_query = sql_query.strip()
        if sql_query.startswith('"') and sql_query.endswith('"'):
            sql_query = sql_query[1:-1]
        if sql_query.startswith("'") and sql_query.endswith("'"):
            sql_query = sql_query[1:-1]
        return sql_query

    # Streaming version
    async def _execute_sql_query_stream(
            input_data: ExecuteDBQueryInput, ) -> AsyncGenerator[ResponseIntermediateStep | ExecuteDBQueryOutput, None]:
        """Stream SQL query execution progress and results."""
        nonlocal engine

        sql_query = extract_sql_from_message(input_data.sql_query)
        logger.info(f"Executing SQL: {sql_query}")

        # Generate parent_id for this function call
        parent_id = str(uuid.uuid4())

        try:
            sql_query = _strip_wrapping_quotes(sql_query)

            yield ResponseIntermediateStep(
                id=str(uuid.uuid4()),
                parent_id=parent_id,
                type="markdown",
                name="execute_db_query_status",
                payload=StatusPayload(message="Connecting to database and executing query...").model_dump_json(),
            )

            # Validate database type
            if config.database_type.lower() != "databricks":
                yield _failure_output(
                    f"Only Databricks is currently supported. Got database_type: {config.database_type}",
                    sql_query,
                )
                return

            if engine is None:
                connection_url_value = config.connection_url.get_secret_value()
                if not connection_url_value:
                    yield _failure_output("Missing required connection URL", sql_query)
                    return

                # There is no await between this check and the assignment below,
                # so concurrent calls on one event loop cannot create two engines.
                new_engine = connect_to_database(
                    database_type=config.database_type,
                    connection_url=connection_url_value,
                )
                if new_engine is None:
                    yield _failure_output("Failed to connect to database", sql_query)
                    return
                engine = new_engine

            # Execute query
            query_result = await async_execute_query(engine, sql_query)
            df = query_result.to_dataframe()

            # Store original row count before limiting
            original_row_count = len(df)
            truncated = original_row_count > config.max_rows
            if truncated:
                df = df.head(config.max_rows)

            # Column metadata is reported even when the query returned zero rows.
            columns = df.columns.tolist()
            dataframe_info = DataFrameInfo(
                shape=[len(df), len(columns)],
                dtypes={str(k): str(v) for k, v in df.dtypes.to_dict().items()},
                columns=columns,
            )

            # Yield final result as ExecuteDBQueryOutput
            yield ExecuteDBQueryOutput(
                success=True,
                columns=columns,
                row_count=original_row_count,
                sql_query=sql_query,
                query_executed=sql_query,
                dataframe_records=df.to_dict("records"),
                dataframe_info=dataframe_info,
                limited_to=config.max_rows if truncated else None,
                truncated=True if truncated else None,
            )

        except Exception as e:
            logger.error("Error executing SQL query", exc_info=e)
            yield _failure_output("SQL execution failed. Please check server logs for details.", sql_query)

        logger.info("Execute DB Query completed")

    # Non-streaming version
    async def _execute_sql_query(input_data: ExecuteDBQueryInput) -> ExecuteDBQueryOutput:
        """Execute SQL query and return results."""
        async for update in _execute_sql_query_stream(input_data):
            # Skip ResponseIntermediateStep objects, only return ExecuteDBQueryOutput
            if isinstance(update, ExecuteDBQueryOutput):
                return update

        # Fallback if no result found
        return ExecuteDBQueryOutput(
            success=False,
            failure_reason="No result returned",
            sql_query="",
            dataframe_info=DataFrameInfo(shape=[0, 0], dtypes={}, columns=[]),
        )

    description = (f"Execute SQL queries on {config.database_type} and return results. "
                   "Connects to the database, executes the provided SQL query, "
                   "and returns results in a structured format.")

    try:
        yield FunctionInfo.create(
            single_fn=_execute_sql_query,
            stream_fn=_execute_sql_query_stream,
            description=description,
            input_schema=ExecuteDBQueryInput,
        )
    finally:
        # Release the pooled connections held by the shared engine.
        if engine is not None:
            try:
                engine.dispose()
            except Exception:
                logger.exception("Failed to dispose database engine")
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# flake8: noqa
|
|
17
|
+
# isort:skip_file
|
|
18
|
+
|
|
19
|
+
# Import any providers which need to be automatically registered here
|
|
20
|
+
|
|
21
|
+
from . import execute_db_query
|
|
22
|
+
from . import text2sql
|