nvidia-nat-vanna 1.4.0b2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nvidia-nat-vanna might be problematic. Click here for more details.

@@ -0,0 +1,250 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ import uuid
18
+ from collections.abc import AsyncGenerator
19
+
20
+ from pydantic import BaseModel
21
+ from pydantic import Field
22
+
23
+ from nat.builder.builder import Builder
24
+ from nat.builder.framework_enum import LLMFrameworkEnum
25
+ from nat.builder.function_info import FunctionInfo
26
+ from nat.cli.register_workflow import register_function
27
+ from nat.data_models.api_server import ResponseIntermediateStep
28
+ from nat.data_models.component_ref import EmbedderRef
29
+ from nat.data_models.component_ref import LLMRef
30
+ from nat.data_models.component_ref import RetrieverRef
31
+ from nat.data_models.function import FunctionBaseConfig
32
+ from nat.plugins.vanna.db_utils import RequiredSecretStr
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
class StatusPayload(BaseModel):
    """Payload for status intermediate steps.

    Instances are serialized with ``model_dump_json()`` and placed in the
    ``payload`` of the ``ResponseIntermediateStep`` objects emitted while
    the text2sql stream is running.
    """

    # Human-readable progress or error text shown to the caller.
    message: str
40
+
41
+
42
class Text2SQLOutput(BaseModel):
    """Output schema for text2sql function.

    ``sql`` is always present; ``explanation`` is optional and defaults to
    ``None`` when no explanation is available.
    """

    sql: str = Field(
        description="Generated SQL query",
    )
    explanation: str | None = Field(
        default=None,
        description="Explanation of the query",
    )
46
+
47
+
48
class Text2SQLConfig(FunctionBaseConfig, name="text2sql"):
    """
    Text2SQL configuration with Vanna integration.

    Currently only Databricks is supported.
    """

    # ------------------------------------------------------------------
    # Model components
    # ------------------------------------------------------------------
    llm_name: LLMRef = Field(description="LLM for SQL generation")
    embedder_name: EmbedderRef = Field(description="Embedder for vector operations")

    # ------------------------------------------------------------------
    # Vector store: the referenced retriever must expose an async client
    # ------------------------------------------------------------------
    milvus_retriever: RetrieverRef = Field(
        description=("Milvus retriever reference for vector operations. "
                     "MUST be configured with use_async_client=true for text2sql function."),
    )

    # ------------------------------------------------------------------
    # Target database
    # ------------------------------------------------------------------
    database_type: str = Field(
        default="databricks",
        description="Database type (currently only 'databricks' is supported)",
    )
    connection_url: RequiredSecretStr = Field(description="Database connection string")

    # ------------------------------------------------------------------
    # Milvus search tuning
    # ------------------------------------------------------------------
    milvus_search_limit: int = Field(
        default=1000,
        description="Maximum limit size for vector search operations in Milvus",
    )

    # ------------------------------------------------------------------
    # Vanna behaviour
    # ------------------------------------------------------------------
    allow_llm_to_see_data: bool = Field(
        default=False,
        description="Allow LLM to see data for intermediate queries",
    )
    execute_sql: bool = Field(
        default=False,
        description="Execute SQL or just return query string",
    )
    train_on_startup: bool = Field(
        default=False,
        description="Train Vanna on startup",
    )
    auto_training: bool = Field(
        default=False,
        description=("Auto-train Vanna (auto-extract DDL and generate training data from database) "
                     "or manually train Vanna (uses training data from training_db_schema.py)"),
    )
    initial_prompt: str | None = Field(
        default=None,
        description="Custom system prompt",
    )
    n_results: int = Field(
        default=5,
        description="Number of similar examples",
    )
    sql_collection: str = Field(
        default="vanna_sql",
        description="Milvus collection for SQL examples",
    )
    ddl_collection: str = Field(
        default="vanna_ddl",
        description="Milvus collection for DDL",
    )
    doc_collection: str = Field(
        default="vanna_documentation",
        description="Milvus collection for docs",
    )

    # ------------------------------------------------------------------
    # Per-model response handling
    # ------------------------------------------------------------------
    reasoning_models: set[str] = Field(
        default={
            "nvidia/llama-3.1-nemotron-ultra-253b-v1",
            "nvidia/llama-3.3-nemotron-super-49b-v1.5",
            "deepseek-ai/deepseek-v3.1",
            "deepseek-ai/deepseek-r1",
        },
        description="Models that require special handling for think tags removal and JSON extraction",
    )

    chat_models: set[str] = Field(
        default={"meta/llama-3.1-70b-instruct"},
        description="Models using standard response handling without think tags",
    )
98
+
99
+
100
@register_function(config_type=Text2SQLConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
async def text2sql(config: Text2SQLConfig, builder: Builder):
    """Register the Text2SQL function with Vanna integration.

    Validates the configuration, obtains (or reuses) the Vanna singleton,
    wires up the database connection, optionally trains Vanna, and then
    yields a ``FunctionInfo`` exposing a streaming and a non-streaming
    entry point.

    Args:
        config: Text2SQL settings (models, Milvus retriever, database, Vanna options).
        builder: Workflow builder used to resolve the LLM, embedder and retriever.

    Yields:
        The ``FunctionInfo`` wrapping ``_generate_sql`` / ``_generate_sql_stream``.

    Raises:
        ValueError: If ``database_type`` is not Databricks, or the Milvus
            retriever is not configured with an async client.
    """
    from nat.plugins.vanna.db_utils import setup_vanna_db_connection
    from nat.plugins.vanna.vanna_utils import VannaSingleton
    from nat.plugins.vanna.vanna_utils import train_vanna

    logger.info("Initializing Text2SQL function")

    # Fail fast: reject unsupported databases before any client or the
    # Vanna singleton is created, so a bad config leaves nothing behind.
    if config.database_type.lower() != "databricks":
        msg = f"Only Databricks is currently supported. Got database_type: {config.database_type}"
        raise ValueError(msg)

    # Check if singleton exists to avoid unnecessary client creation
    existing_instance = VannaSingleton.instance()
    if existing_instance is not None:
        logger.info("Reusing existing Vanna singleton instance")
        vanna_instance = existing_instance
    else:
        # Create all clients only when initializing new singleton
        logger.info("Creating new Vanna singleton instance")

        # Get LLM and embedder
        llm_client = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
        embedder_client = await builder.get_embedder(config.embedder_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)

        # Get Milvus clients from retriever (expects async client)
        logger.info("Getting async Milvus client from milvus_retriever")
        retriever = await builder.get_retriever(config.milvus_retriever)

        # Vanna expects an async client. A retriever that does not expose the
        # flag at all is treated as "not async" so the caller gets the
        # actionable ValueError below rather than an AttributeError.
        if not getattr(retriever, "_is_async", False):
            msg = (f"Milvus retriever '{config.milvus_retriever}' must be configured with "
                   "use_async_client=true for Vanna text2sql function")
            raise ValueError(msg)

        # Get async client from retriever
        async_milvus_client = retriever._client  # type: ignore[attr-defined]

        # Initialize Vanna instance (singleton pattern) with async client only
        vanna_instance = await VannaSingleton.get_instance(
            llm_client=llm_client,
            embedder_client=embedder_client,
            async_milvus_client=async_milvus_client,
            dialect=config.database_type,
            initial_prompt=config.initial_prompt,
            n_results=config.n_results,
            sql_collection=config.sql_collection,
            ddl_collection=config.ddl_collection,
            doc_collection=config.doc_collection,
            milvus_search_limit=config.milvus_search_limit,
            reasoning_models=config.reasoning_models,
            chat_models=config.chat_models,
            create_collections=config.train_on_startup,
        )

    # Setup database connection (Engine stored in vanna_instance.db_engine)
    setup_vanna_db_connection(
        vn=vanna_instance,
        database_type=config.database_type,
        connection_url=config.connection_url.get_secret_value(),
    )

    # Train on startup if configured
    if config.train_on_startup:
        await train_vanna(vanna_instance, auto_train=config.auto_training)

    # Streaming version
    async def _generate_sql_stream(question: str) -> AsyncGenerator[ResponseIntermediateStep | Text2SQLOutput, None]:
        """Stream SQL generation progress and results.

        Yields ``ResponseIntermediateStep`` status updates followed by a
        single ``Text2SQLOutput``. On failure an error step is yielded and
        the original exception is re-raised.
        """
        logger.info("Text2SQL input: %s", question)

        # Generate parent_id for this function call; every step of this call shares it
        parent_id = str(uuid.uuid4())

        def _status_step(name: str, message: str) -> ResponseIntermediateStep:
            """Build one markdown intermediate step carrying ``message``."""
            return ResponseIntermediateStep(
                id=str(uuid.uuid4()),
                parent_id=parent_id,
                type="markdown",
                name=name,
                payload=StatusPayload(message=message).model_dump_json(),
            )

        yield _status_step("text2sql_status", "Starting SQL generation...")

        try:
            # Generate SQL using Vanna (returns dict with sql and explanation)
            sql_result = await vanna_instance.generate_sql(
                question=question,
                allow_llm_to_see_data=config.allow_llm_to_see_data,
            )

            # A missing or null "sql" value becomes "" here; str(None) would
            # otherwise produce the literal text "None" as the query.
            sql = str(sql_result.get("sql") or "")
            explanation: str | None = sql_result.get("explanation")

            # If execute_sql is enabled, run the query
            if config.execute_sql:
                yield _status_step("text2sql_status", "Executing SQL query...")
                # Execute SQL and propagate errors
                # Note: run_sql is dynamically set as async function in setup_vanna_db_connection
                # NOTE(review): the returned frame is only used for the row-count
                # log below; it is not part of Text2SQLOutput.
                df = await vanna_instance.run_sql(sql)  # type: ignore[misc]
                logger.info("SQL executed successfully: %d rows returned", len(df))

            # Yield final result as Text2SQLOutput
            yield Text2SQLOutput(sql=sql, explanation=explanation)

        except Exception:
            logger.exception("SQL generation failed")
            # Error status as ResponseIntermediateStep
            yield _status_step("text2sql_error", "SQL generation failed. Please check server logs for details.")
            raise

        logger.info("Text2SQL completed successfully")

    # Non-streaming version
    async def _generate_sql(question: str) -> Text2SQLOutput:
        """Generate SQL query from natural language.

        Drains the stream to completion instead of returning mid-iteration,
        so the underlying async generator finishes normally rather than
        being left suspended.
        """
        result: Text2SQLOutput | None = None
        async for update in _generate_sql_stream(question):
            # Skip ResponseIntermediateStep objects, only keep Text2SQLOutput
            if isinstance(update, Text2SQLOutput):
                result = update

        # Fallback if no result found
        return result if result is not None else Text2SQLOutput(sql="", explanation=None)

    description = ("Generate SQL queries from natural language questions using AI. "
                   "Leverages similar question-SQL pairs, DDL information, and "
                   "documentation to generate accurate SQL queries. "
                   "Currently supports Databricks only.")

    if config.execute_sql:
        description += " Also executes queries and returns results."

    yield FunctionInfo.create(
        single_fn=_generate_sql,
        stream_fn=_generate_sql_stream,
        description=description,
    )
@@ -0,0 +1,75 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Manual training data and configuration for Vanna text-to-SQL.
16
+
17
+ This module provides default DDL statements, documentation examples,
18
+ question-SQL pairs, and prompt templates used to train and configure
19
+ the Vanna text-to-SQL model with database schema context.
20
+ """
21
+
22
+ # yapf: disable
23
+ # ruff: noqa: E501
24
+
25
+ # DDL statements for training
26
+ # Define your database schema here to help the model understand table structures
27
VANNA_TRAINING_DDL: list[str] = [
    "CREATE TABLE customers (id INT PRIMARY KEY, name VARCHAR(100), email VARCHAR(100), created_at TIMESTAMP)",
    "CREATE TABLE orders (id INT PRIMARY KEY, customer_id INT, product VARCHAR(100), amount DECIMAL(10,2), order_date DATE)",
    "CREATE TABLE products (id INT PRIMARY KEY, name VARCHAR(100), category VARCHAR(50), price DECIMAL(10,2))",
]

# Free-text documentation used as training input.
# Describe the meaning of tables/columns and any business rules here.
VANNA_TRAINING_DOCUMENTATION: list[str] = [
    "The customers table contains all registered users. The created_at field shows registration date.",
    "Orders table tracks all purchases. The amount field is in USD.",
    "Products are organized by category (electronics, clothing, home, etc.).",
]

# Worked question -> SQL pairs used as training input.
# Each entry is a dict with exactly the keys "question" and "sql".
VANNA_TRAINING_EXAMPLES: list[dict[str, str]] = [
    {
        "question": "How many customers do we have?",
        "sql": "SELECT COUNT(*) as customer_count FROM customers",
    },
    {
        "question": "What is the total revenue?",
        "sql": "SELECT SUM(amount) as total_revenue FROM orders",
    },
    {
        "question": "Who are the top 5 customers by spending?",
        "sql": "SELECT c.name, SUM(o.amount) as total_spent FROM customers c JOIN orders o ON c.id = o.customer_id GROUP BY c.id, c.name ORDER BY total_spent DESC LIMIT 5",
    },
]

# Fully-qualified placeholder table names.
# NOTE(review): how these are consumed is not visible in this module - confirm against the caller.
VANNA_ACTIVE_TABLES: list[str] = [
    "catalog.schema.table_a",
    "catalog.schema.table_b",
]
59
+
60
+ # Default prompts
61
# Guidelines appended to the SQL-generation prompt. The JSON example must be
# strictly valid JSON (no trailing comma after the last member), because it is
# the format the model is told to reproduce.
VANNA_RESPONSE_GUIDELINES: str = """
Response Guidelines:
1. Carefully analyze the question to understand the user's intent, target columns, filters, and any aggregation or grouping requirements.
2. Output only JSON:
{
"sql": "<valid SQL query>",
"explanation": "<brief description>"
}
"""

# Prompt asking the model to produce question/SQL training pairs.
# NOTE(review): the doubled braces suggest this template is passed through
# str.format() before use - confirm against the caller.
VANNA_TRAINING_PROMPT: str = """
Response Guidelines:
1. Generate 20 natural language questions and their corresponding valid SQL queries.
2. Output JSON like: [{{"question": "...", "sql": "..."}}]
"""