vanna 0.7.3__tar.gz → 0.7.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {vanna-0.7.3 → vanna-0.7.5}/PKG-INFO +16 -1
- {vanna-0.7.3 → vanna-0.7.5}/pyproject.toml +6 -2
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/google/bigquery_vector.py +34 -15
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/google/gemini_chat.py +2 -2
- vanna-0.7.5/src/vanna/pgvector/__init__.py +1 -0
- vanna-0.7.5/src/vanna/pgvector/pgvector.py +253 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/vllm/vllm.py +7 -0
- vanna-0.7.5/src/vanna/xinference/__init__.py +1 -0
- vanna-0.7.5/src/vanna/xinference/xinference.py +53 -0
- {vanna-0.7.3 → vanna-0.7.5}/README.md +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/ZhipuAI/ZhipuAI_Chat.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/ZhipuAI/ZhipuAI_embeddings.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/ZhipuAI/__init__.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/__init__.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/advanced/__init__.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/anthropic/__init__.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/anthropic/anthropic_chat.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/azuresearch/__init__.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/azuresearch/azuresearch_vector.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/base/__init__.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/base/base.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/bedrock/__init__.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/bedrock/bedrock_converse.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/chromadb/__init__.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/chromadb/chromadb_vector.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/exceptions/__init__.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/faiss/__init__.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/faiss/faiss.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/flask/__init__.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/flask/assets.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/flask/auth.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/google/__init__.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/hf/__init__.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/hf/hf.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/local.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/marqo/__init__.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/marqo/marqo.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/milvus/__init__.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/milvus/milvus_vector.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/mistral/__init__.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/mistral/mistral.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/mock/__init__.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/mock/embedding.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/mock/llm.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/mock/vectordb.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/ollama/__init__.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/ollama/ollama.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/openai/__init__.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/openai/openai_chat.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/openai/openai_embeddings.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/opensearch/__init__.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/opensearch/opensearch_vector.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/pinecone/__init__.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/pinecone/pinecone_vector.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/qdrant/__init__.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/qdrant/qdrant.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/qianfan/Qianfan_Chat.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/qianfan/Qianfan_embeddings.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/qianfan/__init__.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/qianwen/QianwenAI_chat.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/qianwen/QianwenAI_embeddings.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/qianwen/__init__.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/remote.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/types/__init__.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/utils.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/vannadb/__init__.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/vannadb/vannadb_vector.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/vllm/__init__.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/weaviate/__init__.py +0 -0
- {vanna-0.7.3 → vanna-0.7.5}/src/vanna/weaviate/weaviate_vector.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: vanna
|
|
3
|
-
Version: 0.7.3
|
|
3
|
+
Version: 0.7.5
|
|
4
4
|
Summary: Generate SQL queries from natural language
|
|
5
5
|
Author-email: Zain Hoda <zain@vanna.ai>
|
|
6
6
|
Requires-Python: >=3.9
|
|
@@ -46,6 +46,13 @@ Requires-Dist: weaviate-client ; extra == "all"
|
|
|
46
46
|
Requires-Dist: azure-search-documents ; extra == "all"
|
|
47
47
|
Requires-Dist: azure-identity ; extra == "all"
|
|
48
48
|
Requires-Dist: azure-common ; extra == "all"
|
|
49
|
+
Requires-Dist: faiss-cpu ; extra == "all"
|
|
50
|
+
Requires-Dist: boto ; extra == "all"
|
|
51
|
+
Requires-Dist: boto3 ; extra == "all"
|
|
52
|
+
Requires-Dist: botocore ; extra == "all"
|
|
53
|
+
Requires-Dist: langchain_core ; extra == "all"
|
|
54
|
+
Requires-Dist: langchain_postgres ; extra == "all"
|
|
55
|
+
Requires-Dist: xinference-client ; extra == "all"
|
|
49
56
|
Requires-Dist: anthropic ; extra == "anthropic"
|
|
50
57
|
Requires-Dist: azure-search-documents ; extra == "azuresearch"
|
|
51
58
|
Requires-Dist: azure-identity ; extra == "azuresearch"
|
|
@@ -57,6 +64,8 @@ Requires-Dist: google-cloud-bigquery ; extra == "bigquery"
|
|
|
57
64
|
Requires-Dist: chromadb ; extra == "chromadb"
|
|
58
65
|
Requires-Dist: clickhouse_connect ; extra == "clickhouse"
|
|
59
66
|
Requires-Dist: duckdb ; extra == "duckdb"
|
|
67
|
+
Requires-Dist: faiss-cpu ; extra == "faiss-cpu"
|
|
68
|
+
Requires-Dist: faiss-gpu ; extra == "faiss-gpu"
|
|
60
69
|
Requires-Dist: google-generativeai ; extra == "gemini"
|
|
61
70
|
Requires-Dist: google-generativeai ; extra == "google"
|
|
62
71
|
Requires-Dist: google-cloud-aiplatform ; extra == "google"
|
|
@@ -70,6 +79,7 @@ Requires-Dist: httpx ; extra == "ollama"
|
|
|
70
79
|
Requires-Dist: openai ; extra == "openai"
|
|
71
80
|
Requires-Dist: opensearch-py ; extra == "opensearch"
|
|
72
81
|
Requires-Dist: opensearch-dsl ; extra == "opensearch"
|
|
82
|
+
Requires-Dist: langchain-postgres>=0.0.12 ; extra == "pgvector"
|
|
73
83
|
Requires-Dist: pinecone-client ; extra == "pinecone"
|
|
74
84
|
Requires-Dist: fastembed ; extra == "pinecone"
|
|
75
85
|
Requires-Dist: psycopg2-binary ; extra == "postgres"
|
|
@@ -81,6 +91,7 @@ Requires-Dist: snowflake-connector-python ; extra == "snowflake"
|
|
|
81
91
|
Requires-Dist: tox ; extra == "test"
|
|
82
92
|
Requires-Dist: vllm ; extra == "vllm"
|
|
83
93
|
Requires-Dist: weaviate-client ; extra == "weaviate"
|
|
94
|
+
Requires-Dist: xinference-client ; extra == "xinference-client"
|
|
84
95
|
Requires-Dist: zhipuai ; extra == "zhipuai"
|
|
85
96
|
Project-URL: Bug Tracker, https://github.com/vanna-ai/vanna/issues
|
|
86
97
|
Project-URL: Homepage, https://github.com/vanna-ai/vanna
|
|
@@ -92,6 +103,8 @@ Provides-Extra: bigquery
|
|
|
92
103
|
Provides-Extra: chromadb
|
|
93
104
|
Provides-Extra: clickhouse
|
|
94
105
|
Provides-Extra: duckdb
|
|
106
|
+
Provides-Extra: faiss-cpu
|
|
107
|
+
Provides-Extra: faiss-gpu
|
|
95
108
|
Provides-Extra: gemini
|
|
96
109
|
Provides-Extra: google
|
|
97
110
|
Provides-Extra: hf
|
|
@@ -102,6 +115,7 @@ Provides-Extra: mysql
|
|
|
102
115
|
Provides-Extra: ollama
|
|
103
116
|
Provides-Extra: openai
|
|
104
117
|
Provides-Extra: opensearch
|
|
118
|
+
Provides-Extra: pgvector
|
|
105
119
|
Provides-Extra: pinecone
|
|
106
120
|
Provides-Extra: postgres
|
|
107
121
|
Provides-Extra: qdrant
|
|
@@ -110,6 +124,7 @@ Provides-Extra: snowflake
|
|
|
110
124
|
Provides-Extra: test
|
|
111
125
|
Provides-Extra: vllm
|
|
112
126
|
Provides-Extra: weaviate
|
|
127
|
+
Provides-Extra: xinference-client
|
|
113
128
|
Provides-Extra: zhipuai
|
|
114
129
|
|
|
115
130
|
|
|
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "vanna"
|
|
7
|
-
version = "0.7.3"
|
|
7
|
+
version = "0.7.5"
|
|
8
8
|
authors = [
|
|
9
9
|
{ name="Zain Hoda", email="zain@vanna.ai" },
|
|
10
10
|
]
|
|
@@ -33,7 +33,7 @@ bigquery = ["google-cloud-bigquery"]
|
|
|
33
33
|
snowflake = ["snowflake-connector-python"]
|
|
34
34
|
duckdb = ["duckdb"]
|
|
35
35
|
google = ["google-generativeai", "google-cloud-aiplatform"]
|
|
36
|
-
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client", "azure-search-documents", "azure-identity", "azure-common"]
|
|
36
|
+
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client", "azure-search-documents", "azure-identity", "azure-common", "faiss-cpu", "boto", "boto3", "botocore", "langchain_core", "langchain_postgres", "xinference-client"]
|
|
37
37
|
test = ["tox"]
|
|
38
38
|
chromadb = ["chromadb"]
|
|
39
39
|
openai = ["openai"]
|
|
@@ -53,3 +53,7 @@ milvus = ["pymilvus[model]"]
|
|
|
53
53
|
bedrock = ["boto3", "botocore"]
|
|
54
54
|
weaviate = ["weaviate-client"]
|
|
55
55
|
azuresearch = ["azure-search-documents", "azure-identity", "azure-common", "fastembed"]
|
|
56
|
+
pgvector = ["langchain-postgres>=0.0.12"]
|
|
57
|
+
faiss-cpu = ["faiss-cpu"]
|
|
58
|
+
faiss-gpu = ["faiss-gpu"]
|
|
59
|
+
xinference-client = ["xinference-client"]
|
|
@@ -2,6 +2,10 @@ import datetime
|
|
|
2
2
|
import os
|
|
3
3
|
import uuid
|
|
4
4
|
from typing import List, Optional
|
|
5
|
+
from vertexai.language_models import (
|
|
6
|
+
TextEmbeddingInput,
|
|
7
|
+
TextEmbeddingModel
|
|
8
|
+
)
|
|
5
9
|
|
|
6
10
|
import pandas as pd
|
|
7
11
|
from google.cloud import bigquery
|
|
@@ -23,17 +27,15 @@ class BigQuery_VectorStore(VannaBase):
|
|
|
23
27
|
or set as an environment variable, assign it.
|
|
24
28
|
"""
|
|
25
29
|
print("Configuring genai")
|
|
30
|
+
self.type = "GEMINI"
|
|
26
31
|
import google.generativeai as genai
|
|
27
32
|
|
|
28
33
|
genai.configure(api_key=config["api_key"])
|
|
29
34
|
|
|
30
35
|
self.genai = genai
|
|
31
36
|
else:
|
|
37
|
+
self.type = "VERTEX_AI"
|
|
32
38
|
# Authenticate using VertexAI
|
|
33
|
-
from vertexai.language_models import (
|
|
34
|
-
TextEmbeddingInput,
|
|
35
|
-
TextEmbeddingModel,
|
|
36
|
-
)
|
|
37
39
|
|
|
38
40
|
if self.config.get("project_id"):
|
|
39
41
|
self.project_id = self.config.get("project_id")
|
|
@@ -139,25 +141,42 @@ class BigQuery_VectorStore(VannaBase):
|
|
|
139
141
|
results = self.conn.query(query).result().to_dataframe()
|
|
140
142
|
return results
|
|
141
143
|
|
|
142
|
-
def generate_question_embedding(self, data: str, **kwargs) -> List[float]:
|
|
143
|
-
|
|
144
|
+
def get_embeddings(self, data: str, task: str) -> List[float]:
|
|
145
|
+
embeddings = None
|
|
146
|
+
|
|
147
|
+
if self.type == "VERTEX_AI":
|
|
148
|
+
input = [TextEmbeddingInput(data, task)]
|
|
149
|
+
model = TextEmbeddingModel.from_pretrained("text-embedding-004")
|
|
150
|
+
|
|
151
|
+
result = model.get_embeddings(input)
|
|
152
|
+
|
|
153
|
+
if len(result) > 0:
|
|
154
|
+
embeddings = result[0].values
|
|
155
|
+
else:
|
|
156
|
+
# Use Gemini Consumer API
|
|
157
|
+
result = self.genai.embed_content(
|
|
144
158
|
model="models/text-embedding-004",
|
|
145
159
|
content=data,
|
|
146
|
-
task_type="retrieval_query")
|
|
160
|
+
task_type=task)
|
|
147
161
|
|
|
148
|
-
|
|
149
|
-
|
|
162
|
+
if 'embedding' in result:
|
|
163
|
+
embeddings = result['embedding']
|
|
164
|
+
|
|
165
|
+
return embeddings
|
|
166
|
+
|
|
167
|
+
def generate_question_embedding(self, data: str, **kwargs) -> List[float]:
|
|
168
|
+
result = self.get_embeddings(data, "RETRIEVAL_QUERY")
|
|
169
|
+
|
|
170
|
+
if result != None:
|
|
171
|
+
return result
|
|
150
172
|
else:
|
|
151
173
|
raise ValueError("No embeddings returned")
|
|
152
174
|
|
|
153
175
|
def generate_storage_embedding(self, data: str, **kwargs) -> List[float]:
|
|
154
|
-
result = self.genai.embed_content(
|
|
155
|
-
model="models/text-embedding-004",
|
|
156
|
-
content=data,
|
|
157
|
-
task_type="retrieval_document")
|
|
176
|
+
result = self.get_embeddings(data, "RETRIEVAL_DOCUMENT")
|
|
158
177
|
|
|
159
|
-
if 'embedding' in result:
|
|
160
|
-
return result['embedding']
|
|
178
|
+
if result != None:
|
|
179
|
+
return result
|
|
161
180
|
else:
|
|
162
181
|
raise ValueError("No embeddings returned")
|
|
163
182
|
|
|
@@ -15,7 +15,7 @@ class GoogleGeminiChat(VannaBase):
|
|
|
15
15
|
if "model_name" in config:
|
|
16
16
|
model_name = config["model_name"]
|
|
17
17
|
else:
|
|
18
|
-
model_name = "gemini-1.
|
|
18
|
+
model_name = "gemini-1.5-pro"
|
|
19
19
|
|
|
20
20
|
self.google_api_key = None
|
|
21
21
|
|
|
@@ -30,7 +30,7 @@ class GoogleGeminiChat(VannaBase):
|
|
|
30
30
|
self.chat_model = genai.GenerativeModel(model_name)
|
|
31
31
|
else:
|
|
32
32
|
# Authenticate using VertexAI
|
|
33
|
-
from vertexai.
|
|
33
|
+
from vertexai.generative_models import GenerativeModel
|
|
34
34
|
self.chat_model = GenerativeModel(model_name)
|
|
35
35
|
|
|
36
36
|
def system_message(self, message: str) -> any:
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .pgvector import PG_VectorStore
|
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import uuid
|
|
5
|
+
|
|
6
|
+
import pandas as pd
|
|
7
|
+
from langchain_core.documents import Document
|
|
8
|
+
from langchain_postgres.vectorstores import PGVector
|
|
9
|
+
from sqlalchemy import create_engine, text
|
|
10
|
+
|
|
11
|
+
from .. import ValidationError
|
|
12
|
+
from ..base import VannaBase
|
|
13
|
+
from ..types import TrainingPlan, TrainingPlanItem
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class PG_VectorStore(VannaBase):
|
|
17
|
+
def __init__(self, config=None):
|
|
18
|
+
if not config or "connection_string" not in config:
|
|
19
|
+
raise ValueError(
|
|
20
|
+
"A valid 'config' dictionary with a 'connection_string' is required.")
|
|
21
|
+
|
|
22
|
+
VannaBase.__init__(self, config=config)
|
|
23
|
+
|
|
24
|
+
if config and "connection_string" in config:
|
|
25
|
+
self.connection_string = config.get("connection_string")
|
|
26
|
+
self.n_results = config.get("n_results", 10)
|
|
27
|
+
|
|
28
|
+
if config and "embedding_function" in config:
|
|
29
|
+
self.embedding_function = config.get("embedding_function")
|
|
30
|
+
else:
|
|
31
|
+
from langchain_huggingface import HuggingFaceEmbeddings
|
|
32
|
+
self.embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
|
|
33
|
+
|
|
34
|
+
self.sql_collection = PGVector(
|
|
35
|
+
embeddings=self.embedding_function,
|
|
36
|
+
collection_name="sql",
|
|
37
|
+
connection=self.connection_string,
|
|
38
|
+
)
|
|
39
|
+
self.ddl_collection = PGVector(
|
|
40
|
+
embeddings=self.embedding_function,
|
|
41
|
+
collection_name="ddl",
|
|
42
|
+
connection=self.connection_string,
|
|
43
|
+
)
|
|
44
|
+
self.documentation_collection = PGVector(
|
|
45
|
+
embeddings=self.embedding_function,
|
|
46
|
+
collection_name="documentation",
|
|
47
|
+
connection=self.connection_string,
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
|
|
51
|
+
question_sql_json = json.dumps(
|
|
52
|
+
{
|
|
53
|
+
"question": question,
|
|
54
|
+
"sql": sql,
|
|
55
|
+
},
|
|
56
|
+
ensure_ascii=False,
|
|
57
|
+
)
|
|
58
|
+
id = str(uuid.uuid4()) + "-sql"
|
|
59
|
+
createdat = kwargs.get("createdat")
|
|
60
|
+
doc = Document(
|
|
61
|
+
page_content=question_sql_json,
|
|
62
|
+
metadata={"id": id, "createdat": createdat},
|
|
63
|
+
)
|
|
64
|
+
self.sql_collection.add_documents([doc], ids=[doc.metadata["id"]])
|
|
65
|
+
|
|
66
|
+
return id
|
|
67
|
+
|
|
68
|
+
def add_ddl(self, ddl: str, **kwargs) -> str:
|
|
69
|
+
_id = str(uuid.uuid4()) + "-ddl"
|
|
70
|
+
doc = Document(
|
|
71
|
+
page_content=ddl,
|
|
72
|
+
metadata={"id": _id},
|
|
73
|
+
)
|
|
74
|
+
self.ddl_collection.add_documents([doc], ids=[doc.metadata["id"]])
|
|
75
|
+
return _id
|
|
76
|
+
|
|
77
|
+
def add_documentation(self, documentation: str, **kwargs) -> str:
|
|
78
|
+
_id = str(uuid.uuid4()) + "-doc"
|
|
79
|
+
doc = Document(
|
|
80
|
+
page_content=documentation,
|
|
81
|
+
metadata={"id": _id},
|
|
82
|
+
)
|
|
83
|
+
self.documentation_collection.add_documents([doc], ids=[doc.metadata["id"]])
|
|
84
|
+
return _id
|
|
85
|
+
|
|
86
|
+
def get_collection(self, collection_name):
|
|
87
|
+
match collection_name:
|
|
88
|
+
case "sql":
|
|
89
|
+
return self.sql_collection
|
|
90
|
+
case "ddl":
|
|
91
|
+
return self.ddl_collection
|
|
92
|
+
case "documentation":
|
|
93
|
+
return self.documentation_collection
|
|
94
|
+
case _:
|
|
95
|
+
raise ValueError("Specified collection does not exist.")
|
|
96
|
+
|
|
97
|
+
def get_similar_question_sql(self, question: str) -> list:
|
|
98
|
+
documents = self.sql_collection.similarity_search(query=question, k=self.n_results)
|
|
99
|
+
return [ast.literal_eval(document.page_content) for document in documents]
|
|
100
|
+
|
|
101
|
+
def get_related_ddl(self, question: str, **kwargs) -> list:
|
|
102
|
+
documents = self.ddl_collection.similarity_search(query=question, k=self.n_results)
|
|
103
|
+
return [document.page_content for document in documents]
|
|
104
|
+
|
|
105
|
+
def get_related_documentation(self, question: str, **kwargs) -> list:
|
|
106
|
+
documents = self.documentation_collection.similarity_search(query=question, k=self.n_results)
|
|
107
|
+
return [document.page_content for document in documents]
|
|
108
|
+
|
|
109
|
+
def train(
|
|
110
|
+
self,
|
|
111
|
+
question: str | None = None,
|
|
112
|
+
sql: str | None = None,
|
|
113
|
+
ddl: str | None = None,
|
|
114
|
+
documentation: str | None = None,
|
|
115
|
+
plan: TrainingPlan | None = None,
|
|
116
|
+
createdat: str | None = None,
|
|
117
|
+
):
|
|
118
|
+
if question and not sql:
|
|
119
|
+
raise ValidationError("Please provide a SQL query.")
|
|
120
|
+
|
|
121
|
+
if documentation:
|
|
122
|
+
logging.info(f"Adding documentation: {documentation}")
|
|
123
|
+
return self.add_documentation(documentation)
|
|
124
|
+
|
|
125
|
+
if sql and question:
|
|
126
|
+
return self.add_question_sql(question=question, sql=sql, createdat=createdat)
|
|
127
|
+
|
|
128
|
+
if ddl:
|
|
129
|
+
logging.info(f"Adding ddl: {ddl}")
|
|
130
|
+
return self.add_ddl(ddl)
|
|
131
|
+
|
|
132
|
+
if plan:
|
|
133
|
+
for item in plan._plan:
|
|
134
|
+
if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
|
|
135
|
+
self.add_ddl(item.item_value)
|
|
136
|
+
elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
|
|
137
|
+
self.add_documentation(item.item_value)
|
|
138
|
+
elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL and item.item_name:
|
|
139
|
+
self.add_question_sql(question=item.item_name, sql=item.item_value)
|
|
140
|
+
|
|
141
|
+
def get_training_data(self, **kwargs) -> pd.DataFrame:
|
|
142
|
+
# Establishing the connection
|
|
143
|
+
engine = create_engine(self.connection_string)
|
|
144
|
+
|
|
145
|
+
# Querying the 'langchain_pg_embedding' table
|
|
146
|
+
query_embedding = "SELECT cmetadata, document FROM langchain_pg_embedding"
|
|
147
|
+
df_embedding = pd.read_sql(query_embedding, engine)
|
|
148
|
+
|
|
149
|
+
# List to accumulate the processed rows
|
|
150
|
+
processed_rows = []
|
|
151
|
+
|
|
152
|
+
# Process each row in the DataFrame
|
|
153
|
+
for _, row in df_embedding.iterrows():
|
|
154
|
+
custom_id = row["cmetadata"]["id"]
|
|
155
|
+
document = row["document"]
|
|
156
|
+
training_data_type = "documentation" if custom_id[-3:] == "doc" else custom_id[-3:]
|
|
157
|
+
|
|
158
|
+
if training_data_type == "sql":
|
|
159
|
+
# Convert the document string to a dictionary
|
|
160
|
+
try:
|
|
161
|
+
doc_dict = ast.literal_eval(document)
|
|
162
|
+
question = doc_dict.get("question")
|
|
163
|
+
content = doc_dict.get("sql")
|
|
164
|
+
except (ValueError, SyntaxError):
|
|
165
|
+
logging.info(f"Skipping row with custom_id {custom_id} due to parsing error.")
|
|
166
|
+
continue
|
|
167
|
+
elif training_data_type in ["documentation", "ddl"]:
|
|
168
|
+
question = None # Default value for question
|
|
169
|
+
content = document
|
|
170
|
+
else:
|
|
171
|
+
# If the suffix is not recognized, skip this row
|
|
172
|
+
logging.info(f"Skipping row with custom_id {custom_id} due to unrecognized training data type.")
|
|
173
|
+
continue
|
|
174
|
+
|
|
175
|
+
# Append the processed data to the list
|
|
176
|
+
processed_rows.append(
|
|
177
|
+
{"id": custom_id, "question": question, "content": content, "training_data_type": training_data_type}
|
|
178
|
+
)
|
|
179
|
+
|
|
180
|
+
# Create a DataFrame from the list of processed rows
|
|
181
|
+
df_processed = pd.DataFrame(processed_rows)
|
|
182
|
+
|
|
183
|
+
return df_processed
|
|
184
|
+
|
|
185
|
+
def remove_training_data(self, id: str, **kwargs) -> bool:
|
|
186
|
+
# Create the database engine
|
|
187
|
+
engine = create_engine(self.connection_string)
|
|
188
|
+
|
|
189
|
+
# SQL DELETE statement
|
|
190
|
+
delete_statement = text(
|
|
191
|
+
"""
|
|
192
|
+
DELETE FROM langchain_pg_embedding
|
|
193
|
+
WHERE cmetadata ->> 'id' = :id
|
|
194
|
+
"""
|
|
195
|
+
)
|
|
196
|
+
|
|
197
|
+
# Connect to the database and execute the delete statement
|
|
198
|
+
with engine.connect() as connection:
|
|
199
|
+
# Start a transaction
|
|
200
|
+
with connection.begin() as transaction:
|
|
201
|
+
try:
|
|
202
|
+
result = connection.execute(delete_statement, {"id": id})
|
|
203
|
+
# Commit the transaction if the delete was successful
|
|
204
|
+
transaction.commit()
|
|
205
|
+
# Check if any row was deleted and return True or False accordingly
|
|
206
|
+
return result.rowcount > 0
|
|
207
|
+
except Exception as e:
|
|
208
|
+
# Rollback the transaction in case of error
|
|
209
|
+
logging.error(f"An error occurred: {e}")
|
|
210
|
+
transaction.rollback()
|
|
211
|
+
return False
|
|
212
|
+
|
|
213
|
+
def remove_collection(self, collection_name: str) -> bool:
|
|
214
|
+
engine = create_engine(self.connection_string)
|
|
215
|
+
|
|
216
|
+
# Determine the suffix to look for based on the collection name
|
|
217
|
+
suffix_map = {"ddl": "ddl", "sql": "sql", "documentation": "doc"}
|
|
218
|
+
suffix = suffix_map.get(collection_name)
|
|
219
|
+
|
|
220
|
+
if not suffix:
|
|
221
|
+
logging.info("Invalid collection name. Choose from 'ddl', 'sql', or 'documentation'.")
|
|
222
|
+
return False
|
|
223
|
+
|
|
224
|
+
# SQL query to delete rows based on the condition
|
|
225
|
+
query = text(
|
|
226
|
+
f"""
|
|
227
|
+
DELETE FROM langchain_pg_embedding
|
|
228
|
+
WHERE cmetadata->>'id' LIKE '%{suffix}'
|
|
229
|
+
"""
|
|
230
|
+
)
|
|
231
|
+
|
|
232
|
+
# Execute the deletion within a transaction block
|
|
233
|
+
with engine.connect() as connection:
|
|
234
|
+
with connection.begin() as transaction:
|
|
235
|
+
try:
|
|
236
|
+
result = connection.execute(query)
|
|
237
|
+
transaction.commit() # Explicitly commit the transaction
|
|
238
|
+
if result.rowcount > 0:
|
|
239
|
+
logging.info(
|
|
240
|
+
f"Deleted {result.rowcount} rows from "
|
|
241
|
+
f"langchain_pg_embedding where collection is {collection_name}."
|
|
242
|
+
)
|
|
243
|
+
return True
|
|
244
|
+
else:
|
|
245
|
+
logging.info(f"No rows deleted for collection {collection_name}.")
|
|
246
|
+
return False
|
|
247
|
+
except Exception as e:
|
|
248
|
+
logging.error(f"An error occurred: {e}")
|
|
249
|
+
transaction.rollback() # Rollback in case of error
|
|
250
|
+
return False
|
|
251
|
+
|
|
252
|
+
def generate_embedding(self, *args, **kwargs):
|
|
253
|
+
pass
|
|
@@ -22,6 +22,12 @@ class Vllm(VannaBase):
|
|
|
22
22
|
else:
|
|
23
23
|
self.auth_key = None
|
|
24
24
|
|
|
25
|
+
if "temperature" in config:
|
|
26
|
+
self.temperature = config["temperature"]
|
|
27
|
+
else:
|
|
28
|
+
# default temperature - can be overrided using config
|
|
29
|
+
self.temperature = 0.7
|
|
30
|
+
|
|
25
31
|
def system_message(self, message: str) -> any:
|
|
26
32
|
return {"role": "system", "content": message}
|
|
27
33
|
|
|
@@ -68,6 +74,7 @@ class Vllm(VannaBase):
|
|
|
68
74
|
url = f"{self.host}/v1/chat/completions"
|
|
69
75
|
data = {
|
|
70
76
|
"model": self.model,
|
|
77
|
+
"temperature": self.temperature,
|
|
71
78
|
"stream": False,
|
|
72
79
|
"messages": prompt,
|
|
73
80
|
}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .xinference import Xinference
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
from xinference_client.client.restful.restful_client import (
|
|
2
|
+
Client,
|
|
3
|
+
RESTfulChatModelHandle,
|
|
4
|
+
)
|
|
5
|
+
|
|
6
|
+
from ..base import VannaBase
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Xinference(VannaBase):
|
|
10
|
+
def __init__(self, config=None):
|
|
11
|
+
VannaBase.__init__(self, config=config)
|
|
12
|
+
|
|
13
|
+
if not config or "base_url" not in config:
|
|
14
|
+
raise ValueError("config must contain at least Xinference base_url")
|
|
15
|
+
|
|
16
|
+
base_url = config["base_url"]
|
|
17
|
+
api_key = config.get("api_key", "not empty")
|
|
18
|
+
self.xinference_client = Client(base_url=base_url, api_key=api_key)
|
|
19
|
+
|
|
20
|
+
def system_message(self, message: str) -> any:
|
|
21
|
+
return {"role": "system", "content": message}
|
|
22
|
+
|
|
23
|
+
def user_message(self, message: str) -> any:
|
|
24
|
+
return {"role": "user", "content": message}
|
|
25
|
+
|
|
26
|
+
def assistant_message(self, message: str) -> any:
|
|
27
|
+
return {"role": "assistant", "content": message}
|
|
28
|
+
|
|
29
|
+
def submit_prompt(self, prompt, **kwargs) -> str:
|
|
30
|
+
if prompt is None:
|
|
31
|
+
raise Exception("Prompt is None")
|
|
32
|
+
|
|
33
|
+
if len(prompt) == 0:
|
|
34
|
+
raise Exception("Prompt is empty")
|
|
35
|
+
|
|
36
|
+
num_tokens = 0
|
|
37
|
+
for message in prompt:
|
|
38
|
+
num_tokens += len(message["content"]) / 4
|
|
39
|
+
|
|
40
|
+
model_uid = kwargs.get("model_uid") or self.config.get("model_uid", None)
|
|
41
|
+
if model_uid is None:
|
|
42
|
+
raise ValueError("model_uid is required")
|
|
43
|
+
|
|
44
|
+
xinference_model = self.xinference_client.get_model(model_uid)
|
|
45
|
+
if isinstance(xinference_model, RESTfulChatModelHandle):
|
|
46
|
+
print(
|
|
47
|
+
f"Using model_uid {model_uid} for {num_tokens} tokens (approx)"
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
response = xinference_model.chat(prompt)
|
|
51
|
+
return response["choices"][0]["message"]["content"]
|
|
52
|
+
else:
|
|
53
|
+
raise NotImplementedError(f"Xinference model handle type {type(xinference_model)} is not supported, required RESTfulChatModelHandle")
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|