vanna 0.5.3__py3-none-any.whl → 0.5.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/base/base.py +181 -2
- vanna/mock/__init__.py +3 -0
- vanna/mock/embedding.py +11 -0
- vanna/mock/llm.py +19 -0
- vanna/mock/vectordb.py +55 -0
- vanna/opensearch/opensearch_vector.py +11 -2
- vanna/pinecone/__init__.py +3 -0
- vanna/pinecone/pinecone_vector.py +275 -0
- {vanna-0.5.3.dist-info → vanna-0.5.5.dist-info}/METADATA +7 -1
- {vanna-0.5.3.dist-info → vanna-0.5.5.dist-info}/RECORD +11 -5
- {vanna-0.5.3.dist-info → vanna-0.5.5.dist-info}/WHEEL +0 -0
vanna/base/base.py
CHANGED
|
@@ -555,7 +555,7 @@ class VannaBase(ABC):
|
|
|
555
555
|
"""
|
|
556
556
|
|
|
557
557
|
if initial_prompt is None:
|
|
558
|
-
initial_prompt = f"You are a {self.dialect} expert. "
|
|
558
|
+
initial_prompt = f"You are a {self.dialect} expert. " + \
|
|
559
559
|
"Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
|
|
560
560
|
|
|
561
561
|
initial_prompt = self.add_ddl_to_prompt(
|
|
@@ -1012,6 +1012,85 @@ class VannaBase(ABC):
|
|
|
1012
1012
|
self.run_sql_is_set = True
|
|
1013
1013
|
self.run_sql = run_sql_mysql
|
|
1014
1014
|
|
|
1015
|
+
def connect_to_clickhouse(
    self,
    host: str = None,
    dbname: str = None,
    user: str = None,
    password: str = None,
    port: int = None,
):
    """
    Connect to a ClickHouse database. This is just a helper function to set `vn.run_sql`.

    Every argument that is not supplied (or is falsy) is read from the
    environment instead: ``HOST``, ``DATABASE``, ``USER``, ``PASSWORD`` and
    ``PORT`` respectively.

    Args:
        host (str): The host of the ClickHouse server.
        dbname (str): The name of the database to connect to.
        user (str): The username to use for authentication.
        password (str): The password to use for authentication.
        port (int): The port to use for the connection.

    Returns:
        None

    Raises:
        DependencyError: If `clickhouse-driver` is not installed.
        ImproperlyConfigured: If a setting is missing from both the arguments
            and the environment, or the port is not an integer.
        ValidationError: If creating the connection fails.
    """

    try:
        from clickhouse_driver import connect
    except ImportError:
        raise DependencyError(
            "You need to install required dependencies to execute this method,"
            " run command: \npip install clickhouse-driver"
        )

    host = host or os.getenv("HOST")
    if not host:
        raise ImproperlyConfigured("Please set your ClickHouse host")

    dbname = dbname or os.getenv("DATABASE")
    if not dbname:
        raise ImproperlyConfigured("Please set your ClickHouse database")

    user = user or os.getenv("USER")
    if not user:
        raise ImproperlyConfigured("Please set your ClickHouse user")

    password = password or os.getenv("PASSWORD")
    if not password:
        raise ImproperlyConfigured("Please set your ClickHouse password")

    port = port or os.getenv("PORT")
    if not port:
        raise ImproperlyConfigured("Please set your ClickHouse port")

    # A port taken from the environment is a string; normalise it so the
    # driver always receives the `int` the signature promises.
    try:
        port = int(port)
    except (TypeError, ValueError) as e:
        raise ImproperlyConfigured("Your ClickHouse port must be an integer") from e

    conn = None

    try:
        conn = connect(
            host=host,
            user=user,
            password=password,
            database=dbname,
            port=port,
        )
    except Exception as e:
        # Keep the original exception as the cause so the traceback shows
        # what the driver actually complained about.
        raise ValidationError(e) from e

    def run_sql_clickhouse(sql: str) -> Union[pd.DataFrame, None]:
        """Run `sql` on the captured connection and return the rows as a DataFrame."""
        if conn:
            cs = conn.cursor()
            try:
                cs.execute(sql)
                results = cs.fetchall()
                # Read the column names before the cursor is closed.
                columns = [desc[0] for desc in cs.description]
            finally:
                # Release the cursor even when the query fails.
                cs.close()

            # Create a pandas dataframe from the results
            return pd.DataFrame(results, columns=columns)

    self.run_sql_is_set = True
    self.run_sql = run_sql_clickhouse
|
|
1093
|
+
|
|
1015
1094
|
def connect_to_oracle(
|
|
1016
1095
|
self,
|
|
1017
1096
|
user: str = None,
|
|
@@ -1372,6 +1451,10 @@ class VannaBase(ABC):
|
|
|
1372
1451
|
def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]:
|
|
1373
1452
|
if conn:
|
|
1374
1453
|
try:
|
|
1454
|
+
sql = sql.rstrip()
|
|
1455
|
+
# fix for a known problem with presto db where an extra ; will cause an error.
|
|
1456
|
+
if sql.endswith(';'):
|
|
1457
|
+
sql = sql[:-1]
|
|
1375
1458
|
cs = conn.cursor()
|
|
1376
1459
|
cs.execute(sql)
|
|
1377
1460
|
results = cs.fetchall()
|
|
@@ -1393,6 +1476,102 @@ class VannaBase(ABC):
|
|
|
1393
1476
|
self.run_sql_is_set = True
|
|
1394
1477
|
self.run_sql = run_sql_presto
|
|
1395
1478
|
|
|
1479
|
+
def connect_to_hive(
    self,
    host: str = None,
    dbname: str = 'default',
    user: str = None,
    password: str = None,
    port: int = None,
    auth: str = 'CUSTOM'
):
    """
    Connect to a Hive database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]

    Arguments that are not supplied (or are falsy) are read from the
    environment: ``HIVE_HOST``, ``HIVE_DATABASE``, ``HIVE_USER``,
    ``HIVE_PASSWORD`` and ``HIVE_PORT``. The password is the only setting
    that may remain unset.

    Args:
        host (str): The host of the Hive database.
        dbname (str): The name of the database to connect to.
        user (str): The username to use for authentication.
        password (str): The password to use for authentication.
        port (int): The port to use for the connection.
        auth (str): The authentication method to use.

    Returns:
        None

    Raises:
        DependencyError: If `pyhive` is not installed.
        ImproperlyConfigured: If host, database, user or port is missing, or
            the port is not an integer.
        ValidationError: If pyhive reports an error while connecting or
            while running a query.
    """

    try:
        from pyhive import hive
    except ImportError:
        raise DependencyError(
            "You need to install required dependencies to execute this method,"
            " run command: \npip install pyhive"
        )

    host = host or os.getenv("HIVE_HOST")
    if not host:
        raise ImproperlyConfigured("Please set your hive host")

    dbname = dbname or os.getenv("HIVE_DATABASE")
    if not dbname:
        raise ImproperlyConfigured("Please set your hive database")

    user = user or os.getenv("HIVE_USER")
    if not user:
        raise ImproperlyConfigured("Please set your hive user")

    # The password is optional: it is looked up but never required.
    password = password or os.getenv("HIVE_PASSWORD")

    port = port or os.getenv("HIVE_PORT")
    if not port:
        raise ImproperlyConfigured("Please set your hive port")

    # A port taken from the environment is a string; normalise it so the
    # driver always receives the `int` the signature promises.
    try:
        port = int(port)
    except (TypeError, ValueError) as e:
        raise ImproperlyConfigured("Your hive port must be an integer") from e

    conn = None

    try:
        conn = hive.Connection(
            host=host,
            username=user,
            password=password,
            database=dbname,
            port=port,
            auth=auth,
        )
    except hive.Error as e:
        raise ValidationError(e) from e

    def run_sql_hive(sql: str) -> Union[pd.DataFrame, None]:
        """Run `sql` on the captured connection and return the rows as a DataFrame."""
        if conn:
            cs = conn.cursor()
            try:
                cs.execute(sql)
                results = cs.fetchall()
                # Read the column names before the cursor is closed.
                columns = [desc[0] for desc in cs.description]
            except hive.Error as e:
                # Only pyhive errors are translated; anything else
                # propagates to the caller unchanged.
                raise ValidationError(e) from e
            finally:
                # Release the cursor even when the query fails.
                cs.close()

            # Create a pandas dataframe from the results
            return pd.DataFrame(results, columns=columns)

    self.run_sql_is_set = True
    self.run_sql = run_sql_hive
|
|
1574
|
+
|
|
1396
1575
|
def run_sql(self, sql: str, **kwargs) -> pd.DataFrame:
|
|
1397
1576
|
"""
|
|
1398
1577
|
Example:
|
|
@@ -1522,7 +1701,7 @@ class VannaBase(ABC):
|
|
|
1522
1701
|
return None
|
|
1523
1702
|
else:
|
|
1524
1703
|
return sql, None, None
|
|
1525
|
-
return sql, df,
|
|
1704
|
+
return sql, df, fig
|
|
1526
1705
|
|
|
1527
1706
|
def train(
|
|
1528
1707
|
self,
|
vanna/mock/__init__.py
ADDED
vanna/mock/embedding.py
ADDED
vanna/mock/llm.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
|
|
2
|
+
from ..base import VannaBase
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class MockLLM(VannaBase):
    """LLM stand-in that answers every prompt with a fixed string."""

    def __init__(self, config=None):
        # `config` is accepted but ignored, and VannaBase.__init__ is not called.
        pass

    def system_message(self, message: str) -> any:
        """Wrap `message` as a chat entry with the "system" role."""
        return dict(role="system", content=message)

    def user_message(self, message: str) -> any:
        """Wrap `message` as a chat entry with the "user" role."""
        return dict(role="user", content=message)

    def assistant_message(self, message: str) -> any:
        """Wrap `message` as a chat entry with the "assistant" role."""
        return dict(role="assistant", content=message)

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Return the canned reply; `prompt` and `kwargs` are not inspected."""
        canned_reply = "Mock LLM response"
        return canned_reply
|
vanna/mock/vectordb.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
|
|
3
|
+
from ..base import VannaBase
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class MockVectorDB(VannaBase):
    """Vector store stand-in: stores nothing and returns canned data."""

    def __init__(self, config=None):
        # `config` is accepted but ignored, and VannaBase.__init__ is not called.
        pass

    def _get_id(self, value: str, **kwargs) -> str:
        """Return an ID derived from `value`.

        Uses the built-in `hash`, which is salted per process for strings,
        so the IDs are only stable within a single interpreter run.
        """
        # Hash the value and return the ID
        return str(hash(value))

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """Pretend to store `ddl`; return the ID derived from it."""
        return self._get_id(ddl)

    def add_documentation(self, doc: str, **kwargs) -> str:
        """Pretend to store `doc`; return the ID derived from it."""
        return self._get_id(doc)

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Pretend to store the pair; the ID is derived from `question` only."""
        return self._get_id(question)

    def get_related_ddl(self, question: str, **kwargs) -> list:
        """Always return an empty list."""
        return []

    def get_related_documentation(self, question: str, **kwargs) -> list:
        """Always return an empty list."""
        return []

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        """Always return an empty list."""
        return []

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """Return a fixed five-row frame: one DDL, one documentation and three SQL entries."""
        return pd.DataFrame({'id': {0: '19546-ddl',
                                    1: '91597-sql',
                                    2: '133976-sql',
                                    3: '59851-doc',
                                    4: '73046-sql'},
                             'training_data_type': {0: 'ddl',
                                                    1: 'sql',
                                                    2: 'sql',
                                                    3: 'documentation',
                                                    4: 'sql'},
                             'question': {0: None,
                                          1: 'What are the top selling genres?',
                                          2: 'What are the low 7 artists by sales?',
                                          3: None,
                                          4: 'What is the total sales for each customer?'},
                             'content': {0: 'CREATE TABLE [Invoice]\n(\n [InvoiceId] INTEGER NOT NULL,\n [CustomerId] INTEGER NOT NULL,\n [InvoiceDate] DATETIME NOT NULL,\n [BillingAddress] NVARCHAR(70),\n [BillingCity] NVARCHAR(40),\n [BillingState] NVARCHAR(40),\n [BillingCountry] NVARCHAR(40),\n [BillingPostalCode] NVARCHAR(10),\n [Total] NUMERIC(10,2) NOT NULL,\n CONSTRAINT [PK_Invoice] PRIMARY KEY ([InvoiceId]),\n FOREIGN KEY ([CustomerId]) REFERENCES [Customer] ([CustomerId]) \n\t\tON DELETE NO ACTION ON UPDATE NO ACTION\n)',
                                         1: 'SELECT g.Name AS Genre, SUM(il.Quantity) AS TotalSales\nFROM Genre g\nJOIN Track t ON g.GenreId = t.GenreId\nJOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY g.GenreId, g.Name\nORDER BY TotalSales DESC;',
                                         2: 'SELECT a.ArtistId, a.Name, SUM(il.Quantity) AS TotalSales\nFROM Artist a\nINNER JOIN Album al ON a.ArtistId = al.ArtistId\nINNER JOIN Track t ON al.AlbumId = t.AlbumId\nINNER JOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY a.ArtistId, a.Name\nORDER BY TotalSales ASC\nLIMIT 7;',
                                         3: 'This is a SQLite database. For dates rememeber to use SQLite syntax.',
                                         4: 'SELECT c.CustomerId, c.FirstName, c.LastName, SUM(i.Total) AS TotalSales\nFROM Customer c\nJOIN Invoice i ON c.CustomerId = i.CustomerId\nGROUP BY c.CustomerId, c.FirstName, c.LastName;'}})

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Pretend to delete the entry `id`; always report success.

        `self` was missing from the original signature, so any call on an
        instance raised `TypeError`.
        """
        return True
|
|
@@ -155,6 +155,11 @@ class OpenSearch_VectorStore(VannaBase):
|
|
|
155
155
|
else:
|
|
156
156
|
max_retries = 10
|
|
157
157
|
|
|
158
|
+
if config is not None and "es_http_compress" in config:
|
|
159
|
+
es_http_compress = config["es_http_compress"]
|
|
160
|
+
else:
|
|
161
|
+
es_http_compress = False
|
|
162
|
+
|
|
158
163
|
print("OpenSearch_VectorStore initialized with es_urls: ", es_urls,
|
|
159
164
|
" host: ", host, " port: ", port, " ssl: ", ssl, " verify_certs: ",
|
|
160
165
|
verify_certs, " timeout: ", timeout, " max_retries: ", max_retries)
|
|
@@ -162,7 +167,7 @@ class OpenSearch_VectorStore(VannaBase):
|
|
|
162
167
|
# Initialize the OpenSearch client by passing a list of URLs
|
|
163
168
|
self.client = OpenSearch(
|
|
164
169
|
hosts=[es_urls],
|
|
165
|
-
http_compress=
|
|
170
|
+
http_compress=es_http_compress,
|
|
166
171
|
use_ssl=ssl,
|
|
167
172
|
verify_certs=verify_certs,
|
|
168
173
|
timeout=timeout,
|
|
@@ -175,7 +180,7 @@ class OpenSearch_VectorStore(VannaBase):
|
|
|
175
180
|
# Initialize the OpenSearch client by passing a host and port
|
|
176
181
|
self.client = OpenSearch(
|
|
177
182
|
hosts=[{'host': host, 'port': port}],
|
|
178
|
-
http_compress=
|
|
183
|
+
http_compress=es_http_compress,
|
|
179
184
|
use_ssl=ssl,
|
|
180
185
|
verify_certs=verify_certs,
|
|
181
186
|
timeout=timeout,
|
|
@@ -267,6 +272,7 @@ class OpenSearch_VectorStore(VannaBase):
|
|
|
267
272
|
}
|
|
268
273
|
}
|
|
269
274
|
}
|
|
275
|
+
print(query)
|
|
270
276
|
response = self.client.search(index=self.ddl_index, body=query,
|
|
271
277
|
**kwargs)
|
|
272
278
|
return [hit['_source']['ddl'] for hit in response['hits']['hits']]
|
|
@@ -279,6 +285,7 @@ class OpenSearch_VectorStore(VannaBase):
|
|
|
279
285
|
}
|
|
280
286
|
}
|
|
281
287
|
}
|
|
288
|
+
print(query)
|
|
282
289
|
response = self.client.search(index=self.document_index,
|
|
283
290
|
body=query,
|
|
284
291
|
**kwargs)
|
|
@@ -292,6 +299,7 @@ class OpenSearch_VectorStore(VannaBase):
|
|
|
292
299
|
}
|
|
293
300
|
}
|
|
294
301
|
}
|
|
302
|
+
print(query)
|
|
295
303
|
response = self.client.search(index=self.question_sql_index,
|
|
296
304
|
body=query,
|
|
297
305
|
**kwargs)
|
|
@@ -307,6 +315,7 @@ class OpenSearch_VectorStore(VannaBase):
|
|
|
307
315
|
body={"query": {"match_all": {}}},
|
|
308
316
|
size=1000
|
|
309
317
|
)
|
|
318
|
+
print(query)
|
|
310
319
|
# records = [hit['_source'] for hit in response['hits']['hits']]
|
|
311
320
|
for hit in response['hits']['hits']:
|
|
312
321
|
data.append(
|
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from typing import List
|
|
3
|
+
|
|
4
|
+
from pinecone import Pinecone, PodSpec, ServerlessSpec
|
|
5
|
+
import pandas as pd
|
|
6
|
+
from ..base import VannaBase
|
|
7
|
+
from ..utils import deterministic_uuid
|
|
8
|
+
|
|
9
|
+
from fastembed import TextEmbedding
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class PineconeDB_VectorStore(VannaBase):
    """
    Vectorstore using PineconeDB

    Args:
        config (dict): Configuration dictionary. Required. You must provide either a Pinecone Client or an API key in the config.
            - client (Pinecone, optional): Pinecone client. Defaults to None.
            - api_key (str, optional): Pinecone API key. Defaults to None.
            - n_results (int, optional): Number of results to return. Defaults to 10.
            - dimensions (int, optional): Dimensions of the embeddings. Defaults to 384 which corresponds to the dimensions of BAAI/bge-small-en-v1.5.
            - fastembed_model (str, optional): Fastembed model to use. Defaults to "BAAI/bge-small-en-v1.5".
            - documentation_namespace (str, optional): Namespace for documentation. Defaults to "documentation".
            - distance_metric (str, optional): Distance metric to use. Defaults to "cosine".
            - ddl_namespace (str, optional): Namespace for DDL. Defaults to "ddl".
            - sql_namespace (str, optional): Namespace for SQL. Defaults to "sql".
            - index_name (str, optional): Name of the index. Defaults to "vanna-index".
            - metadata_config (dict, optional): Metadata configuration if using a pinecone pod. Defaults to {}.
            - server_type (str, optional): Type of Pinecone server to use. Defaults to "serverless". Options are "serverless" or "pod".
            - podspec (PodSpec, optional): PodSpec configuration if using a pinecone pod. Defaults to PodSpec(environment="us-west-2", pod_type="p1.x1", metadata_config=self.metadata_config).
            - serverless_spec (ServerlessSpec, optional): ServerlessSpec configuration if using a pinecone serverless index. Defaults to ServerlessSpec(cloud="aws", region="us-west-2").
    Raises:
        ValueError: If config is None, api_key is not provided OR client is not provided, client is not an instance of Pinecone, or server_type is not "serverless" or "pod".
    """

    def __init__(self, config=None):
        VannaBase.__init__(self, config=config)
        if config is None:
            raise ValueError(
                "config is required, pass either a Pinecone client or an API key in the config."
            )
        client = config.get("client")
        api_key = config.get("api_key")
        if not api_key and not client:
            raise ValueError(
                "api_key is required in config or pass a configured client"
            )
        if not client and api_key:
            self._client = Pinecone(api_key=api_key)
        elif not isinstance(client, Pinecone):
            raise ValueError("client must be an instance of Pinecone")
        else:
            self._client = client

        self.n_results = config.get("n_results", 10)
        self.dimensions = config.get("dimensions", 384)
        self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5")
        self.documentation_namespace = config.get(
            "documentation_namespace", "documentation"
        )
        self.distance_metric = config.get("distance_metric", "cosine")
        self.ddl_namespace = config.get("ddl_namespace", "ddl")
        self.sql_namespace = config.get("sql_namespace", "sql")
        self.index_name = config.get("index_name", "vanna-index")
        self.metadata_config = config.get("metadata_config", {})
        self.server_type = config.get("server_type", "serverless")
        if self.server_type not in ["serverless", "pod"]:
            raise ValueError("server_type must be either 'serverless' or 'pod'")
        self.podspec = config.get(
            "podspec",
            PodSpec(
                environment="us-west-2",
                pod_type="p1.x1",
                metadata_config=self.metadata_config,
            ),
        )
        self.serverless_spec = config.get(
            "serverless_spec", ServerlessSpec(cloud="aws", region="us-west-2")
        )
        # Lazily populated by generate_embedding as (model name, model).
        self._embedding_model = None
        self._setup_index()

    def _set_index_host(self, host: str) -> None:
        """Bind `self.Index` to the index served at `host`."""
        self.Index = self._client.Index(host=host)

    def _setup_index(self) -> None:
        """Create the configured index if it does not exist, then bind to it."""
        # Only the two validated server types map to a spec; any other value
        # skips creation and goes straight to describe_index.
        spec = {
            "serverless": self.serverless_spec,
            "pod": self.podspec,
        }.get(self.server_type)
        if spec is not None and self.index_name not in self._get_indexes():
            self._client.create_index(
                name=self.index_name,
                dimension=self.dimensions,
                metric=self.distance_metric,
                spec=spec,
            )
        pinecone_index_host = self._client.describe_index(self.index_name)["host"]
        self._set_index_host(pinecone_index_host)

    def _get_indexes(self) -> list:
        """Return the names of all indexes the client can list."""
        return [index["name"] for index in self._client.list_indexes()]

    def _check_if_embedding_exists(self, id: str, namespace: str) -> bool:
        """Return True when a vector with this `id` is stored in `namespace`."""
        fetch_response = self.Index.fetch(ids=[id], namespace=namespace)
        return fetch_response["vectors"] != {}

    def _upsert_if_missing(
        self, item_id: str, namespace: str, text: str, metadata: dict, label: str
    ) -> str:
        """Embed `text` and upsert it under `item_id` unless that ID already exists."""
        if self._check_if_embedding_exists(id=item_id, namespace=namespace):
            print(f"{label} with id: {item_id} already exists in the index. Skipping...")
            return item_id
        self.Index.upsert(
            vectors=[(item_id, self.generate_embedding(text), metadata)],
            namespace=namespace,
        )
        return item_id

    def _query_namespace(self, question: str, namespace: str):
        """Return the `n_results` nearest matches to `question` in `namespace`."""
        return self.Index.query(
            namespace=namespace,
            vector=self.generate_embedding(question),
            top_k=self.n_results,
            include_values=True,
            include_metadata=True,
        )

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """Store a DDL statement and return its deterministic "-ddl" ID."""
        return self._upsert_if_missing(
            item_id=deterministic_uuid(ddl) + "-ddl",
            namespace=self.ddl_namespace,
            text=ddl,
            metadata={"ddl": ddl},
            label="DDL",
        )

    def add_documentation(self, doc: str, **kwargs) -> str:
        """Store a documentation snippet and return its deterministic "-doc" ID."""
        return self._upsert_if_missing(
            item_id=deterministic_uuid(doc) + "-doc",
            namespace=self.documentation_namespace,
            text=doc,
            metadata={"documentation": doc},
            label="Documentation",
        )

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Store a question/SQL pair as one JSON document and return its "-sql" ID."""
        question_sql_json = json.dumps(
            {
                "question": question,
                "sql": sql,
            },
            ensure_ascii=False,
        )
        return self._upsert_if_missing(
            item_id=deterministic_uuid(question_sql_json) + "-sql",
            namespace=self.sql_namespace,
            text=question_sql_json,
            metadata={"sql": question_sql_json},
            label="Question-SQL",
        )

    def get_related_ddl(self, question: str, **kwargs) -> list:
        """Return the DDL statements closest to `question`."""
        res = self._query_namespace(question, self.ddl_namespace)
        return [match["metadata"]["ddl"] for match in res["matches"]] if res else []

    def get_related_documentation(self, question: str, **kwargs) -> list:
        """Return the documentation snippets closest to `question`."""
        res = self._query_namespace(question, self.documentation_namespace)
        return (
            [match["metadata"]["documentation"] for match in res["matches"]]
            if res
            else []
        )

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        """Return the stored question/SQL pairs closest to `question`, decoded from JSON."""
        res = self._query_namespace(question, self.sql_namespace)
        return (
            [json.loads(match["metadata"]["sql"]) for match in res["matches"]]
            if res
            else []
        )

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """Return stored training data from all three namespaces as one DataFrame.

        Columns are `id`, `question`, `content` and `training_data_type`;
        `question` is None for everything except SQL entries.
        """
        # Pinecone does not support getting all vectors in a namespace, so we have to query for the top_k vectors with a dummy vector
        namespaces = {
            "sql": self.sql_namespace,
            "ddl": self.ddl_namespace,
            "documentation": self.documentation_namespace,
        }

        frames = []
        for data_type, namespace in namespaces.items():
            data = self.Index.query(
                top_k=10000,  # max results that pinecone allows
                namespace=namespace,
                include_values=True,
                include_metadata=True,
                vector=[0.0] * self.dimensions,
            )

            if data is None:
                continue

            matches = data["matches"]
            df_data = pd.DataFrame(
                {
                    "id": [match["id"] for match in matches],
                    "question": [
                        (
                            json.loads(match["metadata"][data_type])["question"]
                            if data_type == "sql"
                            else None
                        )
                        for match in matches
                    ],
                    "content": [match["metadata"][data_type] for match in matches],
                }
            )
            df_data["training_data_type"] = data_type
            frames.append(df_data)

        # Concatenate once at the end instead of growing a frame inside the loop.
        return pd.concat(frames) if frames else pd.DataFrame()

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Delete the entry `id`; the ID suffix selects the namespace.

        Returns False when the suffix is none of "-sql", "-ddl" or "-doc".
        """
        namespace_by_suffix = {
            "-sql": self.sql_namespace,
            "-ddl": self.ddl_namespace,
            "-doc": self.documentation_namespace,
        }
        for suffix, namespace in namespace_by_suffix.items():
            if id.endswith(suffix):
                self.Index.delete(ids=[id], namespace=namespace)
                return True
        return False

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        """Embed `data` with the configured fastembed model and return it as a list."""
        # Building the model is the expensive step, so keep one instance and
        # rebuild it only if `fastembed_model` has been changed since.
        cached = getattr(self, "_embedding_model", None)
        if cached is None or cached[0] != self.fastembed_model:
            cached = (
                self.fastembed_model,
                TextEmbedding(model_name=self.fastembed_model),
            )
            self._embedding_model = cached
        embedding = next(cached[1].embed(data))
        return embedding.tolist()
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: vanna
|
|
3
|
-
Version: 0.5.
|
|
3
|
+
Version: 0.5.5
|
|
4
4
|
Summary: Generate SQL queries from natural language
|
|
5
5
|
Author-email: Zain Hoda <zain@vanna.ai>
|
|
6
6
|
Requires-Python: >=3.9
|
|
@@ -38,9 +38,11 @@ Requires-Dist: httpx ; extra == "all"
|
|
|
38
38
|
Requires-Dist: opensearch-py ; extra == "all"
|
|
39
39
|
Requires-Dist: opensearch-dsl ; extra == "all"
|
|
40
40
|
Requires-Dist: transformers ; extra == "all"
|
|
41
|
+
Requires-Dist: pinecone-client ; extra == "all"
|
|
41
42
|
Requires-Dist: anthropic ; extra == "anthropic"
|
|
42
43
|
Requires-Dist: google-cloud-bigquery ; extra == "bigquery"
|
|
43
44
|
Requires-Dist: chromadb ; extra == "chromadb"
|
|
45
|
+
Requires-Dist: clickhouse_driver ; extra == "clickhouse"
|
|
44
46
|
Requires-Dist: duckdb ; extra == "duckdb"
|
|
45
47
|
Requires-Dist: google-generativeai ; extra == "gemini"
|
|
46
48
|
Requires-Dist: google-generativeai ; extra == "google"
|
|
@@ -54,6 +56,8 @@ Requires-Dist: httpx ; extra == "ollama"
|
|
|
54
56
|
Requires-Dist: openai ; extra == "openai"
|
|
55
57
|
Requires-Dist: opensearch-py ; extra == "opensearch"
|
|
56
58
|
Requires-Dist: opensearch-dsl ; extra == "opensearch"
|
|
59
|
+
Requires-Dist: pinecone-client ; extra == "pinecone"
|
|
60
|
+
Requires-Dist: fastembed ; extra == "pinecone"
|
|
57
61
|
Requires-Dist: psycopg2-binary ; extra == "postgres"
|
|
58
62
|
Requires-Dist: db-dtypes ; extra == "postgres"
|
|
59
63
|
Requires-Dist: qdrant-client ; extra == "qdrant"
|
|
@@ -68,6 +72,7 @@ Provides-Extra: all
|
|
|
68
72
|
Provides-Extra: anthropic
|
|
69
73
|
Provides-Extra: bigquery
|
|
70
74
|
Provides-Extra: chromadb
|
|
75
|
+
Provides-Extra: clickhouse
|
|
71
76
|
Provides-Extra: duckdb
|
|
72
77
|
Provides-Extra: gemini
|
|
73
78
|
Provides-Extra: google
|
|
@@ -78,6 +83,7 @@ Provides-Extra: mysql
|
|
|
78
83
|
Provides-Extra: ollama
|
|
79
84
|
Provides-Extra: openai
|
|
80
85
|
Provides-Extra: opensearch
|
|
86
|
+
Provides-Extra: pinecone
|
|
81
87
|
Provides-Extra: postgres
|
|
82
88
|
Provides-Extra: qdrant
|
|
83
89
|
Provides-Extra: snowflake
|
|
@@ -8,7 +8,7 @@ vanna/ZhipuAI/__init__.py,sha256=NlsijtcZp5Tj9jkOe9fNcOQND_QsGgu7otODsCLBPr0,116
|
|
|
8
8
|
vanna/anthropic/__init__.py,sha256=85s_2mAyyPxc0T_0JEvYeAkEKWJwkwqoyUwSC5dw9Gk,43
|
|
9
9
|
vanna/anthropic/anthropic_chat.py,sha256=Wk0o-NMW1uvR2fhSWxrR_2FqNh-dLprNG4uuVqpqAkY,2615
|
|
10
10
|
vanna/base/__init__.py,sha256=Sl-HM1RRYzAZoSqmL1CZQmF3ZF-byYTCFQP3JZ2A5MU,28
|
|
11
|
-
vanna/base/base.py,sha256=
|
|
11
|
+
vanna/base/base.py,sha256=YCL9MhhrGeoVv9da85NdWvEQtnqfkWSyj5AZo_wQ0TU,70853
|
|
12
12
|
vanna/chromadb/__init__.py,sha256=-iL0nW_g4uM8nWKMuWnNePfN4nb9uk8P3WzGvezOqRg,50
|
|
13
13
|
vanna/chromadb/chromadb_vector.py,sha256=eKyPck99Y6Jt-BNWojvxLG-zvAERzLSm-3zY-bKXvaA,8792
|
|
14
14
|
vanna/exceptions/__init__.py,sha256=dJ65xxxZh1lqBeg6nz6Tq_r34jLVmjvBvPO9Q6hFaQ8,685
|
|
@@ -23,13 +23,19 @@ vanna/marqo/__init__.py,sha256=GaAWtJ0B-H5rTY607iLCCrLD7T0zMYM5qWIomEB9gLk,37
|
|
|
23
23
|
vanna/marqo/marqo.py,sha256=W7WTtzWp4RJjZVy6OaXHqncUBIPdI4Q7qH7BRCxZ1_A,5242
|
|
24
24
|
vanna/mistral/__init__.py,sha256=70rTY-69Z2ehkkMj84dNMCukPo6AWdflBGvIB_pztS0,29
|
|
25
25
|
vanna/mistral/mistral.py,sha256=DAEqAT9SzC91rfMM_S3SuzBZ34MrKHw9qAj6EP2MGVk,1508
|
|
26
|
+
vanna/mock/__init__.py,sha256=nYR2WfcV5NdwpK3V64QGOWHBGc3ESN9uV68JLS76aRw,97
|
|
27
|
+
vanna/mock/embedding.py,sha256=ggnP7KuPh6dlqeUFtoN8t0J0P7_yRNtn9rIq6h8g8-w,250
|
|
28
|
+
vanna/mock/llm.py,sha256=WpG9f1pKZftPBHqgIYdARKB2Z9DZhOALYOJWoOjjFEc,518
|
|
29
|
+
vanna/mock/vectordb.py,sha256=h45znfYMUnttE2BBC8v6TKeMaA58pFJL-5B3OGeRNFI,2681
|
|
26
30
|
vanna/ollama/__init__.py,sha256=4xyu8aHPdnEHg5a-QAMwr5o0ns5wevsp_zkI-ndMO2k,27
|
|
27
31
|
vanna/ollama/ollama.py,sha256=rXa7cfvdlO1E5SLysXIl3IZpIaA2r0RBvV5jX2-upiE,3794
|
|
28
32
|
vanna/openai/__init__.py,sha256=tGkeQ7wTIPsando7QhoSHehtoQVdYLwFbKNlSmCmNeQ,86
|
|
29
33
|
vanna/openai/openai_chat.py,sha256=lm-hUsQxu6Q1t06A2csC037zI4VkMk0wFbQ-_Lj74Wg,4764
|
|
30
34
|
vanna/openai/openai_embeddings.py,sha256=g4pNh9LVcYP9wOoO8ecaccDFWmCUYMInebfHucAa2Gc,1260
|
|
31
35
|
vanna/opensearch/__init__.py,sha256=0unDevWOTs7o8S79TOHUKF1mSiuQbBUVm-7k9jV5WW4,54
|
|
32
|
-
vanna/opensearch/opensearch_vector.py,sha256=
|
|
36
|
+
vanna/opensearch/opensearch_vector.py,sha256=VhIcrSyNzWR9ZrqrJnyGFOyuQZs3swfbhr8QyVGI0eI,12226
|
|
37
|
+
vanna/pinecone/__init__.py,sha256=eO5l8aX8vKL6aIUMgAXGPt1jdqKxB_Hic6cmoVAUrD0,90
|
|
38
|
+
vanna/pinecone/pinecone_vector.py,sha256=mpq1lzo3KRj2QfJEw8pwFclFQK1Oi_Nx-lDkx9Gp0mw,11448
|
|
33
39
|
vanna/qdrant/__init__.py,sha256=PX_OsDOiPMvwCJ2iGER1drSdQ9AyM8iN5PEBhRb6qqY,73
|
|
34
40
|
vanna/qdrant/qdrant.py,sha256=6M00nMiuOuftTDf3NsOrOcG9BA4DlIIDck2MNp9iEyg,12613
|
|
35
41
|
vanna/types/__init__.py,sha256=Qhn_YscKtJh7mFPCyCDLa2K8a4ORLMGVnPpTbv9uB2U,4957
|
|
@@ -37,6 +43,6 @@ vanna/vannadb/__init__.py,sha256=C6UkYocmO6dmzfPKZaWojN0mI5YlZZ9VIbdcquBE58A,48
|
|
|
37
43
|
vanna/vannadb/vannadb_vector.py,sha256=9YwTO3Lh5owWQE7KPMBqLp2EkiGV0RC1sEYhslzJzgI,6168
|
|
38
44
|
vanna/vllm/__init__.py,sha256=aNlUkF9tbURdeXAJ8ytuaaF1gYwcG3ny1MfNl_cwQYg,23
|
|
39
45
|
vanna/vllm/vllm.py,sha256=QerC3xF5eNzE_nGBDl6YrPYF4WYnjf0hHxxlDWdKX-0,2427
|
|
40
|
-
vanna-0.5.
|
|
41
|
-
vanna-0.5.
|
|
42
|
-
vanna-0.5.
|
|
46
|
+
vanna-0.5.5.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
|
|
47
|
+
vanna-0.5.5.dist-info/METADATA,sha256=gmPrDKsawOtYazNHG5CQ_iURmdVLt1Jc8Rfwneue1w0,11505
|
|
48
|
+
vanna-0.5.5.dist-info/RECORD,,
|
|
File without changes
|