nvidia-nat-vanna 1.4.0b2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nvidia-nat-vanna might be problematic. Click here for more details.
- nat/meta/pypi.md +129 -0
- nat/plugins/vanna/__init__.py +14 -0
- nat/plugins/vanna/db_utils.py +296 -0
- nat/plugins/vanna/execute_db_query.py +237 -0
- nat/plugins/vanna/register.py +22 -0
- nat/plugins/vanna/text2sql.py +250 -0
- nat/plugins/vanna/training_db_schema.py +75 -0
- nat/plugins/vanna/vanna_utils.py +843 -0
- nvidia_nat_vanna-1.4.0b2.dist-info/METADATA +149 -0
- nvidia_nat_vanna-1.4.0b2.dist-info/RECORD +13 -0
- nvidia_nat_vanna-1.4.0b2.dist-info/WHEEL +5 -0
- nvidia_nat_vanna-1.4.0b2.dist-info/entry_points.txt +2 -0
- nvidia_nat_vanna-1.4.0b2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,250 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
import uuid
|
|
18
|
+
from collections.abc import AsyncGenerator
|
|
19
|
+
|
|
20
|
+
from pydantic import BaseModel
|
|
21
|
+
from pydantic import Field
|
|
22
|
+
|
|
23
|
+
from nat.builder.builder import Builder
|
|
24
|
+
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
25
|
+
from nat.builder.function_info import FunctionInfo
|
|
26
|
+
from nat.cli.register_workflow import register_function
|
|
27
|
+
from nat.data_models.api_server import ResponseIntermediateStep
|
|
28
|
+
from nat.data_models.component_ref import EmbedderRef
|
|
29
|
+
from nat.data_models.component_ref import LLMRef
|
|
30
|
+
from nat.data_models.component_ref import RetrieverRef
|
|
31
|
+
from nat.data_models.function import FunctionBaseConfig
|
|
32
|
+
from nat.plugins.vanna.db_utils import RequiredSecretStr
|
|
33
|
+
|
|
34
|
+
logger = logging.getLogger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class StatusPayload(BaseModel):
    """Body of a text2sql progress or error notification.

    In this module, instances are serialized with ``model_dump_json()`` and
    passed as the ``payload`` of a ``ResponseIntermediateStep``.
    """

    # Human-readable status text.
    message: str
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class Text2SQLOutput(BaseModel):
    """Final result returned by the text2sql function.

    Carries the generated SQL text and, when one was produced, a short
    explanation of the query.
    """

    sql: str = Field(description="Generated SQL query")
    explanation: str | None = Field(
        description="Explanation of the query",
        default=None,
    )
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class Text2SQLConfig(FunctionBaseConfig, name="text2sql"):
    """
    Settings for the Vanna-backed ``text2sql`` function.

    Databricks is the only database backend supported at the moment.
    """

    # --- Model components ---
    llm_name: LLMRef = Field(description="LLM for SQL generation")
    embedder_name: EmbedderRef = Field(description="Embedder for vector operations")

    # --- Vector store: a Milvus retriever is mandatory and must use the async client ---
    milvus_retriever: RetrieverRef = Field(
        description=("Milvus retriever reference for vector operations. MUST be configured with "
                     "use_async_client=true for text2sql function."),
    )

    # --- Target database ---
    database_type: str = Field(
        description="Database type (currently only 'databricks' is supported)",
        default="databricks",
    )
    connection_url: RequiredSecretStr = Field(description="Database connection string")

    # --- Milvus search tuning ---
    milvus_search_limit: int = Field(
        description="Maximum limit size for vector search operations in Milvus",
        default=1000,
    )

    # --- Vanna behaviour ---
    allow_llm_to_see_data: bool = Field(
        description="Allow LLM to see data for intermediate queries",
        default=False,
    )
    execute_sql: bool = Field(
        description="Execute SQL or just return query string",
        default=False,
    )
    train_on_startup: bool = Field(
        description="Train Vanna on startup",
        default=False,
    )
    auto_training: bool = Field(
        description=("Auto-train Vanna (auto-extract DDL and generate training data from database) "
                     "or manually train Vanna (uses training data from training_db_schema.py)"),
        default=False,
    )
    initial_prompt: str | None = Field(description="Custom system prompt", default=None)
    n_results: int = Field(description="Number of similar examples", default=5)
    sql_collection: str = Field(
        description="Milvus collection for SQL examples",
        default="vanna_sql",
    )
    ddl_collection: str = Field(
        description="Milvus collection for DDL",
        default="vanna_ddl",
    )
    doc_collection: str = Field(
        description="Milvus collection for docs",
        default="vanna_documentation",
    )

    # --- Per-model response handling ---
    reasoning_models: set[str] = Field(
        description="Models that require special handling for think tags removal and JSON extraction",
        default={
            "nvidia/llama-3.1-nemotron-ultra-253b-v1",
            "nvidia/llama-3.3-nemotron-super-49b-v1.5",
            "deepseek-ai/deepseek-v3.1",
            "deepseek-ai/deepseek-r1",
        },
    )

    chat_models: set[str] = Field(
        description="Models using standard response handling without think tags",
        default={"meta/llama-3.1-70b-instruct"},
    )
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
@register_function(config_type=Text2SQLConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
async def text2sql(config: Text2SQLConfig, builder: Builder):
    """Register the Text2SQL function with Vanna integration.

    Validates the configuration, obtains (or reuses) the Vanna singleton,
    attaches the database connection, optionally trains Vanna, and then
    yields a ``FunctionInfo`` exposing a streaming and a non-streaming
    SQL-generation callable.

    Args:
        config: Text2SQL settings (models, Milvus retriever, database, Vanna options).
        builder: Builder used to resolve the LLM, embedder and retriever references.

    Yields:
        A ``FunctionInfo`` wrapping ``_generate_sql`` and ``_generate_sql_stream``.

    Raises:
        ValueError: If ``database_type`` is not Databricks, or if the Milvus
            retriever is not configured with an async client.
    """
    from contextlib import aclosing

    from nat.plugins.vanna.db_utils import setup_vanna_db_connection
    from nat.plugins.vanna.vanna_utils import VannaSingleton
    from nat.plugins.vanna.vanna_utils import train_vanna

    logger.info("Initializing Text2SQL function")

    # Fail fast: reject an unsupported database before any client or the
    # process-wide Vanna singleton is created with that dialect.
    if config.database_type.lower() != "databricks":
        msg = f"Only Databricks is currently supported. Got database_type: {config.database_type}"
        raise ValueError(msg)

    # Check if singleton exists to avoid unnecessary client creation
    existing_instance = VannaSingleton.instance()
    if existing_instance is not None:
        logger.info("Reusing existing Vanna singleton instance")
        vanna_instance = existing_instance
    else:
        # Create all clients only when initializing new singleton
        logger.info("Creating new Vanna singleton instance")

        # Get LLM and embedder
        llm_client = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
        embedder_client = await builder.get_embedder(config.embedder_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)

        # Get Milvus clients from retriever (expects async client)
        logger.info("Getting async Milvus client from milvus_retriever")
        retriever = await builder.get_retriever(config.milvus_retriever)

        # Vanna expects an async client. A retriever that lacks the private
        # `_is_async` flag is treated as not async so the user gets the
        # actionable configuration error instead of a bare AttributeError.
        if not getattr(retriever, "_is_async", False):
            msg = (f"Milvus retriever '{config.milvus_retriever}' must be configured with "
                   "use_async_client=true for Vanna text2sql function")
            raise ValueError(msg)

        # Get async client from retriever
        async_milvus_client = retriever._client  # type: ignore[attr-defined]

        # Initialize Vanna instance (singleton pattern) with async client only
        vanna_instance = await VannaSingleton.get_instance(
            llm_client=llm_client,
            embedder_client=embedder_client,
            async_milvus_client=async_milvus_client,
            dialect=config.database_type,
            initial_prompt=config.initial_prompt,
            n_results=config.n_results,
            sql_collection=config.sql_collection,
            ddl_collection=config.ddl_collection,
            doc_collection=config.doc_collection,
            milvus_search_limit=config.milvus_search_limit,
            reasoning_models=config.reasoning_models,
            chat_models=config.chat_models,
            create_collections=config.train_on_startup,
        )

    # Setup database connection (Engine stored in vanna_instance.db_engine)
    setup_vanna_db_connection(
        vn=vanna_instance,
        database_type=config.database_type,
        connection_url=config.connection_url.get_secret_value(),
    )

    # Train on startup if configured
    if config.train_on_startup:
        await train_vanna(vanna_instance, auto_train=config.auto_training)

    def _status_step(parent_id: str, name: str, message: str) -> ResponseIntermediateStep:
        """Build a markdown intermediate step carrying a status message."""
        return ResponseIntermediateStep(
            id=str(uuid.uuid4()),
            parent_id=parent_id,
            type="markdown",
            name=name,
            payload=StatusPayload(message=message).model_dump_json(),
        )

    # Streaming version
    async def _generate_sql_stream(question: str) -> AsyncGenerator[ResponseIntermediateStep | Text2SQLOutput, None]:
        """Stream SQL generation progress and results.

        Yields status steps while working and finally one ``Text2SQLOutput``.
        On failure an error step is yielded and the exception is re-raised.
        """
        logger.info("Text2SQL input: %s", question)

        # Generate parent_id for this function call
        parent_id = str(uuid.uuid4())

        yield _status_step(parent_id, "text2sql_status", "Starting SQL generation...")

        try:
            # Generate SQL using Vanna (returns dict with sql and explanation)
            sql_result = await vanna_instance.generate_sql(
                question=question,
                allow_llm_to_see_data=config.allow_llm_to_see_data,
            )

            # `or ""` guards against an explicit None value, which str() would
            # otherwise turn into the literal text "None".
            sql = str(sql_result.get("sql") or "")
            explanation: str | None = sql_result.get("explanation")

            # If execute_sql is enabled, run the query
            if config.execute_sql:
                yield _status_step(parent_id, "text2sql_status", "Executing SQL query...")
                # Execute SQL and propagate errors
                # Note: run_sql is dynamically set as async function in setup_vanna_db_connection
                # NOTE(review): the returned rows are only counted, not included in the output.
                df = await vanna_instance.run_sql(sql)  # type: ignore[misc]
                logger.info("SQL executed successfully: %d rows returned", len(df))

        except Exception:
            logger.exception("SQL generation failed")
            # Error status as ResponseIntermediateStep
            yield _status_step(parent_id,
                               "text2sql_error",
                               "SQL generation failed. Please check server logs for details.")
            raise

        # Log before the final yield: a consumer that stops at the result
        # (as _generate_sql does) never resumes the generator past it.
        logger.info("Text2SQL completed successfully")

        # Yield final result as Text2SQLOutput
        yield Text2SQLOutput(sql=sql, explanation=explanation)

    # Non-streaming version
    async def _generate_sql(question: str) -> Text2SQLOutput:
        """Generate SQL query from natural language."""
        # aclosing() makes sure the stream is closed deterministically when we
        # return early on the first Text2SQLOutput.
        async with aclosing(_generate_sql_stream(question)) as updates:
            async for update in updates:
                # Skip ResponseIntermediateStep objects, only return Text2SQLOutput
                if isinstance(update, Text2SQLOutput):
                    return update

        # Fallback if no result found
        return Text2SQLOutput(sql="", explanation=None)

    description = ("Generate SQL queries from natural language questions using AI. "
                   "Leverages similar question-SQL pairs, DDL information, and "
                   "documentation to generate accurate SQL queries. "
                   "Currently supports Databricks only.")

    if config.execute_sql:
        description += " Also executes queries and returns results."

    yield FunctionInfo.create(
        single_fn=_generate_sql,
        stream_fn=_generate_sql_stream,
        description=description,
    )
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""Manual training data and configuration for Vanna text-to-SQL.
|
|
16
|
+
|
|
17
|
+
This module provides default DDL statements, documentation examples,
|
|
18
|
+
question-SQL pairs, and prompt templates used to train and configure
|
|
19
|
+
the Vanna text-to-SQL model with database schema context.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
# yapf: disable
|
|
23
|
+
# ruff: noqa: E501
|
|
24
|
+
|
|
25
|
+
# Schema definitions fed to Vanna during manual training.
# Replace these with your own CREATE TABLE statements so the model learns the table structures.
VANNA_TRAINING_DDL: list[str] = [
    ("CREATE TABLE customers ("
     "id INT PRIMARY KEY, "
     "name VARCHAR(100), "
     "email VARCHAR(100), "
     "created_at TIMESTAMP)"),
    ("CREATE TABLE orders ("
     "id INT PRIMARY KEY, "
     "customer_id INT, "
     "product VARCHAR(100), "
     "amount DECIMAL(10,2), "
     "order_date DATE)"),
    ("CREATE TABLE products ("
     "id INT PRIMARY KEY, "
     "name VARCHAR(100), "
     "category VARCHAR(50), "
     "price DECIMAL(10,2))"),
]

# Free-text notes used for training.
# Describe business meaning of tables and columns here.
VANNA_TRAINING_DOCUMENTATION: list[str] = [
    ("The customers table contains all registered users. "
     "The created_at field shows registration date."),
    "Orders table tracks all purchases. The amount field is in USD.",
    "Products are organized by category (electronics, clothing, home, etc.).",
]

# Sample question/SQL pairs used for training.
# Add pairs that reflect the query patterns you want the model to follow.
VANNA_TRAINING_EXAMPLES: list[dict[str, str]] = [
    {"question": "How many customers do we have?", "sql": "SELECT COUNT(*) as customer_count FROM customers"},
    {"question": "What is the total revenue?", "sql": "SELECT SUM(amount) as total_revenue FROM orders"},
    {
        "question": "Who are the top 5 customers by spending?",
        "sql": ("SELECT c.name, SUM(o.amount) as total_spent "
                "FROM customers c JOIN orders o ON c.id = o.customer_id "
                "GROUP BY c.id, c.name "
                "ORDER BY total_spent DESC LIMIT 5"),
    },
]
|
|
57
|
+
|
|
58
|
+
# Fully qualified table names (placeholders).
# NOTE(review): consumers of this list are not visible in this module — confirm how it is used.
VANNA_ACTIVE_TABLES: list[str] = ["catalog.schema.table_a", "catalog.schema.table_b"]

# Default prompts

# Instructions appended to the SQL-generation prompt. The JSON example must be
# strictly valid JSON (no trailing comma), because the model is told to
# "Output only JSON" and tends to copy the example's shape verbatim.
# NOTE(review): braces here are single, so this text presumably is not passed
# through str.format() — confirm against the caller.
VANNA_RESPONSE_GUIDELINES = """
Response Guidelines:
1. Carefully analyze the question to understand the user's intent, target columns, filters, and any aggregation or grouping requirements.
2. Output only JSON:
{
    "sql": "<valid SQL query>",
    "explanation": "<brief description>"
}
"""

# Instructions for generating question/SQL training pairs. Braces are doubled,
# so after str.format() the example renders as a plain JSON array.
VANNA_TRAINING_PROMPT = """
Response Guidelines:
1. Generate 20 natural language questions and their corresponding valid SQL queries.
2. Output JSON like: [{{"question": "...", "sql": "..."}}]
"""
|