vanna 0.3.3__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {vanna-0.3.3 → vanna-0.4.0}/PKG-INFO +6 -1
  2. {vanna-0.3.3 → vanna-0.4.0}/pyproject.toml +3 -2
  3. {vanna-0.3.3 → vanna-0.4.0}/src/vanna/ZhipuAI/ZhipuAI_Chat.py +3 -3
  4. {vanna-0.3.3 → vanna-0.4.0}/src/vanna/base/base.py +105 -4
  5. {vanna-0.3.3 → vanna-0.4.0}/src/vanna/chromadb/chromadb_vector.py +21 -15
  6. {vanna-0.3.3 → vanna-0.4.0}/src/vanna/flask/__init__.py +117 -43
  7. vanna-0.4.0/src/vanna/flask/assets.py +38 -0
  8. vanna-0.4.0/src/vanna/flask/auth.py +55 -0
  9. vanna-0.4.0/src/vanna/google/__init__.py +1 -0
  10. vanna-0.4.0/src/vanna/google/gemini_chat.py +52 -0
  11. vanna-0.4.0/src/vanna/remote.py +77 -0
  12. {vanna-0.3.3 → vanna-0.4.0}/src/vanna/vannadb/vannadb_vector.py +58 -27
  13. vanna-0.3.3/src/vanna/flask/assets.py +0 -38
  14. vanna-0.3.3/src/vanna/remote.py +0 -455
  15. {vanna-0.3.3 → vanna-0.4.0}/README.md +0 -0
  16. {vanna-0.3.3 → vanna-0.4.0}/src/vanna/ZhipuAI/ZhipuAI_embeddings.py +0 -0
  17. {vanna-0.3.3 → vanna-0.4.0}/src/vanna/ZhipuAI/__init__.py +0 -0
  18. {vanna-0.3.3 → vanna-0.4.0}/src/vanna/__init__.py +0 -0
  19. {vanna-0.3.3 → vanna-0.4.0}/src/vanna/anthropic/__init__.py +0 -0
  20. {vanna-0.3.3 → vanna-0.4.0}/src/vanna/anthropic/anthropic_chat.py +0 -0
  21. {vanna-0.3.3 → vanna-0.4.0}/src/vanna/base/__init__.py +0 -0
  22. {vanna-0.3.3 → vanna-0.4.0}/src/vanna/chromadb/__init__.py +0 -0
  23. {vanna-0.3.3 → vanna-0.4.0}/src/vanna/exceptions/__init__.py +0 -0
  24. {vanna-0.3.3 → vanna-0.4.0}/src/vanna/local.py +0 -0
  25. {vanna-0.3.3 → vanna-0.4.0}/src/vanna/marqo/__init__.py +0 -0
  26. {vanna-0.3.3 → vanna-0.4.0}/src/vanna/marqo/marqo.py +0 -0
  27. {vanna-0.3.3 → vanna-0.4.0}/src/vanna/mistral/__init__.py +0 -0
  28. {vanna-0.3.3 → vanna-0.4.0}/src/vanna/mistral/mistral.py +0 -0
  29. {vanna-0.3.3 → vanna-0.4.0}/src/vanna/ollama/__init__.py +0 -0
  30. {vanna-0.3.3 → vanna-0.4.0}/src/vanna/ollama/ollama.py +0 -0
  31. {vanna-0.3.3 → vanna-0.4.0}/src/vanna/openai/__init__.py +0 -0
  32. {vanna-0.3.3 → vanna-0.4.0}/src/vanna/openai/openai_chat.py +0 -0
  33. {vanna-0.3.3 → vanna-0.4.0}/src/vanna/openai/openai_embeddings.py +0 -0
  34. {vanna-0.3.3 → vanna-0.4.0}/src/vanna/types/__init__.py +0 -0
  35. {vanna-0.3.3 → vanna-0.4.0}/src/vanna/utils.py +0 -0
  36. {vanna-0.3.3 → vanna-0.4.0}/src/vanna/vannadb/__init__.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vanna
3
- Version: 0.3.3
3
+ Version: 0.4.0
4
4
  Summary: Generate SQL queries from natural language
5
5
  Author-email: Zain Hoda <zain@vanna.ai>
6
6
  Requires-Python: >=3.9
@@ -28,11 +28,15 @@ Requires-Dist: chromadb ; extra == "all"
28
28
  Requires-Dist: anthropic ; extra == "all"
29
29
  Requires-Dist: zhipuai ; extra == "all"
30
30
  Requires-Dist: marqo ; extra == "all"
31
+ Requires-Dist: google-generativeai ; extra == "all"
32
+ Requires-Dist: google-cloud-aiplatform ; extra == "all"
31
33
  Requires-Dist: anthropic ; extra == "anthropic"
32
34
  Requires-Dist: google-cloud-bigquery ; extra == "bigquery"
33
35
  Requires-Dist: chromadb ; extra == "chromadb"
34
36
  Requires-Dist: duckdb ; extra == "duckdb"
35
37
  Requires-Dist: google-generativeai ; extra == "gemini"
38
+ Requires-Dist: google-generativeai ; extra == "google"
39
+ Requires-Dist: google-cloud-aiplatform ; extra == "google"
36
40
  Requires-Dist: marqo ; extra == "marqo"
37
41
  Requires-Dist: mistralai ; extra == "mistralai"
38
42
  Requires-Dist: PyMySQL ; extra == "mysql"
@@ -50,6 +54,7 @@ Provides-Extra: bigquery
50
54
  Provides-Extra: chromadb
51
55
  Provides-Extra: duckdb
52
56
  Provides-Extra: gemini
57
+ Provides-Extra: google
53
58
  Provides-Extra: marqo
54
59
  Provides-Extra: mistralai
55
60
  Provides-Extra: mysql
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"
4
4
 
5
5
  [project]
6
6
  name = "vanna"
7
- version = "0.3.3"
7
+ version = "0.4.0"
8
8
  authors = [
9
9
  { name="Zain Hoda", email="zain@vanna.ai" },
10
10
  ]
@@ -31,7 +31,8 @@ mysql = ["PyMySQL"]
31
31
  bigquery = ["google-cloud-bigquery"]
32
32
  snowflake = ["snowflake-connector-python"]
33
33
  duckdb = ["duckdb"]
34
- all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo"]
34
+ google = ["google-generativeai", "google-cloud-aiplatform"]
35
+ all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform"]
35
36
  test = ["tox"]
36
37
  chromadb = ["chromadb"]
37
38
  openai = ["openai"]
@@ -40,7 +40,7 @@ class ZhipuAI_Chat(VannaBase):
40
40
  initial_prompt: str, ddl_list: List[str], max_tokens: int = 14000
41
41
  ) -> str:
42
42
  if len(ddl_list) > 0:
43
- initial_prompt += f"\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
43
+ initial_prompt += "\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
44
44
 
45
45
  for ddl in ddl_list:
46
46
  if (
@@ -57,7 +57,7 @@ class ZhipuAI_Chat(VannaBase):
57
57
  initial_prompt: str, documentation_List: List[str], max_tokens: int = 14000
58
58
  ) -> str:
59
59
  if len(documentation_List) > 0:
60
- initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
60
+ initial_prompt += "\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
61
61
 
62
62
  for documentation in documentation_List:
63
63
  if (
@@ -74,7 +74,7 @@ class ZhipuAI_Chat(VannaBase):
74
74
  initial_prompt: str, sql_List: List[str], max_tokens: int = 14000
75
75
  ) -> str:
76
76
  if len(sql_List) > 0:
77
- initial_prompt += f"\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
77
+ initial_prompt += "\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
78
78
 
79
79
  for question in sql_List:
80
80
  if (
@@ -124,6 +124,18 @@ class VannaBase(ABC):
124
124
  return self.extract_sql(llm_response)
125
125
 
126
126
  def extract_sql(self, llm_response: str) -> str:
127
+ # If the llm_response is not markdown formatted, extract sql by finding select and ; in the response
128
+ sql = re.search(r"SELECT.*?;", llm_response, re.DOTALL)
129
+ if sql:
130
+ self.log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(0)}"
131
+ )
132
+ return sql.group(0)
133
+
134
+ # If the llm_response contains a CTE (with clause), extract the sql between WITH and ;
135
+ sql = re.search(r"WITH.*?;", llm_response, re.DOTALL)
136
+ if sql:
137
+ self.log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(0)}")
138
+ return sql.group(0)
127
139
  # If the llm_response contains a markdown code block, with or without the sql tag, extract the sql from it
128
140
  sql = re.search(r"```sql\n(.*)```", llm_response, re.DOTALL)
129
141
  if sql:
@@ -363,7 +375,7 @@ class VannaBase(ABC):
363
375
  self, initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000
364
376
  ) -> str:
365
377
  if len(ddl_list) > 0:
366
- initial_prompt += f"\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
378
+ initial_prompt += "\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
367
379
 
368
380
  for ddl in ddl_list:
369
381
  if (
@@ -382,7 +394,7 @@ class VannaBase(ABC):
382
394
  max_tokens: int = 14000,
383
395
  ) -> str:
384
396
  if len(documentation_list) > 0:
385
- initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
397
+ initial_prompt += "\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
386
398
 
387
399
  for documentation in documentation_list:
388
400
  if (
@@ -398,7 +410,7 @@ class VannaBase(ABC):
398
410
  self, initial_prompt: str, sql_list: list[str], max_tokens: int = 14000
399
411
  ) -> str:
400
412
  if len(sql_list) > 0:
401
- initial_prompt += f"\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
413
+ initial_prompt += "\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
402
414
 
403
415
  for question in sql_list:
404
416
  if (
@@ -642,6 +654,7 @@ class VannaBase(ABC):
642
654
  password=password,
643
655
  account=account,
644
656
  database=database,
657
+ client_session_keep_alive=True
645
658
  )
646
659
 
647
660
  def run_sql_snowflake(sql: str) -> pd.DataFrame:
@@ -890,6 +903,94 @@ class VannaBase(ABC):
890
903
  self.run_sql_is_set = True
891
904
  self.run_sql = run_sql_mysql
892
905
 
906
+ def connect_to_oracle(
907
+ self,
908
+ user: str = None,
909
+ password: str = None,
910
+ dsn: str = None,
911
+ ):
912
+
913
+ """
914
+ Connect to an Oracle db using oracledb package. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
915
+ **Example:**
916
+ ```python
917
+ vn.connect_to_oracle(
918
+ user="username",
919
+ password="password",
920
+ dsn="host:port/sid",
921
+ )
922
+ ```
923
+ Args:
924
+ user (str): Oracle db user name.
925
+ password (str): Oracle db user password.
926
+ dsn (str): Oracle db connect string in the form host:port/sid.
927
+ """
928
+
929
+ try:
930
+ import oracledb
931
+ except ImportError:
932
+
933
+ raise DependencyError(
934
+ "You need to install required dependencies to execute this method,"
935
+ " run command: \npip install oracledb"
936
+ )
937
+
938
+ if not dsn:
939
+ dsn = os.getenv("DSN")
940
+
941
+ if not dsn:
942
+ raise ImproperlyConfigured("Please set your Oracle dsn which should include host:port/sid")
943
+
944
+ if not user:
945
+ user = os.getenv("USER")
946
+
947
+ if not user:
948
+ raise ImproperlyConfigured("Please set your Oracle db user")
949
+
950
+ if not password:
951
+ password = os.getenv("PASSWORD")
952
+
953
+ if not password:
954
+ raise ImproperlyConfigured("Please set your Oracle db password")
955
+
956
+ conn = None
957
+
958
+ try:
959
+ conn = oracledb.connect(
960
+ user=user,
961
+ password=password,
962
+ dsn=dsn,
963
+ )
964
+ except oracledb.Error as e:
965
+ raise ValidationError(e)
966
+
967
+ def run_sql_oracle(sql: str) -> Union[pd.DataFrame, None]:
968
+ if conn:
969
+ try:
970
+ sql = sql.rstrip()
971
+ if sql.endswith(';'): #fix for a known problem with Oracle db where an extra ; will cause an error.
972
+ sql = sql[:-1]
973
+
974
+ cs = conn.cursor()
975
+ cs.execute(sql)
976
+ results = cs.fetchall()
977
+
978
+ # Create a pandas dataframe from the results
979
+ df = pd.DataFrame(
980
+ results, columns=[desc[0] for desc in cs.description]
981
+ )
982
+ return df
983
+
984
+ except oracledb.Error as e:
985
+ conn.rollback()
986
+ raise ValidationError(e)
987
+
988
+ except Exception as e:
989
+ conn.rollback()
990
+ raise e
991
+
992
+ self.run_sql_is_set = True
993
+ self.run_sql = run_sql_oracle
893
994
 
894
995
  def connect_to_bigquery(self, cred_file_path: str = None, project_id: str = None):
895
996
  """
@@ -1238,7 +1339,7 @@ class VannaBase(ABC):
1238
1339
  """
1239
1340
 
1240
1341
  if question and not sql:
1241
- raise ValidationError(f"Please also provide a SQL query")
1342
+ raise ValidationError("Please also provide a SQL query")
1242
1343
 
1243
1344
  if documentation:
1244
1345
  print("Adding documentation....")
@@ -1,5 +1,4 @@
1
1
  import json
2
- import uuid
3
2
  from typing import List
4
3
 
5
4
  import chromadb
@@ -16,17 +15,16 @@ default_ef = embedding_functions.DefaultEmbeddingFunction()
16
15
  class ChromaDB_VectorStore(VannaBase):
17
16
  def __init__(self, config=None):
18
17
  VannaBase.__init__(self, config=config)
18
+ if config is None:
19
+ config = {}
19
20
 
20
- if config is not None:
21
- path = config.get("path", ".")
22
- self.embedding_function = config.get("embedding_function", default_ef)
23
- curr_client = config.get("client", "persistent")
24
- self.n_results = config.get("n_results", 10)
25
- else:
26
- path = "."
27
- self.embedding_function = default_ef
28
- curr_client = "persistent" # defaults to persistent storage
29
- self.n_results = 10 # defaults to 10 documents
21
+ path = config.get("path", ".")
22
+ self.embedding_function = config.get("embedding_function", default_ef)
23
+ curr_client = config.get("client", "persistent")
24
+ collection_metadata = config.get("collection_metadata", None)
25
+ self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
26
+ self.n_results_documentation = config.get("n_results_documentation", config.get("n_results", 10))
27
+ self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))
30
28
 
31
29
  if curr_client == "persistent":
32
30
  self.chroma_client = chromadb.PersistentClient(
@@ -43,13 +41,19 @@ class ChromaDB_VectorStore(VannaBase):
43
41
  raise ValueError(f"Unsupported client was set in config: {curr_client}")
44
42
 
45
43
  self.documentation_collection = self.chroma_client.get_or_create_collection(
46
- name="documentation", embedding_function=self.embedding_function
44
+ name="documentation",
45
+ embedding_function=self.embedding_function,
46
+ metadata=collection_metadata,
47
47
  )
48
48
  self.ddl_collection = self.chroma_client.get_or_create_collection(
49
- name="ddl", embedding_function=self.embedding_function
49
+ name="ddl",
50
+ embedding_function=self.embedding_function,
51
+ metadata=collection_metadata,
50
52
  )
51
53
  self.sql_collection = self.chroma_client.get_or_create_collection(
52
- name="sql", embedding_function=self.embedding_function
54
+ name="sql",
55
+ embedding_function=self.embedding_function,
56
+ metadata=collection_metadata,
53
57
  )
54
58
 
55
59
  def generate_embedding(self, data: str, **kwargs) -> List[float]:
@@ -232,7 +236,7 @@ class ChromaDB_VectorStore(VannaBase):
232
236
  return ChromaDB_VectorStore._extract_documents(
233
237
  self.sql_collection.query(
234
238
  query_texts=[question],
235
- n_results=self.n_results,
239
+ n_results=self.n_results_sql,
236
240
  )
237
241
  )
238
242
 
@@ -240,6 +244,7 @@ class ChromaDB_VectorStore(VannaBase):
240
244
  return ChromaDB_VectorStore._extract_documents(
241
245
  self.ddl_collection.query(
242
246
  query_texts=[question],
247
+ n_results=self.n_results_ddl,
243
248
  )
244
249
  )
245
250
 
@@ -247,5 +252,6 @@ class ChromaDB_VectorStore(VannaBase):
247
252
  return ChromaDB_VectorStore._extract_documents(
248
253
  self.documentation_collection.query(
249
254
  query_texts=[question],
255
+ n_results=self.n_results_documentation,
250
256
  )
251
257
  )