vanna 0.6.0__tar.gz → 0.6.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {vanna-0.6.0 → vanna-0.6.2}/PKG-INFO +4 -1
- {vanna-0.6.0 → vanna-0.6.2}/pyproject.toml +3 -2
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/base/base.py +1 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/flask/__init__.py +2 -5
- vanna-0.6.2/src/vanna/milvus/__init__.py +1 -0
- vanna-0.6.2/src/vanna/milvus/milvus_vector.py +305 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/qdrant/qdrant.py +10 -14
- {vanna-0.6.0 → vanna-0.6.2}/README.md +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/ZhipuAI/ZhipuAI_Chat.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/ZhipuAI/ZhipuAI_embeddings.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/ZhipuAI/__init__.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/__init__.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/advanced/__init__.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/anthropic/__init__.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/anthropic/anthropic_chat.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/base/__init__.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/chromadb/__init__.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/chromadb/chromadb_vector.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/exceptions/__init__.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/flask/assets.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/flask/auth.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/google/__init__.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/google/gemini_chat.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/hf/__init__.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/hf/hf.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/local.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/marqo/__init__.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/marqo/marqo.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/mistral/__init__.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/mistral/mistral.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/mock/__init__.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/mock/embedding.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/mock/llm.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/mock/vectordb.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/ollama/__init__.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/ollama/ollama.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/openai/__init__.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/openai/openai_chat.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/openai/openai_embeddings.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/opensearch/__init__.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/opensearch/opensearch_vector.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/pinecone/__init__.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/pinecone/pinecone_vector.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/qdrant/__init__.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/remote.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/types/__init__.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/utils.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/vannadb/__init__.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/vannadb/vannadb_vector.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/vllm/__init__.py +0 -0
- {vanna-0.6.0 → vanna-0.6.2}/src/vanna/vllm/vllm.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: vanna
|
|
3
|
-
Version: 0.6.0
|
|
3
|
+
Version: 0.6.2
|
|
4
4
|
Summary: Generate SQL queries from natural language
|
|
5
5
|
Author-email: Zain Hoda <zain@vanna.ai>
|
|
6
6
|
Requires-Python: >=3.9
|
|
@@ -39,6 +39,7 @@ Requires-Dist: opensearch-py ; extra == "all"
|
|
|
39
39
|
Requires-Dist: opensearch-dsl ; extra == "all"
|
|
40
40
|
Requires-Dist: transformers ; extra == "all"
|
|
41
41
|
Requires-Dist: pinecone-client ; extra == "all"
|
|
42
|
+
Requires-Dist: pymilvus[model] ; extra == "all"
|
|
42
43
|
Requires-Dist: anthropic ; extra == "anthropic"
|
|
43
44
|
Requires-Dist: google-cloud-bigquery ; extra == "bigquery"
|
|
44
45
|
Requires-Dist: chromadb ; extra == "chromadb"
|
|
@@ -49,6 +50,7 @@ Requires-Dist: google-generativeai ; extra == "google"
|
|
|
49
50
|
Requires-Dist: google-cloud-aiplatform ; extra == "google"
|
|
50
51
|
Requires-Dist: transformers ; extra == "hf"
|
|
51
52
|
Requires-Dist: marqo ; extra == "marqo"
|
|
53
|
+
Requires-Dist: pymilvus[model] ; extra == "milvus"
|
|
52
54
|
Requires-Dist: mistralai ; extra == "mistralai"
|
|
53
55
|
Requires-Dist: PyMySQL ; extra == "mysql"
|
|
54
56
|
Requires-Dist: ollama ; extra == "ollama"
|
|
@@ -78,6 +80,7 @@ Provides-Extra: gemini
|
|
|
78
80
|
Provides-Extra: google
|
|
79
81
|
Provides-Extra: hf
|
|
80
82
|
Provides-Extra: marqo
|
|
83
|
+
Provides-Extra: milvus
|
|
81
84
|
Provides-Extra: mistralai
|
|
82
85
|
Provides-Extra: mysql
|
|
83
86
|
Provides-Extra: ollama
|
|
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "vanna"
|
|
7
|
-
version = "0.6.0"
|
|
7
|
+
version = "0.6.2"
|
|
8
8
|
authors = [
|
|
9
9
|
{ name="Zain Hoda", email="zain@vanna.ai" },
|
|
10
10
|
]
|
|
@@ -33,7 +33,7 @@ bigquery = ["google-cloud-bigquery"]
|
|
|
33
33
|
snowflake = ["snowflake-connector-python"]
|
|
34
34
|
duckdb = ["duckdb"]
|
|
35
35
|
google = ["google-generativeai", "google-cloud-aiplatform"]
|
|
36
|
-
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client"]
|
|
36
|
+
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]"]
|
|
37
37
|
test = ["tox"]
|
|
38
38
|
chromadb = ["chromadb"]
|
|
39
39
|
openai = ["openai"]
|
|
@@ -48,3 +48,4 @@ vllm = ["vllm"]
|
|
|
48
48
|
pinecone = ["pinecone-client", "fastembed"]
|
|
49
49
|
opensearch = ["opensearch-py", "opensearch-dsl"]
|
|
50
50
|
hf = ["transformers"]
|
|
51
|
+
milvus = ["pymilvus[model]"]
|
|
@@ -494,15 +494,12 @@ class VannaFlaskApp:
|
|
|
494
494
|
def generate_plotly_figure(user: any, id: str, df, question, sql):
|
|
495
495
|
chart_instructions = flask.request.args.get('chart_instructions')
|
|
496
496
|
|
|
497
|
-
if chart_instructions is not None:
|
|
498
|
-
question = f"{question}. When generating the chart, use these special instructions: {chart_instructions}"
|
|
499
|
-
|
|
500
497
|
try:
|
|
501
498
|
# If chart_instructions is not set then attempt to retrieve the code from the cache
|
|
502
499
|
if chart_instructions is None or len(chart_instructions) == 0:
|
|
503
500
|
code = self.cache.get(id=id, field="plotly_code")
|
|
504
|
-
|
|
505
|
-
|
|
501
|
+
else:
|
|
502
|
+
question = f"{question}. When generating the chart, use these special instructions: {chart_instructions}"
|
|
506
503
|
code = vn.generate_plotly_code(
|
|
507
504
|
question=question,
|
|
508
505
|
sql=sql,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .milvus_vector import Milvus_VectorStore
|
|
@@ -0,0 +1,305 @@
|
|
|
1
|
+
import uuid
|
|
2
|
+
from typing import List
|
|
3
|
+
|
|
4
|
+
import pandas as pd
|
|
5
|
+
from pymilvus import DataType, MilvusClient, model
|
|
6
|
+
|
|
7
|
+
from ..base import VannaBase
|
|
8
|
+
|
|
9
|
+
# Setting the URI as a local file, e.g.`./milvus.db`,
|
|
10
|
+
# is the most convenient method, as it automatically utilizes Milvus Lite
|
|
11
|
+
# to store all data in this file.
|
|
12
|
+
#
|
|
13
|
+
# If you have large scale of data such as more than a million docs, we
|
|
14
|
+
# recommend setting up a more performant Milvus server on docker or kubernetes.
|
|
15
|
+
# When using this setup, please use the server URI,
|
|
16
|
+
# e.g.`http://localhost:19530`, as your URI.
|
|
17
|
+
|
|
18
|
+
DEFAULT_MILVUS_URI = "./milvus.db"
|
|
19
|
+
# DEFAULT_MILVUS_URI = "http://localhost:19530"
|
|
20
|
+
|
|
21
|
+
MAX_LIMIT_SIZE = 10_000
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class Milvus_VectorStore(VannaBase):
|
|
25
|
+
"""
|
|
26
|
+
Vectorstore implementation using Milvus - https://milvus.io/docs/quickstart.md
|
|
27
|
+
|
|
28
|
+
Args:
|
|
29
|
+
- config (dict, optional): Dictionary of `Milvus_VectorStore config` options. Defaults to `None`.
|
|
30
|
+
- milvus_client: A `pymilvus.MilvusClient` instance.
|
|
31
|
+
- embedding_function:
|
|
32
|
+
A `milvus_model.base.BaseEmbeddingFunction` instance. Defaults to `DefaultEmbeddingFunction()`.
|
|
33
|
+
For more models, please refer to:
|
|
34
|
+
https://milvus.io/docs/embeddings.md
|
|
35
|
+
"""
|
|
36
|
+
def __init__(self, config=None):
|
|
37
|
+
VannaBase.__init__(self, config=config)
|
|
38
|
+
|
|
39
|
+
if "milvus_client" in config:
|
|
40
|
+
self.milvus_client = config["milvus_client"]
|
|
41
|
+
else:
|
|
42
|
+
self.milvus_client = MilvusClient(uri=DEFAULT_MILVUS_URI)
|
|
43
|
+
|
|
44
|
+
if "embedding_function" in config:
|
|
45
|
+
self.embedding_function = config.get("embedding_function")
|
|
46
|
+
else:
|
|
47
|
+
self.embedding_function = model.DefaultEmbeddingFunction()
|
|
48
|
+
self._embedding_dim = self.embedding_function.encode_documents(["foo"])[0].shape[0]
|
|
49
|
+
self._create_collections()
|
|
50
|
+
self.n_results = config.get("n_results", 10)
|
|
51
|
+
|
|
52
|
+
def _create_collections(self):
|
|
53
|
+
self._create_sql_collection("vannasql")
|
|
54
|
+
self._create_ddl_collection("vannaddl")
|
|
55
|
+
self._create_doc_collection("vannadoc")
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def generate_embedding(self, data: str, **kwargs) -> List[float]:
|
|
59
|
+
return self.embedding_function.encode_documents(data).tolist()
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _create_sql_collection(self, name: str):
|
|
63
|
+
if not self.milvus_client.has_collection(collection_name=name):
|
|
64
|
+
vannasql_schema = MilvusClient.create_schema(
|
|
65
|
+
auto_id=False,
|
|
66
|
+
enable_dynamic_field=False,
|
|
67
|
+
)
|
|
68
|
+
vannasql_schema.add_field(field_name="id", datatype=DataType.VARCHAR, max_length=65535, is_primary=True)
|
|
69
|
+
vannasql_schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=65535)
|
|
70
|
+
vannasql_schema.add_field(field_name="sql", datatype=DataType.VARCHAR, max_length=65535)
|
|
71
|
+
vannasql_schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self._embedding_dim)
|
|
72
|
+
|
|
73
|
+
vannasql_index_params = self.milvus_client.prepare_index_params()
|
|
74
|
+
vannasql_index_params.add_index(
|
|
75
|
+
field_name="vector",
|
|
76
|
+
index_name="vector",
|
|
77
|
+
index_type="AUTOINDEX",
|
|
78
|
+
metric_type="L2",
|
|
79
|
+
)
|
|
80
|
+
self.milvus_client.create_collection(
|
|
81
|
+
collection_name=name,
|
|
82
|
+
schema=vannasql_schema,
|
|
83
|
+
index_params=vannasql_index_params,
|
|
84
|
+
consistency_level="Strong"
|
|
85
|
+
)
|
|
86
|
+
|
|
87
|
+
def _create_ddl_collection(self, name: str):
|
|
88
|
+
if not self.milvus_client.has_collection(collection_name=name):
|
|
89
|
+
vannaddl_schema = MilvusClient.create_schema(
|
|
90
|
+
auto_id=False,
|
|
91
|
+
enable_dynamic_field=False,
|
|
92
|
+
)
|
|
93
|
+
vannaddl_schema.add_field(field_name="id", datatype=DataType.VARCHAR, max_length=65535, is_primary=True)
|
|
94
|
+
vannaddl_schema.add_field(field_name="ddl", datatype=DataType.VARCHAR, max_length=65535)
|
|
95
|
+
vannaddl_schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self._embedding_dim)
|
|
96
|
+
|
|
97
|
+
vannaddl_index_params = self.milvus_client.prepare_index_params()
|
|
98
|
+
vannaddl_index_params.add_index(
|
|
99
|
+
field_name="vector",
|
|
100
|
+
index_name="vector",
|
|
101
|
+
index_type="AUTOINDEX",
|
|
102
|
+
metric_type="L2",
|
|
103
|
+
)
|
|
104
|
+
self.milvus_client.create_collection(
|
|
105
|
+
collection_name=name,
|
|
106
|
+
schema=vannaddl_schema,
|
|
107
|
+
index_params=vannaddl_index_params,
|
|
108
|
+
consistency_level="Strong"
|
|
109
|
+
)
|
|
110
|
+
|
|
111
|
+
def _create_doc_collection(self, name: str):
|
|
112
|
+
if not self.milvus_client.has_collection(collection_name=name):
|
|
113
|
+
vannadoc_schema = MilvusClient.create_schema(
|
|
114
|
+
auto_id=False,
|
|
115
|
+
enable_dynamic_field=False,
|
|
116
|
+
)
|
|
117
|
+
vannadoc_schema.add_field(field_name="id", datatype=DataType.VARCHAR, max_length=65535, is_primary=True)
|
|
118
|
+
vannadoc_schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535)
|
|
119
|
+
vannadoc_schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self._embedding_dim)
|
|
120
|
+
|
|
121
|
+
vannadoc_index_params = self.milvus_client.prepare_index_params()
|
|
122
|
+
vannadoc_index_params.add_index(
|
|
123
|
+
field_name="vector",
|
|
124
|
+
index_name="vector",
|
|
125
|
+
index_type="AUTOINDEX",
|
|
126
|
+
metric_type="L2",
|
|
127
|
+
)
|
|
128
|
+
self.milvus_client.create_collection(
|
|
129
|
+
collection_name=name,
|
|
130
|
+
schema=vannadoc_schema,
|
|
131
|
+
index_params=vannadoc_index_params,
|
|
132
|
+
consistency_level="Strong"
|
|
133
|
+
)
|
|
134
|
+
|
|
135
|
+
def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
|
|
136
|
+
if len(question) == 0 or len(sql) == 0:
|
|
137
|
+
raise Exception("pair of question and sql can not be null")
|
|
138
|
+
_id = str(uuid.uuid4()) + "-sql"
|
|
139
|
+
embedding = self.embedding_function.encode_documents([question])[0]
|
|
140
|
+
self.milvus_client.insert(
|
|
141
|
+
collection_name="vannasql",
|
|
142
|
+
data={
|
|
143
|
+
"id": _id,
|
|
144
|
+
"text": question,
|
|
145
|
+
"sql": sql,
|
|
146
|
+
"vector": embedding
|
|
147
|
+
}
|
|
148
|
+
)
|
|
149
|
+
return _id
|
|
150
|
+
|
|
151
|
+
def add_ddl(self, ddl: str, **kwargs) -> str:
|
|
152
|
+
if len(ddl) == 0:
|
|
153
|
+
raise Exception("ddl can not be null")
|
|
154
|
+
_id = str(uuid.uuid4()) + "-ddl"
|
|
155
|
+
embedding = self.embedding_function.encode_documents([ddl])[0]
|
|
156
|
+
self.milvus_client.insert(
|
|
157
|
+
collection_name="vannaddl",
|
|
158
|
+
data={
|
|
159
|
+
"id": _id,
|
|
160
|
+
"ddl": ddl,
|
|
161
|
+
"vector": embedding
|
|
162
|
+
}
|
|
163
|
+
)
|
|
164
|
+
return _id
|
|
165
|
+
|
|
166
|
+
def add_documentation(self, documentation: str, **kwargs) -> str:
|
|
167
|
+
if len(documentation) == 0:
|
|
168
|
+
raise Exception("documentation can not be null")
|
|
169
|
+
_id = str(uuid.uuid4()) + "-doc"
|
|
170
|
+
embedding = self.embedding_function.encode_documents([documentation])[0]
|
|
171
|
+
self.milvus_client.insert(
|
|
172
|
+
collection_name="vannadoc",
|
|
173
|
+
data={
|
|
174
|
+
"id": _id,
|
|
175
|
+
"doc": documentation,
|
|
176
|
+
"vector": embedding
|
|
177
|
+
}
|
|
178
|
+
)
|
|
179
|
+
return _id
|
|
180
|
+
|
|
181
|
+
def get_training_data(self, **kwargs) -> pd.DataFrame:
|
|
182
|
+
sql_data = self.milvus_client.query(
|
|
183
|
+
collection_name="vannasql",
|
|
184
|
+
output_fields=["*"],
|
|
185
|
+
limit=MAX_LIMIT_SIZE,
|
|
186
|
+
)
|
|
187
|
+
df = pd.DataFrame()
|
|
188
|
+
df_sql = pd.DataFrame(
|
|
189
|
+
{
|
|
190
|
+
"id": [doc["id"] for doc in sql_data],
|
|
191
|
+
"question": [doc["text"] for doc in sql_data],
|
|
192
|
+
"content": [doc["sql"] for doc in sql_data],
|
|
193
|
+
}
|
|
194
|
+
)
|
|
195
|
+
df = pd.concat([df, df_sql])
|
|
196
|
+
|
|
197
|
+
ddl_data = self.milvus_client.query(
|
|
198
|
+
collection_name="vannaddl",
|
|
199
|
+
output_fields=["*"],
|
|
200
|
+
limit=MAX_LIMIT_SIZE,
|
|
201
|
+
)
|
|
202
|
+
|
|
203
|
+
df_ddl = pd.DataFrame(
|
|
204
|
+
{
|
|
205
|
+
"id": [doc["id"] for doc in ddl_data],
|
|
206
|
+
"question": [None for doc in ddl_data],
|
|
207
|
+
"content": [doc["ddl"] for doc in ddl_data],
|
|
208
|
+
}
|
|
209
|
+
)
|
|
210
|
+
df = pd.concat([df, df_ddl])
|
|
211
|
+
|
|
212
|
+
doc_data = self.milvus_client.query(
|
|
213
|
+
collection_name="vannadoc",
|
|
214
|
+
output_fields=["*"],
|
|
215
|
+
limit=MAX_LIMIT_SIZE,
|
|
216
|
+
)
|
|
217
|
+
|
|
218
|
+
df_doc = pd.DataFrame(
|
|
219
|
+
{
|
|
220
|
+
"id": [doc["id"] for doc in doc_data],
|
|
221
|
+
"question": [None for doc in doc_data],
|
|
222
|
+
"content": [doc["doc"] for doc in doc_data],
|
|
223
|
+
}
|
|
224
|
+
)
|
|
225
|
+
df = pd.concat([df, df_doc])
|
|
226
|
+
return df
|
|
227
|
+
|
|
228
|
+
def get_similar_question_sql(self, question: str, **kwargs) -> list:
|
|
229
|
+
search_params = {
|
|
230
|
+
"metric_type": "L2",
|
|
231
|
+
"params": {"nprobe": 128},
|
|
232
|
+
}
|
|
233
|
+
embeddings = self.embedding_function.encode_queries([question])
|
|
234
|
+
res = self.milvus_client.search(
|
|
235
|
+
collection_name="vannasql",
|
|
236
|
+
anns_field="vector",
|
|
237
|
+
data=embeddings,
|
|
238
|
+
limit=self.n_results,
|
|
239
|
+
output_fields=["text", "sql"],
|
|
240
|
+
search_params=search_params
|
|
241
|
+
)
|
|
242
|
+
res = res[0]
|
|
243
|
+
|
|
244
|
+
list_sql = []
|
|
245
|
+
for doc in res:
|
|
246
|
+
dict = {}
|
|
247
|
+
dict["question"] = doc["entity"]["text"]
|
|
248
|
+
dict["sql"] = doc["entity"]["sql"]
|
|
249
|
+
list_sql.append(dict)
|
|
250
|
+
return list_sql
|
|
251
|
+
|
|
252
|
+
def get_related_ddl(self, question: str, **kwargs) -> list:
|
|
253
|
+
search_params = {
|
|
254
|
+
"metric_type": "L2",
|
|
255
|
+
"params": {"nprobe": 128},
|
|
256
|
+
}
|
|
257
|
+
embeddings = self.embedding_function.encode_queries([question])
|
|
258
|
+
res = self.milvus_client.search(
|
|
259
|
+
collection_name="vannaddl",
|
|
260
|
+
anns_field="vector",
|
|
261
|
+
data=embeddings,
|
|
262
|
+
limit=self.n_results,
|
|
263
|
+
output_fields=["ddl"],
|
|
264
|
+
search_params=search_params
|
|
265
|
+
)
|
|
266
|
+
res = res[0]
|
|
267
|
+
|
|
268
|
+
list_ddl = []
|
|
269
|
+
for doc in res:
|
|
270
|
+
list_ddl.append(doc["entity"]["ddl"])
|
|
271
|
+
return list_ddl
|
|
272
|
+
|
|
273
|
+
def get_related_documentation(self, question: str, **kwargs) -> list:
|
|
274
|
+
search_params = {
|
|
275
|
+
"metric_type": "L2",
|
|
276
|
+
"params": {"nprobe": 128},
|
|
277
|
+
}
|
|
278
|
+
embeddings = self.embedding_function.encode_queries([question])
|
|
279
|
+
res = self.milvus_client.search(
|
|
280
|
+
collection_name="vannadoc",
|
|
281
|
+
anns_field="vector",
|
|
282
|
+
data=embeddings,
|
|
283
|
+
limit=self.n_results,
|
|
284
|
+
output_fields=["doc"],
|
|
285
|
+
search_params=search_params
|
|
286
|
+
)
|
|
287
|
+
res = res[0]
|
|
288
|
+
|
|
289
|
+
list_doc = []
|
|
290
|
+
for doc in res:
|
|
291
|
+
list_doc.append(doc["entity"]["doc"])
|
|
292
|
+
return list_doc
|
|
293
|
+
|
|
294
|
+
def remove_training_data(self, id: str, **kwargs) -> bool:
|
|
295
|
+
if id.endswith("-sql"):
|
|
296
|
+
self.milvus_client.delete(collection_name="vannasql", ids=[id])
|
|
297
|
+
return True
|
|
298
|
+
elif id.endswith("-ddl"):
|
|
299
|
+
self.milvus_client.delete(collection_name="vannaddl", ids=[id])
|
|
300
|
+
return True
|
|
301
|
+
elif id.endswith("-doc"):
|
|
302
|
+
self.milvus_client.delete(collection_name="vannadoc", ids=[id])
|
|
303
|
+
return True
|
|
304
|
+
else:
|
|
305
|
+
return False
|
|
@@ -39,16 +39,6 @@ class Qdrant_VectorStore(VannaBase):
|
|
|
39
39
|
TypeError: If config["client"] is not a `qdrant_client.QdrantClient` instance
|
|
40
40
|
"""
|
|
41
41
|
|
|
42
|
-
documentation_collection_name = "documentation"
|
|
43
|
-
ddl_collection_name = "ddl"
|
|
44
|
-
sql_collection_name = "sql"
|
|
45
|
-
|
|
46
|
-
id_suffixes = {
|
|
47
|
-
ddl_collection_name: "ddl",
|
|
48
|
-
documentation_collection_name: "doc",
|
|
49
|
-
sql_collection_name: "sql",
|
|
50
|
-
}
|
|
51
|
-
|
|
52
42
|
def __init__(
|
|
53
43
|
self,
|
|
54
44
|
config={},
|
|
@@ -80,15 +70,21 @@ class Qdrant_VectorStore(VannaBase):
|
|
|
80
70
|
self.collection_params = config.get("collection_params", {})
|
|
81
71
|
self.distance_metric = config.get("distance_metric", models.Distance.COSINE)
|
|
82
72
|
self.documentation_collection_name = config.get(
|
|
83
|
-
"documentation_collection_name",
|
|
73
|
+
"documentation_collection_name", "documentation"
|
|
84
74
|
)
|
|
85
75
|
self.ddl_collection_name = config.get(
|
|
86
|
-
"ddl_collection_name",
|
|
76
|
+
"ddl_collection_name", "ddl"
|
|
87
77
|
)
|
|
88
78
|
self.sql_collection_name = config.get(
|
|
89
|
-
"sql_collection_name",
|
|
79
|
+
"sql_collection_name", "sql"
|
|
90
80
|
)
|
|
91
81
|
|
|
82
|
+
self.id_suffixes = {
|
|
83
|
+
self.ddl_collection_name: "ddl",
|
|
84
|
+
self.documentation_collection_name: "doc",
|
|
85
|
+
self.sql_collection_name: "sql",
|
|
86
|
+
}
|
|
87
|
+
|
|
92
88
|
self._setup_collections()
|
|
93
89
|
|
|
94
90
|
def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
|
|
@@ -212,7 +208,7 @@ class Qdrant_VectorStore(VannaBase):
|
|
|
212
208
|
try:
|
|
213
209
|
id, collection_name = self._parse_point_id(id)
|
|
214
210
|
res = self._client.delete(collection_name, points_selector=[id])
|
|
215
|
-
|
|
211
|
+
return True
|
|
216
212
|
except ValueError:
|
|
217
213
|
return False
|
|
218
214
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|