vanna 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/anthropic/__init__.py +0 -0
- vanna/anthropic/anthropic_chat.py +78 -0
- vanna/base/base.py +94 -1
- vanna/flask/__init__.py +159 -18
- vanna/flask/assets.py +22 -20
- vanna/mistral/mistral.py +14 -11
- vanna/ollama/__init__.py +0 -75
- vanna/ollama/ollama.py +77 -0
- vanna/openai/openai_chat.py +8 -8
- {vanna-0.2.0.dist-info → vanna-0.3.0.dist-info}/METADATA +7 -1
- {vanna-0.2.0.dist-info → vanna-0.3.0.dist-info}/RECORD +12 -10
- {vanna-0.2.0.dist-info → vanna-0.3.0.dist-info}/WHEEL +0 -0
vanna/mistral/mistral.py
CHANGED
|
@@ -1,33 +1,36 @@
|
|
|
1
1
|
from mistralai.client import MistralClient
|
|
2
2
|
from mistralai.models.chat_completion import ChatMessage
|
|
3
|
+
|
|
3
4
|
from ..base import VannaBase
|
|
4
|
-
|
|
5
|
+
|
|
5
6
|
|
|
6
7
|
class Mistral(VannaBase):
|
|
7
8
|
def __init__(self, config=None):
|
|
8
9
|
if config is None:
|
|
9
|
-
raise ValueError(
|
|
10
|
+
raise ValueError(
|
|
11
|
+
"For Mistral, config must be provided with an api_key and model"
|
|
12
|
+
)
|
|
10
13
|
|
|
11
|
-
if
|
|
14
|
+
if "api_key" not in config:
|
|
12
15
|
raise ValueError("config must contain a Mistral api_key")
|
|
13
|
-
|
|
14
|
-
if
|
|
16
|
+
|
|
17
|
+
if "model" not in config:
|
|
15
18
|
raise ValueError("config must contain a Mistral model")
|
|
16
19
|
|
|
17
|
-
api_key = config[
|
|
18
|
-
model = config[
|
|
20
|
+
api_key = config["api_key"]
|
|
21
|
+
model = config["model"]
|
|
19
22
|
self.client = MistralClient(api_key=api_key)
|
|
20
23
|
self.model = model
|
|
21
24
|
|
|
22
25
|
def system_message(self, message: str) -> any:
|
|
23
26
|
return ChatMessage(role="system", content=message)
|
|
24
|
-
|
|
27
|
+
|
|
25
28
|
def user_message(self, message: str) -> any:
|
|
26
29
|
return ChatMessage(role="user", content=message)
|
|
27
|
-
|
|
30
|
+
|
|
28
31
|
def assistant_message(self, message: str) -> any:
|
|
29
32
|
return ChatMessage(role="assistant", content=message)
|
|
30
|
-
|
|
33
|
+
|
|
31
34
|
def generate_sql(self, question: str, **kwargs) -> str:
|
|
32
35
|
# Use the super generate_sql
|
|
33
36
|
sql = super().generate_sql(question, **kwargs)
|
|
@@ -42,5 +45,5 @@ class Mistral(VannaBase):
|
|
|
42
45
|
model=self.model,
|
|
43
46
|
messages=prompt,
|
|
44
47
|
)
|
|
45
|
-
|
|
48
|
+
|
|
46
49
|
return chat_response.choices[0].message.content
|
vanna/ollama/__init__.py
CHANGED
|
@@ -1,75 +0,0 @@
|
|
|
1
|
-
from ..base import VannaBase
|
|
2
|
-
import requests
|
|
3
|
-
import json
|
|
4
|
-
import re
|
|
5
|
-
|
|
6
|
-
class Ollama(VannaBase):
|
|
7
|
-
def __init__(self, config=None):
|
|
8
|
-
if config is None or 'ollama_host' not in config:
|
|
9
|
-
self.host = "http://localhost:11434"
|
|
10
|
-
else:
|
|
11
|
-
self.host = config['ollama_host']
|
|
12
|
-
|
|
13
|
-
if config is None or 'model' not in config:
|
|
14
|
-
raise ValueError("config must contain a Ollama model")
|
|
15
|
-
else:
|
|
16
|
-
self.model = config['model']
|
|
17
|
-
|
|
18
|
-
def system_message(self, message: str) -> any:
|
|
19
|
-
return {"role": "system", "content": message}
|
|
20
|
-
|
|
21
|
-
def user_message(self, message: str) -> any:
|
|
22
|
-
return {"role": "user", "content": message}
|
|
23
|
-
|
|
24
|
-
def assistant_message(self, message: str) -> any:
|
|
25
|
-
return {"role": "assistant", "content": message}
|
|
26
|
-
|
|
27
|
-
def extract_sql_query(self, text):
|
|
28
|
-
"""
|
|
29
|
-
Extracts the first SQL statement after the word 'select', ignoring case,
|
|
30
|
-
matches until the first semicolon, three backticks, or the end of the string,
|
|
31
|
-
and removes three backticks if they exist in the extracted string.
|
|
32
|
-
|
|
33
|
-
Args:
|
|
34
|
-
- text (str): The string to search within for an SQL statement.
|
|
35
|
-
|
|
36
|
-
Returns:
|
|
37
|
-
- str: The first SQL statement found, with three backticks removed, or an empty string if no match is found.
|
|
38
|
-
"""
|
|
39
|
-
# Regular expression to find 'select' (ignoring case) and capture until ';', '```', or end of string
|
|
40
|
-
pattern = re.compile(r'select.*?(?:;|```|$)', re.IGNORECASE | re.DOTALL)
|
|
41
|
-
|
|
42
|
-
match = pattern.search(text)
|
|
43
|
-
if match:
|
|
44
|
-
# Remove three backticks from the matched string if they exist
|
|
45
|
-
return match.group(0).replace('```', '')
|
|
46
|
-
else:
|
|
47
|
-
return text
|
|
48
|
-
|
|
49
|
-
def generate_sql(self, question: str, **kwargs) -> str:
|
|
50
|
-
# Use the super generate_sql
|
|
51
|
-
sql = super().generate_sql(question, **kwargs)
|
|
52
|
-
|
|
53
|
-
# Replace "\_" with "_"
|
|
54
|
-
sql = sql.replace("\\_", "_")
|
|
55
|
-
|
|
56
|
-
sql = sql.replace("\\", "")
|
|
57
|
-
|
|
58
|
-
return self.extract_sql_query(sql)
|
|
59
|
-
|
|
60
|
-
def submit_prompt(self, prompt, **kwargs) -> str:
|
|
61
|
-
url = f"{self.host}/api/chat"
|
|
62
|
-
data = {
|
|
63
|
-
"model": self.model,
|
|
64
|
-
"stream": False,
|
|
65
|
-
"messages": prompt,
|
|
66
|
-
}
|
|
67
|
-
|
|
68
|
-
response = requests.post(url, json=data)
|
|
69
|
-
|
|
70
|
-
response_dict = response.json()
|
|
71
|
-
|
|
72
|
-
self.log(response.text)
|
|
73
|
-
|
|
74
|
-
return response_dict['message']['content']
|
|
75
|
-
|
vanna/ollama/ollama.py
CHANGED
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import re
|
|
3
|
+
|
|
4
|
+
import requests
|
|
5
|
+
|
|
6
|
+
from ..base import VannaBase
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Ollama(VannaBase):
|
|
10
|
+
def __init__(self, config=None):
|
|
11
|
+
if config is None or "ollama_host" not in config:
|
|
12
|
+
self.host = "http://localhost:11434"
|
|
13
|
+
else:
|
|
14
|
+
self.host = config["ollama_host"]
|
|
15
|
+
|
|
16
|
+
if config is None or "model" not in config:
|
|
17
|
+
raise ValueError("config must contain a Ollama model")
|
|
18
|
+
else:
|
|
19
|
+
self.model = config["model"]
|
|
20
|
+
|
|
21
|
+
def system_message(self, message: str) -> any:
|
|
22
|
+
return {"role": "system", "content": message}
|
|
23
|
+
|
|
24
|
+
def user_message(self, message: str) -> any:
|
|
25
|
+
return {"role": "user", "content": message}
|
|
26
|
+
|
|
27
|
+
def assistant_message(self, message: str) -> any:
|
|
28
|
+
return {"role": "assistant", "content": message}
|
|
29
|
+
|
|
30
|
+
def extract_sql_query(self, text):
|
|
31
|
+
"""
|
|
32
|
+
Extracts the first SQL statement after the word 'select', ignoring case,
|
|
33
|
+
matches until the first semicolon, three backticks, or the end of the string,
|
|
34
|
+
and removes three backticks if they exist in the extracted string.
|
|
35
|
+
|
|
36
|
+
Args:
|
|
37
|
+
- text (str): The string to search within for an SQL statement.
|
|
38
|
+
|
|
39
|
+
Returns:
|
|
40
|
+
- str: The first SQL statement found, with three backticks removed, or an empty string if no match is found.
|
|
41
|
+
"""
|
|
42
|
+
# Regular expression to find 'select' (ignoring case) and capture until ';', '```', or end of string
|
|
43
|
+
pattern = re.compile(r"select.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL)
|
|
44
|
+
|
|
45
|
+
match = pattern.search(text)
|
|
46
|
+
if match:
|
|
47
|
+
# Remove three backticks from the matched string if they exist
|
|
48
|
+
return match.group(0).replace("```", "")
|
|
49
|
+
else:
|
|
50
|
+
return text
|
|
51
|
+
|
|
52
|
+
def generate_sql(self, question: str, **kwargs) -> str:
|
|
53
|
+
# Use the super generate_sql
|
|
54
|
+
sql = super().generate_sql(question, **kwargs)
|
|
55
|
+
|
|
56
|
+
# Replace "\_" with "_"
|
|
57
|
+
sql = sql.replace("\\_", "_")
|
|
58
|
+
|
|
59
|
+
sql = sql.replace("\\", "")
|
|
60
|
+
|
|
61
|
+
return self.extract_sql_query(sql)
|
|
62
|
+
|
|
63
|
+
def submit_prompt(self, prompt, **kwargs) -> str:
|
|
64
|
+
url = f"{self.host}/api/chat"
|
|
65
|
+
data = {
|
|
66
|
+
"model": self.model,
|
|
67
|
+
"stream": False,
|
|
68
|
+
"messages": prompt,
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
response = requests.post(url, json=data)
|
|
72
|
+
|
|
73
|
+
response_dict = response.json()
|
|
74
|
+
|
|
75
|
+
self.log(response.text)
|
|
76
|
+
|
|
77
|
+
return response_dict["message"]["content"]
|
vanna/openai/openai_chat.py
CHANGED
|
@@ -9,14 +9,6 @@ class OpenAI_Chat(VannaBase):
|
|
|
9
9
|
def __init__(self, client=None, config=None):
|
|
10
10
|
VannaBase.__init__(self, config=config)
|
|
11
11
|
|
|
12
|
-
if client is not None:
|
|
13
|
-
self.client = client
|
|
14
|
-
return
|
|
15
|
-
|
|
16
|
-
if config is None and client is None:
|
|
17
|
-
self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
|
18
|
-
return
|
|
19
|
-
|
|
20
12
|
# default parameters - can be overrided using config
|
|
21
13
|
self.temperature = 0.7
|
|
22
14
|
self.max_tokens = 500
|
|
@@ -42,6 +34,14 @@ class OpenAI_Chat(VannaBase):
|
|
|
42
34
|
"Passing api_version is now deprecated. Please pass an OpenAI client instead."
|
|
43
35
|
)
|
|
44
36
|
|
|
37
|
+
if client is not None:
|
|
38
|
+
self.client = client
|
|
39
|
+
return
|
|
40
|
+
|
|
41
|
+
if config is None and client is None:
|
|
42
|
+
self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
|
43
|
+
return
|
|
44
|
+
|
|
45
45
|
if "api_key" in config:
|
|
46
46
|
self.client = OpenAI(api_key=config["api_key"])
|
|
47
47
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: vanna
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.3.0
|
|
4
4
|
Summary: Generate SQL queries from natural language
|
|
5
5
|
Author-email: Zain Hoda <zain@vanna.ai>
|
|
6
6
|
Requires-Python: >=3.9
|
|
@@ -18,18 +18,22 @@ Requires-Dist: flask
|
|
|
18
18
|
Requires-Dist: sqlalchemy
|
|
19
19
|
Requires-Dist: psycopg2-binary ; extra == "all"
|
|
20
20
|
Requires-Dist: db-dtypes ; extra == "all"
|
|
21
|
+
Requires-Dist: PyMySQL ; extra == "all"
|
|
21
22
|
Requires-Dist: google-cloud-bigquery ; extra == "all"
|
|
22
23
|
Requires-Dist: snowflake-connector-python ; extra == "all"
|
|
23
24
|
Requires-Dist: duckdb ; extra == "all"
|
|
24
25
|
Requires-Dist: openai ; extra == "all"
|
|
25
26
|
Requires-Dist: mistralai ; extra == "all"
|
|
26
27
|
Requires-Dist: chromadb ; extra == "all"
|
|
28
|
+
Requires-Dist: anthropic ; extra == "all"
|
|
29
|
+
Requires-Dist: anthropic ; extra == "anthropic"
|
|
27
30
|
Requires-Dist: google-cloud-bigquery ; extra == "bigquery"
|
|
28
31
|
Requires-Dist: chromadb ; extra == "chromadb"
|
|
29
32
|
Requires-Dist: duckdb ; extra == "duckdb"
|
|
30
33
|
Requires-Dist: google-generativeai ; extra == "gemini"
|
|
31
34
|
Requires-Dist: marqo ; extra == "marqo"
|
|
32
35
|
Requires-Dist: mistralai ; extra == "mistralai"
|
|
36
|
+
Requires-Dist: PyMySQL ; extra == "mysql"
|
|
33
37
|
Requires-Dist: openai ; extra == "openai"
|
|
34
38
|
Requires-Dist: psycopg2-binary ; extra == "postgres"
|
|
35
39
|
Requires-Dist: db-dtypes ; extra == "postgres"
|
|
@@ -38,12 +42,14 @@ Requires-Dist: tox ; extra == "test"
|
|
|
38
42
|
Project-URL: Bug Tracker, https://github.com/vanna-ai/vanna/issues
|
|
39
43
|
Project-URL: Homepage, https://github.com/vanna-ai/vanna
|
|
40
44
|
Provides-Extra: all
|
|
45
|
+
Provides-Extra: anthropic
|
|
41
46
|
Provides-Extra: bigquery
|
|
42
47
|
Provides-Extra: chromadb
|
|
43
48
|
Provides-Extra: duckdb
|
|
44
49
|
Provides-Extra: gemini
|
|
45
50
|
Provides-Extra: marqo
|
|
46
51
|
Provides-Extra: mistralai
|
|
52
|
+
Provides-Extra: mysql
|
|
47
53
|
Provides-Extra: openai
|
|
48
54
|
Provides-Extra: postgres
|
|
49
55
|
Provides-Extra: snowflake
|
|
@@ -5,25 +5,27 @@ vanna/utils.py,sha256=Q0H4eugPYg9SVpEoTWgvmuoJZZxOVRhNzrP97E5lyak,1472
|
|
|
5
5
|
vanna/ZhipuAI/ZhipuAI_Chat.py,sha256=hcx__0ZKHr5wtmIv0Ye2Xake8E-sOi5Qyc5nFIIdFvg,8957
|
|
6
6
|
vanna/ZhipuAI/ZhipuAI_embeddings.py,sha256=lUqzJg9fOx7rVFhjdkFjXcDeVGV4aAB5Ss0oERsa8pE,2849
|
|
7
7
|
vanna/ZhipuAI/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
8
|
+
vanna/anthropic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
9
|
+
vanna/anthropic/anthropic_chat.py,sha256=Wk0o-NMW1uvR2fhSWxrR_2FqNh-dLprNG4uuVqpqAkY,2615
|
|
8
10
|
vanna/base/__init__.py,sha256=Sl-HM1RRYzAZoSqmL1CZQmF3ZF-byYTCFQP3JZ2A5MU,28
|
|
9
|
-
vanna/base/base.py,sha256=
|
|
11
|
+
vanna/base/base.py,sha256=25tHTMsCnoek6X9C0hqpPyNM3H0mihcL4I8Mlq-74c0,54890
|
|
10
12
|
vanna/chromadb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
13
|
vanna/chromadb/chromadb_vector.py,sha256=fa7uj_knzSfzsVLvpSunwwu1ZJNC3GbiNZ4Yy09v4l4,8372
|
|
12
14
|
vanna/exceptions/__init__.py,sha256=N76unE7sjbGGBz6LmCrPQAugFWr9cUFv8ErJxBrCTts,717
|
|
13
|
-
vanna/flask/__init__.py,sha256=
|
|
14
|
-
vanna/flask/assets.py,sha256=
|
|
15
|
+
vanna/flask/__init__.py,sha256=UgM0Ce5pGDdadWV6ZEAXj7RXDE1E420DW1wtR-juBMw,21212
|
|
16
|
+
vanna/flask/assets.py,sha256=uluOgW41cHYmqPyRFHn8HYgaozmPm9nhDLZkHLuSBM4,180728
|
|
15
17
|
vanna/marqo/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
16
18
|
vanna/marqo/marqo.py,sha256=2OBuC5IZmGcFXN2Ah6GVPKHBYtkDXeSwhXsqUbxyU94,5285
|
|
17
19
|
vanna/mistral/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
18
|
-
vanna/mistral/mistral.py,sha256=
|
|
19
|
-
vanna/ollama/__init__.py,sha256=
|
|
20
|
-
vanna/ollama/ollama.py,sha256=
|
|
20
|
+
vanna/mistral/mistral.py,sha256=DAEqAT9SzC91rfMM_S3SuzBZ34MrKHw9qAj6EP2MGVk,1508
|
|
21
|
+
vanna/ollama/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
22
|
+
vanna/ollama/ollama.py,sha256=U02yy6DCAR-Ar7YXF98aBg3Kc-c9CiSWnMIGv4gs8zE,2430
|
|
21
23
|
vanna/openai/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
22
|
-
vanna/openai/openai_chat.py,sha256=
|
|
24
|
+
vanna/openai/openai_chat.py,sha256=Y3-Fhz9c6D-5vrMR1zGibavAPNPAu-hMTwqGfKhAg3Q,3852
|
|
23
25
|
vanna/openai/openai_embeddings.py,sha256=g4pNh9LVcYP9wOoO8ecaccDFWmCUYMInebfHucAa2Gc,1260
|
|
24
26
|
vanna/types/__init__.py,sha256=Qhn_YscKtJh7mFPCyCDLa2K8a4ORLMGVnPpTbv9uB2U,4957
|
|
25
27
|
vanna/vannadb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
26
28
|
vanna/vannadb/vannadb_vector.py,sha256=f4kddaJgTpZync7wnQi09QdODUuMtiHsK7WfKBUAmSo,5644
|
|
27
|
-
vanna-0.
|
|
28
|
-
vanna-0.
|
|
29
|
-
vanna-0.
|
|
29
|
+
vanna-0.3.0.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
|
|
30
|
+
vanna-0.3.0.dist-info/METADATA,sha256=lg_gRR1TR9RgdxDh3UmthfyudyEdcWdxHpwXEMrz7AU,9961
|
|
31
|
+
vanna-0.3.0.dist-info/RECORD,,
|
|
File without changes
|