vanna 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vanna/mistral/mistral.py CHANGED
@@ -1,33 +1,36 @@
1
1
  from mistralai.client import MistralClient
2
2
  from mistralai.models.chat_completion import ChatMessage
3
+
3
4
  from ..base import VannaBase
4
- import re
5
+
5
6
 
6
7
  class Mistral(VannaBase):
7
8
  def __init__(self, config=None):
8
9
  if config is None:
9
- raise ValueError("For Mistral, config must be provided with an api_key and model")
10
+ raise ValueError(
11
+ "For Mistral, config must be provided with an api_key and model"
12
+ )
10
13
 
11
- if 'api_key' not in config:
14
+ if "api_key" not in config:
12
15
  raise ValueError("config must contain a Mistral api_key")
13
-
14
- if 'model' not in config:
16
+
17
+ if "model" not in config:
15
18
  raise ValueError("config must contain a Mistral model")
16
19
 
17
- api_key = config['api_key']
18
- model = config['model']
20
+ api_key = config["api_key"]
21
+ model = config["model"]
19
22
  self.client = MistralClient(api_key=api_key)
20
23
  self.model = model
21
24
 
22
25
  def system_message(self, message: str) -> any:
23
26
  return ChatMessage(role="system", content=message)
24
-
27
+
25
28
  def user_message(self, message: str) -> any:
26
29
  return ChatMessage(role="user", content=message)
27
-
30
+
28
31
  def assistant_message(self, message: str) -> any:
29
32
  return ChatMessage(role="assistant", content=message)
30
-
33
+
31
34
  def generate_sql(self, question: str, **kwargs) -> str:
32
35
  # Use the super generate_sql
33
36
  sql = super().generate_sql(question, **kwargs)
@@ -42,5 +45,5 @@ class Mistral(VannaBase):
42
45
  model=self.model,
43
46
  messages=prompt,
44
47
  )
45
-
48
+
46
49
  return chat_response.choices[0].message.content
vanna/ollama/__init__.py CHANGED
@@ -1,75 +0,0 @@
1
- from ..base import VannaBase
2
- import requests
3
- import json
4
- import re
5
-
6
- class Ollama(VannaBase):
7
- def __init__(self, config=None):
8
- if config is None or 'ollama_host' not in config:
9
- self.host = "http://localhost:11434"
10
- else:
11
- self.host = config['ollama_host']
12
-
13
- if config is None or 'model' not in config:
14
- raise ValueError("config must contain a Ollama model")
15
- else:
16
- self.model = config['model']
17
-
18
- def system_message(self, message: str) -> any:
19
- return {"role": "system", "content": message}
20
-
21
- def user_message(self, message: str) -> any:
22
- return {"role": "user", "content": message}
23
-
24
- def assistant_message(self, message: str) -> any:
25
- return {"role": "assistant", "content": message}
26
-
27
- def extract_sql_query(self, text):
28
- """
29
- Extracts the first SQL statement after the word 'select', ignoring case,
30
- matches until the first semicolon, three backticks, or the end of the string,
31
- and removes three backticks if they exist in the extracted string.
32
-
33
- Args:
34
- - text (str): The string to search within for an SQL statement.
35
-
36
- Returns:
37
- - str: The first SQL statement found, with three backticks removed, or an empty string if no match is found.
38
- """
39
- # Regular expression to find 'select' (ignoring case) and capture until ';', '```', or end of string
40
- pattern = re.compile(r'select.*?(?:;|```|$)', re.IGNORECASE | re.DOTALL)
41
-
42
- match = pattern.search(text)
43
- if match:
44
- # Remove three backticks from the matched string if they exist
45
- return match.group(0).replace('```', '')
46
- else:
47
- return text
48
-
49
- def generate_sql(self, question: str, **kwargs) -> str:
50
- # Use the super generate_sql
51
- sql = super().generate_sql(question, **kwargs)
52
-
53
- # Replace "\_" with "_"
54
- sql = sql.replace("\\_", "_")
55
-
56
- sql = sql.replace("\\", "")
57
-
58
- return self.extract_sql_query(sql)
59
-
60
- def submit_prompt(self, prompt, **kwargs) -> str:
61
- url = f"{self.host}/api/chat"
62
- data = {
63
- "model": self.model,
64
- "stream": False,
65
- "messages": prompt,
66
- }
67
-
68
- response = requests.post(url, json=data)
69
-
70
- response_dict = response.json()
71
-
72
- self.log(response.text)
73
-
74
- return response_dict['message']['content']
75
-
vanna/ollama/ollama.py CHANGED
@@ -0,0 +1,77 @@
1
+ import json
2
+ import re
3
+
4
+ import requests
5
+
6
+ from ..base import VannaBase
7
+
8
+
9
+ class Ollama(VannaBase):
10
+ def __init__(self, config=None):
11
+ if config is None or "ollama_host" not in config:
12
+ self.host = "http://localhost:11434"
13
+ else:
14
+ self.host = config["ollama_host"]
15
+
16
+ if config is None or "model" not in config:
17
+ raise ValueError("config must contain a Ollama model")
18
+ else:
19
+ self.model = config["model"]
20
+
21
+ def system_message(self, message: str) -> any:
22
+ return {"role": "system", "content": message}
23
+
24
+ def user_message(self, message: str) -> any:
25
+ return {"role": "user", "content": message}
26
+
27
+ def assistant_message(self, message: str) -> any:
28
+ return {"role": "assistant", "content": message}
29
+
30
+ def extract_sql_query(self, text):
31
+ """
32
+ Extracts the first SQL statement after the word 'select', ignoring case,
33
+ matches until the first semicolon, three backticks, or the end of the string,
34
+ and removes three backticks if they exist in the extracted string.
35
+
36
+ Args:
37
+ - text (str): The string to search within for an SQL statement.
38
+
39
+ Returns:
40
+ - str: The first SQL statement found, with three backticks removed, or an empty string if no match is found.
41
+ """
42
+ # Regular expression to find 'select' (ignoring case) and capture until ';', '```', or end of string
43
+ pattern = re.compile(r"select.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL)
44
+
45
+ match = pattern.search(text)
46
+ if match:
47
+ # Remove three backticks from the matched string if they exist
48
+ return match.group(0).replace("```", "")
49
+ else:
50
+ return text
51
+
52
+ def generate_sql(self, question: str, **kwargs) -> str:
53
+ # Use the super generate_sql
54
+ sql = super().generate_sql(question, **kwargs)
55
+
56
+ # Replace "\_" with "_"
57
+ sql = sql.replace("\\_", "_")
58
+
59
+ sql = sql.replace("\\", "")
60
+
61
+ return self.extract_sql_query(sql)
62
+
63
+ def submit_prompt(self, prompt, **kwargs) -> str:
64
+ url = f"{self.host}/api/chat"
65
+ data = {
66
+ "model": self.model,
67
+ "stream": False,
68
+ "messages": prompt,
69
+ }
70
+
71
+ response = requests.post(url, json=data)
72
+
73
+ response_dict = response.json()
74
+
75
+ self.log(response.text)
76
+
77
+ return response_dict["message"]["content"]
@@ -9,14 +9,6 @@ class OpenAI_Chat(VannaBase):
9
9
  def __init__(self, client=None, config=None):
10
10
  VannaBase.__init__(self, config=config)
11
11
 
12
- if client is not None:
13
- self.client = client
14
- return
15
-
16
- if config is None and client is None:
17
- self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
18
- return
19
-
20
12
  # default parameters - can be overrided using config
21
13
  self.temperature = 0.7
22
14
  self.max_tokens = 500
@@ -42,6 +34,14 @@ class OpenAI_Chat(VannaBase):
42
34
  "Passing api_version is now deprecated. Please pass an OpenAI client instead."
43
35
  )
44
36
 
37
+ if client is not None:
38
+ self.client = client
39
+ return
40
+
41
+ if config is None and client is None:
42
+ self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
43
+ return
44
+
45
45
  if "api_key" in config:
46
46
  self.client = OpenAI(api_key=config["api_key"])
47
47
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vanna
3
- Version: 0.2.0
3
+ Version: 0.3.0
4
4
  Summary: Generate SQL queries from natural language
5
5
  Author-email: Zain Hoda <zain@vanna.ai>
6
6
  Requires-Python: >=3.9
@@ -18,18 +18,22 @@ Requires-Dist: flask
18
18
  Requires-Dist: sqlalchemy
19
19
  Requires-Dist: psycopg2-binary ; extra == "all"
20
20
  Requires-Dist: db-dtypes ; extra == "all"
21
+ Requires-Dist: PyMySQL ; extra == "all"
21
22
  Requires-Dist: google-cloud-bigquery ; extra == "all"
22
23
  Requires-Dist: snowflake-connector-python ; extra == "all"
23
24
  Requires-Dist: duckdb ; extra == "all"
24
25
  Requires-Dist: openai ; extra == "all"
25
26
  Requires-Dist: mistralai ; extra == "all"
26
27
  Requires-Dist: chromadb ; extra == "all"
28
+ Requires-Dist: anthropic ; extra == "all"
29
+ Requires-Dist: anthropic ; extra == "anthropic"
27
30
  Requires-Dist: google-cloud-bigquery ; extra == "bigquery"
28
31
  Requires-Dist: chromadb ; extra == "chromadb"
29
32
  Requires-Dist: duckdb ; extra == "duckdb"
30
33
  Requires-Dist: google-generativeai ; extra == "gemini"
31
34
  Requires-Dist: marqo ; extra == "marqo"
32
35
  Requires-Dist: mistralai ; extra == "mistralai"
36
+ Requires-Dist: PyMySQL ; extra == "mysql"
33
37
  Requires-Dist: openai ; extra == "openai"
34
38
  Requires-Dist: psycopg2-binary ; extra == "postgres"
35
39
  Requires-Dist: db-dtypes ; extra == "postgres"
@@ -38,12 +42,14 @@ Requires-Dist: tox ; extra == "test"
38
42
  Project-URL: Bug Tracker, https://github.com/vanna-ai/vanna/issues
39
43
  Project-URL: Homepage, https://github.com/vanna-ai/vanna
40
44
  Provides-Extra: all
45
+ Provides-Extra: anthropic
41
46
  Provides-Extra: bigquery
42
47
  Provides-Extra: chromadb
43
48
  Provides-Extra: duckdb
44
49
  Provides-Extra: gemini
45
50
  Provides-Extra: marqo
46
51
  Provides-Extra: mistralai
52
+ Provides-Extra: mysql
47
53
  Provides-Extra: openai
48
54
  Provides-Extra: postgres
49
55
  Provides-Extra: snowflake
@@ -5,25 +5,27 @@ vanna/utils.py,sha256=Q0H4eugPYg9SVpEoTWgvmuoJZZxOVRhNzrP97E5lyak,1472
5
5
  vanna/ZhipuAI/ZhipuAI_Chat.py,sha256=hcx__0ZKHr5wtmIv0Ye2Xake8E-sOi5Qyc5nFIIdFvg,8957
6
6
  vanna/ZhipuAI/ZhipuAI_embeddings.py,sha256=lUqzJg9fOx7rVFhjdkFjXcDeVGV4aAB5Ss0oERsa8pE,2849
7
7
  vanna/ZhipuAI/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
+ vanna/anthropic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
+ vanna/anthropic/anthropic_chat.py,sha256=Wk0o-NMW1uvR2fhSWxrR_2FqNh-dLprNG4uuVqpqAkY,2615
8
10
  vanna/base/__init__.py,sha256=Sl-HM1RRYzAZoSqmL1CZQmF3ZF-byYTCFQP3JZ2A5MU,28
9
- vanna/base/base.py,sha256=6tSgklJdaZwhK9vMhI7ZilZsB_bsngUWk7bW1vVSoOE,52175
11
+ vanna/base/base.py,sha256=25tHTMsCnoek6X9C0hqpPyNM3H0mihcL4I8Mlq-74c0,54890
10
12
  vanna/chromadb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
13
  vanna/chromadb/chromadb_vector.py,sha256=fa7uj_knzSfzsVLvpSunwwu1ZJNC3GbiNZ4Yy09v4l4,8372
12
14
  vanna/exceptions/__init__.py,sha256=N76unE7sjbGGBz6LmCrPQAugFWr9cUFv8ErJxBrCTts,717
13
- vanna/flask/__init__.py,sha256=i7eh55uB64Kio3nnrAK4eg45gAoY_b2uziLLlCmj1Dk,15186
14
- vanna/flask/assets.py,sha256=1DcZNu3Uo93hz7M8ldEYKDeEslROIL3IK44sSS8PZgM,163301
15
+ vanna/flask/__init__.py,sha256=UgM0Ce5pGDdadWV6ZEAXj7RXDE1E420DW1wtR-juBMw,21212
16
+ vanna/flask/assets.py,sha256=uluOgW41cHYmqPyRFHn8HYgaozmPm9nhDLZkHLuSBM4,180728
15
17
  vanna/marqo/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
18
  vanna/marqo/marqo.py,sha256=2OBuC5IZmGcFXN2Ah6GVPKHBYtkDXeSwhXsqUbxyU94,5285
17
19
  vanna/mistral/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
- vanna/mistral/mistral.py,sha256=6NarIVLfh9qxJIyUv_9pdRzbbMfC54gYsqg99Q-K7aA,1514
19
- vanna/ollama/__init__.py,sha256=GW1ek7zw_fpL2yFNgrnN5RNjV2PdkK8CmTmcJF9YpDU,2464
20
- vanna/ollama/ollama.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
+ vanna/mistral/mistral.py,sha256=DAEqAT9SzC91rfMM_S3SuzBZ34MrKHw9qAj6EP2MGVk,1508
21
+ vanna/ollama/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
+ vanna/ollama/ollama.py,sha256=U02yy6DCAR-Ar7YXF98aBg3Kc-c9CiSWnMIGv4gs8zE,2430
21
23
  vanna/openai/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
- vanna/openai/openai_chat.py,sha256=2A6YEibwnMgdc6r1mvInqi_32xDDkdGi-bx73UEUaJM,3846
24
+ vanna/openai/openai_chat.py,sha256=Y3-Fhz9c6D-5vrMR1zGibavAPNPAu-hMTwqGfKhAg3Q,3852
23
25
  vanna/openai/openai_embeddings.py,sha256=g4pNh9LVcYP9wOoO8ecaccDFWmCUYMInebfHucAa2Gc,1260
24
26
  vanna/types/__init__.py,sha256=Qhn_YscKtJh7mFPCyCDLa2K8a4ORLMGVnPpTbv9uB2U,4957
25
27
  vanna/vannadb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
26
28
  vanna/vannadb/vannadb_vector.py,sha256=f4kddaJgTpZync7wnQi09QdODUuMtiHsK7WfKBUAmSo,5644
27
- vanna-0.2.0.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
28
- vanna-0.2.0.dist-info/METADATA,sha256=UzXxOVxudUIRanqrhsqp_ssL-lpvLqulSrn_PMyjnSY,9741
29
- vanna-0.2.0.dist-info/RECORD,,
29
+ vanna-0.3.0.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
30
+ vanna-0.3.0.dist-info/METADATA,sha256=lg_gRR1TR9RgdxDh3UmthfyudyEdcWdxHpwXEMrz7AU,9961
31
+ vanna-0.3.0.dist-info/RECORD,,
File without changes