vanna 0.6.4__tar.gz → 0.6.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {vanna-0.6.4 → vanna-0.6.6}/PKG-INFO +6 -3
- {vanna-0.6.4 → vanna-0.6.6}/pyproject.toml +4 -3
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/base/base.py +5 -10
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/flask/__init__.py +2 -2
- vanna-0.6.6/src/vanna/google/__init__.py +2 -0
- vanna-0.6.6/src/vanna/google/bigquery_vector.py +230 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/mistral/mistral.py +8 -6
- vanna-0.6.6/src/vanna/qianfan/Qianfan_Chat.py +165 -0
- vanna-0.6.6/src/vanna/qianfan/Qianfan_embeddings.py +36 -0
- vanna-0.6.6/src/vanna/qianfan/__init__.py +2 -0
- vanna-0.6.4/src/vanna/google/__init__.py +0 -1
- {vanna-0.6.4 → vanna-0.6.6}/README.md +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/ZhipuAI/ZhipuAI_Chat.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/ZhipuAI/ZhipuAI_embeddings.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/ZhipuAI/__init__.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/__init__.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/advanced/__init__.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/anthropic/__init__.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/anthropic/anthropic_chat.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/base/__init__.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/bedrock/__init__.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/bedrock/bedrock_converse.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/chromadb/__init__.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/chromadb/chromadb_vector.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/exceptions/__init__.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/flask/assets.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/flask/auth.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/google/gemini_chat.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/hf/__init__.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/hf/hf.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/local.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/marqo/__init__.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/marqo/marqo.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/milvus/__init__.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/milvus/milvus_vector.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/mistral/__init__.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/mock/__init__.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/mock/embedding.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/mock/llm.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/mock/vectordb.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/ollama/__init__.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/ollama/ollama.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/openai/__init__.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/openai/openai_chat.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/openai/openai_embeddings.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/opensearch/__init__.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/opensearch/opensearch_vector.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/pinecone/__init__.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/pinecone/pinecone_vector.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/qdrant/__init__.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/qdrant/qdrant.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/remote.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/types/__init__.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/utils.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/vannadb/__init__.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/vannadb/vannadb_vector.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/vllm/__init__.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/vllm/vllm.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/weaviate/__init__.py +0 -0
- {vanna-0.6.4 → vanna-0.6.6}/src/vanna/weaviate/weaviate_vector.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: vanna
|
|
3
|
-
Version: 0.6.4
|
|
3
|
+
Version: 0.6.6
|
|
4
4
|
Summary: Generate SQL queries from natural language
|
|
5
5
|
Author-email: Zain Hoda <zain@vanna.ai>
|
|
6
6
|
Requires-Python: >=3.9
|
|
@@ -25,7 +25,8 @@ Requires-Dist: google-cloud-bigquery ; extra == "all"
|
|
|
25
25
|
Requires-Dist: snowflake-connector-python ; extra == "all"
|
|
26
26
|
Requires-Dist: duckdb ; extra == "all"
|
|
27
27
|
Requires-Dist: openai ; extra == "all"
|
|
28
|
-
Requires-Dist: mistralai ; extra == "all"
|
|
28
|
+
Requires-Dist: qianfan ; extra == "all"
|
|
29
|
+
Requires-Dist: mistralai>=1.0.0 ; extra == "all"
|
|
29
30
|
Requires-Dist: chromadb ; extra == "all"
|
|
30
31
|
Requires-Dist: anthropic ; extra == "all"
|
|
31
32
|
Requires-Dist: zhipuai ; extra == "all"
|
|
@@ -55,7 +56,7 @@ Requires-Dist: google-cloud-aiplatform ; extra == "google"
|
|
|
55
56
|
Requires-Dist: transformers ; extra == "hf"
|
|
56
57
|
Requires-Dist: marqo ; extra == "marqo"
|
|
57
58
|
Requires-Dist: pymilvus[model] ; extra == "milvus"
|
|
58
|
-
Requires-Dist: mistralai ; extra == "mistralai"
|
|
59
|
+
Requires-Dist: mistralai>=1.0.0 ; extra == "mistralai"
|
|
59
60
|
Requires-Dist: PyMySQL ; extra == "mysql"
|
|
60
61
|
Requires-Dist: ollama ; extra == "ollama"
|
|
61
62
|
Requires-Dist: httpx ; extra == "ollama"
|
|
@@ -68,6 +69,7 @@ Requires-Dist: psycopg2-binary ; extra == "postgres"
|
|
|
68
69
|
Requires-Dist: db-dtypes ; extra == "postgres"
|
|
69
70
|
Requires-Dist: qdrant-client ; extra == "qdrant"
|
|
70
71
|
Requires-Dist: fastembed ; extra == "qdrant"
|
|
72
|
+
Requires-Dist: qianfan ; extra == "qianfan"
|
|
71
73
|
Requires-Dist: snowflake-connector-python ; extra == "snowflake"
|
|
72
74
|
Requires-Dist: tox ; extra == "test"
|
|
73
75
|
Requires-Dist: vllm ; extra == "vllm"
|
|
@@ -95,6 +97,7 @@ Provides-Extra: opensearch
|
|
|
95
97
|
Provides-Extra: pinecone
|
|
96
98
|
Provides-Extra: postgres
|
|
97
99
|
Provides-Extra: qdrant
|
|
100
|
+
Provides-Extra: qianfan
|
|
98
101
|
Provides-Extra: snowflake
|
|
99
102
|
Provides-Extra: test
|
|
100
103
|
Provides-Extra: vllm
|
|
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "vanna"
|
|
7
|
-
version = "0.6.4"
|
|
7
|
+
version = "0.6.6"
|
|
8
8
|
authors = [
|
|
9
9
|
{ name="Zain Hoda", email="zain@vanna.ai" },
|
|
10
10
|
]
|
|
@@ -33,11 +33,12 @@ bigquery = ["google-cloud-bigquery"]
|
|
|
33
33
|
snowflake = ["snowflake-connector-python"]
|
|
34
34
|
duckdb = ["duckdb"]
|
|
35
35
|
google = ["google-generativeai", "google-cloud-aiplatform"]
|
|
36
|
-
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client"]
|
|
36
|
+
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client"]
|
|
37
37
|
test = ["tox"]
|
|
38
38
|
chromadb = ["chromadb"]
|
|
39
39
|
openai = ["openai"]
|
|
40
|
-
mistralai = ["mistralai"]
|
|
40
|
+
qianfan = ["qianfan"]
|
|
41
|
+
mistralai = ["mistralai>=1.0.0"]
|
|
41
42
|
anthropic = ["anthropic"]
|
|
42
43
|
gemini = ["google-generativeai"]
|
|
43
44
|
marqo = ["marqo"]
|
|
@@ -437,7 +437,7 @@ class VannaBase(ABC):
|
|
|
437
437
|
pass
|
|
438
438
|
|
|
439
439
|
@abstractmethod
|
|
440
|
-
def remove_training_data(id: str, **kwargs) -> bool:
|
|
440
|
+
def remove_training_data(self, id: str, **kwargs) -> bool:
|
|
441
441
|
"""
|
|
442
442
|
Example:
|
|
443
443
|
```python
|
|
@@ -1276,15 +1276,10 @@ class VannaBase(ABC):
|
|
|
1276
1276
|
|
|
1277
1277
|
def run_sql_bigquery(sql: str) -> Union[pd.DataFrame, None]:
|
|
1278
1278
|
if conn:
|
|
1279
|
-
try:
|
|
1280
|
-
job = conn.query(sql)
|
|
1281
|
-
df = job.result().to_dataframe()
|
|
1282
|
-
return df
|
|
1283
|
-
except GoogleAPIError as error:
|
|
1284
|
-
errors = []
|
|
1285
|
-
for error in error.errors:
|
|
1286
|
-
errors.append(error["message"])
|
|
1287
|
-
raise errors
|
|
1279
|
+
job = conn.query(sql)
|
|
1280
|
+
df = job.result().to_dataframe()
|
|
1281
|
+
return df
|
|
1282
|
+
|
|
1288
1283
|
return None
|
|
1289
1284
|
|
|
1290
1285
|
self.dialect = "BigQuery SQL"
|
|
@@ -12,9 +12,9 @@ from flasgger import Swagger
|
|
|
12
12
|
from flask import Flask, Response, jsonify, request, send_from_directory
|
|
13
13
|
from flask_sock import Sock
|
|
14
14
|
|
|
15
|
+
from ..base import VannaBase
|
|
15
16
|
from .assets import css_content, html_content, js_content
|
|
16
17
|
from .auth import AuthInterface, NoAuth
|
|
17
|
-
from ..base import VannaBase
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
class Cache(ABC):
|
|
@@ -1211,7 +1211,7 @@ class VannaFlaskApp(VannaFlaskAPI):
|
|
|
1211
1211
|
self.config["ask_results_correct"] = ask_results_correct
|
|
1212
1212
|
self.config["followup_questions"] = followup_questions
|
|
1213
1213
|
self.config["summarization"] = summarization
|
|
1214
|
-
self.config["function_generation"] = function_generation
|
|
1214
|
+
self.config["function_generation"] = function_generation and hasattr(vn, "get_function")
|
|
1215
1215
|
|
|
1216
1216
|
self.index_html_path = index_html_path
|
|
1217
1217
|
self.assets_folder = assets_folder
|
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
import datetime
|
|
2
|
+
import os
|
|
3
|
+
import uuid
|
|
4
|
+
from typing import List, Optional
|
|
5
|
+
|
|
6
|
+
import pandas as pd
|
|
7
|
+
from google.cloud import bigquery
|
|
8
|
+
|
|
9
|
+
from ..base import VannaBase
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class BigQuery_VectorStore(VannaBase):
|
|
13
|
+
def __init__(self, config: dict, **kwargs):
|
|
14
|
+
self.config = config
|
|
15
|
+
|
|
16
|
+
self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
|
|
17
|
+
self.n_results_documentation = config.get("n_results_documentation", config.get("n_results", 10))
|
|
18
|
+
self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))
|
|
19
|
+
|
|
20
|
+
if "api_key" in config or os.getenv("GOOGLE_API_KEY"):
|
|
21
|
+
"""
|
|
22
|
+
If Google api_key is provided through config
|
|
23
|
+
or set as an environment variable, assign it.
|
|
24
|
+
"""
|
|
25
|
+
print("Configuring genai")
|
|
26
|
+
import google.generativeai as genai
|
|
27
|
+
|
|
28
|
+
genai.configure(api_key=config["api_key"])
|
|
29
|
+
|
|
30
|
+
self.genai = genai
|
|
31
|
+
else:
|
|
32
|
+
# Authenticate using VertexAI
|
|
33
|
+
from vertexai.language_models import (
|
|
34
|
+
TextEmbeddingInput,
|
|
35
|
+
TextEmbeddingModel,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
if self.config.get("project_id"):
|
|
39
|
+
self.project_id = self.config.get("project_id")
|
|
40
|
+
else:
|
|
41
|
+
self.project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
|
|
42
|
+
|
|
43
|
+
if self.project_id is None:
|
|
44
|
+
raise ValueError("Project ID is not set")
|
|
45
|
+
|
|
46
|
+
self.conn = bigquery.Client(project=self.project_id)
|
|
47
|
+
|
|
48
|
+
dataset_name = self.config.get('bigquery_dataset_name', 'vanna_managed')
|
|
49
|
+
self.dataset_id = f"{self.project_id}.{dataset_name}"
|
|
50
|
+
dataset = bigquery.Dataset(self.dataset_id)
|
|
51
|
+
|
|
52
|
+
try:
|
|
53
|
+
self.conn.get_dataset(self.dataset_id) # Make an API request.
|
|
54
|
+
print(f"Dataset {self.dataset_id} already exists")
|
|
55
|
+
except Exception:
|
|
56
|
+
# Dataset does not exist, create it
|
|
57
|
+
dataset.location = "US"
|
|
58
|
+
self.conn.create_dataset(dataset, timeout=30) # Make an API request.
|
|
59
|
+
print(f"Created dataset {self.dataset_id}")
|
|
60
|
+
|
|
61
|
+
# Create a table called training_data in the dataset that contains the columns:
|
|
62
|
+
# id, training_data_type, question, content, embedding, created_at
|
|
63
|
+
|
|
64
|
+
self.table_id = f"{self.dataset_id}.training_data"
|
|
65
|
+
schema = [
|
|
66
|
+
bigquery.SchemaField("id", "STRING", mode="REQUIRED"),
|
|
67
|
+
bigquery.SchemaField("training_data_type", "STRING", mode="REQUIRED"),
|
|
68
|
+
bigquery.SchemaField("question", "STRING", mode="REQUIRED"),
|
|
69
|
+
bigquery.SchemaField("content", "STRING", mode="REQUIRED"),
|
|
70
|
+
bigquery.SchemaField("embedding", "FLOAT64", mode="REPEATED"),
|
|
71
|
+
bigquery.SchemaField("created_at", "TIMESTAMP", mode="REQUIRED"),
|
|
72
|
+
]
|
|
73
|
+
|
|
74
|
+
table = bigquery.Table(self.table_id, schema=schema)
|
|
75
|
+
|
|
76
|
+
try:
|
|
77
|
+
self.conn.get_table(self.table_id) # Make an API request.
|
|
78
|
+
print(f"Table {self.table_id} already exists")
|
|
79
|
+
except Exception:
|
|
80
|
+
# Table does not exist, create it
|
|
81
|
+
self.conn.create_table(table, timeout=30) # Make an API request.
|
|
82
|
+
print(f"Created table {self.table_id}")
|
|
83
|
+
|
|
84
|
+
# Create VECTOR INDEX IF NOT EXISTS
|
|
85
|
+
# TODO: This requires 5000 rows before it can be created
|
|
86
|
+
# vector_index_query = f"""
|
|
87
|
+
# CREATE VECTOR INDEX IF NOT EXISTS my_index
|
|
88
|
+
# ON `{self.table_id}`(embedding)
|
|
89
|
+
# OPTIONS(
|
|
90
|
+
# distance_type='COSINE',
|
|
91
|
+
# index_type='IVF',
|
|
92
|
+
# ivf_options='{{"num_lists": 1000}}'
|
|
93
|
+
# )
|
|
94
|
+
# """
|
|
95
|
+
|
|
96
|
+
# try:
|
|
97
|
+
# self.conn.query(vector_index_query).result() # Make an API request.
|
|
98
|
+
# print(f"Vector index on {self.table_id} created or already exists")
|
|
99
|
+
# except Exception as e:
|
|
100
|
+
# print(f"Failed to create vector index: {e}")
|
|
101
|
+
|
|
102
|
+
def store_training_data(self, training_data_type: str, question: str, content: str, embedding: List[float], **kwargs) -> str:
|
|
103
|
+
id = str(uuid.uuid4())
|
|
104
|
+
created_at = datetime.datetime.now()
|
|
105
|
+
self.conn.insert_rows_json(self.table_id, [{
|
|
106
|
+
"id": id,
|
|
107
|
+
"training_data_type": training_data_type,
|
|
108
|
+
"question": question,
|
|
109
|
+
"content": content,
|
|
110
|
+
"embedding": embedding,
|
|
111
|
+
"created_at": created_at.isoformat()
|
|
112
|
+
}])
|
|
113
|
+
|
|
114
|
+
return id
|
|
115
|
+
|
|
116
|
+
def fetch_similar_training_data(self, training_data_type: str, question: str, n_results, **kwargs) -> pd.DataFrame:
|
|
117
|
+
question_embedding = self.generate_question_embedding(question)
|
|
118
|
+
|
|
119
|
+
query = f"""
|
|
120
|
+
SELECT
|
|
121
|
+
base.id as id,
|
|
122
|
+
base.question as question,
|
|
123
|
+
base.training_data_type as training_data_type,
|
|
124
|
+
base.content as content,
|
|
125
|
+
distance
|
|
126
|
+
FROM
|
|
127
|
+
VECTOR_SEARCH(
|
|
128
|
+
TABLE `{self.table_id}`,
|
|
129
|
+
'embedding',
|
|
130
|
+
(SELECT * FROM UNNEST([STRUCT({question_embedding})])),
|
|
131
|
+
top_k => 5,
|
|
132
|
+
distance_type => 'COSINE',
|
|
133
|
+
options => '{{"use_brute_force":true}}'
|
|
134
|
+
)
|
|
135
|
+
WHERE
|
|
136
|
+
base.training_data_type = '{training_data_type}'
|
|
137
|
+
"""
|
|
138
|
+
|
|
139
|
+
results = self.conn.query(query).result().to_dataframe()
|
|
140
|
+
return results
|
|
141
|
+
|
|
142
|
+
def generate_question_embedding(self, data: str, **kwargs) -> List[float]:
|
|
143
|
+
result = self.genai.embed_content(
|
|
144
|
+
model="models/text-embedding-004",
|
|
145
|
+
content=data,
|
|
146
|
+
task_type="retrieval_query")
|
|
147
|
+
|
|
148
|
+
if 'embedding' in result:
|
|
149
|
+
return result['embedding']
|
|
150
|
+
else:
|
|
151
|
+
raise ValueError("No embeddings returned")
|
|
152
|
+
|
|
153
|
+
def generate_storage_embedding(self, data: str, **kwargs) -> List[float]:
|
|
154
|
+
result = self.genai.embed_content(
|
|
155
|
+
model="models/text-embedding-004",
|
|
156
|
+
content=data,
|
|
157
|
+
task_type="retrieval_document")
|
|
158
|
+
|
|
159
|
+
if 'embedding' in result:
|
|
160
|
+
return result['embedding']
|
|
161
|
+
else:
|
|
162
|
+
raise ValueError("No embeddings returned")
|
|
163
|
+
|
|
164
|
+
# task = "RETRIEVAL_DOCUMENT"
|
|
165
|
+
# inputs = [TextEmbeddingInput(data, task)]
|
|
166
|
+
# embeddings = self.vertex_embedding_model.get_embeddings(inputs)
|
|
167
|
+
|
|
168
|
+
# if len(embeddings) == 0:
|
|
169
|
+
# raise ValueError("No embeddings returned")
|
|
170
|
+
|
|
171
|
+
# return embeddings[0].values
|
|
172
|
+
|
|
173
|
+
return result
|
|
174
|
+
|
|
175
|
+
def generate_embedding(self, data: str, **kwargs) -> List[float]:
|
|
176
|
+
return self.generate_storage_embedding(data, **kwargs)
|
|
177
|
+
|
|
178
|
+
def get_similar_question_sql(self, question: str, **kwargs) -> list:
|
|
179
|
+
df = self.fetch_similar_training_data(training_data_type="sql", question=question, n_results=self.n_results_sql)
|
|
180
|
+
|
|
181
|
+
# Return a list of dictionaries with only question, sql fields. The content field needs to be renamed to sql
|
|
182
|
+
return df.rename(columns={"content": "sql"})[["question", "sql"]].to_dict(orient="records")
|
|
183
|
+
|
|
184
|
+
def get_related_ddl(self, question: str, **kwargs) -> list:
|
|
185
|
+
df = self.fetch_similar_training_data(training_data_type="ddl", question=question, n_results=self.n_results_ddl)
|
|
186
|
+
|
|
187
|
+
# Return a list of strings of the content
|
|
188
|
+
return df["content"].tolist()
|
|
189
|
+
|
|
190
|
+
def get_related_documentation(self, question: str, **kwargs) -> list:
|
|
191
|
+
df = self.fetch_similar_training_data(training_data_type="documentation", question=question, n_results=self.n_results_documentation)
|
|
192
|
+
|
|
193
|
+
# Return a list of strings of the content
|
|
194
|
+
return df["content"].tolist()
|
|
195
|
+
|
|
196
|
+
def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
|
|
197
|
+
doc = {
|
|
198
|
+
"question": question,
|
|
199
|
+
"sql": sql
|
|
200
|
+
}
|
|
201
|
+
|
|
202
|
+
embedding = self.generate_embedding(str(doc))
|
|
203
|
+
|
|
204
|
+
return self.store_training_data(training_data_type="sql", question=question, content=sql, embedding=embedding)
|
|
205
|
+
|
|
206
|
+
def add_ddl(self, ddl: str, **kwargs) -> str:
|
|
207
|
+
embedding = self.generate_embedding(ddl)
|
|
208
|
+
|
|
209
|
+
return self.store_training_data(training_data_type="ddl", question="", content=ddl, embedding=embedding)
|
|
210
|
+
|
|
211
|
+
def add_documentation(self, documentation: str, **kwargs) -> str:
|
|
212
|
+
embedding = self.generate_embedding(documentation)
|
|
213
|
+
|
|
214
|
+
return self.store_training_data(training_data_type="documentation", question="", content=documentation, embedding=embedding)
|
|
215
|
+
|
|
216
|
+
def get_training_data(self, **kwargs) -> pd.DataFrame:
|
|
217
|
+
query = f"SELECT id, training_data_type, question, content FROM `{self.table_id}`"
|
|
218
|
+
|
|
219
|
+
return self.conn.query(query).result().to_dataframe()
|
|
220
|
+
|
|
221
|
+
def remove_training_data(self, id: str, **kwargs) -> bool:
|
|
222
|
+
query = f"DELETE FROM `{self.table_id}` WHERE id = '{id}'"
|
|
223
|
+
|
|
224
|
+
try:
|
|
225
|
+
self.conn.query(query).result()
|
|
226
|
+
return True
|
|
227
|
+
|
|
228
|
+
except Exception as e:
|
|
229
|
+
print(f"Failed to remove training data: {e}")
|
|
230
|
+
return False
|
|
@@ -1,5 +1,7 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
from mistralai import Mistral as MistralClient
|
|
4
|
+
from mistralai import UserMessage
|
|
3
5
|
|
|
4
6
|
from ..base import VannaBase
|
|
5
7
|
|
|
@@ -23,13 +25,13 @@ class Mistral(VannaBase):
|
|
|
23
25
|
self.model = model
|
|
24
26
|
|
|
25
27
|
def system_message(self, message: str) -> any:
|
|
26
|
-
return
|
|
28
|
+
return {"role": "system", "content": message}
|
|
27
29
|
|
|
28
30
|
def user_message(self, message: str) -> any:
|
|
29
|
-
return
|
|
31
|
+
return {"role": "user", "content": message}
|
|
30
32
|
|
|
31
33
|
def assistant_message(self, message: str) -> any:
|
|
32
|
-
return
|
|
34
|
+
return {"role": "assistant", "content": message}
|
|
33
35
|
|
|
34
36
|
def generate_sql(self, question: str, **kwargs) -> str:
|
|
35
37
|
# Use the super generate_sql
|
|
@@ -41,7 +43,7 @@ class Mistral(VannaBase):
|
|
|
41
43
|
return sql
|
|
42
44
|
|
|
43
45
|
def submit_prompt(self, prompt, **kwargs) -> str:
|
|
44
|
-
chat_response = self.client.chat(
|
|
46
|
+
chat_response = self.client.chat.complete(
|
|
45
47
|
model=self.model,
|
|
46
48
|
messages=prompt,
|
|
47
49
|
)
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
import qianfan
|
|
2
|
+
|
|
3
|
+
from ..base import VannaBase
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Qianfan_Chat(VannaBase):
|
|
7
|
+
def __init__(self, client=None, config=None):
|
|
8
|
+
VannaBase.__init__(self, config=config)
|
|
9
|
+
|
|
10
|
+
if "api_key" not in config:
|
|
11
|
+
raise Exception("Missing api_key in config")
|
|
12
|
+
self.api_key = config["api_key"]
|
|
13
|
+
|
|
14
|
+
if "secret_key" not in config:
|
|
15
|
+
raise Exception("Missing secret_key in config")
|
|
16
|
+
self.secret_key = config["secret_key"]
|
|
17
|
+
|
|
18
|
+
# default parameters - can be overrided using config
|
|
19
|
+
self.temperature = 0.9
|
|
20
|
+
self.max_tokens = 1024
|
|
21
|
+
|
|
22
|
+
if "temperature" in config:
|
|
23
|
+
self.temperature = config["temperature"]
|
|
24
|
+
|
|
25
|
+
if "max_tokens" in config:
|
|
26
|
+
self.max_tokens = config["max_tokens"]
|
|
27
|
+
|
|
28
|
+
self.model = config["model"] if "model" in config else "ERNIE-Speed"
|
|
29
|
+
|
|
30
|
+
if client is not None:
|
|
31
|
+
self.client = client
|
|
32
|
+
return
|
|
33
|
+
|
|
34
|
+
self.client = qianfan.ChatCompletion(ak=self.api_key,
|
|
35
|
+
sk=self.secret_key)
|
|
36
|
+
|
|
37
|
+
def system_message(self, message: str) -> any:
|
|
38
|
+
return {"role": "system", "content": message}
|
|
39
|
+
|
|
40
|
+
def user_message(self, message: str) -> any:
|
|
41
|
+
return {"role": "user", "content": message}
|
|
42
|
+
|
|
43
|
+
def assistant_message(self, message: str) -> any:
|
|
44
|
+
return {"role": "assistant", "content": message}
|
|
45
|
+
|
|
46
|
+
def get_sql_prompt(
|
|
47
|
+
self,
|
|
48
|
+
initial_prompt: str,
|
|
49
|
+
question: str,
|
|
50
|
+
question_sql_list: list,
|
|
51
|
+
ddl_list: list,
|
|
52
|
+
doc_list: list,
|
|
53
|
+
**kwargs,
|
|
54
|
+
):
|
|
55
|
+
"""
|
|
56
|
+
Example:
|
|
57
|
+
```python
|
|
58
|
+
vn.get_sql_prompt(
|
|
59
|
+
question="What are the top 10 customers by sales?",
|
|
60
|
+
question_sql_list=[{"question": "What are the top 10 customers by sales?", "sql": "SELECT * FROM customers ORDER BY sales DESC LIMIT 10"}],
|
|
61
|
+
ddl_list=["CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)"],
|
|
62
|
+
doc_list=["The customers table contains information about customers and their sales."],
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
```
|
|
66
|
+
|
|
67
|
+
This method is used to generate a prompt for the LLM to generate SQL.
|
|
68
|
+
|
|
69
|
+
Args:
|
|
70
|
+
question (str): The question to generate SQL for.
|
|
71
|
+
question_sql_list (list): A list of questions and their corresponding SQL statements.
|
|
72
|
+
ddl_list (list): A list of DDL statements.
|
|
73
|
+
doc_list (list): A list of documentation.
|
|
74
|
+
|
|
75
|
+
Returns:
|
|
76
|
+
any: The prompt for the LLM to generate SQL.
|
|
77
|
+
"""
|
|
78
|
+
|
|
79
|
+
if initial_prompt is None:
|
|
80
|
+
initial_prompt = f"You are a {self.dialect} expert. " + \
|
|
81
|
+
"Please help to generate a SQL to answer the question based on some context.Please don't give any explanation for your answer. Just only generate a SQL \n"
|
|
82
|
+
|
|
83
|
+
initial_prompt = self.add_ddl_to_prompt(
|
|
84
|
+
initial_prompt, ddl_list, max_tokens=self.max_tokens
|
|
85
|
+
)
|
|
86
|
+
|
|
87
|
+
if self.static_documentation != "":
|
|
88
|
+
doc_list.append(self.static_documentation)
|
|
89
|
+
|
|
90
|
+
initial_prompt = self.add_documentation_to_prompt(
|
|
91
|
+
initial_prompt, doc_list, max_tokens=self.max_tokens
|
|
92
|
+
)
|
|
93
|
+
message_log = []
|
|
94
|
+
|
|
95
|
+
if question_sql_list is None or len(question_sql_list) == 0:
|
|
96
|
+
initial_prompt = initial_prompt + f"question: {question}"
|
|
97
|
+
message_log.append(self.user_message(initial_prompt))
|
|
98
|
+
else:
|
|
99
|
+
for i, example in question_sql_list:
|
|
100
|
+
if example is None:
|
|
101
|
+
print("example is None")
|
|
102
|
+
else:
|
|
103
|
+
if example is not None and "question" in example and "sql" in example:
|
|
104
|
+
if i == 0:
|
|
105
|
+
initial_prompt = initial_prompt + f"question: {example['question']}"
|
|
106
|
+
message_log.append(self.user_message(initial_prompt))
|
|
107
|
+
else:
|
|
108
|
+
message_log.append(self.user_message(example["question"]))
|
|
109
|
+
message_log.append(self.assistant_message(example["sql"]))
|
|
110
|
+
|
|
111
|
+
message_log.append(self.user_message(question))
|
|
112
|
+
return message_log
|
|
113
|
+
|
|
114
|
+
def submit_prompt(self, prompt, **kwargs) -> str:
|
|
115
|
+
if prompt is None:
|
|
116
|
+
raise Exception("Prompt is None")
|
|
117
|
+
|
|
118
|
+
if len(prompt) == 0:
|
|
119
|
+
raise Exception("Prompt is empty")
|
|
120
|
+
|
|
121
|
+
# Count the number of tokens in the message log
|
|
122
|
+
# Use 4 as an approximation for the number of characters per token
|
|
123
|
+
num_tokens = 0
|
|
124
|
+
for message in prompt:
|
|
125
|
+
num_tokens += len(message["content"]) / 4
|
|
126
|
+
|
|
127
|
+
if kwargs.get("model", None) is not None:
|
|
128
|
+
model = kwargs.get("model", None)
|
|
129
|
+
print(
|
|
130
|
+
f"Using model {model} for {num_tokens} tokens (approx)"
|
|
131
|
+
)
|
|
132
|
+
response = self.client.do(
|
|
133
|
+
model=self.model,
|
|
134
|
+
messages=prompt,
|
|
135
|
+
max_output_tokens=self.max_tokens,
|
|
136
|
+
stop=None,
|
|
137
|
+
temperature=self.temperature,
|
|
138
|
+
)
|
|
139
|
+
elif self.config is not None and "model" in self.config:
|
|
140
|
+
print(
|
|
141
|
+
f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
|
|
142
|
+
)
|
|
143
|
+
response = self.client.do(
|
|
144
|
+
model=self.config.get("model"),
|
|
145
|
+
messages=prompt,
|
|
146
|
+
max_output_tokens=self.max_tokens,
|
|
147
|
+
stop=None,
|
|
148
|
+
temperature=self.temperature,
|
|
149
|
+
)
|
|
150
|
+
else:
|
|
151
|
+
if num_tokens > 3500:
|
|
152
|
+
model = "ERNIE-Speed-128K"
|
|
153
|
+
else:
|
|
154
|
+
model = "ERNIE-Speed-8K"
|
|
155
|
+
|
|
156
|
+
print(f"Using model {model} for {num_tokens} tokens (approx)")
|
|
157
|
+
response = self.client.do(
|
|
158
|
+
model=model,
|
|
159
|
+
messages=prompt,
|
|
160
|
+
max_output_tokens=self.max_tokens,
|
|
161
|
+
stop=None,
|
|
162
|
+
temperature=self.temperature,
|
|
163
|
+
)
|
|
164
|
+
|
|
165
|
+
return response.body.get("result")
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
import qianfan
|
|
2
|
+
|
|
3
|
+
from ..base import VannaBase
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Qianfan_Embeddings(VannaBase):
|
|
7
|
+
def __init__(self, client=None, config=None):
|
|
8
|
+
VannaBase.__init__(self, config=config)
|
|
9
|
+
|
|
10
|
+
if client is not None:
|
|
11
|
+
self.client = client
|
|
12
|
+
return
|
|
13
|
+
|
|
14
|
+
if "api_key" not in config:
|
|
15
|
+
raise Exception("Missing api_key in config")
|
|
16
|
+
self.api_key = config["api_key"]
|
|
17
|
+
|
|
18
|
+
if "secret_key" not in config:
|
|
19
|
+
raise Exception("Missing secret_key in config")
|
|
20
|
+
self.secret_key = config["secret_key"]
|
|
21
|
+
|
|
22
|
+
self.client = qianfan.Embedding(ak=self.api_key, sk=self.secret_key)
|
|
23
|
+
|
|
24
|
+
def generate_embedding(self, data: str, **kwargs) -> list[float]:
|
|
25
|
+
if self.config is not None and "model" in self.config:
|
|
26
|
+
embedding = self.client.do(
|
|
27
|
+
model=self.config["model"],
|
|
28
|
+
input=[data],
|
|
29
|
+
)
|
|
30
|
+
else:
|
|
31
|
+
embedding = self.client.do(
|
|
32
|
+
model="bge-large-zh",
|
|
33
|
+
input=[data],
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
return embedding.get("data")[0]["embedding"]
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
from .gemini_chat import GoogleGeminiChat
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|