vanna 0.3.3__tar.gz → 0.3.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {vanna-0.3.3 → vanna-0.3.4}/PKG-INFO +1 -1
- {vanna-0.3.3 → vanna-0.3.4}/pyproject.toml +1 -1
- {vanna-0.3.3 → vanna-0.3.4}/src/vanna/ZhipuAI/ZhipuAI_Chat.py +3 -3
- {vanna-0.3.3 → vanna-0.3.4}/src/vanna/base/base.py +104 -4
- {vanna-0.3.3 → vanna-0.3.4}/src/vanna/chromadb/chromadb_vector.py +16 -14
- {vanna-0.3.3 → vanna-0.3.4}/src/vanna/flask/__init__.py +1 -1
- vanna-0.3.4/src/vanna/remote.py +77 -0
- {vanna-0.3.3 → vanna-0.3.4}/src/vanna/vannadb/vannadb_vector.py +58 -27
- vanna-0.3.3/src/vanna/remote.py +0 -455
- {vanna-0.3.3 → vanna-0.3.4}/README.md +0 -0
- {vanna-0.3.3 → vanna-0.3.4}/src/vanna/ZhipuAI/ZhipuAI_embeddings.py +0 -0
- {vanna-0.3.3 → vanna-0.3.4}/src/vanna/ZhipuAI/__init__.py +0 -0
- {vanna-0.3.3 → vanna-0.3.4}/src/vanna/__init__.py +0 -0
- {vanna-0.3.3 → vanna-0.3.4}/src/vanna/anthropic/__init__.py +0 -0
- {vanna-0.3.3 → vanna-0.3.4}/src/vanna/anthropic/anthropic_chat.py +0 -0
- {vanna-0.3.3 → vanna-0.3.4}/src/vanna/base/__init__.py +0 -0
- {vanna-0.3.3 → vanna-0.3.4}/src/vanna/chromadb/__init__.py +0 -0
- {vanna-0.3.3 → vanna-0.3.4}/src/vanna/exceptions/__init__.py +0 -0
- {vanna-0.3.3 → vanna-0.3.4}/src/vanna/flask/assets.py +0 -0
- {vanna-0.3.3 → vanna-0.3.4}/src/vanna/local.py +0 -0
- {vanna-0.3.3 → vanna-0.3.4}/src/vanna/marqo/__init__.py +0 -0
- {vanna-0.3.3 → vanna-0.3.4}/src/vanna/marqo/marqo.py +0 -0
- {vanna-0.3.3 → vanna-0.3.4}/src/vanna/mistral/__init__.py +0 -0
- {vanna-0.3.3 → vanna-0.3.4}/src/vanna/mistral/mistral.py +0 -0
- {vanna-0.3.3 → vanna-0.3.4}/src/vanna/ollama/__init__.py +0 -0
- {vanna-0.3.3 → vanna-0.3.4}/src/vanna/ollama/ollama.py +0 -0
- {vanna-0.3.3 → vanna-0.3.4}/src/vanna/openai/__init__.py +0 -0
- {vanna-0.3.3 → vanna-0.3.4}/src/vanna/openai/openai_chat.py +0 -0
- {vanna-0.3.3 → vanna-0.3.4}/src/vanna/openai/openai_embeddings.py +0 -0
- {vanna-0.3.3 → vanna-0.3.4}/src/vanna/types/__init__.py +0 -0
- {vanna-0.3.3 → vanna-0.3.4}/src/vanna/utils.py +0 -0
- {vanna-0.3.3 → vanna-0.3.4}/src/vanna/vannadb/__init__.py +0 -0
|
@@ -40,7 +40,7 @@ class ZhipuAI_Chat(VannaBase):
|
|
|
40
40
|
initial_prompt: str, ddl_list: List[str], max_tokens: int = 14000
|
|
41
41
|
) -> str:
|
|
42
42
|
if len(ddl_list) > 0:
|
|
43
|
-
initial_prompt +=
|
|
43
|
+
initial_prompt += "\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
|
|
44
44
|
|
|
45
45
|
for ddl in ddl_list:
|
|
46
46
|
if (
|
|
@@ -57,7 +57,7 @@ class ZhipuAI_Chat(VannaBase):
|
|
|
57
57
|
initial_prompt: str, documentation_List: List[str], max_tokens: int = 14000
|
|
58
58
|
) -> str:
|
|
59
59
|
if len(documentation_List) > 0:
|
|
60
|
-
initial_prompt +=
|
|
60
|
+
initial_prompt += "\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
|
|
61
61
|
|
|
62
62
|
for documentation in documentation_List:
|
|
63
63
|
if (
|
|
@@ -74,7 +74,7 @@ class ZhipuAI_Chat(VannaBase):
|
|
|
74
74
|
initial_prompt: str, sql_List: List[str], max_tokens: int = 14000
|
|
75
75
|
) -> str:
|
|
76
76
|
if len(sql_List) > 0:
|
|
77
|
-
initial_prompt +=
|
|
77
|
+
initial_prompt += "\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
|
|
78
78
|
|
|
79
79
|
for question in sql_List:
|
|
80
80
|
if (
|
|
@@ -124,6 +124,18 @@ class VannaBase(ABC):
|
|
|
124
124
|
return self.extract_sql(llm_response)
|
|
125
125
|
|
|
126
126
|
def extract_sql(self, llm_response: str) -> str:
|
|
127
|
+
# If the llm_response is not markdown formatted, extract sql by finding select and ; in the response
|
|
128
|
+
sql = re.search(r"SELECT.*?;", llm_response, re.DOTALL)
|
|
129
|
+
if sql:
|
|
130
|
+
self.log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(0)}"
|
|
131
|
+
)
|
|
132
|
+
return sql.group(0)
|
|
133
|
+
|
|
134
|
+
# If the llm_response contains a CTE (with clause), extract the sql bewteen WITH and ;
|
|
135
|
+
sql = re.search(r"WITH.*?;", llm_response, re.DOTALL)
|
|
136
|
+
if sql:
|
|
137
|
+
self.log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(0)}")
|
|
138
|
+
return sql.group(0)
|
|
127
139
|
# If the llm_response contains a markdown code block, with or without the sql tag, extract the sql from it
|
|
128
140
|
sql = re.search(r"```sql\n(.*)```", llm_response, re.DOTALL)
|
|
129
141
|
if sql:
|
|
@@ -363,7 +375,7 @@ class VannaBase(ABC):
|
|
|
363
375
|
self, initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000
|
|
364
376
|
) -> str:
|
|
365
377
|
if len(ddl_list) > 0:
|
|
366
|
-
initial_prompt +=
|
|
378
|
+
initial_prompt += "\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
|
|
367
379
|
|
|
368
380
|
for ddl in ddl_list:
|
|
369
381
|
if (
|
|
@@ -382,7 +394,7 @@ class VannaBase(ABC):
|
|
|
382
394
|
max_tokens: int = 14000,
|
|
383
395
|
) -> str:
|
|
384
396
|
if len(documentation_list) > 0:
|
|
385
|
-
initial_prompt +=
|
|
397
|
+
initial_prompt += "\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
|
|
386
398
|
|
|
387
399
|
for documentation in documentation_list:
|
|
388
400
|
if (
|
|
@@ -398,7 +410,7 @@ class VannaBase(ABC):
|
|
|
398
410
|
self, initial_prompt: str, sql_list: list[str], max_tokens: int = 14000
|
|
399
411
|
) -> str:
|
|
400
412
|
if len(sql_list) > 0:
|
|
401
|
-
initial_prompt +=
|
|
413
|
+
initial_prompt += "\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
|
|
402
414
|
|
|
403
415
|
for question in sql_list:
|
|
404
416
|
if (
|
|
@@ -890,6 +902,94 @@ class VannaBase(ABC):
|
|
|
890
902
|
self.run_sql_is_set = True
|
|
891
903
|
self.run_sql = run_sql_mysql
|
|
892
904
|
|
|
905
|
+
def connect_to_oracle(
|
|
906
|
+
self,
|
|
907
|
+
user: str = None,
|
|
908
|
+
password: str = None,
|
|
909
|
+
dsn: str = None,
|
|
910
|
+
):
|
|
911
|
+
|
|
912
|
+
"""
|
|
913
|
+
Connect to an Oracle db using oracledb package. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
|
|
914
|
+
**Example:**
|
|
915
|
+
```python
|
|
916
|
+
vn.connect_to_oracle(
|
|
917
|
+
user="username",
|
|
918
|
+
password="password",
|
|
919
|
+
dns="host:port/sid",
|
|
920
|
+
)
|
|
921
|
+
```
|
|
922
|
+
Args:
|
|
923
|
+
USER (str): Oracle db user name.
|
|
924
|
+
PASSWORD (str): Oracle db user password.
|
|
925
|
+
DSN (str): Oracle db host ip - host:port/sid.
|
|
926
|
+
"""
|
|
927
|
+
|
|
928
|
+
try:
|
|
929
|
+
import oracledb
|
|
930
|
+
except ImportError:
|
|
931
|
+
|
|
932
|
+
raise DependencyError(
|
|
933
|
+
"You need to install required dependencies to execute this method,"
|
|
934
|
+
" run command: \npip install oracledb"
|
|
935
|
+
)
|
|
936
|
+
|
|
937
|
+
if not dsn:
|
|
938
|
+
dsn = os.getenv("DSN")
|
|
939
|
+
|
|
940
|
+
if not dsn:
|
|
941
|
+
raise ImproperlyConfigured("Please set your Oracle dsn which should include host:port/sid")
|
|
942
|
+
|
|
943
|
+
if not user:
|
|
944
|
+
user = os.getenv("USER")
|
|
945
|
+
|
|
946
|
+
if not user:
|
|
947
|
+
raise ImproperlyConfigured("Please set your Oracle db user")
|
|
948
|
+
|
|
949
|
+
if not password:
|
|
950
|
+
password = os.getenv("PASSWORD")
|
|
951
|
+
|
|
952
|
+
if not password:
|
|
953
|
+
raise ImproperlyConfigured("Please set your Oracle db password")
|
|
954
|
+
|
|
955
|
+
conn = None
|
|
956
|
+
|
|
957
|
+
try:
|
|
958
|
+
conn = oracledb.connect(
|
|
959
|
+
user=user,
|
|
960
|
+
password=password,
|
|
961
|
+
dsn=dsn,
|
|
962
|
+
)
|
|
963
|
+
except oracledb.Error as e:
|
|
964
|
+
raise ValidationError(e)
|
|
965
|
+
|
|
966
|
+
def run_sql_oracle(sql: str) -> Union[pd.DataFrame, None]:
|
|
967
|
+
if conn:
|
|
968
|
+
try:
|
|
969
|
+
sql = sql.rstrip()
|
|
970
|
+
if sql.endswith(';'): #fix for a known problem with Oracle db where an extra ; will cause an error.
|
|
971
|
+
sql = sql[:-1]
|
|
972
|
+
|
|
973
|
+
cs = conn.cursor()
|
|
974
|
+
cs.execute(sql)
|
|
975
|
+
results = cs.fetchall()
|
|
976
|
+
|
|
977
|
+
# Create a pandas dataframe from the results
|
|
978
|
+
df = pd.DataFrame(
|
|
979
|
+
results, columns=[desc[0] for desc in cs.description]
|
|
980
|
+
)
|
|
981
|
+
return df
|
|
982
|
+
|
|
983
|
+
except oracledb.Error as e:
|
|
984
|
+
conn.rollback()
|
|
985
|
+
raise ValidationError(e)
|
|
986
|
+
|
|
987
|
+
except Exception as e:
|
|
988
|
+
conn.rollback()
|
|
989
|
+
raise e
|
|
990
|
+
|
|
991
|
+
self.run_sql_is_set = True
|
|
992
|
+
self.run_sql = run_sql_oracle
|
|
893
993
|
|
|
894
994
|
def connect_to_bigquery(self, cred_file_path: str = None, project_id: str = None):
|
|
895
995
|
"""
|
|
@@ -1238,7 +1338,7 @@ class VannaBase(ABC):
|
|
|
1238
1338
|
"""
|
|
1239
1339
|
|
|
1240
1340
|
if question and not sql:
|
|
1241
|
-
raise ValidationError(
|
|
1341
|
+
raise ValidationError("Please also provide a SQL query")
|
|
1242
1342
|
|
|
1243
1343
|
if documentation:
|
|
1244
1344
|
print("Adding documentation....")
|
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import json
|
|
2
|
-
import uuid
|
|
3
2
|
from typing import List
|
|
4
3
|
|
|
5
4
|
import chromadb
|
|
@@ -16,17 +15,14 @@ default_ef = embedding_functions.DefaultEmbeddingFunction()
|
|
|
16
15
|
class ChromaDB_VectorStore(VannaBase):
|
|
17
16
|
def __init__(self, config=None):
|
|
18
17
|
VannaBase.__init__(self, config=config)
|
|
18
|
+
if config is None:
|
|
19
|
+
config = {}
|
|
19
20
|
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
else:
|
|
26
|
-
path = "."
|
|
27
|
-
self.embedding_function = default_ef
|
|
28
|
-
curr_client = "persistent" # defaults to persistent storage
|
|
29
|
-
self.n_results = 10 # defaults to 10 documents
|
|
21
|
+
path = config.get("path", ".")
|
|
22
|
+
self.embedding_function = config.get("embedding_function", default_ef)
|
|
23
|
+
curr_client = config.get("client", "persistent")
|
|
24
|
+
collection_metadata = config.get("collection_metadata", None)
|
|
25
|
+
self.n_results = config.get("n_results", 10)
|
|
30
26
|
|
|
31
27
|
if curr_client == "persistent":
|
|
32
28
|
self.chroma_client = chromadb.PersistentClient(
|
|
@@ -43,13 +39,19 @@ class ChromaDB_VectorStore(VannaBase):
|
|
|
43
39
|
raise ValueError(f"Unsupported client was set in config: {curr_client}")
|
|
44
40
|
|
|
45
41
|
self.documentation_collection = self.chroma_client.get_or_create_collection(
|
|
46
|
-
name="documentation",
|
|
42
|
+
name="documentation",
|
|
43
|
+
embedding_function=self.embedding_function,
|
|
44
|
+
metadata=collection_metadata,
|
|
47
45
|
)
|
|
48
46
|
self.ddl_collection = self.chroma_client.get_or_create_collection(
|
|
49
|
-
name="ddl",
|
|
47
|
+
name="ddl",
|
|
48
|
+
embedding_function=self.embedding_function,
|
|
49
|
+
metadata=collection_metadata,
|
|
50
50
|
)
|
|
51
51
|
self.sql_collection = self.chroma_client.get_or_create_collection(
|
|
52
|
-
name="sql",
|
|
52
|
+
name="sql",
|
|
53
|
+
embedding_function=self.embedding_function,
|
|
54
|
+
metadata=collection_metadata,
|
|
53
55
|
)
|
|
54
56
|
|
|
55
57
|
def generate_embedding(self, data: str, **kwargs) -> List[float]:
|
|
@@ -525,7 +525,7 @@ class VannaFlaskApp:
|
|
|
525
525
|
# Proxy the /vanna.svg file to the remote server
|
|
526
526
|
@self.flask_app.route("/vanna.svg")
|
|
527
527
|
def proxy_vanna_svg():
|
|
528
|
-
remote_url =
|
|
528
|
+
remote_url = "https://vanna.ai/img/vanna.svg"
|
|
529
529
|
response = requests.get(remote_url, stream=True)
|
|
530
530
|
|
|
531
531
|
# Check if the request to the remote URL was successful
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
import json
|
|
3
|
+
from io import StringIO
|
|
4
|
+
from typing import Callable, List, Tuple, Union
|
|
5
|
+
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import requests
|
|
8
|
+
|
|
9
|
+
from .base import VannaBase
|
|
10
|
+
from .types import (
|
|
11
|
+
AccuracyStats,
|
|
12
|
+
ApiKey,
|
|
13
|
+
DataFrameJSON,
|
|
14
|
+
DataResult,
|
|
15
|
+
Explanation,
|
|
16
|
+
FullQuestionDocument,
|
|
17
|
+
NewOrganization,
|
|
18
|
+
NewOrganizationMember,
|
|
19
|
+
Organization,
|
|
20
|
+
OrganizationList,
|
|
21
|
+
PlotlyResult,
|
|
22
|
+
Question,
|
|
23
|
+
QuestionCategory,
|
|
24
|
+
QuestionId,
|
|
25
|
+
QuestionList,
|
|
26
|
+
QuestionSQLPair,
|
|
27
|
+
QuestionStringList,
|
|
28
|
+
SQLAnswer,
|
|
29
|
+
Status,
|
|
30
|
+
StatusWithId,
|
|
31
|
+
StringData,
|
|
32
|
+
TrainingData,
|
|
33
|
+
UserEmail,
|
|
34
|
+
UserOTP,
|
|
35
|
+
Visibility,
|
|
36
|
+
)
|
|
37
|
+
from .vannadb import VannaDB_VectorStore
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class VannaDefault(VannaDB_VectorStore):
|
|
41
|
+
def __init__(self, model: str, api_key: str, config=None):
|
|
42
|
+
VannaBase.__init__(self, config=config)
|
|
43
|
+
VannaDB_VectorStore.__init__(self, vanna_model=model, vanna_api_key=api_key, config=config)
|
|
44
|
+
|
|
45
|
+
self._model = model
|
|
46
|
+
self._api_key = api_key
|
|
47
|
+
|
|
48
|
+
self._endpoint = (
|
|
49
|
+
"https://ask.vanna.ai/rpc"
|
|
50
|
+
if config is None or "endpoint" not in config
|
|
51
|
+
else config["endpoint"]
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
def system_message(self, message: str) -> any:
|
|
55
|
+
return {"role": "system", "content": message}
|
|
56
|
+
|
|
57
|
+
def user_message(self, message: str) -> any:
|
|
58
|
+
return {"role": "user", "content": message}
|
|
59
|
+
|
|
60
|
+
def assistant_message(self, message: str) -> any:
|
|
61
|
+
return {"role": "assistant", "content": message}
|
|
62
|
+
|
|
63
|
+
def submit_prompt(self, prompt, **kwargs) -> str:
|
|
64
|
+
# JSON-ify the prompt
|
|
65
|
+
json_prompt = json.dumps(prompt)
|
|
66
|
+
|
|
67
|
+
params = [StringData(data=json_prompt)]
|
|
68
|
+
|
|
69
|
+
d = self._rpc_call(method="submit_prompt", params=params)
|
|
70
|
+
|
|
71
|
+
if "result" not in d:
|
|
72
|
+
return None
|
|
73
|
+
|
|
74
|
+
# Load the result into a dataclass
|
|
75
|
+
results = StringData(**d["result"])
|
|
76
|
+
|
|
77
|
+
return results.data
|
|
@@ -7,14 +7,17 @@ import requests
|
|
|
7
7
|
|
|
8
8
|
from ..base import VannaBase
|
|
9
9
|
from ..types import (
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
10
|
+
DataFrameJSON,
|
|
11
|
+
NewOrganization,
|
|
12
|
+
OrganizationList,
|
|
13
|
+
Question,
|
|
14
|
+
QuestionSQLPair,
|
|
15
|
+
Status,
|
|
16
|
+
StatusWithId,
|
|
17
|
+
StringData,
|
|
18
|
+
TrainingData,
|
|
17
19
|
)
|
|
20
|
+
from ..utils import sanitize_model_name
|
|
18
21
|
|
|
19
22
|
|
|
20
23
|
class VannaDB_VectorStore(VannaBase):
|
|
@@ -29,27 +32,8 @@ class VannaDB_VectorStore(VannaBase):
|
|
|
29
32
|
if config is None or "endpoint" not in config
|
|
30
33
|
else config["endpoint"]
|
|
31
34
|
)
|
|
32
|
-
self._unauthenticated_endpoint = (
|
|
33
|
-
"https://ask.vanna.ai/unauthenticated_rpc"
|
|
34
|
-
if config is None or "unauthenticated_endpoint" not in config
|
|
35
|
-
else config["unauthenticated_endpoint"]
|
|
36
|
-
)
|
|
37
35
|
self.related_training_data = {}
|
|
38
36
|
|
|
39
|
-
def _unauthenticated_rpc_call(self, method, params):
|
|
40
|
-
headers = {
|
|
41
|
-
"Content-Type": "application/json",
|
|
42
|
-
}
|
|
43
|
-
data = {
|
|
44
|
-
"method": method,
|
|
45
|
-
"params": [self._dataclass_to_dict(obj) for obj in params],
|
|
46
|
-
}
|
|
47
|
-
|
|
48
|
-
response = requests.post(
|
|
49
|
-
self._unauthenticated_endpoint, headers=headers, data=json.dumps(data)
|
|
50
|
-
)
|
|
51
|
-
return response.json()
|
|
52
|
-
|
|
53
37
|
def _rpc_call(self, method, params):
|
|
54
38
|
if method != "list_orgs":
|
|
55
39
|
headers = {
|
|
@@ -75,6 +59,53 @@ class VannaDB_VectorStore(VannaBase):
|
|
|
75
59
|
def _dataclass_to_dict(self, obj):
|
|
76
60
|
return dataclasses.asdict(obj)
|
|
77
61
|
|
|
62
|
+
def create_model(self, model: str, **kwargs) -> bool:
|
|
63
|
+
"""
|
|
64
|
+
**Example:**
|
|
65
|
+
```python
|
|
66
|
+
success = vn.create_model("my_model")
|
|
67
|
+
```
|
|
68
|
+
Create a new model.
|
|
69
|
+
|
|
70
|
+
Args:
|
|
71
|
+
model (str): The name of the model to create.
|
|
72
|
+
|
|
73
|
+
Returns:
|
|
74
|
+
bool: True if the model was created, False otherwise.
|
|
75
|
+
"""
|
|
76
|
+
model = sanitize_model_name(model)
|
|
77
|
+
params = [NewOrganization(org_name=model, db_type="")]
|
|
78
|
+
|
|
79
|
+
d = self._rpc_call(method="create_org", params=params)
|
|
80
|
+
|
|
81
|
+
if "result" not in d:
|
|
82
|
+
return False
|
|
83
|
+
|
|
84
|
+
status = Status(**d["result"])
|
|
85
|
+
|
|
86
|
+
return status.success
|
|
87
|
+
|
|
88
|
+
def get_models(self) -> list:
|
|
89
|
+
"""
|
|
90
|
+
**Example:**
|
|
91
|
+
```python
|
|
92
|
+
models = vn.get_models()
|
|
93
|
+
```
|
|
94
|
+
|
|
95
|
+
List the models that belong to the user.
|
|
96
|
+
|
|
97
|
+
Returns:
|
|
98
|
+
List[str]: A list of model names.
|
|
99
|
+
"""
|
|
100
|
+
d = self._rpc_call(method="list_my_models", params=[])
|
|
101
|
+
|
|
102
|
+
if "result" not in d:
|
|
103
|
+
return []
|
|
104
|
+
|
|
105
|
+
orgs = OrganizationList(**d["result"])
|
|
106
|
+
|
|
107
|
+
return orgs.organizations
|
|
108
|
+
|
|
78
109
|
def generate_embedding(self, data: str, **kwargs) -> list[float]:
|
|
79
110
|
# This is done server-side
|
|
80
111
|
pass
|
|
@@ -141,7 +172,7 @@ class VannaDB_VectorStore(VannaBase):
|
|
|
141
172
|
d = self._rpc_call(method="remove_training_data", params=params)
|
|
142
173
|
|
|
143
174
|
if "result" not in d:
|
|
144
|
-
raise Exception(
|
|
175
|
+
raise Exception("Error removing training data")
|
|
145
176
|
|
|
146
177
|
status = Status(**d["result"])
|
|
147
178
|
|
vanna-0.3.3/src/vanna/remote.py
DELETED
|
@@ -1,455 +0,0 @@
|
|
|
1
|
-
import dataclasses
|
|
2
|
-
import json
|
|
3
|
-
from io import StringIO
|
|
4
|
-
from typing import Callable, List, Tuple, Union
|
|
5
|
-
|
|
6
|
-
import pandas as pd
|
|
7
|
-
import requests
|
|
8
|
-
|
|
9
|
-
from .base import VannaBase
|
|
10
|
-
from .types import (
|
|
11
|
-
AccuracyStats,
|
|
12
|
-
ApiKey,
|
|
13
|
-
DataFrameJSON,
|
|
14
|
-
DataResult,
|
|
15
|
-
Explanation,
|
|
16
|
-
FullQuestionDocument,
|
|
17
|
-
NewOrganization,
|
|
18
|
-
NewOrganizationMember,
|
|
19
|
-
Organization,
|
|
20
|
-
OrganizationList,
|
|
21
|
-
PlotlyResult,
|
|
22
|
-
Question,
|
|
23
|
-
QuestionCategory,
|
|
24
|
-
QuestionId,
|
|
25
|
-
QuestionList,
|
|
26
|
-
QuestionSQLPair,
|
|
27
|
-
QuestionStringList,
|
|
28
|
-
SQLAnswer,
|
|
29
|
-
Status,
|
|
30
|
-
StatusWithId,
|
|
31
|
-
StringData,
|
|
32
|
-
TrainingData,
|
|
33
|
-
UserEmail,
|
|
34
|
-
UserOTP,
|
|
35
|
-
Visibility,
|
|
36
|
-
)
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
class VannaDefault(VannaBase):
|
|
40
|
-
def __init__(self, model: str, api_key: str, config=None):
|
|
41
|
-
VannaBase.__init__(self, config=config)
|
|
42
|
-
|
|
43
|
-
self._model = model
|
|
44
|
-
self._api_key = api_key
|
|
45
|
-
|
|
46
|
-
self._endpoint = (
|
|
47
|
-
"https://ask.vanna.ai/rpc"
|
|
48
|
-
if config is None or "endpoint" not in config
|
|
49
|
-
else config["endpoint"]
|
|
50
|
-
)
|
|
51
|
-
self._unauthenticated_endpoint = (
|
|
52
|
-
"https://ask.vanna.ai/unauthenticated_rpc"
|
|
53
|
-
if config is None or "unauthenticated_endpoint" not in config
|
|
54
|
-
else config["unauthenticated_endpoint"]
|
|
55
|
-
)
|
|
56
|
-
|
|
57
|
-
def _unauthenticated_rpc_call(self, method, params):
|
|
58
|
-
headers = {
|
|
59
|
-
"Content-Type": "application/json",
|
|
60
|
-
}
|
|
61
|
-
data = {
|
|
62
|
-
"method": method,
|
|
63
|
-
"params": [self._dataclass_to_dict(obj) for obj in params],
|
|
64
|
-
}
|
|
65
|
-
|
|
66
|
-
response = requests.post(
|
|
67
|
-
self._unauthenticated_endpoint, headers=headers, data=json.dumps(data)
|
|
68
|
-
)
|
|
69
|
-
return response.json()
|
|
70
|
-
|
|
71
|
-
def _rpc_call(self, method, params):
|
|
72
|
-
if method != "list_orgs":
|
|
73
|
-
headers = {
|
|
74
|
-
"Content-Type": "application/json",
|
|
75
|
-
"Vanna-Key": self._api_key,
|
|
76
|
-
"Vanna-Org": self._model,
|
|
77
|
-
}
|
|
78
|
-
else:
|
|
79
|
-
headers = {
|
|
80
|
-
"Content-Type": "application/json",
|
|
81
|
-
"Vanna-Key": self._api_key,
|
|
82
|
-
"Vanna-Org": "demo-tpc-h",
|
|
83
|
-
}
|
|
84
|
-
|
|
85
|
-
data = {
|
|
86
|
-
"method": method,
|
|
87
|
-
"params": [self._dataclass_to_dict(obj) for obj in params],
|
|
88
|
-
}
|
|
89
|
-
|
|
90
|
-
response = requests.post(self._endpoint, headers=headers, data=json.dumps(data))
|
|
91
|
-
return response.json()
|
|
92
|
-
|
|
93
|
-
def _dataclass_to_dict(self, obj):
|
|
94
|
-
return dataclasses.asdict(obj)
|
|
95
|
-
|
|
96
|
-
def system_message(self, message: str) -> any:
|
|
97
|
-
return {"role": "system", "content": message}
|
|
98
|
-
|
|
99
|
-
def user_message(self, message: str) -> any:
|
|
100
|
-
return {"role": "user", "content": message}
|
|
101
|
-
|
|
102
|
-
def assistant_message(self, message: str) -> any:
|
|
103
|
-
return {"role": "assistant", "content": message}
|
|
104
|
-
|
|
105
|
-
def get_training_data(self, **kwargs) -> pd.DataFrame:
|
|
106
|
-
"""
|
|
107
|
-
Get the training data for the current model
|
|
108
|
-
|
|
109
|
-
**Example:**
|
|
110
|
-
```python
|
|
111
|
-
training_data = vn.get_training_data()
|
|
112
|
-
```
|
|
113
|
-
|
|
114
|
-
Returns:
|
|
115
|
-
pd.DataFrame or None: The training data, or None if an error occurred.
|
|
116
|
-
|
|
117
|
-
"""
|
|
118
|
-
params = []
|
|
119
|
-
|
|
120
|
-
d = self._rpc_call(method="get_training_data", params=params)
|
|
121
|
-
|
|
122
|
-
if "result" not in d:
|
|
123
|
-
return None
|
|
124
|
-
|
|
125
|
-
# Load the result into a dataclass
|
|
126
|
-
training_data = DataFrameJSON(**d["result"])
|
|
127
|
-
|
|
128
|
-
df = pd.read_json(StringIO(training_data.data))
|
|
129
|
-
|
|
130
|
-
return df
|
|
131
|
-
|
|
132
|
-
def remove_training_data(self, id: str, **kwargs) -> bool:
|
|
133
|
-
"""
|
|
134
|
-
Remove training data from the model
|
|
135
|
-
|
|
136
|
-
**Example:**
|
|
137
|
-
```python
|
|
138
|
-
vn.remove_training_data(id="1-ddl")
|
|
139
|
-
```
|
|
140
|
-
|
|
141
|
-
Args:
|
|
142
|
-
id (str): The ID of the training data to remove.
|
|
143
|
-
"""
|
|
144
|
-
params = [StringData(data=id)]
|
|
145
|
-
|
|
146
|
-
d = self._rpc_call(method="remove_training_data", params=params)
|
|
147
|
-
|
|
148
|
-
if "result" not in d:
|
|
149
|
-
raise Exception(f"Error removing training data")
|
|
150
|
-
|
|
151
|
-
status = Status(**d["result"])
|
|
152
|
-
|
|
153
|
-
if not status.success:
|
|
154
|
-
raise Exception(f"Error removing training data: {status.message}")
|
|
155
|
-
|
|
156
|
-
return status.success
|
|
157
|
-
|
|
158
|
-
def generate_questions(self) -> list[str]:
|
|
159
|
-
"""
|
|
160
|
-
**Example:**
|
|
161
|
-
```python
|
|
162
|
-
vn.generate_questions()
|
|
163
|
-
# ['What is the average salary of employees?', 'What is the total salary of employees?', ...]
|
|
164
|
-
```
|
|
165
|
-
|
|
166
|
-
Generate questions using the Vanna.AI API.
|
|
167
|
-
|
|
168
|
-
Returns:
|
|
169
|
-
List[str] or None: The questions, or None if an error occurred.
|
|
170
|
-
"""
|
|
171
|
-
d = self._rpc_call(method="generate_questions", params=[])
|
|
172
|
-
|
|
173
|
-
if "result" not in d:
|
|
174
|
-
return None
|
|
175
|
-
|
|
176
|
-
# Load the result into a dataclass
|
|
177
|
-
question_string_list = QuestionStringList(**d["result"])
|
|
178
|
-
|
|
179
|
-
return question_string_list.questions
|
|
180
|
-
|
|
181
|
-
def add_ddl(self, ddl: str, **kwargs) -> str:
|
|
182
|
-
"""
|
|
183
|
-
Adds a DDL statement to the model's training data
|
|
184
|
-
|
|
185
|
-
**Example:**
|
|
186
|
-
```python
|
|
187
|
-
vn.add_ddl(
|
|
188
|
-
ddl="CREATE TABLE employees (id INT, name VARCHAR(255), salary INT)"
|
|
189
|
-
)
|
|
190
|
-
```
|
|
191
|
-
|
|
192
|
-
Args:
|
|
193
|
-
ddl (str): The DDL statement to store.
|
|
194
|
-
|
|
195
|
-
Returns:
|
|
196
|
-
str: The ID of the DDL statement.
|
|
197
|
-
"""
|
|
198
|
-
params = [StringData(data=ddl)]
|
|
199
|
-
|
|
200
|
-
d = self._rpc_call(method="add_ddl", params=params)
|
|
201
|
-
|
|
202
|
-
if "result" not in d:
|
|
203
|
-
raise Exception("Error adding DDL", d)
|
|
204
|
-
|
|
205
|
-
status = StatusWithId(**d["result"])
|
|
206
|
-
|
|
207
|
-
return status.id
|
|
208
|
-
|
|
209
|
-
def add_documentation(self, documentation: str, **kwargs) -> str:
|
|
210
|
-
"""
|
|
211
|
-
Adds documentation to the model's training data
|
|
212
|
-
|
|
213
|
-
**Example:**
|
|
214
|
-
```python
|
|
215
|
-
vn.add_documentation(
|
|
216
|
-
documentation="Our organization's definition of sales is the discount price of an item multiplied by the quantity sold."
|
|
217
|
-
)
|
|
218
|
-
```
|
|
219
|
-
|
|
220
|
-
Args:
|
|
221
|
-
documentation (str): The documentation string to store.
|
|
222
|
-
|
|
223
|
-
Returns:
|
|
224
|
-
str: The ID of the documentation string.
|
|
225
|
-
"""
|
|
226
|
-
params = [StringData(data=documentation)]
|
|
227
|
-
|
|
228
|
-
d = self._rpc_call(method="add_documentation", params=params)
|
|
229
|
-
|
|
230
|
-
if "result" not in d:
|
|
231
|
-
raise Exception("Error adding documentation", d)
|
|
232
|
-
|
|
233
|
-
status = StatusWithId(**d["result"])
|
|
234
|
-
|
|
235
|
-
return status.id
|
|
236
|
-
|
|
237
|
-
def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
|
|
238
|
-
"""
|
|
239
|
-
Adds a question and its corresponding SQL query to the model's training data. The preferred way to call this is to use [`vn.train(sql=...)`][vanna.train].
|
|
240
|
-
|
|
241
|
-
**Example:**
|
|
242
|
-
```python
|
|
243
|
-
vn.add_sql(
|
|
244
|
-
question="What is the average salary of employees?",
|
|
245
|
-
sql="SELECT AVG(salary) FROM employees"
|
|
246
|
-
)
|
|
247
|
-
```
|
|
248
|
-
|
|
249
|
-
Args:
|
|
250
|
-
question (str): The question to store.
|
|
251
|
-
sql (str): The SQL query to store.
|
|
252
|
-
tag (Union[str, None]): A tag to associate with the question and SQL query.
|
|
253
|
-
|
|
254
|
-
Returns:
|
|
255
|
-
str: The ID of the question and SQL query.
|
|
256
|
-
"""
|
|
257
|
-
if "tag" in kwargs:
|
|
258
|
-
tag = kwargs["tag"]
|
|
259
|
-
else:
|
|
260
|
-
tag = "Manually Trained"
|
|
261
|
-
|
|
262
|
-
params = [QuestionSQLPair(question=question, sql=sql, tag=tag)]
|
|
263
|
-
|
|
264
|
-
d = self._rpc_call(method="add_sql", params=params)
|
|
265
|
-
|
|
266
|
-
if "result" not in d:
|
|
267
|
-
raise Exception("Error adding question and SQL pair", d)
|
|
268
|
-
|
|
269
|
-
status = StatusWithId(**d["result"])
|
|
270
|
-
|
|
271
|
-
return status.id
|
|
272
|
-
|
|
273
|
-
def generate_embedding(self, data: str, **kwargs) -> list[float]:
|
|
274
|
-
"""
|
|
275
|
-
Not necessary for remote models as embeddings are generated on the server side.
|
|
276
|
-
"""
|
|
277
|
-
pass
|
|
278
|
-
|
|
279
|
-
def generate_plotly_code(
|
|
280
|
-
self, question: str = None, sql: str = None, df_metadata: str = None, **kwargs
|
|
281
|
-
) -> str:
|
|
282
|
-
"""
|
|
283
|
-
**Example:**
|
|
284
|
-
```python
|
|
285
|
-
vn.generate_plotly_code(
|
|
286
|
-
question="What is the average salary of employees?",
|
|
287
|
-
sql="SELECT AVG(salary) FROM employees",
|
|
288
|
-
df_metadata=df.dtypes
|
|
289
|
-
)
|
|
290
|
-
# fig = px.bar(df, x="name", y="salary")
|
|
291
|
-
```
|
|
292
|
-
Generate Plotly code using the Vanna.AI API.
|
|
293
|
-
|
|
294
|
-
Args:
|
|
295
|
-
question (str): The question to generate Plotly code for.
|
|
296
|
-
sql (str): The SQL query to generate Plotly code for.
|
|
297
|
-
df (pd.DataFrame): The dataframe to generate Plotly code for.
|
|
298
|
-
chart_instructions (str): Optional instructions for how to plot the chart.
|
|
299
|
-
|
|
300
|
-
Returns:
|
|
301
|
-
str or None: The Plotly code, or None if an error occurred.
|
|
302
|
-
"""
|
|
303
|
-
if kwargs is not None and "chart_instructions" in kwargs:
|
|
304
|
-
if question is not None:
|
|
305
|
-
question = (
|
|
306
|
-
question
|
|
307
|
-
+ " -- When plotting, follow these instructions: "
|
|
308
|
-
+ kwargs["chart_instructions"]
|
|
309
|
-
)
|
|
310
|
-
else:
|
|
311
|
-
question = (
|
|
312
|
-
"When plotting, follow these instructions: "
|
|
313
|
-
+ kwargs["chart_instructions"]
|
|
314
|
-
)
|
|
315
|
-
|
|
316
|
-
params = [
|
|
317
|
-
DataResult(
|
|
318
|
-
question=question,
|
|
319
|
-
sql=sql,
|
|
320
|
-
table_markdown=df_metadata,
|
|
321
|
-
error=None,
|
|
322
|
-
correction_attempts=0,
|
|
323
|
-
)
|
|
324
|
-
]
|
|
325
|
-
|
|
326
|
-
d = self._rpc_call(method="generate_plotly_code", params=params)
|
|
327
|
-
|
|
328
|
-
if "result" not in d:
|
|
329
|
-
return None
|
|
330
|
-
|
|
331
|
-
# Load the result into a dataclass
|
|
332
|
-
plotly_code = PlotlyResult(**d["result"])
|
|
333
|
-
|
|
334
|
-
return plotly_code.plotly_code
|
|
335
|
-
|
|
336
|
-
def generate_question(self, sql: str, **kwargs) -> str:
|
|
337
|
-
"""
|
|
338
|
-
|
|
339
|
-
**Example:**
|
|
340
|
-
```python
|
|
341
|
-
vn.generate_question(sql="SELECT * FROM students WHERE name = 'John Doe'")
|
|
342
|
-
# 'What is the name of the student?'
|
|
343
|
-
```
|
|
344
|
-
|
|
345
|
-
Generate a question from an SQL query using the Vanna.AI API.
|
|
346
|
-
|
|
347
|
-
Args:
|
|
348
|
-
sql (str): The SQL query to generate a question for.
|
|
349
|
-
|
|
350
|
-
Returns:
|
|
351
|
-
str or None: The question, or None if an error occurred.
|
|
352
|
-
|
|
353
|
-
"""
|
|
354
|
-
params = [
|
|
355
|
-
SQLAnswer(
|
|
356
|
-
raw_answer="",
|
|
357
|
-
prefix="",
|
|
358
|
-
postfix="",
|
|
359
|
-
sql=sql,
|
|
360
|
-
)
|
|
361
|
-
]
|
|
362
|
-
|
|
363
|
-
d = self._rpc_call(method="generate_question", params=params)
|
|
364
|
-
|
|
365
|
-
if "result" not in d:
|
|
366
|
-
return None
|
|
367
|
-
|
|
368
|
-
# Load the result into a dataclass
|
|
369
|
-
question = Question(**d["result"])
|
|
370
|
-
|
|
371
|
-
return question.question
|
|
372
|
-
|
|
373
|
-
def get_sql_prompt(
|
|
374
|
-
self,
|
|
375
|
-
question: str,
|
|
376
|
-
question_sql_list: list,
|
|
377
|
-
ddl_list: list,
|
|
378
|
-
doc_list: list,
|
|
379
|
-
**kwargs,
|
|
380
|
-
):
|
|
381
|
-
"""
|
|
382
|
-
Not necessary for remote models as prompts are generated on the server side.
|
|
383
|
-
"""
|
|
384
|
-
|
|
385
|
-
def get_followup_questions_prompt(
|
|
386
|
-
self,
|
|
387
|
-
question: str,
|
|
388
|
-
df: pd.DataFrame,
|
|
389
|
-
question_sql_list: list,
|
|
390
|
-
ddl_list: list,
|
|
391
|
-
doc_list: list,
|
|
392
|
-
**kwargs,
|
|
393
|
-
):
|
|
394
|
-
"""
|
|
395
|
-
Not necessary for remote models as prompts are generated on the server side.
|
|
396
|
-
"""
|
|
397
|
-
|
|
398
|
-
def submit_prompt(self, prompt, **kwargs) -> str:
|
|
399
|
-
# JSON-ify the prompt
|
|
400
|
-
json_prompt = json.dumps(prompt)
|
|
401
|
-
|
|
402
|
-
params = [StringData(data=json_prompt)]
|
|
403
|
-
|
|
404
|
-
d = self._rpc_call(method="submit_prompt", params=params)
|
|
405
|
-
|
|
406
|
-
if "result" not in d:
|
|
407
|
-
return None
|
|
408
|
-
|
|
409
|
-
# Load the result into a dataclass
|
|
410
|
-
results = StringData(**d["result"])
|
|
411
|
-
|
|
412
|
-
return results.data
|
|
413
|
-
|
|
414
|
-
def get_similar_question_sql(self, question: str, **kwargs) -> list:
|
|
415
|
-
"""
|
|
416
|
-
Not necessary for remote models as similar questions are generated on the server side.
|
|
417
|
-
"""
|
|
418
|
-
|
|
419
|
-
def get_related_ddl(self, question: str, **kwargs) -> list:
|
|
420
|
-
"""
|
|
421
|
-
Not necessary for remote models as related DDL statements are generated on the server side.
|
|
422
|
-
"""
|
|
423
|
-
|
|
424
|
-
def get_related_documentation(self, question: str, **kwargs) -> list:
|
|
425
|
-
"""
|
|
426
|
-
Not necessary for remote models as related documentation is generated on the server side.
|
|
427
|
-
"""
|
|
428
|
-
|
|
429
|
-
def generate_sql(self, question: str, **kwargs) -> str:
|
|
430
|
-
"""
|
|
431
|
-
**Example:**
|
|
432
|
-
```python
|
|
433
|
-
vn.generate_sql_from_question(question="What is the average salary of employees?")
|
|
434
|
-
# SELECT AVG(salary) FROM employees
|
|
435
|
-
```
|
|
436
|
-
|
|
437
|
-
Generate an SQL query using the Vanna.AI API.
|
|
438
|
-
|
|
439
|
-
Args:
|
|
440
|
-
question (str): The question to generate an SQL query for.
|
|
441
|
-
|
|
442
|
-
Returns:
|
|
443
|
-
str or None: The SQL query, or None if an error occurred.
|
|
444
|
-
"""
|
|
445
|
-
params = [Question(question=question)]
|
|
446
|
-
|
|
447
|
-
d = self._rpc_call(method="generate_sql_from_question", params=params)
|
|
448
|
-
|
|
449
|
-
if "result" not in d:
|
|
450
|
-
return None
|
|
451
|
-
|
|
452
|
-
# Load the result into a dataclass
|
|
453
|
-
sql_answer = SQLAnswer(**d["result"])
|
|
454
|
-
|
|
455
|
-
return sql_answer.sql
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|