vanna 0.5.3__py3-none-any.whl → 0.5.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vanna/base/base.py CHANGED
@@ -555,7 +555,7 @@ class VannaBase(ABC):
555
555
  """
556
556
 
557
557
  if initial_prompt is None:
558
- initial_prompt = f"You are a {self.dialect} expert. "
558
+ initial_prompt = f"You are a {self.dialect} expert. " + \
559
559
  "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
560
560
 
561
561
  initial_prompt = self.add_ddl_to_prompt(
@@ -1012,6 +1012,85 @@ class VannaBase(ABC):
1012
1012
  self.run_sql_is_set = True
1013
1013
  self.run_sql = run_sql_mysql
1014
1014
 
1015
+ def connect_to_clickhouse(
1016
+ self,
1017
+ host: str = None,
1018
+ dbname: str = None,
1019
+ user: str = None,
1020
+ password: str = None,
1021
+ port: int = None,
1022
+ ):
1023
+
1024
+ try:
1025
+ from clickhouse_driver import connect
1026
+ except ImportError:
1027
+ raise DependencyError(
1028
+ "You need to install required dependencies to execute this method,"
1029
+ " run command: \npip install clickhouse-driver"
1030
+ )
1031
+
1032
+ if not host:
1033
+ host = os.getenv("HOST")
1034
+
1035
+ if not host:
1036
+ raise ImproperlyConfigured("Please set your ClickHouse host")
1037
+
1038
+ if not dbname:
1039
+ dbname = os.getenv("DATABASE")
1040
+
1041
+ if not dbname:
1042
+ raise ImproperlyConfigured("Please set your ClickHouse database")
1043
+
1044
+ if not user:
1045
+ user = os.getenv("USER")
1046
+
1047
+ if not user:
1048
+ raise ImproperlyConfigured("Please set your ClickHouse user")
1049
+
1050
+ if not password:
1051
+ password = os.getenv("PASSWORD")
1052
+
1053
+ if not password:
1054
+ raise ImproperlyConfigured("Please set your ClickHouse password")
1055
+
1056
+ if not port:
1057
+ port = os.getenv("PORT")
1058
+
1059
+ if not port:
1060
+ raise ImproperlyConfigured("Please set your ClickHouse port")
1061
+
1062
+ conn = None
1063
+
1064
+ try:
1065
+ conn = connect(host=host,
1066
+ user=user,
1067
+ password=password,
1068
+ database=dbname,
1069
+ port=port,
1070
+ )
1071
+ print(conn)
1072
+ except Exception as e:
1073
+ raise ValidationError(e)
1074
+
1075
+ def run_sql_clickhouse(sql: str) -> Union[pd.DataFrame, None]:
1076
+ if conn:
1077
+ try:
1078
+ cs = conn.cursor()
1079
+ cs.execute(sql)
1080
+ results = cs.fetchall()
1081
+
1082
+ # Create a pandas dataframe from the results
1083
+ df = pd.DataFrame(
1084
+ results, columns=[desc[0] for desc in cs.description]
1085
+ )
1086
+ return df
1087
+
1088
+ except Exception as e:
1089
+ raise e
1090
+
1091
+ self.run_sql_is_set = True
1092
+ self.run_sql = run_sql_clickhouse
1093
+
1015
1094
  def connect_to_oracle(
1016
1095
  self,
1017
1096
  user: str = None,
@@ -1372,6 +1451,10 @@ class VannaBase(ABC):
1372
1451
  def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]:
1373
1452
  if conn:
1374
1453
  try:
1454
+ sql = sql.rstrip()
1455
+ # Fix for a known Presto DB problem where a trailing ';' causes an error.
1456
+ if sql.endswith(';'):
1457
+ sql = sql[:-1]
1375
1458
  cs = conn.cursor()
1376
1459
  cs.execute(sql)
1377
1460
  results = cs.fetchall()
@@ -1393,6 +1476,102 @@ class VannaBase(ABC):
1393
1476
  self.run_sql_is_set = True
1394
1477
  self.run_sql = run_sql_presto
1395
1478
 
1479
+ def connect_to_hive(
1480
+ self,
1481
+ host: str = None,
1482
+ dbname: str = 'default',
1483
+ user: str = None,
1484
+ password: str = None,
1485
+ port: int = None,
1486
+ auth: str = 'CUSTOM'
1487
+ ):
1488
+ """
1489
+ Connect to a Hive database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
1490
+ Connect to a Hive database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
1491
+
1492
+ Args:
1493
+ host (str): The host of the Hive database.
1494
+ dbname (str): The name of the database to connect to.
1495
+ user (str): The username to use for authentication.
1496
+ password (str): The password to use for authentication.
1497
+ port (int): The port to use for the connection.
1498
+ auth (str): The authentication method to use.
1499
+
1500
+ Returns:
1501
+ None
1502
+ """
1503
+
1504
+ try:
1505
+ from pyhive import hive
1506
+ except ImportError:
1507
+ raise DependencyError(
1508
+ "You need to install required dependencies to execute this method,"
1509
+ " run command: \npip install pyhive"
1510
+ )
1511
+
1512
+ if not host:
1513
+ host = os.getenv("HIVE_HOST")
1514
+
1515
+ if not host:
1516
+ raise ImproperlyConfigured("Please set your hive host")
1517
+
1518
+ if not dbname:
1519
+ dbname = os.getenv("HIVE_DATABASE")
1520
+
1521
+ if not dbname:
1522
+ raise ImproperlyConfigured("Please set your hive database")
1523
+
1524
+ if not user:
1525
+ user = os.getenv("HIVE_USER")
1526
+
1527
+ if not user:
1528
+ raise ImproperlyConfigured("Please set your hive user")
1529
+
1530
+ if not password:
1531
+ password = os.getenv("HIVE_PASSWORD")
1532
+
1533
+ if not port:
1534
+ port = os.getenv("HIVE_PORT")
1535
+
1536
+ if not port:
1537
+ raise ImproperlyConfigured("Please set your hive port")
1538
+
1539
+ conn = None
1540
+
1541
+ try:
1542
+ conn = hive.Connection(host=host,
1543
+ username=user,
1544
+ password=password,
1545
+ database=dbname,
1546
+ port=port,
1547
+ auth=auth)
1548
+ except hive.Error as e:
1549
+ raise ValidationError(e)
1550
+
1551
+ def run_sql_hive(sql: str) -> Union[pd.DataFrame, None]:
1552
+ if conn:
1553
+ try:
1554
+ cs = conn.cursor()
1555
+ cs.execute(sql)
1556
+ results = cs.fetchall()
1557
+
1558
+ # Create a pandas dataframe from the results
1559
+ df = pd.DataFrame(
1560
+ results, columns=[desc[0] for desc in cs.description]
1561
+ )
1562
+ return df
1563
+
1564
+ except hive.Error as e:
1565
+ print(e)
1566
+ raise ValidationError(e)
1567
+
1568
+ except Exception as e:
1569
+ print(e)
1570
+ raise e
1571
+
1572
+ self.run_sql_is_set = True
1573
+ self.run_sql = run_sql_hive
1574
+
1396
1575
  def run_sql(self, sql: str, **kwargs) -> pd.DataFrame:
1397
1576
  """
1398
1577
  Example:
@@ -1522,7 +1701,7 @@ class VannaBase(ABC):
1522
1701
  return None
1523
1702
  else:
1524
1703
  return sql, None, None
1525
- return sql, df, None
1704
+ return sql, df, fig
1526
1705
 
1527
1706
  def train(
1528
1707
  self,
vanna/mock/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .embedding import MockEmbedding
2
+ from .llm import MockLLM
3
+ from .vectordb import MockVectorDB
@@ -0,0 +1,11 @@
1
+ from typing import List
2
+
3
+ from ..base import VannaBase
4
+
5
+
6
+ class MockEmbedding(VannaBase):
7
+ def __init__(self, config=None):
8
+ pass
9
+
10
+ def generate_embedding(self, data: str, **kwargs) -> List[float]:
11
+ return [1.0, 2.0, 3.0, 4.0, 5.0]
vanna/mock/llm.py ADDED
@@ -0,0 +1,19 @@
1
+
2
+ from ..base import VannaBase
3
+
4
+
5
+ class MockLLM(VannaBase):
6
+ def __init__(self, config=None):
7
+ pass
8
+
9
+ def system_message(self, message: str) -> any:
10
+ return {"role": "system", "content": message}
11
+
12
+ def user_message(self, message: str) -> any:
13
+ return {"role": "user", "content": message}
14
+
15
+ def assistant_message(self, message: str) -> any:
16
+ return {"role": "assistant", "content": message}
17
+
18
+ def submit_prompt(self, prompt, **kwargs) -> str:
19
+ return "Mock LLM response"
vanna/mock/vectordb.py ADDED
@@ -0,0 +1,55 @@
1
+ import pandas as pd
2
+
3
+ from ..base import VannaBase
4
+
5
+
6
+ class MockVectorDB(VannaBase):
7
+ def __init__(self, config=None):
8
+ pass
9
+
10
+ def _get_id(self, value: str, **kwargs) -> str:
11
+ # Hash the value and return the ID
12
+ return str(hash(value))
13
+
14
+ def add_ddl(self, ddl: str, **kwargs) -> str:
15
+ return self._get_id(ddl)
16
+
17
+ def add_documentation(self, doc: str, **kwargs) -> str:
18
+ return self._get_id(doc)
19
+
20
+ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
21
+ return self._get_id(question)
22
+
23
+ def get_related_ddl(self, question: str, **kwargs) -> list:
24
+ return []
25
+
26
+ def get_related_documentation(self, question: str, **kwargs) -> list:
27
+ return []
28
+
29
+ def get_similar_question_sql(self, question: str, **kwargs) -> list:
30
+ return []
31
+
32
+ def get_training_data(self, **kwargs) -> pd.DataFrame:
33
+ return pd.DataFrame({'id': {0: '19546-ddl',
34
+ 1: '91597-sql',
35
+ 2: '133976-sql',
36
+ 3: '59851-doc',
37
+ 4: '73046-sql'},
38
+ 'training_data_type': {0: 'ddl',
39
+ 1: 'sql',
40
+ 2: 'sql',
41
+ 3: 'documentation',
42
+ 4: 'sql'},
43
+ 'question': {0: None,
44
+ 1: 'What are the top selling genres?',
45
+ 2: 'What are the low 7 artists by sales?',
46
+ 3: None,
47
+ 4: 'What is the total sales for each customer?'},
48
+ 'content': {0: 'CREATE TABLE [Invoice]\n(\n [InvoiceId] INTEGER NOT NULL,\n [CustomerId] INTEGER NOT NULL,\n [InvoiceDate] DATETIME NOT NULL,\n [BillingAddress] NVARCHAR(70),\n [BillingCity] NVARCHAR(40),\n [BillingState] NVARCHAR(40),\n [BillingCountry] NVARCHAR(40),\n [BillingPostalCode] NVARCHAR(10),\n [Total] NUMERIC(10,2) NOT NULL,\n CONSTRAINT [PK_Invoice] PRIMARY KEY ([InvoiceId]),\n FOREIGN KEY ([CustomerId]) REFERENCES [Customer] ([CustomerId]) \n\t\tON DELETE NO ACTION ON UPDATE NO ACTION\n)',
49
+ 1: 'SELECT g.Name AS Genre, SUM(il.Quantity) AS TotalSales\nFROM Genre g\nJOIN Track t ON g.GenreId = t.GenreId\nJOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY g.GenreId, g.Name\nORDER BY TotalSales DESC;',
50
+ 2: 'SELECT a.ArtistId, a.Name, SUM(il.Quantity) AS TotalSales\nFROM Artist a\nINNER JOIN Album al ON a.ArtistId = al.ArtistId\nINNER JOIN Track t ON al.AlbumId = t.AlbumId\nINNER JOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY a.ArtistId, a.Name\nORDER BY TotalSales ASC\nLIMIT 7;',
51
+ 3: 'This is a SQLite database. For dates rememeber to use SQLite syntax.',
52
+ 4: 'SELECT c.CustomerId, c.FirstName, c.LastName, SUM(i.Total) AS TotalSales\nFROM Customer c\nJOIN Invoice i ON c.CustomerId = i.CustomerId\nGROUP BY c.CustomerId, c.FirstName, c.LastName;'}})
53
+
54
+ def remove_training_data(id: str, **kwargs) -> bool:
55
+ return True
@@ -155,6 +155,11 @@ class OpenSearch_VectorStore(VannaBase):
155
155
  else:
156
156
  max_retries = 10
157
157
 
158
+ if config is not None and "es_http_compress" in config:
159
+ es_http_compress = config["es_http_compress"]
160
+ else:
161
+ es_http_compress = False
162
+
158
163
  print("OpenSearch_VectorStore initialized with es_urls: ", es_urls,
159
164
  " host: ", host, " port: ", port, " ssl: ", ssl, " verify_certs: ",
160
165
  verify_certs, " timeout: ", timeout, " max_retries: ", max_retries)
@@ -162,7 +167,7 @@ class OpenSearch_VectorStore(VannaBase):
162
167
  # Initialize the OpenSearch client by passing a list of URLs
163
168
  self.client = OpenSearch(
164
169
  hosts=[es_urls],
165
- http_compress=True,
170
+ http_compress=es_http_compress,
166
171
  use_ssl=ssl,
167
172
  verify_certs=verify_certs,
168
173
  timeout=timeout,
@@ -175,7 +180,7 @@ class OpenSearch_VectorStore(VannaBase):
175
180
  # Initialize the OpenSearch client by passing a host and port
176
181
  self.client = OpenSearch(
177
182
  hosts=[{'host': host, 'port': port}],
178
- http_compress=True,
183
+ http_compress=es_http_compress,
179
184
  use_ssl=ssl,
180
185
  verify_certs=verify_certs,
181
186
  timeout=timeout,
@@ -267,6 +272,7 @@ class OpenSearch_VectorStore(VannaBase):
267
272
  }
268
273
  }
269
274
  }
275
+ print(query)
270
276
  response = self.client.search(index=self.ddl_index, body=query,
271
277
  **kwargs)
272
278
  return [hit['_source']['ddl'] for hit in response['hits']['hits']]
@@ -279,6 +285,7 @@ class OpenSearch_VectorStore(VannaBase):
279
285
  }
280
286
  }
281
287
  }
288
+ print(query)
282
289
  response = self.client.search(index=self.document_index,
283
290
  body=query,
284
291
  **kwargs)
@@ -292,6 +299,7 @@ class OpenSearch_VectorStore(VannaBase):
292
299
  }
293
300
  }
294
301
  }
302
+ print(query)
295
303
  response = self.client.search(index=self.question_sql_index,
296
304
  body=query,
297
305
  **kwargs)
@@ -307,6 +315,7 @@ class OpenSearch_VectorStore(VannaBase):
307
315
  body={"query": {"match_all": {}}},
308
316
  size=1000
309
317
  )
318
+ print(query)
310
319
  # records = [hit['_source'] for hit in response['hits']['hits']]
311
320
  for hit in response['hits']['hits']:
312
321
  data.append(
@@ -0,0 +1,3 @@
1
+ from .pinecone_vector import PineconeDB_VectorStore
2
+
3
+ __all__ = ["PineconeDB_VectorStore"]
@@ -0,0 +1,275 @@
1
+ import json
2
+ from typing import List
3
+
4
+ from pinecone import Pinecone, PodSpec, ServerlessSpec
5
+ import pandas as pd
6
+ from ..base import VannaBase
7
+ from ..utils import deterministic_uuid
8
+
9
+ from fastembed import TextEmbedding
10
+
11
+
12
+ class PineconeDB_VectorStore(VannaBase):
13
+ """
14
+ Vectorstore using PineconeDB
15
+
16
+ Args:
17
+ config (dict): Configuration dictionary. Defaults to {}. You must provide either a Pinecone Client or an API key in the config.
18
+ - client (Pinecone, optional): Pinecone client. Defaults to None.
19
+ - api_key (str, optional): Pinecone API key. Defaults to None.
20
+ - n_results (int, optional): Number of results to return. Defaults to 10.
21
+ - dimensions (int, optional): Dimensions of the embeddings. Defaults to 384 which corresponds to the dimensions of BAAI/bge-small-en-v1.5.
22
+ - fastembed_model (str, optional): Fastembed model to use. Defaults to "BAAI/bge-small-en-v1.5".
23
+ - documentation_namespace (str, optional): Namespace for documentation. Defaults to "documentation".
24
+ - distance_metric (str, optional): Distance metric to use. Defaults to "cosine".
25
+ - ddl_namespace (str, optional): Namespace for DDL. Defaults to "ddl".
26
+ - sql_namespace (str, optional): Namespace for SQL. Defaults to "sql".
27
+ - index_name (str, optional): Name of the index. Defaults to "vanna-index".
28
+ - metadata_config (dict, optional): Metadata configuration if using a pinecone pod. Defaults to {}.
29
+ - server_type (str, optional): Type of Pinecone server to use. Defaults to "serverless". Options are "serverless" or "pod".
30
+ - podspec (PodSpec, optional): PodSpec configuration if using a pinecone pod. Defaults to PodSpec(environment="us-west-2", pod_type="p1.x1", metadata_config=self.metadata_config).
31
+ - serverless_spec (ServerlessSpec, optional): ServerlessSpec configuration if using a pinecone serverless index. Defaults to ServerlessSpec(cloud="aws", region="us-west-2").
32
+ Raises:
33
+ ValueError: If config is None, api_key is not provided OR client is not provided, client is not an instance of Pinecone, or server_type is not "serverless" or "pod".
34
+ """
35
+
36
+ def __init__(self, config=None):
37
+ VannaBase.__init__(self, config=config)
38
+ if config is None:
39
+ raise ValueError(
40
+ "config is required, pass either a Pinecone client or an API key in the config."
41
+ )
42
+ client = config.get("client")
43
+ api_key = config.get("api_key")
44
+ if not api_key and not client:
45
+ raise ValueError(
46
+ "api_key is required in config or pass a configured client"
47
+ )
48
+ if not client and api_key:
49
+ self._client = Pinecone(api_key=api_key)
50
+ elif not isinstance(client, Pinecone):
51
+ raise ValueError("client must be an instance of Pinecone")
52
+ else:
53
+ self._client = client
54
+
55
+ self.n_results = config.get("n_results", 10)
56
+ self.dimensions = config.get("dimensions", 384)
57
+ self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5")
58
+ self.documentation_namespace = config.get(
59
+ "documentation_namespace", "documentation"
60
+ )
61
+ self.distance_metric = config.get("distance_metric", "cosine")
62
+ self.ddl_namespace = config.get("ddl_namespace", "ddl")
63
+ self.sql_namespace = config.get("sql_namespace", "sql")
64
+ self.index_name = config.get("index_name", "vanna-index")
65
+ self.metadata_config = config.get("metadata_config", {})
66
+ self.server_type = config.get("server_type", "serverless")
67
+ if self.server_type not in ["serverless", "pod"]:
68
+ raise ValueError("server_type must be either 'serverless' or 'pod'")
69
+ self.podspec = config.get(
70
+ "podspec",
71
+ PodSpec(
72
+ environment="us-west-2",
73
+ pod_type="p1.x1",
74
+ metadata_config=self.metadata_config,
75
+ ),
76
+ )
77
+ self.serverless_spec = config.get(
78
+ "serverless_spec", ServerlessSpec(cloud="aws", region="us-west-2")
79
+ )
80
+ self._setup_index()
81
+
82
+ def _set_index_host(self, host: str) -> None:
83
+ self.Index = self._client.Index(host=host)
84
+
85
+ def _setup_index(self) -> None:
86
+ existing_indexes = self._get_indexes()
87
+ if self.index_name not in existing_indexes and self.server_type == "serverless":
88
+ self._client.create_index(
89
+ name=self.index_name,
90
+ dimension=self.dimensions,
91
+ metric=self.distance_metric,
92
+ spec=self.serverless_spec,
93
+ )
94
+ pinecone_index_host = self._client.describe_index(self.index_name)["host"]
95
+ self._set_index_host(pinecone_index_host)
96
+ elif self.index_name not in existing_indexes and self.server_type == "pod":
97
+ self._client.create_index(
98
+ name=self.index_name,
99
+ dimension=self.dimensions,
100
+ metric=self.distance_metric,
101
+ spec=self.podspec,
102
+ )
103
+ pinecone_index_host = self._client.describe_index(self.index_name)["host"]
104
+ self._set_index_host(pinecone_index_host)
105
+ else:
106
+ pinecone_index_host = self._client.describe_index(self.index_name)["host"]
107
+ self._set_index_host(pinecone_index_host)
108
+
109
+ def _get_indexes(self) -> list:
110
+ return [index["name"] for index in self._client.list_indexes()]
111
+
112
+ def _check_if_embedding_exists(self, id: str, namespace: str) -> bool:
113
+ fetch_response = self.Index.fetch(ids=[id], namespace=namespace)
114
+ if fetch_response["vectors"] == {}:
115
+ return False
116
+ return True
117
+
118
+ def add_ddl(self, ddl: str, **kwargs) -> str:
119
+ id = deterministic_uuid(ddl) + "-ddl"
120
+ if self._check_if_embedding_exists(id=id, namespace=self.ddl_namespace):
121
+ print(f"DDL with id: {id} already exists in the index. Skipping...")
122
+ return id
123
+ self.Index.upsert(
124
+ vectors=[(id, self.generate_embedding(ddl), {"ddl": ddl})],
125
+ namespace=self.ddl_namespace,
126
+ )
127
+ return id
128
+
129
+ def add_documentation(self, doc: str, **kwargs) -> str:
130
+ id = deterministic_uuid(doc) + "-doc"
131
+
132
+ if self._check_if_embedding_exists(
133
+ id=id, namespace=self.documentation_namespace
134
+ ):
135
+ print(
136
+ f"Documentation with id: {id} already exists in the index. Skipping..."
137
+ )
138
+ return id
139
+ self.Index.upsert(
140
+ vectors=[(id, self.generate_embedding(doc), {"documentation": doc})],
141
+ namespace=self.documentation_namespace,
142
+ )
143
+ return id
144
+
145
+ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
146
+ question_sql_json = json.dumps(
147
+ {
148
+ "question": question,
149
+ "sql": sql,
150
+ },
151
+ ensure_ascii=False,
152
+ )
153
+ id = deterministic_uuid(question_sql_json) + "-sql"
154
+ if self._check_if_embedding_exists(id=id, namespace=self.sql_namespace):
155
+ print(
156
+ f"Question-SQL with id: {id} already exists in the index. Skipping..."
157
+ )
158
+ return id
159
+ self.Index.upsert(
160
+ vectors=[
161
+ (
162
+ id,
163
+ self.generate_embedding(question_sql_json),
164
+ {"sql": question_sql_json},
165
+ )
166
+ ],
167
+ namespace=self.sql_namespace,
168
+ )
169
+ return id
170
+
171
+ def get_related_ddl(self, question: str, **kwargs) -> list:
172
+ res = self.Index.query(
173
+ namespace=self.ddl_namespace,
174
+ vector=self.generate_embedding(question),
175
+ top_k=self.n_results,
176
+ include_values=True,
177
+ include_metadata=True,
178
+ )
179
+ return [match["metadata"]["ddl"] for match in res["matches"]] if res else []
180
+
181
+ def get_related_documentation(self, question: str, **kwargs) -> list:
182
+ res = self.Index.query(
183
+ namespace=self.documentation_namespace,
184
+ vector=self.generate_embedding(question),
185
+ top_k=self.n_results,
186
+ include_values=True,
187
+ include_metadata=True,
188
+ )
189
+ return (
190
+ [match["metadata"]["documentation"] for match in res["matches"]]
191
+ if res
192
+ else []
193
+ )
194
+
195
+ def get_similar_question_sql(self, question: str, **kwargs) -> list:
196
+ res = self.Index.query(
197
+ namespace=self.sql_namespace,
198
+ vector=self.generate_embedding(question),
199
+ top_k=self.n_results,
200
+ include_values=True,
201
+ include_metadata=True,
202
+ )
203
+ return (
204
+ [
205
+ {
206
+ key: value
207
+ for key, value in json.loads(match["metadata"]["sql"]).items()
208
+ }
209
+ for match in res["matches"]
210
+ ]
211
+ if res
212
+ else []
213
+ )
214
+
215
+ def get_training_data(self, **kwargs) -> pd.DataFrame:
216
+ # Pinecone does not support getting all vectors in a namespace, so we have to query for the top_k vectors with a dummy vector
217
+ df = pd.DataFrame()
218
+ namespaces = {
219
+ "sql": self.sql_namespace,
220
+ "ddl": self.ddl_namespace,
221
+ "documentation": self.documentation_namespace,
222
+ }
223
+
224
+ for data_type, namespace in namespaces.items():
225
+ data = self.Index.query(
226
+ top_k=10000, # max results that pinecone allows
227
+ namespace=namespace,
228
+ include_values=True,
229
+ include_metadata=True,
230
+ vector=[0.0] * self.dimensions,
231
+ )
232
+
233
+ if data is not None:
234
+ id_list = [match["id"] for match in data["matches"]]
235
+ content_list = [
236
+ match["metadata"][data_type] for match in data["matches"]
237
+ ]
238
+ question_list = [
239
+ (
240
+ json.loads(match["metadata"][data_type])["question"]
241
+ if data_type == "sql"
242
+ else None
243
+ )
244
+ for match in data["matches"]
245
+ ]
246
+
247
+ df_data = pd.DataFrame(
248
+ {
249
+ "id": id_list,
250
+ "question": question_list,
251
+ "content": content_list,
252
+ }
253
+ )
254
+ df_data["training_data_type"] = data_type
255
+ df = pd.concat([df, df_data])
256
+
257
+ return df
258
+
259
+ def remove_training_data(self, id: str, **kwargs) -> bool:
260
+ if id.endswith("-sql"):
261
+ self.Index.delete(ids=[id], namespace=self.sql_namespace)
262
+ return True
263
+ elif id.endswith("-ddl"):
264
+ self.Index.delete(ids=[id], namespace=self.ddl_namespace)
265
+ return True
266
+ elif id.endswith("-doc"):
267
+ self.Index.delete(ids=[id], namespace=self.documentation_namespace)
268
+ return True
269
+ else:
270
+ return False
271
+
272
+ def generate_embedding(self, data: str, **kwargs) -> List[float]:
273
+ embedding_model = TextEmbedding(model_name=self.fastembed_model)
274
+ embedding = next(embedding_model.embed(data))
275
+ return embedding.tolist()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vanna
3
- Version: 0.5.3
3
+ Version: 0.5.5
4
4
  Summary: Generate SQL queries from natural language
5
5
  Author-email: Zain Hoda <zain@vanna.ai>
6
6
  Requires-Python: >=3.9
@@ -38,9 +38,11 @@ Requires-Dist: httpx ; extra == "all"
38
38
  Requires-Dist: opensearch-py ; extra == "all"
39
39
  Requires-Dist: opensearch-dsl ; extra == "all"
40
40
  Requires-Dist: transformers ; extra == "all"
41
+ Requires-Dist: pinecone-client ; extra == "all"
41
42
  Requires-Dist: anthropic ; extra == "anthropic"
42
43
  Requires-Dist: google-cloud-bigquery ; extra == "bigquery"
43
44
  Requires-Dist: chromadb ; extra == "chromadb"
45
+ Requires-Dist: clickhouse_driver ; extra == "clickhouse"
44
46
  Requires-Dist: duckdb ; extra == "duckdb"
45
47
  Requires-Dist: google-generativeai ; extra == "gemini"
46
48
  Requires-Dist: google-generativeai ; extra == "google"
@@ -54,6 +56,8 @@ Requires-Dist: httpx ; extra == "ollama"
54
56
  Requires-Dist: openai ; extra == "openai"
55
57
  Requires-Dist: opensearch-py ; extra == "opensearch"
56
58
  Requires-Dist: opensearch-dsl ; extra == "opensearch"
59
+ Requires-Dist: pinecone-client ; extra == "pinecone"
60
+ Requires-Dist: fastembed ; extra == "pinecone"
57
61
  Requires-Dist: psycopg2-binary ; extra == "postgres"
58
62
  Requires-Dist: db-dtypes ; extra == "postgres"
59
63
  Requires-Dist: qdrant-client ; extra == "qdrant"
@@ -68,6 +72,7 @@ Provides-Extra: all
68
72
  Provides-Extra: anthropic
69
73
  Provides-Extra: bigquery
70
74
  Provides-Extra: chromadb
75
+ Provides-Extra: clickhouse
71
76
  Provides-Extra: duckdb
72
77
  Provides-Extra: gemini
73
78
  Provides-Extra: google
@@ -78,6 +83,7 @@ Provides-Extra: mysql
78
83
  Provides-Extra: ollama
79
84
  Provides-Extra: openai
80
85
  Provides-Extra: opensearch
86
+ Provides-Extra: pinecone
81
87
  Provides-Extra: postgres
82
88
  Provides-Extra: qdrant
83
89
  Provides-Extra: snowflake
@@ -8,7 +8,7 @@ vanna/ZhipuAI/__init__.py,sha256=NlsijtcZp5Tj9jkOe9fNcOQND_QsGgu7otODsCLBPr0,116
8
8
  vanna/anthropic/__init__.py,sha256=85s_2mAyyPxc0T_0JEvYeAkEKWJwkwqoyUwSC5dw9Gk,43
9
9
  vanna/anthropic/anthropic_chat.py,sha256=Wk0o-NMW1uvR2fhSWxrR_2FqNh-dLprNG4uuVqpqAkY,2615
10
10
  vanna/base/__init__.py,sha256=Sl-HM1RRYzAZoSqmL1CZQmF3ZF-byYTCFQP3JZ2A5MU,28
11
- vanna/base/base.py,sha256=hUvb94NSVjSIsIa_x38xx8OOdBwL4mmuGmuDj5ednjo,65593
11
+ vanna/base/base.py,sha256=YCL9MhhrGeoVv9da85NdWvEQtnqfkWSyj5AZo_wQ0TU,70853
12
12
  vanna/chromadb/__init__.py,sha256=-iL0nW_g4uM8nWKMuWnNePfN4nb9uk8P3WzGvezOqRg,50
13
13
  vanna/chromadb/chromadb_vector.py,sha256=eKyPck99Y6Jt-BNWojvxLG-zvAERzLSm-3zY-bKXvaA,8792
14
14
  vanna/exceptions/__init__.py,sha256=dJ65xxxZh1lqBeg6nz6Tq_r34jLVmjvBvPO9Q6hFaQ8,685
@@ -23,13 +23,19 @@ vanna/marqo/__init__.py,sha256=GaAWtJ0B-H5rTY607iLCCrLD7T0zMYM5qWIomEB9gLk,37
23
23
  vanna/marqo/marqo.py,sha256=W7WTtzWp4RJjZVy6OaXHqncUBIPdI4Q7qH7BRCxZ1_A,5242
24
24
  vanna/mistral/__init__.py,sha256=70rTY-69Z2ehkkMj84dNMCukPo6AWdflBGvIB_pztS0,29
25
25
  vanna/mistral/mistral.py,sha256=DAEqAT9SzC91rfMM_S3SuzBZ34MrKHw9qAj6EP2MGVk,1508
26
+ vanna/mock/__init__.py,sha256=nYR2WfcV5NdwpK3V64QGOWHBGc3ESN9uV68JLS76aRw,97
27
+ vanna/mock/embedding.py,sha256=ggnP7KuPh6dlqeUFtoN8t0J0P7_yRNtn9rIq6h8g8-w,250
28
+ vanna/mock/llm.py,sha256=WpG9f1pKZftPBHqgIYdARKB2Z9DZhOALYOJWoOjjFEc,518
29
+ vanna/mock/vectordb.py,sha256=h45znfYMUnttE2BBC8v6TKeMaA58pFJL-5B3OGeRNFI,2681
26
30
  vanna/ollama/__init__.py,sha256=4xyu8aHPdnEHg5a-QAMwr5o0ns5wevsp_zkI-ndMO2k,27
27
31
  vanna/ollama/ollama.py,sha256=rXa7cfvdlO1E5SLysXIl3IZpIaA2r0RBvV5jX2-upiE,3794
28
32
  vanna/openai/__init__.py,sha256=tGkeQ7wTIPsando7QhoSHehtoQVdYLwFbKNlSmCmNeQ,86
29
33
  vanna/openai/openai_chat.py,sha256=lm-hUsQxu6Q1t06A2csC037zI4VkMk0wFbQ-_Lj74Wg,4764
30
34
  vanna/openai/openai_embeddings.py,sha256=g4pNh9LVcYP9wOoO8ecaccDFWmCUYMInebfHucAa2Gc,1260
31
35
  vanna/opensearch/__init__.py,sha256=0unDevWOTs7o8S79TOHUKF1mSiuQbBUVm-7k9jV5WW4,54
32
- vanna/opensearch/opensearch_vector.py,sha256=90-nJuRkgOBh49VH5Lknw5qfBAlfSuqvUAqvHFbfa7g,11980
36
+ vanna/opensearch/opensearch_vector.py,sha256=VhIcrSyNzWR9ZrqrJnyGFOyuQZs3swfbhr8QyVGI0eI,12226
37
+ vanna/pinecone/__init__.py,sha256=eO5l8aX8vKL6aIUMgAXGPt1jdqKxB_Hic6cmoVAUrD0,90
38
+ vanna/pinecone/pinecone_vector.py,sha256=mpq1lzo3KRj2QfJEw8pwFclFQK1Oi_Nx-lDkx9Gp0mw,11448
33
39
  vanna/qdrant/__init__.py,sha256=PX_OsDOiPMvwCJ2iGER1drSdQ9AyM8iN5PEBhRb6qqY,73
34
40
  vanna/qdrant/qdrant.py,sha256=6M00nMiuOuftTDf3NsOrOcG9BA4DlIIDck2MNp9iEyg,12613
35
41
  vanna/types/__init__.py,sha256=Qhn_YscKtJh7mFPCyCDLa2K8a4ORLMGVnPpTbv9uB2U,4957
@@ -37,6 +43,6 @@ vanna/vannadb/__init__.py,sha256=C6UkYocmO6dmzfPKZaWojN0mI5YlZZ9VIbdcquBE58A,48
37
43
  vanna/vannadb/vannadb_vector.py,sha256=9YwTO3Lh5owWQE7KPMBqLp2EkiGV0RC1sEYhslzJzgI,6168
38
44
  vanna/vllm/__init__.py,sha256=aNlUkF9tbURdeXAJ8ytuaaF1gYwcG3ny1MfNl_cwQYg,23
39
45
  vanna/vllm/vllm.py,sha256=QerC3xF5eNzE_nGBDl6YrPYF4WYnjf0hHxxlDWdKX-0,2427
40
- vanna-0.5.3.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
41
- vanna-0.5.3.dist-info/METADATA,sha256=JPRFSFUNqsLazznvDMWQIUy_bL9CKRQKCSgPlIwpucU,11248
42
- vanna-0.5.3.dist-info/RECORD,,
46
+ vanna-0.5.5.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
47
+ vanna-0.5.5.dist-info/METADATA,sha256=gmPrDKsawOtYazNHG5CQ_iURmdVLt1Jc8Rfwneue1w0,11505
48
+ vanna-0.5.5.dist-info/RECORD,,
File without changes