nvidia-nat-vanna 1.5.0a20260115__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nvidia-nat-vanna might be problematic. Click here for more details.

@@ -0,0 +1,843 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import asyncio
17
+ import json
18
+ import logging
19
+ import uuid
20
+
21
+ from nat.plugins.vanna.training_db_schema import VANNA_RESPONSE_GUIDELINES
22
+ from nat.plugins.vanna.training_db_schema import VANNA_TRAINING_DDL
23
+ from nat.plugins.vanna.training_db_schema import VANNA_TRAINING_DOCUMENTATION
24
+ from nat.plugins.vanna.training_db_schema import VANNA_TRAINING_EXAMPLES
25
+ from nat.plugins.vanna.training_db_schema import VANNA_TRAINING_PROMPT
26
+ from vanna.legacy.base import VannaBase
27
+ from vanna.legacy.milvus import Milvus_VectorStore
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
def extract_json_from_string(content: str) -> dict:
    """Extract JSON from a string that may contain additional content.

    The whole string is parsed first. If that fails, candidate substrings are
    tried in order:

    1. If the text contains a pair of ``` fences: the fenced text, then the
       fenced text with a leading language tag line (e.g. ``json``) removed.
    2. Otherwise: the span from the first ``{`` to the last ``}``, then the
       span from the first ``[`` to the last ``]``.

    Args:
        content: String containing JSON data

    Returns:
        Parsed JSON value. Usually a dictionary, but a list (or any other
        JSON value) is returned as-is when that is what the text contains.

    Raises:
        ValueError: If no valid JSON found
    """
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        pass

    fence = "```"
    candidates: list[str] = []

    fence_open = content.find(fence)
    if fence_open != -1:
        body_start = fence_open + len(fence)
        fence_close = content.find(fence, body_start)
        # An unterminated fence yields no candidates and ends in ValueError below.
        if fence_close != -1:
            fenced = content[body_start:fence_close].strip()
            candidates.append(fenced)
            # LLMs commonly emit ```json ... ```; the tag line is not JSON.
            tag, newline, remainder = fenced.partition("\n")
            if newline and tag.strip().isalnum():
                candidates.append(remainder.strip())
    else:
        # Objects first (the historical behaviour), arrays as a fallback.
        for opener, closer in (("{", "}"), ("[", "]")):
            start = content.find(opener)
            end = content.rfind(closer)
            if start != -1 and end > start:
                candidates.append(content[start:end + 1])

    last_error: Exception | None = None
    for candidate in candidates:
        try:
            return json.loads(candidate)
        except json.JSONDecodeError as e:
            last_error = e

    reason = last_error if last_error is not None else "No JSON found in response"
    logger.error(f"Failed to extract JSON from content: {reason}")
    raise ValueError("Could not extract valid JSON from response") from last_error
70
+
71
+
72
def remove_think_tags(text: str, model_name: str, reasoning_models: set[str]) -> str:
    """Strip think tags from a reasoning model's output when the model requires it.

    Args:
        text: Text potentially containing think tags
        model_name: Name of the model that produced the text
        reasoning_models: Set of model names that require think tag removal

    Returns:
        The input text unchanged for ``openai/gpt-oss`` models and for models
        not listed in ``reasoning_models``; otherwise the result of
        ``remove_r1_think_tags``.
    """
    # gpt-oss models are exempt even when they are listed as reasoning models.
    needs_stripping = "openai/gpt-oss" not in model_name and model_name in reasoning_models
    if not needs_stripping:
        return text

    # Imported lazily so the helper is only loaded when it is actually needed.
    from nat.utils.io.model_processing import remove_r1_think_tags

    return remove_r1_think_tags(text)
91
+
92
+
93
def to_langchain_msgs(msgs):
    """Turn ``{"role", "content"}`` dicts into LangChain message objects.

    Roles ``system``, ``user`` and ``assistant`` are supported; any other role
    raises ``KeyError``.
    """
    from langchain_core.messages import AIMessage
    from langchain_core.messages import HumanMessage
    from langchain_core.messages import SystemMessage

    message_types = {
        "system": SystemMessage,
        "user": HumanMessage,
        "assistant": AIMessage,
    }

    converted = []
    for entry in msgs:
        message_cls = message_types[entry["role"]]
        converted.append(message_cls(content=entry["content"]))
    return converted
101
+
102
+
103
class VannaLangChainLLM(VannaBase):
    """LangChain LLM integration for Vanna framework.

    Builds chat prompts as lists of ``{"role": ..., "content": ...}`` dicts and
    submits them to a LangChain client.

    The prompt builders read ``self.max_tokens`` and ``self.static_documentation``,
    which are not assigned in this class.
    NOTE(review): presumably they are set by ``VannaBase.__init__`` - confirm.
    """

    def __init__(self, client=None, config=None):
        """Store the LangChain client and the LLM-related settings.

        Args:
            client: LangChain LLM client; required.
            config: Optional settings dict. Keys read here: ``dialect``,
                ``milvus_search_limit``, ``reasoning_models``, ``chat_models``.

        Raises:
            ValueError: If ``client`` is None.
        """
        if client is None:
            msg = "LangChain client must be provided"
            raise ValueError(msg)

        self.client = client
        self.config = config or {}
        self.dialect = self.config.get("dialect", "SQL")
        self.model = getattr(self.client, "model", "unknown")

        # Store configurable values
        self.milvus_search_limit = self.config.get("milvus_search_limit", 1000)
        # A missing or None entry becomes an empty set so that the membership
        # test in submit_prompt cannot fail with "NoneType is not iterable".
        self.reasoning_models = self.config.get("reasoning_models") or set()
        self.chat_models = self.config.get("chat_models") or set()

    def system_message(self, message: str) -> dict:
        """Create system message."""
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        """Create user message."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        """Create assistant message."""
        return {"role": "assistant", "content": message}

    def _with_static_documentation(self, doc_list: list) -> list:
        """Return a copy of ``doc_list`` with the static documentation appended, if any."""
        # Copy first: callers keep ownership of the list they passed in.
        docs = list(doc_list)
        if self.static_documentation != "":
            docs.append(self.static_documentation)
        return docs

    def get_training_sql_prompt(
        self,
        ddl_list: list,
        doc_list: list,
    ) -> list:
        """Generate prompt for synthetic question-SQL pairs.

        Args:
            ddl_list: DDL statements to include as context
            doc_list: Documentation snippets to include as context (not modified)

        Returns:
            Message log: one system message followed by the user message ``Begin:``
        """
        initial_prompt = (f"You are a {self.dialect} expert. "
                          "Please generate diverse question-SQL pairs where each SQL "
                          "statement starts with either `SELECT` or `WITH`. "
                          "Your response should follow the response guidelines and format instructions.")

        # Add DDL information
        initial_prompt = self.add_ddl_to_prompt(initial_prompt, ddl_list, max_tokens=self.max_tokens)

        # Add documentation
        initial_prompt = self.add_documentation_to_prompt(initial_prompt,
                                                          self._with_static_documentation(doc_list),
                                                          max_tokens=self.max_tokens)

        # Add response guidelines
        initial_prompt += VANNA_TRAINING_PROMPT

        # Build message log
        return [self.system_message(initial_prompt), self.user_message('Begin:')]

    def get_sql_prompt(
        self,
        initial_prompt: str | None,
        question: str,
        question_sql_list: list,
        ddl_list: list,
        doc_list: list,
        error_message: dict | None = None,
        **kwargs,
    ) -> list:
        """Generate prompt for SQL generation.

        Args:
            initial_prompt: Custom system prompt; a default is used when None
            question: Natural language question to answer
            question_sql_list: Example dicts with ``question`` and ``sql`` keys,
                added as user/assistant message pairs; incomplete entries are skipped
            ddl_list: DDL statements to include as context
            doc_list: Documentation snippets to include as context (not modified)
            error_message: Optional dict with ``sql_error`` and ``previous_sql``
                keys describing a failed previous attempt
            kwargs: Accepted for interface compatibility; unused here

        Returns:
            Message log: system message, example pairs, then the question
        """
        if initial_prompt is None:
            initial_prompt = (f"You are a {self.dialect} expert. "
                              "Please help to generate a SQL query to answer the question. "
                              "Your response should ONLY be based on the given context "
                              "and follow the response guidelines and format instructions.")

        # Add DDL information
        initial_prompt = self.add_ddl_to_prompt(initial_prompt, ddl_list, max_tokens=self.max_tokens)

        # Add documentation
        initial_prompt = self.add_documentation_to_prompt(initial_prompt,
                                                          self._with_static_documentation(doc_list),
                                                          max_tokens=self.max_tokens)

        # Add response guidelines
        initial_prompt += VANNA_RESPONSE_GUIDELINES
        initial_prompt += (f"3. Ensure that the output SQL is {self.dialect}-compliant "
                           "and executable, and free of syntax errors.\n")

        # Add error message if provided
        if error_message is not None:
            initial_prompt += (f"4. For question: {question}. "
                               "\tPrevious SQL attempt failed with error: "
                               f"{error_message['sql_error']}\n"
                               f"\tPrevious SQL was: {error_message['previous_sql']}\n"
                               "\tPlease fix the SQL syntax/logic error and regenerate.")

        # Build message log with examples
        message_log = [self.system_message(initial_prompt)]

        for example in question_sql_list:
            if example and "question" in example and "sql" in example:
                message_log.append(self.user_message(example["question"]))
                message_log.append(self.assistant_message(example["sql"]))

        message_log.append(self.user_message(question))
        return message_log

    async def submit_prompt(self, prompt, **kwargs) -> str:
        """Submit prompt to LLM.

        Models listed in ``self.reasoning_models`` are streamed and their output
        is passed through ``remove_think_tags``; all other models are invoked once.

        Args:
            prompt: Message log to send to the client
            kwargs: Accepted for interface compatibility; unused here

        Returns:
            The response content

        Raises:
            Exception: Any client error is logged and re-raised unchanged.
        """
        try:
            # Determine model name
            llm_name = getattr(self.client, 'model_name', None) or getattr(self.client, 'model', 'unknown')

            # Get LLM response (with streaming for reasoning models)
            if llm_name in self.reasoning_models:
                chunks = [chunk.content async for chunk in self.client.astream(prompt)]
                llm_response = remove_think_tags("".join(chunks), llm_name, self.reasoning_models)
            else:
                llm_response = (await self.client.ainvoke(prompt)).content

            logger.debug(f"LLM Response: {llm_response}")
            return llm_response

        except Exception as e:
            logger.error(f"Error calling LLM during SQL query generation: {e}")
            raise
232
+
233
+
234
class MilvusVectorStore(Milvus_VectorStore):
    """Extended Milvus vector store for Vanna.

    All Milvus access goes through the async client supplied in the config.
    Three collections are used: question-SQL pairs, DDL and documentation.
    """

    def __init__(self, config=None):
        """Read the Milvus and embedder settings from ``config``.

        Args:
            config: Settings dict. ``async_milvus_client`` and ``embedder_client``
                are required; ``n_results``, ``milvus_search_limit``,
                ``sql_collection``, ``ddl_collection`` and ``doc_collection``
                are optional.

        Raises:
            ValueError: If no embedder client is configured.
            Exception: Any other initialization error is logged and re-raised.
        """
        try:
            # Only VannaBase is initialized here; Milvus_VectorStore.__init__ is
            # deliberately not called.
            VannaBase.__init__(self, config=config)

            # Only use async client
            self.async_milvus_client = config["async_milvus_client"]
            self.n_results = config.get("n_results", 5)
            self.milvus_search_limit = config.get("milvus_search_limit", 1000)

            # Use configured embedder
            if config.get("embedder_client") is not None:
                logger.info("Using configured embedder client")
                self.embedder = config["embedder_client"]
            else:
                msg = "Embedder client must be provided in config"
                raise ValueError(msg)

            try:
                # Probe the embedder once to learn the vector dimension used
                # for the collection schemas.
                self._embedding_dim = len(self.embedder.embed_documents(["test"])[0])
                logger.info(f"Embedding dimension: {self._embedding_dim}")
            except Exception as e:
                logger.error(f"Error calling embedder during Milvus initialization: {e}")
                raise

            # Collection names
            self.sql_collection = config.get("sql_collection", "vanna_sql")
            self.ddl_collection = config.get("ddl_collection", "vanna_ddl")
            self.doc_collection = config.get("doc_collection", "vanna_documentation")

            # Collection creation tracking
            self._collections_created = False
        except Exception as e:
            logger.error(f"Error initializing MilvusVectorStore: {e}")
            raise

    async def _ensure_collections_created(self):
        """Ensure all necessary Milvus collections are created (async)."""
        if self._collections_created:
            return

        logger.info("Creating Milvus collections if they don't exist...")
        await self._create_sql_collection(self.sql_collection)
        await self._create_ddl_collection(self.ddl_collection)
        await self._create_doc_collection(self.doc_collection)
        self._collections_created = True

    async def _create_collection(self, name: str, text_fields: tuple[str, ...]) -> None:
        """Create collection ``name`` unless loading it shows that it already exists.

        The schema has a VARCHAR primary key ``id``, one VARCHAR field per entry
        in ``text_fields`` and a FLOAT_VECTOR field ``vector`` indexed with
        AUTOINDEX / L2.

        Raises:
            MilvusException: If loading fails for any reason other than
                "collection not found", or if creation fails.
        """
        from pymilvus import DataType
        from pymilvus import MilvusClient
        from pymilvus import MilvusException

        # Check if collection already exists by attempting to load it
        try:
            await self.async_milvus_client.load_collection(collection_name=name)
            logger.debug(f"Collection {name} already exists, skipping creation")
            return
        except MilvusException as e:
            if "collection not found" not in str(e).lower():
                raise  # Unexpected error, re-raise
            # Collection doesn't exist, proceed to create it

        schema = MilvusClient.create_schema(
            auto_id=False,
            enable_dynamic_field=False,
        )
        schema.add_field(
            field_name="id",
            datatype=DataType.VARCHAR,
            is_primary=True,
            max_length=65535,
        )
        for field_name in text_fields:
            schema.add_field(field_name=field_name, datatype=DataType.VARCHAR, max_length=65535)
        schema.add_field(
            field_name="vector",
            datatype=DataType.FLOAT_VECTOR,
            dim=self._embedding_dim,
        )

        index_params = MilvusClient.prepare_index_params()
        index_params.add_index(field_name="vector", index_type="AUTOINDEX", metric_type="L2")
        await self.async_milvus_client.create_collection(
            collection_name=name,
            schema=schema,
            index_params=index_params,
            consistency_level="Strong",
        )
        logger.info(f"Created collection: {name}")

    async def _create_sql_collection(self, name: str):
        """Create SQL collection (fields ``text`` and ``sql``) using async client."""
        await self._create_collection(name, ("text", "sql"))

    async def _create_ddl_collection(self, name: str):
        """Create DDL collection (field ``ddl``) using async client."""
        await self._create_collection(name, ("ddl", ))

    async def _create_doc_collection(self, name: str):
        """Create documentation collection (field ``doc``) using async client."""
        await self._create_collection(name, ("doc", ))

    async def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Add question-SQL pair to collection using async client.

        Returns:
            The generated record id (a UUID with the suffix ``-sql``)

        Raises:
            ValueError: If ``question`` or ``sql`` is empty.
        """
        if len(question) == 0 or len(sql) == 0:
            msg = "Question and SQL cannot be empty"
            raise ValueError(msg)
        _id = str(uuid.uuid4()) + "-sql"
        embedding = (await self.embedder.aembed_documents([question]))[0]
        data = {"id": _id, "text": question, "sql": sql, "vector": embedding}
        await self.async_milvus_client.insert(collection_name=self.sql_collection, data=data)
        return _id

    async def add_ddl(self, ddl: str, **kwargs) -> str:
        """Add DDL to collection using async client.

        Returns:
            The generated record id (a UUID with the suffix ``-ddl``)

        Raises:
            ValueError: If ``ddl`` is empty.
        """
        if len(ddl) == 0:
            msg = "DDL cannot be empty"
            raise ValueError(msg)
        _id = str(uuid.uuid4()) + "-ddl"
        # Async embedding, consistent with add_question_sql, so the event loop
        # is not blocked by the embedder call.
        embedding = (await self.embedder.aembed_documents([ddl]))[0]
        await self.async_milvus_client.insert(
            collection_name=self.ddl_collection,
            data={
                "id": _id, "ddl": ddl, "vector": embedding
            },
        )
        return _id

    async def add_documentation(self, documentation: str, **kwargs) -> str:
        """Add documentation to collection using async client.

        Returns:
            The generated record id (a UUID with the suffix ``-doc``)

        Raises:
            ValueError: If ``documentation`` is empty.
        """
        if len(documentation) == 0:
            msg = "Documentation cannot be empty"
            raise ValueError(msg)
        _id = str(uuid.uuid4()) + "-doc"
        # Async embedding, consistent with add_question_sql.
        embedding = (await self.embedder.aembed_documents([documentation]))[0]
        await self.async_milvus_client.insert(
            collection_name=self.doc_collection,
            data={
                "id": _id, "doc": documentation, "vector": embedding
            },
        )
        return _id

    async def get_related_record(self, collection_name: str) -> list:
        """Retrieve up to ``milvus_search_limit`` records from a collection.

        The field to read is ``ddl`` for the configured DDL collection and
        ``doc`` for the configured documentation collection. For any other name
        the previous substring rule applies (``ddl`` / ``doc`` in the name,
        otherwise the collection name itself is used as the field name).

        Returns:
            List of field values; empty (or partial) if the query fails, in
            which case the error is logged instead of raised.
        """
        # Exact matches first so custom collection names resolve correctly.
        if collection_name == self.ddl_collection:
            output_field = "ddl"
        elif collection_name == self.doc_collection:
            output_field = "doc"
        elif 'ddl' in collection_name:
            output_field = "ddl"
        elif 'doc' in collection_name:
            output_field = "doc"
        else:
            output_field = collection_name

        record_list = []
        try:
            records = await self.async_milvus_client.query(
                collection_name=collection_name,
                output_fields=[output_field],
                limit=self.milvus_search_limit,
            )
            for record in records:
                record_list.append(record[output_field])
        except Exception as e:
            # Best effort: retrieval problems degrade the prompt, they do not abort it.
            logger.exception(f"Error retrieving {collection_name}: {e}")
        return record_list

    async def get_similar_question_sql(self, question: str, **kwargs) -> list:
        """Get similar question-SQL pairs using async client.

        Returns:
            List of ``{"question", "sql"}`` dicts, at most ``n_results`` long;
            empty (or partial) if embedding or search fails, in which case the
            error is logged instead of raised.
        """
        search_params = {"metric_type": "L2", "params": {"nprobe": 128}}
        list_sql = []
        try:
            # Use async embedder and async Milvus client
            embeddings = [await self.embedder.aembed_query(question)]
            res = await self.async_milvus_client.search(
                collection_name=self.sql_collection,
                anns_field="vector",
                data=embeddings,
                limit=self.n_results,
                output_fields=["text", "sql"],
                search_params=search_params,
            )
            # One query vector was sent, so only the first result set is relevant.
            for doc in res[0]:
                list_sql.append({
                    "question": doc["entity"]["text"],
                    "sql": doc["entity"]["sql"],
                })

            logger.info(f"Retrieved {len(list_sql)} similar SQL examples")
        except Exception as e:
            logger.exception(f"Error retrieving similar questions: {e}")
        return list_sql

    async def get_training_data(self, **kwargs):
        """Get all training data using async client.

        Returns:
            DataFrame with columns ``id``, ``question``, ``content`` and
            ``training_data_type`` (``sql``, ``ddl`` or ``documentation``);
            an empty DataFrame when no collection returns rows. At most
            ``milvus_search_limit`` rows are read per collection.
        """
        import pandas as pd

        # (collection, training_data_type, question field or None, content field)
        sources = (
            (self.sql_collection, "sql", "text", "sql"),
            (self.ddl_collection, "ddl", None, "ddl"),
            (self.doc_collection, "documentation", None, "doc"),
        )

        frames = []
        for collection, data_type, question_field, content_field in sources:
            rows = await self.async_milvus_client.query(collection_name=collection,
                                                        output_fields=["*"],
                                                        limit=self.milvus_search_limit)
            if not rows:
                continue
            frame = pd.DataFrame({
                "id": [row["id"] for row in rows],
                "question": [row[question_field] if question_field else None for row in rows],
                "content": [row[content_field] for row in rows],
            })
            frame["training_data_type"] = data_type
            frames.append(frame)

        # Concatenate once; avoids concatenating onto an empty DataFrame.
        return pd.concat(frames) if frames else pd.DataFrame()

    async def close(self):
        """Close async Milvus client connection.

        Errors while closing are logged as warnings and not raised.
        """
        if hasattr(self, 'async_milvus_client') and self.async_milvus_client is not None:
            try:
                await self.async_milvus_client.close()
                logger.info("Closed async Milvus client")
            except Exception as e:
                logger.warning(f"Error closing async Milvus client: {e}")
564
+
565
+
566
class VannaLangChain(MilvusVectorStore, VannaLangChainLLM):
    """Vanna implementation combining the Milvus vector store with the LangChain LLM."""

    def __init__(self, client, config=None):
        """Initialize both parent classes explicitly.

        Args:
            client: LangChain LLM client
            config: Configuration dict for Milvus vector store and LLM settings
        """
        # The vector store is initialized before the LLM part.
        MilvusVectorStore.__init__(self, config=config)
        VannaLangChainLLM.__init__(self, client=client, config=config)
        # Optional database engine; starts unset (VannaSingleton.reset disposes it if present).
        self.db_engine = None

    async def generate_sql(
        self,
        question: str,
        allow_llm_to_see_data: bool = False,
        error_message: dict | None = None,
        **kwargs,
    ) -> dict[str, str | None]:
        """Turn a natural language question into SQL with the LLM.

        Args:
            question: Natural language question to convert to SQL
            allow_llm_to_see_data: Accepted for interface compatibility; not used here
            error_message: Optional error details from a previous SQL execution
            kwargs: Forwarded to example retrieval and prompt construction

        Returns:
            Dictionary with key ``sql`` and key ``explanation`` (None when the
            LLM response carried no structured explanation)
        """
        logger.info("Starting SQL Generation with Vanna")

        # A custom system prompt may be supplied through the config.
        system_prompt = self.config.get("initial_prompt", None)

        # Fetch examples, DDL and documentation concurrently.
        examples, ddl_records, doc_records = await asyncio.gather(
            self.get_similar_question_sql(question, **kwargs),
            self.get_related_record(self.ddl_collection),
            self.get_related_record(self.doc_collection),
        )

        messages = self.get_sql_prompt(
            initial_prompt=system_prompt,
            question=question,
            question_sql_list=examples,
            ddl_list=ddl_records,
            doc_list=doc_records,
            error_message=error_message,
            **kwargs,
        )

        raw_answer = await self.submit_prompt(messages)

        # Prefer a structured JSON answer (sql + explanation); otherwise the
        # whole answer is treated as SQL without an explanation.
        try:
            parsed = extract_json_from_string(raw_answer)
            fields = (parsed.get("sql", ""), parsed.get("explanation"))
        except Exception:
            fields = (raw_answer, None)
        sql_text, explanation = fields

        cleaned_sql = self.extract_sql(sql_text).replace("\\_", "_")
        return {"sql": cleaned_sql, "explanation": explanation}
638
+
639
+
640
class VannaSingleton:
    """Singleton manager for Vanna instances."""

    # The shared instance; None until get_instance() has completed successfully.
    _instance: VannaLangChain | None = None
    # Lock guarding instance creation; created lazily by _get_lock().
    _lock: asyncio.Lock | None = None

    @classmethod
    def _get_lock(cls) -> asyncio.Lock:
        """Return the class-level creation lock, creating it on first use.

        The lock is cached on the class and reused by every later call.
        """
        if cls._lock is None:
            cls._lock = asyncio.Lock()
        return cls._lock

    @classmethod
    def instance(cls) -> VannaLangChain | None:
        """Get current instance without creating one.

        Returns:
            Current Vanna instance or None if not initialized
        """
        return cls._instance

    @classmethod
    async def get_instance(
        cls,
        llm_client,
        embedder_client,
        async_milvus_client,
        dialect: str = "SQLite",
        initial_prompt: str | None = None,
        n_results: int = 5,
        sql_collection: str = "vanna_sql",
        ddl_collection: str = "vanna_ddl",
        doc_collection: str = "vanna_documentation",
        milvus_search_limit: int = 1000,
        reasoning_models: set[str] | None = None,
        chat_models: set[str] | None = None,
        create_collections: bool = True,
    ) -> VannaLangChain:
        """Get or create a singleton Vanna instance.

        The arguments are only used when the instance is created; once an
        instance exists it is returned as-is, whatever arguments are passed.

        Args:
            llm_client: LangChain LLM client for SQL generation
            embedder_client: LangChain embedder for vector operations
            async_milvus_client: Async Milvus client
            dialect: SQL dialect (e.g., 'databricks', 'postgres', 'mysql')
            initial_prompt: Optional custom system prompt
            n_results: Number of similar examples to retrieve
            sql_collection: Collection name for SQL examples
            ddl_collection: Collection name for DDL
            doc_collection: Collection name for documentation
            milvus_search_limit: Maximum limit size for vector search operations
            reasoning_models: Models requiring special handling for think tags
                (None is treated as an empty set)
            chat_models: Models using standard response handling
                (None is treated as an empty set)
            create_collections: Whether to create Milvus collections if they don't exist (default True)

        Returns:
            Initialized Vanna instance
        """
        logger.info("Setting up Vanna instance...")

        # Fast path - return existing instance
        if cls._instance is not None:
            logger.info("Vanna instance already exists")
            return cls._instance

        # Slow path - create new instance
        async with cls._get_lock():
            # Double check after acquiring lock
            if cls._instance is not None:
                logger.info("Vanna instance already exists")
                return cls._instance

            config = {
                "async_milvus_client": async_milvus_client,
                "embedder_client": embedder_client,
                "dialect": dialect,
                "initial_prompt": initial_prompt,
                "n_results": n_results,
                "sql_collection": sql_collection,
                "ddl_collection": ddl_collection,
                "doc_collection": doc_collection,
                "milvus_search_limit": milvus_search_limit,
                # Never hand None to the instance: it performs membership tests on these.
                "reasoning_models": reasoning_models if reasoning_models is not None else set(),
                "chat_models": chat_models if chat_models is not None else set(),
                "create_collections": create_collections,
            }

            logger.info(f"Creating new Vanna instance with LangChain (dialect: {dialect})")
            instance = VannaLangChain(client=llm_client, config=config)

            # Create collections if requested
            if create_collections:
                await instance._ensure_collections_created()  # type: ignore[attr-defined]

            # Publish only after setup succeeded. Assigning earlier would let the
            # lock-free fast path hand out an instance whose collections are still
            # being created, and would keep a broken instance if creation failed.
            cls._instance = instance
            return instance

    @classmethod
    async def reset(cls):
        """Reset the singleton Vanna instance.

        Useful for testing or when configuration changes.
        Properly disposes of database engine if present. Errors during disposal
        or closing are logged as warnings; the instance is cleared regardless.
        """
        current = cls._instance
        if current is None:
            return

        try:
            # Dispose database engine if present
            if hasattr(current, "db_engine") and current.db_engine is not None:
                try:
                    current.db_engine.dispose()
                    logger.info("Disposed database engine pool")
                except Exception as e:
                    logger.warning(f"Error disposing database engine: {e}")

            await current.close()
        except Exception as e:
            logger.warning(f"Error closing Vanna instance: {e}")
        cls._instance = None
758
+
759
+
760
async def train_vanna(
    vn: VannaLangChain,
    auto_train: bool = False,
    training_data_path: str | None = "vanna_training_data.csv",
):
    """Train Vanna with DDL, documentation, and question-SQL examples.

    Steps:
        1. Store DDL: extracted with ``SHOW CREATE TABLE`` when ``auto_train``
           is set (Databricks only), otherwise ``VANNA_TRAINING_DDL``.
        2. Store ``VANNA_TRAINING_DOCUMENTATION``.
        3. Store ``VANNA_TRAINING_EXAMPLES`` plus, when ``auto_train`` is set,
           LLM-generated question-SQL pairs that executed successfully.
        4. Export all stored training data to a CSV file.

    Args:
        vn: Vanna instance
        auto_train: Whether to automatically train Vanna (auto-extract DDL and generate training data from database)
        training_data_path: Where to write the CSV export; None skips the export

    Raises:
        NotImplementedError: If ``auto_train`` is set and the dialect is not Databricks.
    """
    logger.info("Training Vanna...")

    # Train with DDL
    if auto_train:
        from nat.plugins.vanna.training_db_schema import VANNA_ACTIVE_TABLES

        dialect = vn.dialect.lower()
        ddls = []

        if dialect == 'databricks':
            for table in VANNA_ACTIVE_TABLES:
                ddl_sql = f"SHOW CREATE TABLE {table}"
                ddl = await vn.run_sql(ddl_sql)
                ddls.append(ddl.to_string())  # Convert DataFrame to string
        else:
            error_msg = (f"Auto-extraction of DDL is currently only supported for Databricks. "
                         f"Current dialect: {vn.dialect}. "
                         "Please either set auto_train=False or use 'databricks' as the dialect.")
            logger.error(error_msg)
            raise NotImplementedError(error_msg)
    else:
        ddls = VANNA_TRAINING_DDL

    for ddl in ddls:
        await vn.add_ddl(ddl=ddl)

    # Train with documentation
    for doc in VANNA_TRAINING_DOCUMENTATION:
        await vn.add_documentation(documentation=doc)

    # Train with examples, starting from the manual ones
    examples = list(VANNA_TRAINING_EXAMPLES)

    if auto_train:
        logger.info("Generating training examples with LLM...")
        # Retrieve relevant context in parallel
        ddl_list, doc_list = await asyncio.gather(
            vn.get_related_record(vn.ddl_collection),
            vn.get_related_record(vn.doc_collection),
        )

        prompt = vn.get_training_sql_prompt(
            ddl_list=ddl_list,
            doc_list=doc_list,
        )

        llm_response = await vn.submit_prompt(prompt)

        # Parse the LLM answer; a parse failure only costs the generated examples.
        try:
            generated = extract_json_from_string(llm_response)
        except (ValueError, TypeError) as e:
            logger.warning(f"Failed to parse LLM response for training examples: {e}")
            generated = []

        # Accept a single pair as well as a list of pairs.
        if isinstance(generated, dict):
            generated = [generated]
        elif not isinstance(generated, list):
            logger.warning("Failed to parse LLM response for training examples: "
                           f"expected a list of question-SQL pairs, got {type(generated).__name__}")
            generated = []

        # Validate LLM-generated examples
        for question_sql in generated:
            if not isinstance(question_sql, dict):
                continue
            question = question_sql.get("question", "")
            sql = question_sql.get("sql", "")
            # Both parts are required: add_question_sql rejects empty values,
            # and one bad pair must not abort the whole training run.
            if not isinstance(question, str) or not isinstance(sql, str) or not question or not sql:
                continue
            # The prompt asks for SELECT/WITH statements only; enforce that
            # before executing model-written SQL against the database.
            if not sql.lstrip().upper().startswith(("SELECT", "WITH")):
                logger.debug("Dropping LLM-generated SQL that is not a SELECT/WITH statement")
                continue
            try:
                await vn.run_sql(sql)
            except Exception as e:
                logger.debug(f"Dropping invalid LLM-generated SQL: {e}")
                continue
            examples.append({"question": question, "sql": sql})
            logger.info(f"Adding valid LLM-generated Question-SQL:\n{question}\n{sql}")

    # Train with validated examples
    logger.info(f"Training Vanna with {len(examples)} validated examples")
    for example in examples:
        await vn.add_question_sql(question=example["question"], sql=example["sql"])

    if training_data_path is not None:
        df = await vn.get_training_data()
        df.to_csv(training_data_path, index=False)
    logger.info("Vanna training complete")