vanna 0.7.3__tar.gz → 0.7.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. {vanna-0.7.3 → vanna-0.7.5}/PKG-INFO +16 -1
  2. {vanna-0.7.3 → vanna-0.7.5}/pyproject.toml +6 -2
  3. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/google/bigquery_vector.py +34 -15
  4. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/google/gemini_chat.py +2 -2
  5. vanna-0.7.5/src/vanna/pgvector/__init__.py +1 -0
  6. vanna-0.7.5/src/vanna/pgvector/pgvector.py +253 -0
  7. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/vllm/vllm.py +7 -0
  8. vanna-0.7.5/src/vanna/xinference/__init__.py +1 -0
  9. vanna-0.7.5/src/vanna/xinference/xinference.py +53 -0
  10. {vanna-0.7.3 → vanna-0.7.5}/README.md +0 -0
  11. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/ZhipuAI/ZhipuAI_Chat.py +0 -0
  12. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/ZhipuAI/ZhipuAI_embeddings.py +0 -0
  13. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/ZhipuAI/__init__.py +0 -0
  14. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/__init__.py +0 -0
  15. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/advanced/__init__.py +0 -0
  16. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/anthropic/__init__.py +0 -0
  17. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/anthropic/anthropic_chat.py +0 -0
  18. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/azuresearch/__init__.py +0 -0
  19. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/azuresearch/azuresearch_vector.py +0 -0
  20. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/base/__init__.py +0 -0
  21. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/base/base.py +0 -0
  22. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/bedrock/__init__.py +0 -0
  23. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/bedrock/bedrock_converse.py +0 -0
  24. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/chromadb/__init__.py +0 -0
  25. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/chromadb/chromadb_vector.py +0 -0
  26. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/exceptions/__init__.py +0 -0
  27. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/faiss/__init__.py +0 -0
  28. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/faiss/faiss.py +0 -0
  29. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/flask/__init__.py +0 -0
  30. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/flask/assets.py +0 -0
  31. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/flask/auth.py +0 -0
  32. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/google/__init__.py +0 -0
  33. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/hf/__init__.py +0 -0
  34. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/hf/hf.py +0 -0
  35. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/local.py +0 -0
  36. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/marqo/__init__.py +0 -0
  37. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/marqo/marqo.py +0 -0
  38. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/milvus/__init__.py +0 -0
  39. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/milvus/milvus_vector.py +0 -0
  40. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/mistral/__init__.py +0 -0
  41. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/mistral/mistral.py +0 -0
  42. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/mock/__init__.py +0 -0
  43. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/mock/embedding.py +0 -0
  44. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/mock/llm.py +0 -0
  45. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/mock/vectordb.py +0 -0
  46. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/ollama/__init__.py +0 -0
  47. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/ollama/ollama.py +0 -0
  48. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/openai/__init__.py +0 -0
  49. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/openai/openai_chat.py +0 -0
  50. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/openai/openai_embeddings.py +0 -0
  51. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/opensearch/__init__.py +0 -0
  52. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/opensearch/opensearch_vector.py +0 -0
  53. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/pinecone/__init__.py +0 -0
  54. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/pinecone/pinecone_vector.py +0 -0
  55. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/qdrant/__init__.py +0 -0
  56. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/qdrant/qdrant.py +0 -0
  57. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/qianfan/Qianfan_Chat.py +0 -0
  58. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/qianfan/Qianfan_embeddings.py +0 -0
  59. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/qianfan/__init__.py +0 -0
  60. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/qianwen/QianwenAI_chat.py +0 -0
  61. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/qianwen/QianwenAI_embeddings.py +0 -0
  62. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/qianwen/__init__.py +0 -0
  63. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/remote.py +0 -0
  64. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/types/__init__.py +0 -0
  65. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/utils.py +0 -0
  66. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/vannadb/__init__.py +0 -0
  67. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/vannadb/vannadb_vector.py +0 -0
  68. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/vllm/__init__.py +0 -0
  69. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/weaviate/__init__.py +0 -0
  70. {vanna-0.7.3 → vanna-0.7.5}/src/vanna/weaviate/weaviate_vector.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vanna
3
- Version: 0.7.3
3
+ Version: 0.7.5
4
4
  Summary: Generate SQL queries from natural language
5
5
  Author-email: Zain Hoda <zain@vanna.ai>
6
6
  Requires-Python: >=3.9
@@ -46,6 +46,13 @@ Requires-Dist: weaviate-client ; extra == "all"
46
46
  Requires-Dist: azure-search-documents ; extra == "all"
47
47
  Requires-Dist: azure-identity ; extra == "all"
48
48
  Requires-Dist: azure-common ; extra == "all"
49
+ Requires-Dist: faiss-cpu ; extra == "all"
50
+ Requires-Dist: boto ; extra == "all"
51
+ Requires-Dist: boto3 ; extra == "all"
52
+ Requires-Dist: botocore ; extra == "all"
53
+ Requires-Dist: langchain_core ; extra == "all"
54
+ Requires-Dist: langchain_postgres ; extra == "all"
55
+ Requires-Dist: xinference-client ; extra == "all"
49
56
  Requires-Dist: anthropic ; extra == "anthropic"
50
57
  Requires-Dist: azure-search-documents ; extra == "azuresearch"
51
58
  Requires-Dist: azure-identity ; extra == "azuresearch"
@@ -57,6 +64,8 @@ Requires-Dist: google-cloud-bigquery ; extra == "bigquery"
57
64
  Requires-Dist: chromadb ; extra == "chromadb"
58
65
  Requires-Dist: clickhouse_connect ; extra == "clickhouse"
59
66
  Requires-Dist: duckdb ; extra == "duckdb"
67
+ Requires-Dist: faiss-cpu ; extra == "faiss-cpu"
68
+ Requires-Dist: faiss-gpu ; extra == "faiss-gpu"
60
69
  Requires-Dist: google-generativeai ; extra == "gemini"
61
70
  Requires-Dist: google-generativeai ; extra == "google"
62
71
  Requires-Dist: google-cloud-aiplatform ; extra == "google"
@@ -70,6 +79,7 @@ Requires-Dist: httpx ; extra == "ollama"
70
79
  Requires-Dist: openai ; extra == "openai"
71
80
  Requires-Dist: opensearch-py ; extra == "opensearch"
72
81
  Requires-Dist: opensearch-dsl ; extra == "opensearch"
82
+ Requires-Dist: langchain-postgres>=0.0.12 ; extra == "pgvector"
73
83
  Requires-Dist: pinecone-client ; extra == "pinecone"
74
84
  Requires-Dist: fastembed ; extra == "pinecone"
75
85
  Requires-Dist: psycopg2-binary ; extra == "postgres"
@@ -81,6 +91,7 @@ Requires-Dist: snowflake-connector-python ; extra == "snowflake"
81
91
  Requires-Dist: tox ; extra == "test"
82
92
  Requires-Dist: vllm ; extra == "vllm"
83
93
  Requires-Dist: weaviate-client ; extra == "weaviate"
94
+ Requires-Dist: xinference-client ; extra == "xinference-client"
84
95
  Requires-Dist: zhipuai ; extra == "zhipuai"
85
96
  Project-URL: Bug Tracker, https://github.com/vanna-ai/vanna/issues
86
97
  Project-URL: Homepage, https://github.com/vanna-ai/vanna
@@ -92,6 +103,8 @@ Provides-Extra: bigquery
92
103
  Provides-Extra: chromadb
93
104
  Provides-Extra: clickhouse
94
105
  Provides-Extra: duckdb
106
+ Provides-Extra: faiss-cpu
107
+ Provides-Extra: faiss-gpu
95
108
  Provides-Extra: gemini
96
109
  Provides-Extra: google
97
110
  Provides-Extra: hf
@@ -102,6 +115,7 @@ Provides-Extra: mysql
102
115
  Provides-Extra: ollama
103
116
  Provides-Extra: openai
104
117
  Provides-Extra: opensearch
118
+ Provides-Extra: pgvector
105
119
  Provides-Extra: pinecone
106
120
  Provides-Extra: postgres
107
121
  Provides-Extra: qdrant
@@ -110,6 +124,7 @@ Provides-Extra: snowflake
110
124
  Provides-Extra: test
111
125
  Provides-Extra: vllm
112
126
  Provides-Extra: weaviate
127
+ Provides-Extra: xinference-client
113
128
  Provides-Extra: zhipuai
114
129
 
115
130
 
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"
4
4
 
5
5
  [project]
6
6
  name = "vanna"
7
- version = "0.7.3"
7
+ version = "0.7.5"
8
8
  authors = [
9
9
  { name="Zain Hoda", email="zain@vanna.ai" },
10
10
  ]
@@ -33,7 +33,7 @@ bigquery = ["google-cloud-bigquery"]
33
33
  snowflake = ["snowflake-connector-python"]
34
34
  duckdb = ["duckdb"]
35
35
  google = ["google-generativeai", "google-cloud-aiplatform"]
36
- all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client", "azure-search-documents", "azure-identity", "azure-common"]
36
+ all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client", "azure-search-documents", "azure-identity", "azure-common", "faiss-cpu", "boto", "boto3", "botocore", "langchain_core", "langchain_postgres", "xinference-client"]
37
37
  test = ["tox"]
38
38
  chromadb = ["chromadb"]
39
39
  openai = ["openai"]
@@ -53,3 +53,7 @@ milvus = ["pymilvus[model]"]
53
53
  bedrock = ["boto3", "botocore"]
54
54
  weaviate = ["weaviate-client"]
55
55
  azuresearch = ["azure-search-documents", "azure-identity", "azure-common", "fastembed"]
56
+ pgvector = ["langchain-postgres>=0.0.12"]
57
+ faiss-cpu = ["faiss-cpu"]
58
+ faiss-gpu = ["faiss-gpu"]
59
+ xinference-client = ["xinference-client"]
@@ -2,6 +2,10 @@ import datetime
2
2
  import os
3
3
  import uuid
4
4
  from typing import List, Optional
5
+ from vertexai.language_models import (
6
+ TextEmbeddingInput,
7
+ TextEmbeddingModel
8
+ )
5
9
 
6
10
  import pandas as pd
7
11
  from google.cloud import bigquery
@@ -23,17 +27,15 @@ class BigQuery_VectorStore(VannaBase):
23
27
  or set as an environment variable, assign it.
24
28
  """
25
29
  print("Configuring genai")
30
+ self.type = "GEMINI"
26
31
  import google.generativeai as genai
27
32
 
28
33
  genai.configure(api_key=config["api_key"])
29
34
 
30
35
  self.genai = genai
31
36
  else:
37
+ self.type = "VERTEX_AI"
32
38
  # Authenticate using VertexAI
33
- from vertexai.language_models import (
34
- TextEmbeddingInput,
35
- TextEmbeddingModel,
36
- )
37
39
 
38
40
  if self.config.get("project_id"):
39
41
  self.project_id = self.config.get("project_id")
@@ -139,25 +141,42 @@ class BigQuery_VectorStore(VannaBase):
139
141
  results = self.conn.query(query).result().to_dataframe()
140
142
  return results
141
143
 
142
- def generate_question_embedding(self, data: str, **kwargs) -> List[float]:
143
- result = self.genai.embed_content(
144
+ def get_embeddings(self, data: str, task: str) -> List[float]:
145
+ embeddings = None
146
+
147
+ if self.type == "VERTEX_AI":
148
+ input = [TextEmbeddingInput(data, task)]
149
+ model = TextEmbeddingModel.from_pretrained("text-embedding-004")
150
+
151
+ result = model.get_embeddings(input)
152
+
153
+ if len(result) > 0:
154
+ embeddings = result[0].values
155
+ else:
156
+ # Use Gemini Consumer API
157
+ result = self.genai.embed_content(
144
158
  model="models/text-embedding-004",
145
159
  content=data,
146
- task_type="retrieval_query")
160
+ task_type=task)
147
161
 
148
- if 'embedding' in result:
149
- return result['embedding']
162
+ if 'embedding' in result:
163
+ embeddings = result['embedding']
164
+
165
+ return embeddings
166
+
167
+ def generate_question_embedding(self, data: str, **kwargs) -> List[float]:
168
+ result = self.get_embeddings(data, "RETRIEVAL_QUERY")
169
+
170
+ if result != None:
171
+ return result
150
172
  else:
151
173
  raise ValueError("No embeddings returned")
152
174
 
153
175
  def generate_storage_embedding(self, data: str, **kwargs) -> List[float]:
154
- result = self.genai.embed_content(
155
- model="models/text-embedding-004",
156
- content=data,
157
- task_type="retrieval_document")
176
+ result = self.get_embeddings(data, "RETRIEVAL_DOCUMENT")
158
177
 
159
- if 'embedding' in result:
160
- return result['embedding']
178
+ if result != None:
179
+ return result
161
180
  else:
162
181
  raise ValueError("No embeddings returned")
163
182
 
@@ -15,7 +15,7 @@ class GoogleGeminiChat(VannaBase):
15
15
  if "model_name" in config:
16
16
  model_name = config["model_name"]
17
17
  else:
18
- model_name = "gemini-1.0-pro"
18
+ model_name = "gemini-1.5-pro"
19
19
 
20
20
  self.google_api_key = None
21
21
 
@@ -30,7 +30,7 @@ class GoogleGeminiChat(VannaBase):
30
30
  self.chat_model = genai.GenerativeModel(model_name)
31
31
  else:
32
32
  # Authenticate using VertexAI
33
- from vertexai.preview.generative_models import GenerativeModel
33
+ from vertexai.generative_models import GenerativeModel
34
34
  self.chat_model = GenerativeModel(model_name)
35
35
 
36
36
  def system_message(self, message: str) -> any:
@@ -0,0 +1 @@
1
+ from .pgvector import PG_VectorStore
@@ -0,0 +1,253 @@
1
+ import ast
2
+ import json
3
+ import logging
4
+ import uuid
5
+
6
+ import pandas as pd
7
+ from langchain_core.documents import Document
8
+ from langchain_postgres.vectorstores import PGVector
9
+ from sqlalchemy import create_engine, text
10
+
11
+ from .. import ValidationError
12
+ from ..base import VannaBase
13
+ from ..types import TrainingPlan, TrainingPlanItem
14
+
15
+
16
+ class PG_VectorStore(VannaBase):
17
+ def __init__(self, config=None):
18
+ if not config or "connection_string" not in config:
19
+ raise ValueError(
20
+ "A valid 'config' dictionary with a 'connection_string' is required.")
21
+
22
+ VannaBase.__init__(self, config=config)
23
+
24
+ if config and "connection_string" in config:
25
+ self.connection_string = config.get("connection_string")
26
+ self.n_results = config.get("n_results", 10)
27
+
28
+ if config and "embedding_function" in config:
29
+ self.embedding_function = config.get("embedding_function")
30
+ else:
31
+ from langchain_huggingface import HuggingFaceEmbeddings
32
+ self.embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
33
+
34
+ self.sql_collection = PGVector(
35
+ embeddings=self.embedding_function,
36
+ collection_name="sql",
37
+ connection=self.connection_string,
38
+ )
39
+ self.ddl_collection = PGVector(
40
+ embeddings=self.embedding_function,
41
+ collection_name="ddl",
42
+ connection=self.connection_string,
43
+ )
44
+ self.documentation_collection = PGVector(
45
+ embeddings=self.embedding_function,
46
+ collection_name="documentation",
47
+ connection=self.connection_string,
48
+ )
49
+
50
+ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
51
+ question_sql_json = json.dumps(
52
+ {
53
+ "question": question,
54
+ "sql": sql,
55
+ },
56
+ ensure_ascii=False,
57
+ )
58
+ id = str(uuid.uuid4()) + "-sql"
59
+ createdat = kwargs.get("createdat")
60
+ doc = Document(
61
+ page_content=question_sql_json,
62
+ metadata={"id": id, "createdat": createdat},
63
+ )
64
+ self.sql_collection.add_documents([doc], ids=[doc.metadata["id"]])
65
+
66
+ return id
67
+
68
+ def add_ddl(self, ddl: str, **kwargs) -> str:
69
+ _id = str(uuid.uuid4()) + "-ddl"
70
+ doc = Document(
71
+ page_content=ddl,
72
+ metadata={"id": _id},
73
+ )
74
+ self.ddl_collection.add_documents([doc], ids=[doc.metadata["id"]])
75
+ return _id
76
+
77
+ def add_documentation(self, documentation: str, **kwargs) -> str:
78
+ _id = str(uuid.uuid4()) + "-doc"
79
+ doc = Document(
80
+ page_content=documentation,
81
+ metadata={"id": _id},
82
+ )
83
+ self.documentation_collection.add_documents([doc], ids=[doc.metadata["id"]])
84
+ return _id
85
+
86
+ def get_collection(self, collection_name):
87
+ match collection_name:
88
+ case "sql":
89
+ return self.sql_collection
90
+ case "ddl":
91
+ return self.ddl_collection
92
+ case "documentation":
93
+ return self.documentation_collection
94
+ case _:
95
+ raise ValueError("Specified collection does not exist.")
96
+
97
+ def get_similar_question_sql(self, question: str) -> list:
98
+ documents = self.sql_collection.similarity_search(query=question, k=self.n_results)
99
+ return [ast.literal_eval(document.page_content) for document in documents]
100
+
101
+ def get_related_ddl(self, question: str, **kwargs) -> list:
102
+ documents = self.ddl_collection.similarity_search(query=question, k=self.n_results)
103
+ return [document.page_content for document in documents]
104
+
105
+ def get_related_documentation(self, question: str, **kwargs) -> list:
106
+ documents = self.documentation_collection.similarity_search(query=question, k=self.n_results)
107
+ return [document.page_content for document in documents]
108
+
109
+ def train(
110
+ self,
111
+ question: str | None = None,
112
+ sql: str | None = None,
113
+ ddl: str | None = None,
114
+ documentation: str | None = None,
115
+ plan: TrainingPlan | None = None,
116
+ createdat: str | None = None,
117
+ ):
118
+ if question and not sql:
119
+ raise ValidationError("Please provide a SQL query.")
120
+
121
+ if documentation:
122
+ logging.info(f"Adding documentation: {documentation}")
123
+ return self.add_documentation(documentation)
124
+
125
+ if sql and question:
126
+ return self.add_question_sql(question=question, sql=sql, createdat=createdat)
127
+
128
+ if ddl:
129
+ logging.info(f"Adding ddl: {ddl}")
130
+ return self.add_ddl(ddl)
131
+
132
+ if plan:
133
+ for item in plan._plan:
134
+ if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
135
+ self.add_ddl(item.item_value)
136
+ elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
137
+ self.add_documentation(item.item_value)
138
+ elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL and item.item_name:
139
+ self.add_question_sql(question=item.item_name, sql=item.item_value)
140
+
141
+ def get_training_data(self, **kwargs) -> pd.DataFrame:
142
+ # Establishing the connection
143
+ engine = create_engine(self.connection_string)
144
+
145
+ # Querying the 'langchain_pg_embedding' table
146
+ query_embedding = "SELECT cmetadata, document FROM langchain_pg_embedding"
147
+ df_embedding = pd.read_sql(query_embedding, engine)
148
+
149
+ # List to accumulate the processed rows
150
+ processed_rows = []
151
+
152
+ # Process each row in the DataFrame
153
+ for _, row in df_embedding.iterrows():
154
+ custom_id = row["cmetadata"]["id"]
155
+ document = row["document"]
156
+ training_data_type = "documentation" if custom_id[-3:] == "doc" else custom_id[-3:]
157
+
158
+ if training_data_type == "sql":
159
+ # Convert the document string to a dictionary
160
+ try:
161
+ doc_dict = ast.literal_eval(document)
162
+ question = doc_dict.get("question")
163
+ content = doc_dict.get("sql")
164
+ except (ValueError, SyntaxError):
165
+ logging.info(f"Skipping row with custom_id {custom_id} due to parsing error.")
166
+ continue
167
+ elif training_data_type in ["documentation", "ddl"]:
168
+ question = None # Default value for question
169
+ content = document
170
+ else:
171
+ # If the suffix is not recognized, skip this row
172
+ logging.info(f"Skipping row with custom_id {custom_id} due to unrecognized training data type.")
173
+ continue
174
+
175
+ # Append the processed data to the list
176
+ processed_rows.append(
177
+ {"id": custom_id, "question": question, "content": content, "training_data_type": training_data_type}
178
+ )
179
+
180
+ # Create a DataFrame from the list of processed rows
181
+ df_processed = pd.DataFrame(processed_rows)
182
+
183
+ return df_processed
184
+
185
+ def remove_training_data(self, id: str, **kwargs) -> bool:
186
+ # Create the database engine
187
+ engine = create_engine(self.connection_string)
188
+
189
+ # SQL DELETE statement
190
+ delete_statement = text(
191
+ """
192
+ DELETE FROM langchain_pg_embedding
193
+ WHERE cmetadata ->> 'id' = :id
194
+ """
195
+ )
196
+
197
+ # Connect to the database and execute the delete statement
198
+ with engine.connect() as connection:
199
+ # Start a transaction
200
+ with connection.begin() as transaction:
201
+ try:
202
+ result = connection.execute(delete_statement, {"id": id})
203
+ # Commit the transaction if the delete was successful
204
+ transaction.commit()
205
+ # Check if any row was deleted and return True or False accordingly
206
+ return result.rowcount > 0
207
+ except Exception as e:
208
+ # Rollback the transaction in case of error
209
+ logging.error(f"An error occurred: {e}")
210
+ transaction.rollback()
211
+ return False
212
+
213
+ def remove_collection(self, collection_name: str) -> bool:
214
+ engine = create_engine(self.connection_string)
215
+
216
+ # Determine the suffix to look for based on the collection name
217
+ suffix_map = {"ddl": "ddl", "sql": "sql", "documentation": "doc"}
218
+ suffix = suffix_map.get(collection_name)
219
+
220
+ if not suffix:
221
+ logging.info("Invalid collection name. Choose from 'ddl', 'sql', or 'documentation'.")
222
+ return False
223
+
224
+ # SQL query to delete rows based on the condition
225
+ query = text(
226
+ f"""
227
+ DELETE FROM langchain_pg_embedding
228
+ WHERE cmetadata->>'id' LIKE '%{suffix}'
229
+ """
230
+ )
231
+
232
+ # Execute the deletion within a transaction block
233
+ with engine.connect() as connection:
234
+ with connection.begin() as transaction:
235
+ try:
236
+ result = connection.execute(query)
237
+ transaction.commit() # Explicitly commit the transaction
238
+ if result.rowcount > 0:
239
+ logging.info(
240
+ f"Deleted {result.rowcount} rows from "
241
+ f"langchain_pg_embedding where collection is {collection_name}."
242
+ )
243
+ return True
244
+ else:
245
+ logging.info(f"No rows deleted for collection {collection_name}.")
246
+ return False
247
+ except Exception as e:
248
+ logging.error(f"An error occurred: {e}")
249
+ transaction.rollback() # Rollback in case of error
250
+ return False
251
+
252
+ def generate_embedding(self, *args, **kwargs):
253
+ pass
@@ -22,6 +22,12 @@ class Vllm(VannaBase):
22
22
  else:
23
23
  self.auth_key = None
24
24
 
25
+ if "temperature" in config:
26
+ self.temperature = config["temperature"]
27
+ else:
28
+ # default temperature - can be overrided using config
29
+ self.temperature = 0.7
30
+
25
31
  def system_message(self, message: str) -> any:
26
32
  return {"role": "system", "content": message}
27
33
 
@@ -68,6 +74,7 @@ class Vllm(VannaBase):
68
74
  url = f"{self.host}/v1/chat/completions"
69
75
  data = {
70
76
  "model": self.model,
77
+ "temperature": self.temperature,
71
78
  "stream": False,
72
79
  "messages": prompt,
73
80
  }
@@ -0,0 +1 @@
1
+ from .xinference import Xinference
@@ -0,0 +1,53 @@
1
+ from xinference_client.client.restful.restful_client import (
2
+ Client,
3
+ RESTfulChatModelHandle,
4
+ )
5
+
6
+ from ..base import VannaBase
7
+
8
+
9
+ class Xinference(VannaBase):
10
+ def __init__(self, config=None):
11
+ VannaBase.__init__(self, config=config)
12
+
13
+ if not config or "base_url" not in config:
14
+ raise ValueError("config must contain at least Xinference base_url")
15
+
16
+ base_url = config["base_url"]
17
+ api_key = config.get("api_key", "not empty")
18
+ self.xinference_client = Client(base_url=base_url, api_key=api_key)
19
+
20
+ def system_message(self, message: str) -> any:
21
+ return {"role": "system", "content": message}
22
+
23
+ def user_message(self, message: str) -> any:
24
+ return {"role": "user", "content": message}
25
+
26
+ def assistant_message(self, message: str) -> any:
27
+ return {"role": "assistant", "content": message}
28
+
29
+ def submit_prompt(self, prompt, **kwargs) -> str:
30
+ if prompt is None:
31
+ raise Exception("Prompt is None")
32
+
33
+ if len(prompt) == 0:
34
+ raise Exception("Prompt is empty")
35
+
36
+ num_tokens = 0
37
+ for message in prompt:
38
+ num_tokens += len(message["content"]) / 4
39
+
40
+ model_uid = kwargs.get("model_uid") or self.config.get("model_uid", None)
41
+ if model_uid is None:
42
+ raise ValueError("model_uid is required")
43
+
44
+ xinference_model = self.xinference_client.get_model(model_uid)
45
+ if isinstance(xinference_model, RESTfulChatModelHandle):
46
+ print(
47
+ f"Using model_uid {model_uid} for {num_tokens} tokens (approx)"
48
+ )
49
+
50
+ response = xinference_model.chat(prompt)
51
+ return response["choices"][0]["message"]["content"]
52
+ else:
53
+ raise NotImplementedError(f"Xinference model handle type {type(xinference_model)} is not supported, required RESTfulChatModelHandle")
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes