vanna 0.3.3__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {vanna-0.3.3 → vanna-0.4.0}/PKG-INFO +6 -1
- {vanna-0.3.3 → vanna-0.4.0}/pyproject.toml +3 -2
- {vanna-0.3.3 → vanna-0.4.0}/src/vanna/ZhipuAI/ZhipuAI_Chat.py +3 -3
- {vanna-0.3.3 → vanna-0.4.0}/src/vanna/base/base.py +105 -4
- {vanna-0.3.3 → vanna-0.4.0}/src/vanna/chromadb/chromadb_vector.py +21 -15
- {vanna-0.3.3 → vanna-0.4.0}/src/vanna/flask/__init__.py +117 -43
- vanna-0.4.0/src/vanna/flask/assets.py +38 -0
- vanna-0.4.0/src/vanna/flask/auth.py +55 -0
- vanna-0.4.0/src/vanna/google/__init__.py +1 -0
- vanna-0.4.0/src/vanna/google/gemini_chat.py +52 -0
- vanna-0.4.0/src/vanna/remote.py +77 -0
- {vanna-0.3.3 → vanna-0.4.0}/src/vanna/vannadb/vannadb_vector.py +58 -27
- vanna-0.3.3/src/vanna/flask/assets.py +0 -38
- vanna-0.3.3/src/vanna/remote.py +0 -455
- {vanna-0.3.3 → vanna-0.4.0}/README.md +0 -0
- {vanna-0.3.3 → vanna-0.4.0}/src/vanna/ZhipuAI/ZhipuAI_embeddings.py +0 -0
- {vanna-0.3.3 → vanna-0.4.0}/src/vanna/ZhipuAI/__init__.py +0 -0
- {vanna-0.3.3 → vanna-0.4.0}/src/vanna/__init__.py +0 -0
- {vanna-0.3.3 → vanna-0.4.0}/src/vanna/anthropic/__init__.py +0 -0
- {vanna-0.3.3 → vanna-0.4.0}/src/vanna/anthropic/anthropic_chat.py +0 -0
- {vanna-0.3.3 → vanna-0.4.0}/src/vanna/base/__init__.py +0 -0
- {vanna-0.3.3 → vanna-0.4.0}/src/vanna/chromadb/__init__.py +0 -0
- {vanna-0.3.3 → vanna-0.4.0}/src/vanna/exceptions/__init__.py +0 -0
- {vanna-0.3.3 → vanna-0.4.0}/src/vanna/local.py +0 -0
- {vanna-0.3.3 → vanna-0.4.0}/src/vanna/marqo/__init__.py +0 -0
- {vanna-0.3.3 → vanna-0.4.0}/src/vanna/marqo/marqo.py +0 -0
- {vanna-0.3.3 → vanna-0.4.0}/src/vanna/mistral/__init__.py +0 -0
- {vanna-0.3.3 → vanna-0.4.0}/src/vanna/mistral/mistral.py +0 -0
- {vanna-0.3.3 → vanna-0.4.0}/src/vanna/ollama/__init__.py +0 -0
- {vanna-0.3.3 → vanna-0.4.0}/src/vanna/ollama/ollama.py +0 -0
- {vanna-0.3.3 → vanna-0.4.0}/src/vanna/openai/__init__.py +0 -0
- {vanna-0.3.3 → vanna-0.4.0}/src/vanna/openai/openai_chat.py +0 -0
- {vanna-0.3.3 → vanna-0.4.0}/src/vanna/openai/openai_embeddings.py +0 -0
- {vanna-0.3.3 → vanna-0.4.0}/src/vanna/types/__init__.py +0 -0
- {vanna-0.3.3 → vanna-0.4.0}/src/vanna/utils.py +0 -0
- {vanna-0.3.3 → vanna-0.4.0}/src/vanna/vannadb/__init__.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: vanna
|
|
3
|
-
Version: 0.3.3
|
|
3
|
+
Version: 0.4.0
|
|
4
4
|
Summary: Generate SQL queries from natural language
|
|
5
5
|
Author-email: Zain Hoda <zain@vanna.ai>
|
|
6
6
|
Requires-Python: >=3.9
|
|
@@ -28,11 +28,15 @@ Requires-Dist: chromadb ; extra == "all"
|
|
|
28
28
|
Requires-Dist: anthropic ; extra == "all"
|
|
29
29
|
Requires-Dist: zhipuai ; extra == "all"
|
|
30
30
|
Requires-Dist: marqo ; extra == "all"
|
|
31
|
+
Requires-Dist: google-generativeai ; extra == "all"
|
|
32
|
+
Requires-Dist: google-cloud-aiplatform ; extra == "all"
|
|
31
33
|
Requires-Dist: anthropic ; extra == "anthropic"
|
|
32
34
|
Requires-Dist: google-cloud-bigquery ; extra == "bigquery"
|
|
33
35
|
Requires-Dist: chromadb ; extra == "chromadb"
|
|
34
36
|
Requires-Dist: duckdb ; extra == "duckdb"
|
|
35
37
|
Requires-Dist: google-generativeai ; extra == "gemini"
|
|
38
|
+
Requires-Dist: google-generativeai ; extra == "google"
|
|
39
|
+
Requires-Dist: google-cloud-aiplatform ; extra == "google"
|
|
36
40
|
Requires-Dist: marqo ; extra == "marqo"
|
|
37
41
|
Requires-Dist: mistralai ; extra == "mistralai"
|
|
38
42
|
Requires-Dist: PyMySQL ; extra == "mysql"
|
|
@@ -50,6 +54,7 @@ Provides-Extra: bigquery
|
|
|
50
54
|
Provides-Extra: chromadb
|
|
51
55
|
Provides-Extra: duckdb
|
|
52
56
|
Provides-Extra: gemini
|
|
57
|
+
Provides-Extra: google
|
|
53
58
|
Provides-Extra: marqo
|
|
54
59
|
Provides-Extra: mistralai
|
|
55
60
|
Provides-Extra: mysql
|
|
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "vanna"
|
|
7
|
-
version = "0.3.3"
|
|
7
|
+
version = "0.4.0"
|
|
8
8
|
authors = [
|
|
9
9
|
{ name="Zain Hoda", email="zain@vanna.ai" },
|
|
10
10
|
]
|
|
@@ -31,7 +31,8 @@ mysql = ["PyMySQL"]
|
|
|
31
31
|
bigquery = ["google-cloud-bigquery"]
|
|
32
32
|
snowflake = ["snowflake-connector-python"]
|
|
33
33
|
duckdb = ["duckdb"]
|
|
34
|
-
|
|
34
|
+
google = ["google-generativeai", "google-cloud-aiplatform"]
|
|
35
|
+
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform"]
|
|
35
36
|
test = ["tox"]
|
|
36
37
|
chromadb = ["chromadb"]
|
|
37
38
|
openai = ["openai"]
|
|
@@ -40,7 +40,7 @@ class ZhipuAI_Chat(VannaBase):
|
|
|
40
40
|
initial_prompt: str, ddl_list: List[str], max_tokens: int = 14000
|
|
41
41
|
) -> str:
|
|
42
42
|
if len(ddl_list) > 0:
|
|
43
|
-
initial_prompt +=
|
|
43
|
+
initial_prompt += "\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
|
|
44
44
|
|
|
45
45
|
for ddl in ddl_list:
|
|
46
46
|
if (
|
|
@@ -57,7 +57,7 @@ class ZhipuAI_Chat(VannaBase):
|
|
|
57
57
|
initial_prompt: str, documentation_List: List[str], max_tokens: int = 14000
|
|
58
58
|
) -> str:
|
|
59
59
|
if len(documentation_List) > 0:
|
|
60
|
-
initial_prompt +=
|
|
60
|
+
initial_prompt += "\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
|
|
61
61
|
|
|
62
62
|
for documentation in documentation_List:
|
|
63
63
|
if (
|
|
@@ -74,7 +74,7 @@ class ZhipuAI_Chat(VannaBase):
|
|
|
74
74
|
initial_prompt: str, sql_List: List[str], max_tokens: int = 14000
|
|
75
75
|
) -> str:
|
|
76
76
|
if len(sql_List) > 0:
|
|
77
|
-
initial_prompt +=
|
|
77
|
+
initial_prompt += "\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
|
|
78
78
|
|
|
79
79
|
for question in sql_List:
|
|
80
80
|
if (
|
|
@@ -124,6 +124,18 @@ class VannaBase(ABC):
|
|
|
124
124
|
return self.extract_sql(llm_response)
|
|
125
125
|
|
|
126
126
|
def extract_sql(self, llm_response: str) -> str:
|
|
127
|
+
# If the llm_response is not markdown formatted, extract sql by finding select and ; in the response
|
|
128
|
+
sql = re.search(r"SELECT.*?;", llm_response, re.DOTALL)
|
|
129
|
+
if sql:
|
|
130
|
+
self.log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(0)}"
|
|
131
|
+
)
|
|
132
|
+
return sql.group(0)
|
|
133
|
+
|
|
134
|
+
# If the llm_response contains a CTE (with clause), extract the sql bewteen WITH and ;
|
|
135
|
+
sql = re.search(r"WITH.*?;", llm_response, re.DOTALL)
|
|
136
|
+
if sql:
|
|
137
|
+
self.log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(0)}")
|
|
138
|
+
return sql.group(0)
|
|
127
139
|
# If the llm_response contains a markdown code block, with or without the sql tag, extract the sql from it
|
|
128
140
|
sql = re.search(r"```sql\n(.*)```", llm_response, re.DOTALL)
|
|
129
141
|
if sql:
|
|
@@ -363,7 +375,7 @@ class VannaBase(ABC):
|
|
|
363
375
|
self, initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000
|
|
364
376
|
) -> str:
|
|
365
377
|
if len(ddl_list) > 0:
|
|
366
|
-
initial_prompt +=
|
|
378
|
+
initial_prompt += "\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
|
|
367
379
|
|
|
368
380
|
for ddl in ddl_list:
|
|
369
381
|
if (
|
|
@@ -382,7 +394,7 @@ class VannaBase(ABC):
|
|
|
382
394
|
max_tokens: int = 14000,
|
|
383
395
|
) -> str:
|
|
384
396
|
if len(documentation_list) > 0:
|
|
385
|
-
initial_prompt +=
|
|
397
|
+
initial_prompt += "\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
|
|
386
398
|
|
|
387
399
|
for documentation in documentation_list:
|
|
388
400
|
if (
|
|
@@ -398,7 +410,7 @@ class VannaBase(ABC):
|
|
|
398
410
|
self, initial_prompt: str, sql_list: list[str], max_tokens: int = 14000
|
|
399
411
|
) -> str:
|
|
400
412
|
if len(sql_list) > 0:
|
|
401
|
-
initial_prompt +=
|
|
413
|
+
initial_prompt += "\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
|
|
402
414
|
|
|
403
415
|
for question in sql_list:
|
|
404
416
|
if (
|
|
@@ -642,6 +654,7 @@ class VannaBase(ABC):
|
|
|
642
654
|
password=password,
|
|
643
655
|
account=account,
|
|
644
656
|
database=database,
|
|
657
|
+
client_session_keep_alive=True
|
|
645
658
|
)
|
|
646
659
|
|
|
647
660
|
def run_sql_snowflake(sql: str) -> pd.DataFrame:
|
|
@@ -890,6 +903,94 @@ class VannaBase(ABC):
|
|
|
890
903
|
self.run_sql_is_set = True
|
|
891
904
|
self.run_sql = run_sql_mysql
|
|
892
905
|
|
|
906
|
+
def connect_to_oracle(
|
|
907
|
+
self,
|
|
908
|
+
user: str = None,
|
|
909
|
+
password: str = None,
|
|
910
|
+
dsn: str = None,
|
|
911
|
+
):
|
|
912
|
+
|
|
913
|
+
"""
|
|
914
|
+
Connect to an Oracle db using oracledb package. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
|
|
915
|
+
**Example:**
|
|
916
|
+
```python
|
|
917
|
+
vn.connect_to_oracle(
|
|
918
|
+
user="username",
|
|
919
|
+
password="password",
|
|
920
|
+
dns="host:port/sid",
|
|
921
|
+
)
|
|
922
|
+
```
|
|
923
|
+
Args:
|
|
924
|
+
USER (str): Oracle db user name.
|
|
925
|
+
PASSWORD (str): Oracle db user password.
|
|
926
|
+
DSN (str): Oracle db host ip - host:port/sid.
|
|
927
|
+
"""
|
|
928
|
+
|
|
929
|
+
try:
|
|
930
|
+
import oracledb
|
|
931
|
+
except ImportError:
|
|
932
|
+
|
|
933
|
+
raise DependencyError(
|
|
934
|
+
"You need to install required dependencies to execute this method,"
|
|
935
|
+
" run command: \npip install oracledb"
|
|
936
|
+
)
|
|
937
|
+
|
|
938
|
+
if not dsn:
|
|
939
|
+
dsn = os.getenv("DSN")
|
|
940
|
+
|
|
941
|
+
if not dsn:
|
|
942
|
+
raise ImproperlyConfigured("Please set your Oracle dsn which should include host:port/sid")
|
|
943
|
+
|
|
944
|
+
if not user:
|
|
945
|
+
user = os.getenv("USER")
|
|
946
|
+
|
|
947
|
+
if not user:
|
|
948
|
+
raise ImproperlyConfigured("Please set your Oracle db user")
|
|
949
|
+
|
|
950
|
+
if not password:
|
|
951
|
+
password = os.getenv("PASSWORD")
|
|
952
|
+
|
|
953
|
+
if not password:
|
|
954
|
+
raise ImproperlyConfigured("Please set your Oracle db password")
|
|
955
|
+
|
|
956
|
+
conn = None
|
|
957
|
+
|
|
958
|
+
try:
|
|
959
|
+
conn = oracledb.connect(
|
|
960
|
+
user=user,
|
|
961
|
+
password=password,
|
|
962
|
+
dsn=dsn,
|
|
963
|
+
)
|
|
964
|
+
except oracledb.Error as e:
|
|
965
|
+
raise ValidationError(e)
|
|
966
|
+
|
|
967
|
+
def run_sql_oracle(sql: str) -> Union[pd.DataFrame, None]:
|
|
968
|
+
if conn:
|
|
969
|
+
try:
|
|
970
|
+
sql = sql.rstrip()
|
|
971
|
+
if sql.endswith(';'): #fix for a known problem with Oracle db where an extra ; will cause an error.
|
|
972
|
+
sql = sql[:-1]
|
|
973
|
+
|
|
974
|
+
cs = conn.cursor()
|
|
975
|
+
cs.execute(sql)
|
|
976
|
+
results = cs.fetchall()
|
|
977
|
+
|
|
978
|
+
# Create a pandas dataframe from the results
|
|
979
|
+
df = pd.DataFrame(
|
|
980
|
+
results, columns=[desc[0] for desc in cs.description]
|
|
981
|
+
)
|
|
982
|
+
return df
|
|
983
|
+
|
|
984
|
+
except oracledb.Error as e:
|
|
985
|
+
conn.rollback()
|
|
986
|
+
raise ValidationError(e)
|
|
987
|
+
|
|
988
|
+
except Exception as e:
|
|
989
|
+
conn.rollback()
|
|
990
|
+
raise e
|
|
991
|
+
|
|
992
|
+
self.run_sql_is_set = True
|
|
993
|
+
self.run_sql = run_sql_oracle
|
|
893
994
|
|
|
894
995
|
def connect_to_bigquery(self, cred_file_path: str = None, project_id: str = None):
|
|
895
996
|
"""
|
|
@@ -1238,7 +1339,7 @@ class VannaBase(ABC):
|
|
|
1238
1339
|
"""
|
|
1239
1340
|
|
|
1240
1341
|
if question and not sql:
|
|
1241
|
-
raise ValidationError(
|
|
1342
|
+
raise ValidationError("Please also provide a SQL query")
|
|
1242
1343
|
|
|
1243
1344
|
if documentation:
|
|
1244
1345
|
print("Adding documentation....")
|
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import json
|
|
2
|
-
import uuid
|
|
3
2
|
from typing import List
|
|
4
3
|
|
|
5
4
|
import chromadb
|
|
@@ -16,17 +15,16 @@ default_ef = embedding_functions.DefaultEmbeddingFunction()
|
|
|
16
15
|
class ChromaDB_VectorStore(VannaBase):
|
|
17
16
|
def __init__(self, config=None):
|
|
18
17
|
VannaBase.__init__(self, config=config)
|
|
18
|
+
if config is None:
|
|
19
|
+
config = {}
|
|
19
20
|
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
self.embedding_function = default_ef
|
|
28
|
-
curr_client = "persistent" # defaults to persistent storage
|
|
29
|
-
self.n_results = 10 # defaults to 10 documents
|
|
21
|
+
path = config.get("path", ".")
|
|
22
|
+
self.embedding_function = config.get("embedding_function", default_ef)
|
|
23
|
+
curr_client = config.get("client", "persistent")
|
|
24
|
+
collection_metadata = config.get("collection_metadata", None)
|
|
25
|
+
self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
|
|
26
|
+
self.n_results_documentation = config.get("n_results_documentation", config.get("n_results", 10))
|
|
27
|
+
self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))
|
|
30
28
|
|
|
31
29
|
if curr_client == "persistent":
|
|
32
30
|
self.chroma_client = chromadb.PersistentClient(
|
|
@@ -43,13 +41,19 @@ class ChromaDB_VectorStore(VannaBase):
|
|
|
43
41
|
raise ValueError(f"Unsupported client was set in config: {curr_client}")
|
|
44
42
|
|
|
45
43
|
self.documentation_collection = self.chroma_client.get_or_create_collection(
|
|
46
|
-
name="documentation",
|
|
44
|
+
name="documentation",
|
|
45
|
+
embedding_function=self.embedding_function,
|
|
46
|
+
metadata=collection_metadata,
|
|
47
47
|
)
|
|
48
48
|
self.ddl_collection = self.chroma_client.get_or_create_collection(
|
|
49
|
-
name="ddl",
|
|
49
|
+
name="ddl",
|
|
50
|
+
embedding_function=self.embedding_function,
|
|
51
|
+
metadata=collection_metadata,
|
|
50
52
|
)
|
|
51
53
|
self.sql_collection = self.chroma_client.get_or_create_collection(
|
|
52
|
-
name="sql",
|
|
54
|
+
name="sql",
|
|
55
|
+
embedding_function=self.embedding_function,
|
|
56
|
+
metadata=collection_metadata,
|
|
53
57
|
)
|
|
54
58
|
|
|
55
59
|
def generate_embedding(self, data: str, **kwargs) -> List[float]:
|
|
@@ -232,7 +236,7 @@ class ChromaDB_VectorStore(VannaBase):
|
|
|
232
236
|
return ChromaDB_VectorStore._extract_documents(
|
|
233
237
|
self.sql_collection.query(
|
|
234
238
|
query_texts=[question],
|
|
235
|
-
n_results=self.
|
|
239
|
+
n_results=self.n_results_sql,
|
|
236
240
|
)
|
|
237
241
|
)
|
|
238
242
|
|
|
@@ -240,6 +244,7 @@ class ChromaDB_VectorStore(VannaBase):
|
|
|
240
244
|
return ChromaDB_VectorStore._extract_documents(
|
|
241
245
|
self.ddl_collection.query(
|
|
242
246
|
query_texts=[question],
|
|
247
|
+
n_results=self.n_results_ddl,
|
|
243
248
|
)
|
|
244
249
|
)
|
|
245
250
|
|
|
@@ -247,5 +252,6 @@ class ChromaDB_VectorStore(VannaBase):
|
|
|
247
252
|
return ChromaDB_VectorStore._extract_documents(
|
|
248
253
|
self.documentation_collection.query(
|
|
249
254
|
query_texts=[question],
|
|
255
|
+
n_results=self.n_results_documentation,
|
|
250
256
|
)
|
|
251
257
|
)
|