vanna 0.6.2__tar.gz → 0.6.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {vanna-0.6.2 → vanna-0.6.3}/PKG-INFO +7 -1
- {vanna-0.6.2 → vanna-0.6.3}/pyproject.toml +4 -2
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/base/base.py +1 -1
- vanna-0.6.3/src/vanna/bedrock/__init__.py +1 -0
- vanna-0.6.3/src/vanna/bedrock/bedrock_converse.py +85 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/google/gemini_chat.py +2 -2
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/hf/hf.py +8 -6
- vanna-0.6.3/src/vanna/weaviate/__init__.py +1 -0
- vanna-0.6.3/src/vanna/weaviate/weaviate_vector.py +174 -0
- {vanna-0.6.2 → vanna-0.6.3}/README.md +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/ZhipuAI/ZhipuAI_Chat.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/ZhipuAI/ZhipuAI_embeddings.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/ZhipuAI/__init__.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/__init__.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/advanced/__init__.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/anthropic/__init__.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/anthropic/anthropic_chat.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/base/__init__.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/chromadb/__init__.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/chromadb/chromadb_vector.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/exceptions/__init__.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/flask/__init__.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/flask/assets.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/flask/auth.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/google/__init__.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/hf/__init__.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/local.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/marqo/__init__.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/marqo/marqo.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/milvus/__init__.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/milvus/milvus_vector.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/mistral/__init__.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/mistral/mistral.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/mock/__init__.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/mock/embedding.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/mock/llm.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/mock/vectordb.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/ollama/__init__.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/ollama/ollama.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/openai/__init__.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/openai/openai_chat.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/openai/openai_embeddings.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/opensearch/__init__.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/opensearch/opensearch_vector.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/pinecone/__init__.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/pinecone/pinecone_vector.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/qdrant/__init__.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/qdrant/qdrant.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/remote.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/types/__init__.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/utils.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/vannadb/__init__.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/vannadb/vannadb_vector.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/vllm/__init__.py +0 -0
- {vanna-0.6.2 → vanna-0.6.3}/src/vanna/vllm/vllm.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: vanna
|
|
3
|
-
Version: 0.6.
|
|
3
|
+
Version: 0.6.3
|
|
4
4
|
Summary: Generate SQL queries from natural language
|
|
5
5
|
Author-email: Zain Hoda <zain@vanna.ai>
|
|
6
6
|
Requires-Python: >=3.9
|
|
@@ -40,7 +40,10 @@ Requires-Dist: opensearch-dsl ; extra == "all"
|
|
|
40
40
|
Requires-Dist: transformers ; extra == "all"
|
|
41
41
|
Requires-Dist: pinecone-client ; extra == "all"
|
|
42
42
|
Requires-Dist: pymilvus[model] ; extra == "all"
|
|
43
|
+
Requires-Dist: weaviate-client ; extra == "all"
|
|
43
44
|
Requires-Dist: anthropic ; extra == "anthropic"
|
|
45
|
+
Requires-Dist: boto3 ; extra == "bedrock"
|
|
46
|
+
Requires-Dist: botocore ; extra == "bedrock"
|
|
44
47
|
Requires-Dist: google-cloud-bigquery ; extra == "bigquery"
|
|
45
48
|
Requires-Dist: chromadb ; extra == "chromadb"
|
|
46
49
|
Requires-Dist: clickhouse_connect ; extra == "clickhouse"
|
|
@@ -67,11 +70,13 @@ Requires-Dist: fastembed ; extra == "qdrant"
|
|
|
67
70
|
Requires-Dist: snowflake-connector-python ; extra == "snowflake"
|
|
68
71
|
Requires-Dist: tox ; extra == "test"
|
|
69
72
|
Requires-Dist: vllm ; extra == "vllm"
|
|
73
|
+
Requires-Dist: weaviate-client ; extra == "weaviate"
|
|
70
74
|
Requires-Dist: zhipuai ; extra == "zhipuai"
|
|
71
75
|
Project-URL: Bug Tracker, https://github.com/vanna-ai/vanna/issues
|
|
72
76
|
Project-URL: Homepage, https://github.com/vanna-ai/vanna
|
|
73
77
|
Provides-Extra: all
|
|
74
78
|
Provides-Extra: anthropic
|
|
79
|
+
Provides-Extra: bedrock
|
|
75
80
|
Provides-Extra: bigquery
|
|
76
81
|
Provides-Extra: chromadb
|
|
77
82
|
Provides-Extra: clickhouse
|
|
@@ -92,6 +97,7 @@ Provides-Extra: qdrant
|
|
|
92
97
|
Provides-Extra: snowflake
|
|
93
98
|
Provides-Extra: test
|
|
94
99
|
Provides-Extra: vllm
|
|
100
|
+
Provides-Extra: weaviate
|
|
95
101
|
Provides-Extra: zhipuai
|
|
96
102
|
|
|
97
103
|
|
|
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "vanna"
|
|
7
|
-
version = "0.6.
|
|
7
|
+
version = "0.6.3"
|
|
8
8
|
authors = [
|
|
9
9
|
{ name="Zain Hoda", email="zain@vanna.ai" },
|
|
10
10
|
]
|
|
@@ -33,7 +33,7 @@ bigquery = ["google-cloud-bigquery"]
|
|
|
33
33
|
snowflake = ["snowflake-connector-python"]
|
|
34
34
|
duckdb = ["duckdb"]
|
|
35
35
|
google = ["google-generativeai", "google-cloud-aiplatform"]
|
|
36
|
-
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]"]
|
|
36
|
+
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client"]
|
|
37
37
|
test = ["tox"]
|
|
38
38
|
chromadb = ["chromadb"]
|
|
39
39
|
openai = ["openai"]
|
|
@@ -49,3 +49,5 @@ pinecone = ["pinecone-client", "fastembed"]
|
|
|
49
49
|
opensearch = ["opensearch-py", "opensearch-dsl"]
|
|
50
50
|
hf = ["transformers"]
|
|
51
51
|
milvus = ["pymilvus[model]"]
|
|
52
|
+
bedrock = ["boto3", "botocore"]
|
|
53
|
+
weaviate = ["weaviate-client"]
|
|
@@ -182,7 +182,7 @@ class VannaBase(ABC):
|
|
|
182
182
|
"""
|
|
183
183
|
|
|
184
184
|
# If the llm_response contains a CTE (with clause), extract the last sql between WITH and ;
|
|
185
|
-
sqls = re.findall(r"
|
|
185
|
+
sqls = re.findall(r"\bWITH\b .*?;", llm_response, re.DOTALL)
|
|
186
186
|
if sqls:
|
|
187
187
|
sql = sqls[-1]
|
|
188
188
|
self.log(title="Extracted SQL", message=f"{sql}")
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .bedrock_converse import Bedrock_Converse
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
from ..base import VannaBase
|
|
2
|
+
|
|
3
|
+
try:
|
|
4
|
+
import boto3
|
|
5
|
+
from botocore.exceptions import ClientError
|
|
6
|
+
except ImportError:
|
|
7
|
+
raise ImportError("Please install boto3 and botocore to use Amazon Bedrock models")
|
|
8
|
+
|
|
9
|
+
class Bedrock_Converse(VannaBase):
    """LLM connector for Amazon Bedrock using the Converse API.

    Requires a pre-built Bedrock runtime client and a config dict with at
    least ``modelId``; ``temperature`` and ``max_tokens`` may override the
    defaults (0.0 and 500).
    """

    def __init__(self, client=None, config=None):
        VannaBase.__init__(self, config=config)

        # Inference defaults; overridden through ``config`` below.
        self.temperature = 0.0
        self.max_tokens = 500

        if client is None:
            raise ValueError(
                "A valid Bedrock runtime client must be provided to invoke Bedrock models"
            )
        self.client = client

        if config is None:
            raise ValueError(
                "Config is required with model_id and inference parameters"
            )
        if "modelId" not in config:
            raise ValueError(
                "config must contain a modelId to invoke"
            )
        self.model = config["modelId"]

        self.temperature = config.get("temperature", self.temperature)
        self.max_tokens = config.get("max_tokens", self.max_tokens)

    def system_message(self, message: str) -> dict:
        """Wrap *message* as a system-role chat entry."""
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        """Wrap *message* as a user-role chat entry."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        """Wrap *message* as an assistant-role chat entry."""
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send *prompt* (a list of role/content dicts) to Bedrock Converse.

        A system message, if present, is lifted into the Converse ``system``
        field; every other message is forwarded in the Converse content-block
        shape. Returns the first text block of the model's reply.

        Raises Exception when the Bedrock client reports an error.
        """
        system_text = None
        messages = []
        for entry in prompt:
            if entry["role"] == "system":
                system_text = entry["content"]
            else:
                messages.append(
                    {"role": entry["role"], "content": [{"text": entry["content"]}]}
                )

        converse_api_params = {
            "modelId": self.model,
            "messages": messages,
            "inferenceConfig": {
                "temperature": self.temperature,
                "maxTokens": self.max_tokens,
            },
            # top_p is not part of inferenceConfig, so it is passed as a
            # model-specific field (nucleus sampling left fully open).
            "additionalModelRequestFields": {"top_p": 1},
        }
        if system_text:
            converse_api_params["system"] = [{"text": system_text}]

        try:
            response = self.client.converse(**converse_api_params)
        except ClientError as err:
            message = err.response["Error"]["Message"]
            raise Exception(f"A Bedrock client error occurred: {message}")
        return response["output"]["message"]["content"][0]["text"]
|
|
@@ -7,7 +7,7 @@ class GoogleGeminiChat(VannaBase):
|
|
|
7
7
|
VannaBase.__init__(self, config=config)
|
|
8
8
|
|
|
9
9
|
# default temperature - can be overrided using config
|
|
10
|
-
self.temperature = 0.7
|
|
10
|
+
self.temperature = 0.7
|
|
11
11
|
|
|
12
12
|
if "temperature" in config:
|
|
13
13
|
self.temperature = config["temperature"]
|
|
@@ -31,7 +31,7 @@ class GoogleGeminiChat(VannaBase):
|
|
|
31
31
|
else:
|
|
32
32
|
# Authenticate using VertexAI
|
|
33
33
|
from vertexai.preview.generative_models import GenerativeModel
|
|
34
|
-
self.chat_model = GenerativeModel(
|
|
34
|
+
self.chat_model = GenerativeModel(model_name)
|
|
35
35
|
|
|
36
36
|
def system_message(self, message: str) -> any:
|
|
37
37
|
return message
|
|
@@ -6,13 +6,15 @@ from ..base import VannaBase
|
|
|
6
6
|
|
|
7
7
|
class Hf(VannaBase):
|
|
8
8
|
def __init__(self, config=None):
|
|
9
|
-
|
|
10
|
-
"
|
|
11
|
-
) # e.g. meta-llama/Meta-Llama-3-8B-Instruct
|
|
12
|
-
|
|
9
|
+
model_name_or_path = self.config.get(
|
|
10
|
+
"model_name_or_path", None
|
|
11
|
+
) # e.g. meta-llama/Meta-Llama-3-8B-Instruct or local path to the model checkpoint files
|
|
12
|
+
# list of quantization methods supported by transformers package: https://huggingface.co/docs/transformers/main/en/quantization/overview
|
|
13
|
+
quantization_config = self.config.get("quantization_config", None)
|
|
14
|
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
|
13
15
|
self.model = AutoModelForCausalLM.from_pretrained(
|
|
14
|
-
|
|
15
|
-
|
|
16
|
+
model_name_or_path,
|
|
17
|
+
quantization_config=quantization_config,
|
|
16
18
|
device_map="auto",
|
|
17
19
|
)
|
|
18
20
|
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .weaviate_vector import WeaviateDatabase
|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
import weaviate
|
|
2
|
+
import weaviate.classes as wvc
|
|
3
|
+
from fastembed import TextEmbedding
|
|
4
|
+
|
|
5
|
+
from vanna.base import VannaBase
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class WeaviateDatabase(VannaBase):
    """Vector store backed by Weaviate with fastembed text embeddings.

    Training data is split across three Weaviate collections (SQL
    question/answer pairs, DDL statements, documentation snippets).
    IDs handed back to callers are the Weaviate UUID suffixed with the
    entry kind (``-sql`` / ``-ddl`` / ``-doc``).
    """

    def __init__(self, config=None):
        """
        Initialize the WeaviateDatabase class with the provided configuration.

        :param config: Dictionary containing configuration parameters.

        params:
            weaviate_url (str): Weaviate cluster URL while using weaviate cloud,
            weaviate_api_key (str): Weaviate API key while using weaviate cloud,
            weaviate_port (num): Weaviate port while using local weaviate,
            weaviate_grpc (num): Weaviate gRPC port while using local weaviate,
            fastembed_model (str): Fastembed model name for text embeddings. BAAI/bge-small-en-v1.5 by default.

        :raises ValueError: if ``config`` is missing or has neither cloud
            credentials nor a local port.
        """
        super().__init__(config=config)

        if config is None:
            raise ValueError("config is required")

        self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5")
        self.weaviate_api_key = config.get("weaviate_api_key")
        self.weaviate_url = config.get("weaviate_url")
        self.weaviate_port = config.get("weaviate_port")
        self.weaviate_grpc_port = config.get("weaviate_grpc", 50051)

        # Either cloud credentials (api key) or a local port must be given.
        if not self.weaviate_api_key and not self.weaviate_port:
            raise ValueError("Add proper credentials to connect to weaviate")

        self.weaviate_client = self._initialize_weaviate_client()
        # Embedding model is loaded once here and reused by
        # generate_embedding(); loading it is expensive.
        self.embeddings = TextEmbedding(model_name=self.fastembed_model)

        # Maps internal entry kinds to Weaviate collection names.
        self.training_data_cluster = {
            "sql": "SQLTrainingDataEntry",
            "ddl": "DDLEntry",
            "doc": "DocumentationEntry"
        }

        self._create_collections_if_not_exist()

    def _create_collections_if_not_exist(self):
        # One collection per entry kind; properties hold the raw text so it
        # can be returned verbatim alongside the stored vector.
        properties_dict = {
            self.training_data_cluster['ddl']: [
                wvc.config.Property(name="description", data_type=wvc.config.DataType.TEXT),
            ],
            self.training_data_cluster['doc']: [
                wvc.config.Property(name="description", data_type=wvc.config.DataType.TEXT),
            ],
            self.training_data_cluster['sql']: [
                wvc.config.Property(name="sql", data_type=wvc.config.DataType.TEXT),
                wvc.config.Property(name="natural_language_question", data_type=wvc.config.DataType.TEXT),
            ]
        }

        for cluster, properties in properties_dict.items():
            if not self.weaviate_client.collections.exists(cluster):
                self.weaviate_client.collections.create(
                    name=cluster,
                    properties=properties
                )

    def _initialize_weaviate_client(self):
        """Connect to Weaviate Cloud when an API key is set, else locally."""
        if self.weaviate_api_key:
            return weaviate.connect_to_wcs(
                cluster_url=self.weaviate_url,
                auth_credentials=weaviate.auth.AuthApiKey(self.weaviate_api_key),
                additional_config=weaviate.config.AdditionalConfig(timeout=(10, 300)),
                skip_init_checks=True
            )
        else:
            return weaviate.connect_to_local(
                port=self.weaviate_port,
                grpc_port=self.weaviate_grpc_port,
                additional_config=weaviate.config.AdditionalConfig(timeout=(10, 300)),
                skip_init_checks=True
            )

    def generate_embedding(self, data: str, **kwargs):
        """Return the fastembed embedding of *data* as a plain list."""
        # Reuse the model instantiated in __init__ instead of loading a
        # fresh TextEmbedding on every call (model load is expensive).
        embedding = next(self.embeddings.embed(data))
        return embedding.tolist()

    def _insert_data(self, cluster_key: str, data_object: dict, vector: list) -> str:
        """Insert one object + vector into the collection for *cluster_key*.

        Returns the UUID Weaviate assigned to the new object.
        """
        self.weaviate_client.connect()
        response = self.weaviate_client.collections.get(self.training_data_cluster[cluster_key]).data.insert(
            properties=data_object,
            vector=vector
        )
        self.weaviate_client.close()
        return response

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """Store a DDL statement; returns its suffixed id."""
        data_object = {
            "description": ddl,
        }
        response = self._insert_data('ddl', data_object, self.generate_embedding(ddl))
        return f'{response}-ddl'

    def add_documentation(self, doc: str, **kwargs) -> str:
        """Store a documentation snippet; returns its suffixed id."""
        data_object = {
            "description": doc,
        }
        response = self._insert_data('doc', data_object, self.generate_embedding(doc))
        return f'{response}-doc'

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Store a question/SQL pair (embedded by question); returns its suffixed id."""
        data_object = {
            "sql": sql,
            "natural_language_question": question,
        }
        response = self._insert_data('sql', data_object, self.generate_embedding(question))
        return f'{response}-sql'

    def _query_collection(self, cluster_key: str, vector_input: list, return_properties: list, limit: int = 3) -> list:
        """Nearest-vector search in the *cluster_key* collection.

        Returns up to *limit* property dicts, closest first.
        """
        self.weaviate_client.connect()
        collection = self.weaviate_client.collections.get(self.training_data_cluster[cluster_key])
        response = collection.query.near_vector(
            near_vector=vector_input,
            limit=limit,
            return_properties=return_properties
        )
        response_list = [item.properties for item in response.objects]
        self.weaviate_client.close()
        return response_list

    def get_related_ddl(self, question: str, **kwargs) -> list:
        """Return DDL strings most similar to *question*."""
        vector_input = self.generate_embedding(question)
        response_list = self._query_collection('ddl', vector_input, ["description"])
        return [item["description"] for item in response_list]

    def get_related_documentation(self, question: str, **kwargs) -> list:
        """Return documentation snippets most similar to *question*."""
        vector_input = self.generate_embedding(question)
        response_list = self._query_collection('doc', vector_input, ["description"])
        return [item["description"] for item in response_list]

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        """Return {"question", "sql"} dicts most similar to *question*."""
        vector_input = self.generate_embedding(question)
        response_list = self._query_collection('sql', vector_input, ["sql", "natural_language_question"])
        return [{"question": item["natural_language_question"], "sql": item["sql"]} for item in response_list]

    def get_training_data(self, **kwargs) -> list:
        """Return the properties of every stored training entry across all collections."""
        self.weaviate_client.connect()
        combined_response_list = []
        for collection_name in self.training_data_cluster.values():
            if self.weaviate_client.collections.exists(collection_name):
                collection = self.weaviate_client.collections.get(collection_name)
                response_list = [item.properties for item in collection.iterator()]
                combined_response_list.extend(response_list)
        self.weaviate_client.close()
        return combined_response_list

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Delete one training entry by its suffixed id.

        Returns the delete status from Weaviate, or False when the id does
        not carry a recognized suffix.
        """
        self.weaviate_client.connect()
        success = False
        for suffix, cluster_key in (("-sql", "sql"), ("-ddl", "ddl"), ("-doc", "doc")):
            if id.endswith(suffix):
                # removesuffix strips exactly the trailing marker, whereas
                # str.replace would also remove the pattern mid-string.
                uuid = id.removesuffix(suffix)
                success = self.weaviate_client.collections.get(
                    self.training_data_cluster[cluster_key]
                ).data.delete_by_id(uuid)
                break
        self.weaviate_client.close()
        return success
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|