nvidia-nat-vanna 1.5.0a20260115__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nvidia-nat-vanna might be problematic. Click here for more details.
- nat/meta/pypi.md +129 -0
- nat/plugins/vanna/__init__.py +14 -0
- nat/plugins/vanna/db_utils.py +296 -0
- nat/plugins/vanna/execute_db_query.py +237 -0
- nat/plugins/vanna/register.py +22 -0
- nat/plugins/vanna/text2sql.py +250 -0
- nat/plugins/vanna/training_db_schema.py +75 -0
- nat/plugins/vanna/vanna_utils.py +843 -0
- nvidia_nat_vanna-1.5.0a20260115.dist-info/METADATA +149 -0
- nvidia_nat_vanna-1.5.0a20260115.dist-info/RECORD +13 -0
- nvidia_nat_vanna-1.5.0a20260115.dist-info/WHEEL +5 -0
- nvidia_nat_vanna-1.5.0a20260115.dist-info/entry_points.txt +2 -0
- nvidia_nat_vanna-1.5.0a20260115.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,843 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import json
|
|
18
|
+
import logging
|
|
19
|
+
import uuid
|
|
20
|
+
|
|
21
|
+
from nat.plugins.vanna.training_db_schema import VANNA_RESPONSE_GUIDELINES
|
|
22
|
+
from nat.plugins.vanna.training_db_schema import VANNA_TRAINING_DDL
|
|
23
|
+
from nat.plugins.vanna.training_db_schema import VANNA_TRAINING_DOCUMENTATION
|
|
24
|
+
from nat.plugins.vanna.training_db_schema import VANNA_TRAINING_EXAMPLES
|
|
25
|
+
from nat.plugins.vanna.training_db_schema import VANNA_TRAINING_PROMPT
|
|
26
|
+
from vanna.legacy.base import VannaBase
|
|
27
|
+
from vanna.legacy.milvus import Milvus_VectorStore
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def extract_json_from_string(content: str) -> dict:
    """Extract JSON from a string that may contain additional content.

    Parsing is attempted in this order:

    1. ``content`` as-is.
    2. The text between the first pair of triple-backtick fences when a fence
       is present, otherwise ``content`` again.
    3. The span from the first ``{`` to the last ``}`` of the text used in
       step 2. This is what makes tagged fences (a fence opened with a
       language tag such as ``json``) parse, because the tag itself sits
       inside the fenced text.

    Args:
        content: String containing JSON data

    Returns:
        Parsed JSON as dictionary (whatever ``json.loads`` produces for the
        first candidate that parses)

    Raises:
        ValueError: If no valid JSON found
    """
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        pass

    fence = "```"
    try:
        candidate = content
        fence_open = content.find(fence)
        if fence_open != -1:
            body_start = fence_open + len(fence)
            fence_close = content.find(fence, body_start)
            if fence_close == -1:
                # An opening fence without a closing one is treated as "no JSON".
                msg = "No JSON found in response"
                raise ValueError(msg)
            candidate = content[body_start:fence_close]

        try:
            return json.loads(candidate.strip())
        except json.JSONDecodeError:
            # Drop anything around the outermost braces (language tag,
            # leading/trailing prose) and try once more.
            brace_open = candidate.find("{")
            brace_close = candidate.rfind("}")
            if brace_open == -1 or brace_close < brace_open:
                raise
            return json.loads(candidate[brace_open:brace_close + 1])
    except (json.JSONDecodeError, ValueError) as e:
        logger.error(f"Failed to extract JSON from content: {e}")
        raise ValueError("Could not extract valid JSON from response") from e
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def remove_think_tags(text: str, model_name: str, reasoning_models: set[str]) -> str:
    """Strip think tags from model output when the model calls for it.

    The text is returned untouched when the model name contains
    ``openai/gpt-oss`` or when the model is not listed in
    ``reasoning_models``; otherwise it is passed through
    ``remove_r1_think_tags``.

    Args:
        text: Text potentially containing think tags
        model_name: Name of the model
        reasoning_models: Set of model names that require think tag removal

    Returns:
        Text with think tags removed if applicable
    """
    needs_stripping = "openai/gpt-oss" not in model_name and model_name in reasoning_models
    if not needs_stripping:
        return text

    # Imported lazily, only on the path that actually strips tags.
    from nat.utils.io.model_processing import remove_r1_think_tags

    return remove_r1_think_tags(text)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def to_langchain_msgs(msgs):
    """Turn ``{"role": ..., "content": ...}`` dicts into LangChain messages.

    Roles ``system``, ``user`` and ``assistant`` map to ``SystemMessage``,
    ``HumanMessage`` and ``AIMessage`` respectively; any other role raises
    ``KeyError``.
    """
    from langchain_core.messages import AIMessage
    from langchain_core.messages import HumanMessage
    from langchain_core.messages import SystemMessage

    constructors = {
        "system": SystemMessage,
        "user": HumanMessage,
        "assistant": AIMessage,
    }

    converted = []
    for entry in msgs:
        message_cls = constructors[entry["role"]]
        converted.append(message_cls(content=entry["content"]))
    return converted
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class VannaLangChainLLM(VannaBase):
    """LangChain LLM integration for Vanna framework.

    Builds chat-style prompts (lists of ``{"role", "content"}`` dicts) and
    submits them to a LangChain client asynchronously.

    NOTE(review): ``self.max_tokens``, ``self.static_documentation``,
    ``add_ddl_to_prompt`` and ``add_documentation_to_prompt`` are not defined
    in this class; they are expected to come from ``VannaBase`` — confirm that
    ``VannaBase.__init__`` has run before the prompt builders are used.
    """

    def __init__(self, client=None, config=None):
        """Store the LLM client and configuration.

        Args:
            client: LangChain LLM client (required)
            config: Optional configuration dict; recognised keys are
                ``dialect``, ``milvus_search_limit``, ``reasoning_models``
                and ``chat_models``

        Raises:
            ValueError: If ``client`` is None
        """
        if client is None:
            msg = "LangChain client must be provided"
            raise ValueError(msg)

        self.client = client
        self.config = config or {}
        self.dialect = self.config.get("dialect", "SQL")
        self.model = getattr(self.client, "model", "unknown")

        # Store configurable values
        self.milvus_search_limit = self.config.get("milvus_search_limit", 1000)
        # A missing key or an explicit None is normalised to an empty set so
        # the membership test in submit_prompt cannot fail with
        # "argument of type 'NoneType' is not iterable".
        self.reasoning_models = self.config.get("reasoning_models") or set()
        self.chat_models = self.config.get("chat_models") or set()

    def system_message(self, message: str) -> dict:
        """Create system message."""
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        """Create user message."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        """Create assistant message."""
        return {"role": "assistant", "content": message}

    def get_training_sql_prompt(
        self,
        ddl_list: list,
        doc_list: list,
    ) -> list:
        """Generate prompt for synthetic question-SQL pairs.

        Args:
            ddl_list: DDL statements to include in the prompt
            doc_list: Documentation snippets to include in the prompt; the
                caller's list is not modified

        Returns:
            Message log: one system message followed by a ``Begin:`` user message
        """
        initial_prompt = (f"You are a {self.dialect} expert. "
                          "Please generate diverse question-SQL pairs where each SQL "
                          "statement starts with either `SELECT` or `WITH`. "
                          "Your response should follow the response guidelines and format instructions.")

        # Add DDL information
        initial_prompt = self.add_ddl_to_prompt(initial_prompt, ddl_list, max_tokens=self.max_tokens)

        # Add documentation (build a new list rather than appending to the caller's)
        if self.static_documentation != "":
            doc_list = [*doc_list, self.static_documentation]

        initial_prompt = self.add_documentation_to_prompt(initial_prompt, doc_list, max_tokens=self.max_tokens)

        # Add response guidelines
        initial_prompt += VANNA_TRAINING_PROMPT

        # Build message log
        return [
            self.system_message(initial_prompt),
            self.user_message('Begin:'),
        ]

    def get_sql_prompt(
        self,
        initial_prompt: str | None,
        question: str,
        question_sql_list: list,
        ddl_list: list,
        doc_list: list,
        error_message: dict | None = None,
        **kwargs,
    ) -> list:
        """Generate prompt for SQL generation.

        Args:
            initial_prompt: Custom system prompt prefix, or None for the default
            question: Natural language question to answer
            question_sql_list: Example dicts; only entries having both
                ``question`` and ``sql`` keys are used as few-shot turns
            ddl_list: DDL statements to include in the prompt
            doc_list: Documentation snippets to include in the prompt; the
                caller's list is not modified
            error_message: Optional dict with ``sql_error`` and
                ``previous_sql`` keys describing a failed earlier attempt
            kwargs: Accepted for interface compatibility; not used here

        Returns:
            Message log: system message, example user/assistant pairs, then
            the question as the final user message
        """
        if initial_prompt is None:
            initial_prompt = (f"You are a {self.dialect} expert. "
                              "Please help to generate a SQL query to answer the question. "
                              "Your response should ONLY be based on the given context "
                              "and follow the response guidelines and format instructions.")

        # Add DDL information
        initial_prompt = self.add_ddl_to_prompt(initial_prompt, ddl_list, max_tokens=self.max_tokens)

        # Add documentation (build a new list rather than appending to the caller's)
        if self.static_documentation != "":
            doc_list = [*doc_list, self.static_documentation]

        initial_prompt = self.add_documentation_to_prompt(initial_prompt, doc_list, max_tokens=self.max_tokens)

        # Add response guidelines
        initial_prompt += VANNA_RESPONSE_GUIDELINES
        initial_prompt += (f"3. Ensure that the output SQL is {self.dialect}-compliant "
                           "and executable, and free of syntax errors.\n")

        # Add error message if provided
        if error_message is not None:
            initial_prompt += (f"4. For question: {question}. "
                               "\tPrevious SQL attempt failed with error: "
                               f"{error_message['sql_error']}\n"
                               f"\tPrevious SQL was: {error_message['previous_sql']}\n"
                               "\tPlease fix the SQL syntax/logic error and regenerate.")

        # Build message log with examples
        message_log = [self.system_message(initial_prompt)]

        for example in question_sql_list:
            if example and "question" in example and "sql" in example:
                message_log.append(self.user_message(example["question"]))
                message_log.append(self.assistant_message(example["sql"]))

        message_log.append(self.user_message(question))
        return message_log

    async def submit_prompt(self, prompt, **kwargs) -> str:
        """Submit prompt to LLM.

        Models listed in ``self.reasoning_models`` are streamed and the joined
        output is passed through ``remove_think_tags``; all other models are
        called with a single ``ainvoke``.

        Args:
            prompt: Prompt passed unchanged to the client
            kwargs: Accepted for interface compatibility; not used here

        Returns:
            The model's response text

        Raises:
            Exception: Any error from the client is logged and re-raised
        """
        try:
            # Determine model name
            llm_name = getattr(self.client, 'model_name', None) or getattr(self.client, 'model', 'unknown')

            # Get LLM response (with streaming for reasoning models)
            if llm_name in self.reasoning_models:
                chunks = [chunk.content async for chunk in self.client.astream(prompt)]
                llm_response = remove_think_tags("".join(chunks), llm_name, self.reasoning_models)
            else:
                llm_response = (await self.client.ainvoke(prompt)).content

            logger.debug(f"LLM Response: {llm_response}")
            return llm_response

        except Exception as e:
            logger.error(f"Error calling LLM during SQL query generation: {e}")
            raise
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
class MilvusVectorStore(Milvus_VectorStore):
    """Extended Milvus vector store for Vanna.

    All Milvus access in this class goes through the async client supplied in
    ``config["async_milvus_client"]``; embeddings come from
    ``config["embedder_client"]``.
    """

    def __init__(self, config=None):
        """Read clients and settings from ``config`` and probe the embedder.

        Args:
            config: Dict with required keys ``async_milvus_client`` and
                ``embedder_client``, and optional keys ``n_results``,
                ``milvus_search_limit``, ``sql_collection``,
                ``ddl_collection`` and ``doc_collection``

        Raises:
            ValueError: If no embedder client is configured
            Exception: Any other initialisation error is logged and re-raised
        """
        try:
            # NOTE(review): VannaBase.__init__ is called directly rather than
            # Milvus_VectorStore.__init__ — presumably to skip the parent's own
            # client setup; confirm against the vanna package.
            VannaBase.__init__(self, config=config)

            # Only use async client
            self.async_milvus_client = config["async_milvus_client"]
            self.n_results = config.get("n_results", 5)
            self.milvus_search_limit = config.get("milvus_search_limit", 1000)

            # Use configured embedder
            if config.get("embedder_client") is not None:
                logger.info("Using configured embedder client")
                self.embedder = config["embedder_client"]
            else:
                msg = "Embedder client must be provided in config"
                raise ValueError(msg)

            # Embed a probe string once to learn the vector dimension used
            # when the collection schemas are built.
            try:
                self._embedding_dim = len(self.embedder.embed_documents(["test"])[0])
                logger.info(f"Embedding dimension: {self._embedding_dim}")
            except Exception as e:
                logger.error(f"Error calling embedder during Milvus initialization: {e}")
                raise

            # Collection names
            self.sql_collection = config.get("sql_collection", "vanna_sql")
            self.ddl_collection = config.get("ddl_collection", "vanna_ddl")
            self.doc_collection = config.get("doc_collection", "vanna_documentation")

            # Collection creation tracking
            self._collections_created = False
        except Exception as e:
            logger.error(f"Error initializing MilvusVectorStore: {e}")
            raise

    async def _ensure_collections_created(self):
        """Ensure all necessary Milvus collections are created (async).

        Does nothing after the first successful run on this instance.
        """
        if self._collections_created:
            return

        logger.info("Creating Milvus collections if they don't exist...")
        await self._create_sql_collection(self.sql_collection)
        await self._create_ddl_collection(self.ddl_collection)
        await self._create_doc_collection(self.doc_collection)
        self._collections_created = True

    async def _create_collection_with_fields(self, name: str, text_fields: tuple[str, ...]):
        """Create collection ``name`` unless loading it shows it already exists.

        The schema is a VARCHAR primary key ``id``, one VARCHAR field per
        entry of ``text_fields`` and a FLOAT_VECTOR field ``vector`` indexed
        with AUTOINDEX / L2.

        Args:
            name: Collection name
            text_fields: Names of the VARCHAR payload fields, in schema order

        Raises:
            MilvusException: If loading fails for any reason other than the
                collection not being found
        """
        from pymilvus import DataType
        from pymilvus import MilvusClient
        from pymilvus import MilvusException

        # Check if collection already exists by attempting to load it
        try:
            await self.async_milvus_client.load_collection(collection_name=name)
            logger.debug(f"Collection {name} already exists, skipping creation")
            return
        except MilvusException as e:
            if "collection not found" not in str(e).lower():
                raise  # Unexpected error, re-raise
            # Collection doesn't exist, proceed to create it

        # Create the collection
        schema = MilvusClient.create_schema(
            auto_id=False,
            enable_dynamic_field=False,
        )
        schema.add_field(
            field_name="id",
            datatype=DataType.VARCHAR,
            is_primary=True,
            max_length=65535,
        )
        for field_name in text_fields:
            schema.add_field(field_name=field_name, datatype=DataType.VARCHAR, max_length=65535)
        schema.add_field(
            field_name="vector",
            datatype=DataType.FLOAT_VECTOR,
            dim=self._embedding_dim,
        )

        index_params = MilvusClient.prepare_index_params()
        index_params.add_index(field_name="vector", index_type="AUTOINDEX", metric_type="L2")
        await self.async_milvus_client.create_collection(
            collection_name=name,
            schema=schema,
            index_params=index_params,
            consistency_level="Strong",
        )
        logger.info(f"Created collection: {name}")

    async def _create_sql_collection(self, name: str):
        """Create SQL collection (fields ``text`` and ``sql``) using async client."""
        await self._create_collection_with_fields(name, ("text", "sql"))

    async def _create_ddl_collection(self, name: str):
        """Create DDL collection (field ``ddl``) using async client."""
        await self._create_collection_with_fields(name, ("ddl", ))

    async def _create_doc_collection(self, name: str):
        """Create documentation collection (field ``doc``) using async client."""
        await self._create_collection_with_fields(name, ("doc", ))

    async def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Add question-SQL pair to collection using async client.

        Args:
            question: Natural language question (embedded for similarity search)
            sql: SQL answering the question
            kwargs: Accepted for interface compatibility; not used here

        Returns:
            Generated record id, suffixed with ``-sql``

        Raises:
            ValueError: If ``question`` or ``sql`` is empty
        """
        if len(question) == 0 or len(sql) == 0:
            msg = "Question and SQL cannot be empty"
            raise ValueError(msg)
        _id = str(uuid.uuid4()) + "-sql"
        embedding = (await self.embedder.aembed_documents([question]))[0]
        data = {"id": _id, "text": question, "sql": sql, "vector": embedding}
        await self.async_milvus_client.insert(collection_name=self.sql_collection, data=data)
        return _id

    async def add_ddl(self, ddl: str, **kwargs) -> str:
        """Add DDL to collection using async client.

        Args:
            ddl: DDL statement to store
            kwargs: Accepted for interface compatibility; not used here

        Returns:
            Generated record id, suffixed with ``-ddl``

        Raises:
            ValueError: If ``ddl`` is empty
        """
        if len(ddl) == 0:
            msg = "DDL cannot be empty"
            raise ValueError(msg)
        _id = str(uuid.uuid4()) + "-ddl"
        # Await the async embedder, as add_question_sql does, instead of
        # calling the synchronous one from inside a coroutine.
        embedding = (await self.embedder.aembed_documents([ddl]))[0]
        await self.async_milvus_client.insert(
            collection_name=self.ddl_collection,
            data={
                "id": _id, "ddl": ddl, "vector": embedding
            },
        )
        return _id

    async def add_documentation(self, documentation: str, **kwargs) -> str:
        """Add documentation to collection using async client.

        Args:
            documentation: Documentation text to store
            kwargs: Accepted for interface compatibility; not used here

        Returns:
            Generated record id, suffixed with ``-doc``

        Raises:
            ValueError: If ``documentation`` is empty
        """
        if len(documentation) == 0:
            msg = "Documentation cannot be empty"
            raise ValueError(msg)
        _id = str(uuid.uuid4()) + "-doc"
        # Await the async embedder, as add_question_sql does, instead of
        # calling the synchronous one from inside a coroutine.
        embedding = (await self.embedder.aembed_documents([documentation]))[0]
        await self.async_milvus_client.insert(
            collection_name=self.doc_collection,
            data={
                "id": _id, "doc": documentation, "vector": embedding
            },
        )
        return _id

    async def get_related_record(self, collection_name: str) -> list:
        """Retrieve all related records using async client.

        The payload field is ``ddl`` or ``doc`` when ``collection_name`` is
        the configured DDL or documentation collection. For any other name the
        earlier substring rule applies (``ddl`` / ``doc`` contained in the
        name, otherwise the name itself is used as the field).

        Args:
            collection_name: Collection to read

        Returns:
            Payload values of up to ``self.milvus_search_limit`` records; on a
            query error the exception is logged and whatever was collected so
            far (possibly nothing) is returned
        """
        # Prefer an exact match on the configured names so custom collection
        # names that lack the "ddl"/"doc" substring still resolve correctly.
        if collection_name == self.ddl_collection:
            output_field = "ddl"
        elif collection_name == self.doc_collection:
            output_field = "doc"
        elif 'ddl' in collection_name:
            output_field = "ddl"
        elif 'doc' in collection_name:
            output_field = "doc"
        else:
            output_field = collection_name

        record_list = []
        try:
            records = await self.async_milvus_client.query(
                collection_name=collection_name,
                output_fields=[output_field],
                limit=self.milvus_search_limit,
            )
            for record in records:
                record_list.append(record[output_field])
        except Exception as e:
            logger.exception(f"Error retrieving {collection_name}: {e}")
        return record_list

    async def get_similar_question_sql(self, question: str, **kwargs) -> list:
        """Get similar question-SQL pairs using async client.

        Args:
            question: Question to embed and search with
            kwargs: Accepted for interface compatibility; not used here

        Returns:
            Up to ``self.n_results`` dicts with ``question`` and ``sql`` keys;
            on any error the exception is logged and the entries collected so
            far (possibly none) are returned
        """
        search_params = {"metric_type": "L2", "params": {"nprobe": 128}}
        list_sql = []
        try:
            # Use async embedder and async Milvus client
            embeddings = [await self.embedder.aembed_query(question)]
            res = await self.async_milvus_client.search(
                collection_name=self.sql_collection,
                anns_field="vector",
                data=embeddings,
                limit=self.n_results,
                output_fields=["text", "sql"],
                search_params=search_params,
            )
            # One query vector was sent, so only the first result set is read.
            res = res[0]

            for doc in res:
                entry = {
                    "question": doc["entity"]["text"],
                    "sql": doc["entity"]["sql"],
                }
                list_sql.append(entry)

            logger.info(f"Retrieved {len(list_sql)} similar SQL examples")
        except Exception as e:
            logger.exception(f"Error retrieving similar questions: {e}")
        return list_sql

    async def get_training_data(self, **kwargs):
        """Get all training data using async client.

        Queries the SQL, DDL and documentation collections in that order, each
        limited to ``self.milvus_search_limit`` records.

        Args:
            kwargs: Accepted for interface compatibility; not used here

        Returns:
            DataFrame with columns ``id``, ``question``, ``content`` and
            ``training_data_type`` (``sql``, ``ddl`` or ``documentation``);
            an empty DataFrame when no collection returned any records
        """
        import pandas as pd

        # (collection, training_data_type, content field, question field or None)
        sources = (
            (self.sql_collection, "sql", "sql", "text"),
            (self.ddl_collection, "ddl", "ddl", None),
            (self.doc_collection, "documentation", "doc", None),
        )

        frames = []
        for collection, data_type, content_field, question_field in sources:
            rows = await self.async_milvus_client.query(
                collection_name=collection,
                output_fields=["*"],
                limit=self.milvus_search_limit,
            )
            if not rows:
                continue
            frame = pd.DataFrame({
                "id": [row["id"] for row in rows],
                "question": [row[question_field] if question_field else None for row in rows],
                "content": [row[content_field] for row in rows],
            })
            frame["training_data_type"] = data_type
            frames.append(frame)

        # Concatenate once, and never against an empty placeholder frame.
        return pd.concat(frames) if frames else pd.DataFrame()

    async def close(self):
        """Close async Milvus client connection.

        Errors raised while closing are logged as warnings and not propagated.
        """
        if hasattr(self, 'async_milvus_client') and self.async_milvus_client is not None:
            try:
                await self.async_milvus_client.close()
                logger.info("Closed async Milvus client")
            except Exception as e:
                logger.warning(f"Error closing async Milvus client: {e}")
|
|
564
|
+
|
|
565
|
+
|
|
566
|
+
class VannaLangChain(MilvusVectorStore, VannaLangChainLLM):
    """Vanna implementation backed by the Milvus store and a LangChain LLM."""

    def __init__(self, client, config=None):
        """Set up the vector store first, then the LLM wrapper.

        Args:
            client: LangChain LLM client
            config: Configuration dict for Milvus vector store and LLM settings
        """
        MilvusVectorStore.__init__(self, config=config)
        VannaLangChainLLM.__init__(self, client=client, config=config)
        # Database engine slot (None until something assigns one); it is
        # disposed together with the singleton that owns this instance.
        self.db_engine = None

    async def generate_sql(
        self,
        question: str,
        allow_llm_to_see_data: bool = False,
        error_message: dict | None = None,
        **kwargs,
    ) -> dict[str, str | None]:
        """Turn a natural language question into SQL via the LLM.

        Args:
            question: Natural language question to convert to SQL
            allow_llm_to_see_data: Accepted for interface compatibility; not
                read by this method
            error_message: Optional error details from a previous SQL attempt,
                forwarded to the prompt builder
            kwargs: Forwarded to the example lookup and the prompt builder

        Returns:
            Dictionary with key 'sql' and key 'explanation' (None when the
            response was not structured JSON or carried no explanation)
        """
        logger.info("Starting SQL Generation with Vanna")

        # Optional custom system prompt from the configuration
        custom_prompt = self.config.get("initial_prompt", None)

        # Fetch examples, DDL and documentation concurrently
        examples, ddl_entries, doc_entries = await asyncio.gather(
            self.get_similar_question_sql(question, **kwargs),
            self.get_related_record(self.ddl_collection),
            self.get_related_record(self.doc_collection),
        )

        messages = self.get_sql_prompt(
            initial_prompt=custom_prompt,
            question=question,
            question_sql_list=examples,
            ddl_list=ddl_entries,
            doc_list=doc_entries,
            error_message=error_message,
            **kwargs,
        )

        raw_response = await self.submit_prompt(messages)

        # Prefer a structured {"sql": ..., "explanation": ...} answer; if that
        # cannot be read for any reason, use the raw response as the SQL text.
        try:
            parsed = extract_json_from_string(raw_response)
            candidate_sql = parsed.get("sql", "")
            explanation = parsed.get("explanation")
        except Exception:
            candidate_sql = raw_response
            explanation = None

        extracted = self.extract_sql(candidate_sql)
        # Undo markdown-style escaping of underscores in identifiers
        cleaned_sql = extracted.replace("\\_", "_")
        return {"sql": cleaned_sql, "explanation": explanation}
|
|
638
|
+
|
|
639
|
+
|
|
640
|
+
class VannaSingleton:
|
|
641
|
+
"""Singleton manager for Vanna instances."""
|
|
642
|
+
|
|
643
|
+
_instance: VannaLangChain | None = None
|
|
644
|
+
_lock: asyncio.Lock | None = None
|
|
645
|
+
|
|
646
|
+
@classmethod
|
|
647
|
+
def _get_lock(cls) -> asyncio.Lock:
|
|
648
|
+
"""Get or create the lock in the current event loop."""
|
|
649
|
+
if cls._lock is None:
|
|
650
|
+
cls._lock = asyncio.Lock()
|
|
651
|
+
return cls._lock
|
|
652
|
+
|
|
653
|
+
@classmethod
|
|
654
|
+
def instance(cls) -> VannaLangChain | None:
|
|
655
|
+
"""Get current instance without creating one.
|
|
656
|
+
|
|
657
|
+
Returns:
|
|
658
|
+
Current Vanna instance or None if not initialized
|
|
659
|
+
"""
|
|
660
|
+
return cls._instance
|
|
661
|
+
|
|
662
|
+
@classmethod
|
|
663
|
+
async def get_instance(
|
|
664
|
+
cls,
|
|
665
|
+
llm_client,
|
|
666
|
+
embedder_client,
|
|
667
|
+
async_milvus_client,
|
|
668
|
+
dialect: str = "SQLite",
|
|
669
|
+
initial_prompt: str | None = None,
|
|
670
|
+
n_results: int = 5,
|
|
671
|
+
sql_collection: str = "vanna_sql",
|
|
672
|
+
ddl_collection: str = "vanna_ddl",
|
|
673
|
+
doc_collection: str = "vanna_documentation",
|
|
674
|
+
milvus_search_limit: int = 1000,
|
|
675
|
+
reasoning_models: set[str] | None = None,
|
|
676
|
+
chat_models: set[str] | None = None,
|
|
677
|
+
create_collections: bool = True,
|
|
678
|
+
) -> VannaLangChain:
|
|
679
|
+
"""Get or create a singleton Vanna instance.
|
|
680
|
+
|
|
681
|
+
Args:
|
|
682
|
+
llm_client: LangChain LLM client for SQL generation
|
|
683
|
+
embedder_client: LangChain embedder for vector operations
|
|
684
|
+
async_milvus_client: Async Milvus client
|
|
685
|
+
dialect: SQL dialect (e.g., 'databricks', 'postgres', 'mysql')
|
|
686
|
+
initial_prompt: Optional custom system prompt
|
|
687
|
+
n_results: Number of similar examples to retrieve
|
|
688
|
+
sql_collection: Collection name for SQL examples
|
|
689
|
+
ddl_collection: Collection name for DDL
|
|
690
|
+
doc_collection: Collection name for documentation
|
|
691
|
+
milvus_search_limit: Maximum limit size for vector search operations
|
|
692
|
+
reasoning_models: Models requiring special handling for think tags
|
|
693
|
+
chat_models: Models using standard response handling
|
|
694
|
+
create_collections: Whether to create Milvus collections if they don't exist (default True)
|
|
695
|
+
|
|
696
|
+
Returns:
|
|
697
|
+
Initialized Vanna instance
|
|
698
|
+
"""
|
|
699
|
+
logger.info("Setting up Vanna instance...")
|
|
700
|
+
|
|
701
|
+
# Fast path - return existing instance
|
|
702
|
+
if cls._instance is not None:
|
|
703
|
+
logger.info("Vanna instance already exists")
|
|
704
|
+
return cls._instance
|
|
705
|
+
|
|
706
|
+
# Slow path - create new instance
|
|
707
|
+
async with cls._get_lock():
|
|
708
|
+
# Double check after acquiring lock
|
|
709
|
+
if cls._instance is not None:
|
|
710
|
+
logger.info("Vanna instance already exists")
|
|
711
|
+
return cls._instance
|
|
712
|
+
|
|
713
|
+
config = {
|
|
714
|
+
"async_milvus_client": async_milvus_client,
|
|
715
|
+
"embedder_client": embedder_client,
|
|
716
|
+
"dialect": dialect,
|
|
717
|
+
"initial_prompt": initial_prompt,
|
|
718
|
+
"n_results": n_results,
|
|
719
|
+
"sql_collection": sql_collection,
|
|
720
|
+
"ddl_collection": ddl_collection,
|
|
721
|
+
"doc_collection": doc_collection,
|
|
722
|
+
"milvus_search_limit": milvus_search_limit,
|
|
723
|
+
"reasoning_models": reasoning_models,
|
|
724
|
+
"chat_models": chat_models,
|
|
725
|
+
"create_collections": create_collections,
|
|
726
|
+
}
|
|
727
|
+
|
|
728
|
+
logger.info(f"Creating new Vanna instance with LangChain (dialect: {dialect})")
|
|
729
|
+
cls._instance = VannaLangChain(client=llm_client, config=config)
|
|
730
|
+
|
|
731
|
+
# Create collections if requested
|
|
732
|
+
if create_collections:
|
|
733
|
+
await cls._instance._ensure_collections_created() # type: ignore[attr-defined]
|
|
734
|
+
|
|
735
|
+
return cls._instance
|
|
736
|
+
|
|
737
|
+
@classmethod
async def reset(cls):
    """Tear down the singleton Vanna instance and clear the cached reference.

    Does nothing when no instance exists. Otherwise the instance's
    ``db_engine`` (if the attribute is present and not ``None``) is disposed,
    the instance is closed, and the singleton slot is set back to ``None``.

    Errors raised while disposing the engine or closing the instance are
    logged as warnings and not propagated, so the slot is always cleared.
    """
    instance = cls._instance
    if instance is None:
        return

    try:
        # Release the database engine first, when the instance carries one.
        engine = getattr(instance, "db_engine", None)
        if engine is not None:
            try:
                engine.dispose()
                logger.info("Disposed database engine pool")
            except Exception as dispose_err:
                logger.warning(f"Error disposing database engine: {dispose_err}")

        await instance.close()
    except Exception as close_err:
        logger.warning(f"Error closing Vanna instance: {close_err}")

    # Clear the slot even when cleanup failed, so a fresh instance can be built.
    cls._instance = None
|
|
758
|
+
|
|
759
|
+
|
|
760
|
+
async def train_vanna(vn: VannaLangChain, auto_train: bool = False) -> None:
    """Train Vanna with DDL, documentation, and question-SQL examples.

    Training proceeds in three stages:

    1. DDL: taken from ``VANNA_TRAINING_DDL``, or, when ``auto_train`` is set,
       extracted by running ``SHOW CREATE TABLE`` for every table listed in
       ``VANNA_ACTIVE_TABLES`` (Databricks dialect only).
    2. Documentation: every entry of ``VANNA_TRAINING_DOCUMENTATION``.
    3. Question-SQL pairs: ``VANNA_TRAINING_EXAMPLES`` plus, when ``auto_train``
       is set, LLM-generated pairs whose SQL executed without raising.

    Afterwards the stored training data is written to
    ``vanna_training_data.csv`` (a path relative to the current working
    directory).

    Args:
        vn: Vanna instance
        auto_train: Whether to automatically train Vanna (auto-extract DDL and
            generate training data from database)

    Raises:
        NotImplementedError: If ``auto_train`` is True and ``vn.dialect`` is
            not ``databricks`` (compared case-insensitively).
    """
    logger.info("Training Vanna...")

    # Train with DDL
    if auto_train:
        # Only the auto-train path uses the active-table list.
        from nat.plugins.vanna.training_db_schema import VANNA_ACTIVE_TABLES

        if vn.dialect.lower() != "databricks":
            error_msg = ("Auto-extraction of DDL is currently only supported for Databricks. "
                         f"Current dialect: {vn.dialect}. "
                         "Please either set auto_train=False or use 'databricks' as the dialect.")
            logger.error(error_msg)
            raise NotImplementedError(error_msg)

        ddls = []
        for table in VANNA_ACTIVE_TABLES:
            ddl_df = await vn.run_sql(f"SHOW CREATE TABLE {table}")
            ddls.append(ddl_df.to_string())  # Convert DataFrame to string
    else:
        ddls = VANNA_TRAINING_DDL

    for ddl in ddls:
        await vn.add_ddl(ddl=ddl)

    # Train with documentation
    for doc in VANNA_TRAINING_DOCUMENTATION:
        await vn.add_documentation(documentation=doc)

    # Train with examples: start from a copy of the manual examples so the
    # module-level list is never mutated.
    examples = list(VANNA_TRAINING_EXAMPLES)

    if auto_train:
        logger.info("Generating training examples with LLM...")

        # Retrieve relevant context in parallel
        ddl_list, doc_list = await asyncio.gather(
            vn.get_related_record(vn.ddl_collection),
            vn.get_related_record(vn.doc_collection),
        )

        prompt = vn.get_training_sql_prompt(ddl_list=ddl_list, doc_list=doc_list)
        llm_response = await vn.submit_prompt(prompt)

        # Validate LLM-generated examples
        try:
            question_sql_list = extract_json_from_string(llm_response)
            for question_sql in question_sql_list:
                if not isinstance(question_sql, dict):
                    # A single malformed entry must not abort the loop and
                    # throw away the valid entries that follow it.
                    logger.debug(f"Skipping malformed LLM-generated entry: {question_sql!r}")
                    continue

                sql = question_sql.get("sql", "")
                if not sql:
                    continue
                question = question_sql.get("question", "")

                # NOTE(review): this executes LLM-generated SQL against the
                # configured database purely to check that it runs — confirm
                # the connection cannot modify data.
                try:
                    await vn.run_sql(sql)
                except Exception as e:
                    logger.debug(f"Dropping invalid LLM-generated SQL: {e}")
                    continue

                examples.append({
                    "question": question,
                    "sql": sql,
                })
                logger.info(f"Adding valid LLM-generated Question-SQL:\n{question}\n{sql}")
        except Exception as e:
            logger.warning(f"Failed to parse LLM response for training examples: {e}")

    # Train with validated examples
    logger.info(f"Training Vanna with {len(examples)} validated examples")
    for example in examples:
        await vn.add_question_sql(question=example["question"], sql=example["sql"])

    # NOTE(review): the snapshot path is hard-coded and relative to the current
    # working directory; a failure here propagates to the caller.
    df = await vn.get_training_data()
    df.to_csv("vanna_training_data.csv", index=False)
    logger.info("Vanna training complete")
|