vanna 0.7.9__py3-none-any.whl → 2.0.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/__init__.py +167 -395
- vanna/agents/__init__.py +7 -0
- vanna/capabilities/__init__.py +17 -0
- vanna/capabilities/agent_memory/__init__.py +21 -0
- vanna/capabilities/agent_memory/base.py +103 -0
- vanna/capabilities/agent_memory/models.py +53 -0
- vanna/capabilities/file_system/__init__.py +14 -0
- vanna/capabilities/file_system/base.py +71 -0
- vanna/capabilities/file_system/models.py +25 -0
- vanna/capabilities/sql_runner/__init__.py +13 -0
- vanna/capabilities/sql_runner/base.py +37 -0
- vanna/capabilities/sql_runner/models.py +13 -0
- vanna/components/__init__.py +92 -0
- vanna/components/base.py +11 -0
- vanna/components/rich/__init__.py +83 -0
- vanna/components/rich/containers/__init__.py +7 -0
- vanna/components/rich/containers/card.py +20 -0
- vanna/components/rich/data/__init__.py +9 -0
- vanna/components/rich/data/chart.py +17 -0
- vanna/components/rich/data/dataframe.py +93 -0
- vanna/components/rich/feedback/__init__.py +21 -0
- vanna/components/rich/feedback/badge.py +16 -0
- vanna/components/rich/feedback/icon_text.py +14 -0
- vanna/components/rich/feedback/log_viewer.py +41 -0
- vanna/components/rich/feedback/notification.py +19 -0
- vanna/components/rich/feedback/progress.py +37 -0
- vanna/components/rich/feedback/status_card.py +28 -0
- vanna/components/rich/feedback/status_indicator.py +14 -0
- vanna/components/rich/interactive/__init__.py +21 -0
- vanna/components/rich/interactive/button.py +95 -0
- vanna/components/rich/interactive/task_list.py +58 -0
- vanna/components/rich/interactive/ui_state.py +93 -0
- vanna/components/rich/specialized/__init__.py +7 -0
- vanna/components/rich/specialized/artifact.py +20 -0
- vanna/components/rich/text.py +16 -0
- vanna/components/simple/__init__.py +15 -0
- vanna/components/simple/image.py +15 -0
- vanna/components/simple/link.py +15 -0
- vanna/components/simple/text.py +11 -0
- vanna/core/__init__.py +193 -0
- vanna/core/_compat.py +19 -0
- vanna/core/agent/__init__.py +10 -0
- vanna/core/agent/agent.py +1407 -0
- vanna/core/agent/config.py +123 -0
- vanna/core/audit/__init__.py +28 -0
- vanna/core/audit/base.py +299 -0
- vanna/core/audit/models.py +131 -0
- vanna/core/component_manager.py +329 -0
- vanna/core/components.py +53 -0
- vanna/core/enhancer/__init__.py +11 -0
- vanna/core/enhancer/base.py +94 -0
- vanna/core/enhancer/default.py +118 -0
- vanna/core/enricher/__init__.py +10 -0
- vanna/core/enricher/base.py +59 -0
- vanna/core/errors.py +47 -0
- vanna/core/evaluation/__init__.py +81 -0
- vanna/core/evaluation/base.py +186 -0
- vanna/core/evaluation/dataset.py +254 -0
- vanna/core/evaluation/evaluators.py +376 -0
- vanna/core/evaluation/report.py +289 -0
- vanna/core/evaluation/runner.py +313 -0
- vanna/core/filter/__init__.py +10 -0
- vanna/core/filter/base.py +67 -0
- vanna/core/lifecycle/__init__.py +10 -0
- vanna/core/lifecycle/base.py +83 -0
- vanna/core/llm/__init__.py +16 -0
- vanna/core/llm/base.py +40 -0
- vanna/core/llm/models.py +61 -0
- vanna/core/middleware/__init__.py +10 -0
- vanna/core/middleware/base.py +69 -0
- vanna/core/observability/__init__.py +11 -0
- vanna/core/observability/base.py +88 -0
- vanna/core/observability/models.py +47 -0
- vanna/core/recovery/__init__.py +11 -0
- vanna/core/recovery/base.py +84 -0
- vanna/core/recovery/models.py +32 -0
- vanna/core/registry.py +278 -0
- vanna/core/rich_component.py +156 -0
- vanna/core/simple_component.py +27 -0
- vanna/core/storage/__init__.py +14 -0
- vanna/core/storage/base.py +46 -0
- vanna/core/storage/models.py +46 -0
- vanna/core/system_prompt/__init__.py +13 -0
- vanna/core/system_prompt/base.py +36 -0
- vanna/core/system_prompt/default.py +157 -0
- vanna/core/tool/__init__.py +18 -0
- vanna/core/tool/base.py +70 -0
- vanna/core/tool/models.py +84 -0
- vanna/core/user/__init__.py +17 -0
- vanna/core/user/base.py +29 -0
- vanna/core/user/models.py +25 -0
- vanna/core/user/request_context.py +70 -0
- vanna/core/user/resolver.py +42 -0
- vanna/core/validation.py +164 -0
- vanna/core/workflow/__init__.py +12 -0
- vanna/core/workflow/base.py +254 -0
- vanna/core/workflow/default.py +789 -0
- vanna/examples/__init__.py +1 -0
- vanna/examples/__main__.py +44 -0
- vanna/examples/anthropic_quickstart.py +80 -0
- vanna/examples/artifact_example.py +293 -0
- vanna/examples/claude_sqlite_example.py +236 -0
- vanna/examples/coding_agent_example.py +300 -0
- vanna/examples/custom_system_prompt_example.py +174 -0
- vanna/examples/default_workflow_handler_example.py +208 -0
- vanna/examples/email_auth_example.py +340 -0
- vanna/examples/evaluation_example.py +269 -0
- vanna/examples/extensibility_example.py +262 -0
- vanna/examples/minimal_example.py +67 -0
- vanna/examples/mock_auth_example.py +227 -0
- vanna/examples/mock_custom_tool.py +311 -0
- vanna/examples/mock_quickstart.py +79 -0
- vanna/examples/mock_quota_example.py +145 -0
- vanna/examples/mock_rich_components_demo.py +396 -0
- vanna/examples/mock_sqlite_example.py +223 -0
- vanna/examples/openai_quickstart.py +83 -0
- vanna/examples/primitive_components_demo.py +305 -0
- vanna/examples/quota_lifecycle_example.py +139 -0
- vanna/examples/visualization_example.py +251 -0
- vanna/integrations/__init__.py +17 -0
- vanna/integrations/anthropic/__init__.py +9 -0
- vanna/integrations/anthropic/llm.py +270 -0
- vanna/integrations/azureopenai/__init__.py +9 -0
- vanna/integrations/azureopenai/llm.py +329 -0
- vanna/integrations/azuresearch/__init__.py +7 -0
- vanna/integrations/azuresearch/agent_memory.py +413 -0
- vanna/integrations/bigquery/__init__.py +5 -0
- vanna/integrations/bigquery/sql_runner.py +81 -0
- vanna/integrations/chromadb/__init__.py +104 -0
- vanna/integrations/chromadb/agent_memory.py +416 -0
- vanna/integrations/clickhouse/__init__.py +5 -0
- vanna/integrations/clickhouse/sql_runner.py +82 -0
- vanna/integrations/duckdb/__init__.py +5 -0
- vanna/integrations/duckdb/sql_runner.py +65 -0
- vanna/integrations/faiss/__init__.py +7 -0
- vanna/integrations/faiss/agent_memory.py +431 -0
- vanna/integrations/google/__init__.py +9 -0
- vanna/integrations/google/gemini.py +370 -0
- vanna/integrations/hive/__init__.py +5 -0
- vanna/integrations/hive/sql_runner.py +87 -0
- vanna/integrations/local/__init__.py +17 -0
- vanna/integrations/local/agent_memory/__init__.py +7 -0
- vanna/integrations/local/agent_memory/in_memory.py +285 -0
- vanna/integrations/local/audit.py +59 -0
- vanna/integrations/local/file_system.py +242 -0
- vanna/integrations/local/file_system_conversation_store.py +255 -0
- vanna/integrations/local/storage.py +62 -0
- vanna/integrations/marqo/__init__.py +7 -0
- vanna/integrations/marqo/agent_memory.py +354 -0
- vanna/integrations/milvus/__init__.py +7 -0
- vanna/integrations/milvus/agent_memory.py +458 -0
- vanna/integrations/mock/__init__.py +9 -0
- vanna/integrations/mock/llm.py +65 -0
- vanna/integrations/mssql/__init__.py +5 -0
- vanna/integrations/mssql/sql_runner.py +66 -0
- vanna/integrations/mysql/__init__.py +5 -0
- vanna/integrations/mysql/sql_runner.py +92 -0
- vanna/integrations/ollama/__init__.py +7 -0
- vanna/integrations/ollama/llm.py +252 -0
- vanna/integrations/openai/__init__.py +10 -0
- vanna/integrations/openai/llm.py +267 -0
- vanna/integrations/openai/responses.py +163 -0
- vanna/integrations/opensearch/__init__.py +7 -0
- vanna/integrations/opensearch/agent_memory.py +411 -0
- vanna/integrations/oracle/__init__.py +5 -0
- vanna/integrations/oracle/sql_runner.py +75 -0
- vanna/integrations/pinecone/__init__.py +7 -0
- vanna/integrations/pinecone/agent_memory.py +329 -0
- vanna/integrations/plotly/__init__.py +5 -0
- vanna/integrations/plotly/chart_generator.py +313 -0
- vanna/integrations/postgres/__init__.py +9 -0
- vanna/integrations/postgres/sql_runner.py +112 -0
- vanna/integrations/premium/agent_memory/__init__.py +7 -0
- vanna/integrations/premium/agent_memory/premium.py +186 -0
- vanna/integrations/presto/__init__.py +5 -0
- vanna/integrations/presto/sql_runner.py +107 -0
- vanna/integrations/qdrant/__init__.py +7 -0
- vanna/integrations/qdrant/agent_memory.py +439 -0
- vanna/integrations/snowflake/__init__.py +5 -0
- vanna/integrations/snowflake/sql_runner.py +147 -0
- vanna/integrations/sqlite/__init__.py +9 -0
- vanna/integrations/sqlite/sql_runner.py +65 -0
- vanna/integrations/weaviate/__init__.py +7 -0
- vanna/integrations/weaviate/agent_memory.py +428 -0
- vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_embeddings.py +11 -11
- vanna/legacy/__init__.py +403 -0
- vanna/legacy/adapter.py +463 -0
- vanna/{advanced → legacy/advanced}/__init__.py +3 -1
- vanna/{anthropic → legacy/anthropic}/anthropic_chat.py +9 -7
- vanna/{azuresearch → legacy/azuresearch}/azuresearch_vector.py +79 -41
- vanna/{base → legacy/base}/base.py +224 -217
- vanna/legacy/bedrock/__init__.py +1 -0
- vanna/{bedrock → legacy/bedrock}/bedrock_converse.py +13 -12
- vanna/{chromadb → legacy/chromadb}/chromadb_vector.py +3 -1
- vanna/legacy/cohere/__init__.py +2 -0
- vanna/{cohere → legacy/cohere}/cohere_chat.py +19 -14
- vanna/{cohere → legacy/cohere}/cohere_embeddings.py +25 -19
- vanna/{deepseek → legacy/deepseek}/deepseek_chat.py +5 -6
- vanna/legacy/faiss/__init__.py +1 -0
- vanna/{faiss → legacy/faiss}/faiss.py +113 -59
- vanna/{flask → legacy/flask}/__init__.py +84 -43
- vanna/{flask → legacy/flask}/assets.py +5 -5
- vanna/{flask → legacy/flask}/auth.py +5 -4
- vanna/{google → legacy/google}/bigquery_vector.py +75 -42
- vanna/{google → legacy/google}/gemini_chat.py +7 -3
- vanna/{hf → legacy/hf}/hf.py +0 -1
- vanna/{milvus → legacy/milvus}/milvus_vector.py +58 -35
- vanna/{mock → legacy/mock}/llm.py +0 -1
- vanna/legacy/mock/vectordb.py +67 -0
- vanna/legacy/ollama/ollama.py +110 -0
- vanna/{openai → legacy/openai}/openai_chat.py +2 -6
- vanna/legacy/opensearch/opensearch_vector.py +369 -0
- vanna/legacy/opensearch/opensearch_vector_semantic.py +200 -0
- vanna/legacy/oracle/oracle_vector.py +584 -0
- vanna/{pgvector → legacy/pgvector}/pgvector.py +42 -13
- vanna/{qdrant → legacy/qdrant}/qdrant.py +2 -6
- vanna/legacy/qianfan/Qianfan_Chat.py +170 -0
- vanna/legacy/qianfan/Qianfan_embeddings.py +36 -0
- vanna/legacy/qianwen/QianwenAI_chat.py +132 -0
- vanna/{remote.py → legacy/remote.py} +28 -26
- vanna/{utils.py → legacy/utils.py} +6 -11
- vanna/{vannadb → legacy/vannadb}/vannadb_vector.py +115 -46
- vanna/{vllm → legacy/vllm}/vllm.py +5 -6
- vanna/{weaviate → legacy/weaviate}/weaviate_vector.py +59 -40
- vanna/{xinference → legacy/xinference}/xinference.py +6 -6
- vanna/py.typed +0 -0
- vanna/servers/__init__.py +16 -0
- vanna/servers/__main__.py +8 -0
- vanna/servers/base/__init__.py +18 -0
- vanna/servers/base/chat_handler.py +65 -0
- vanna/servers/base/models.py +111 -0
- vanna/servers/base/rich_chat_handler.py +141 -0
- vanna/servers/base/templates.py +331 -0
- vanna/servers/cli/__init__.py +7 -0
- vanna/servers/cli/server_runner.py +204 -0
- vanna/servers/fastapi/__init__.py +7 -0
- vanna/servers/fastapi/app.py +163 -0
- vanna/servers/fastapi/routes.py +183 -0
- vanna/servers/flask/__init__.py +7 -0
- vanna/servers/flask/app.py +132 -0
- vanna/servers/flask/routes.py +137 -0
- vanna/tools/__init__.py +41 -0
- vanna/tools/agent_memory.py +322 -0
- vanna/tools/file_system.py +879 -0
- vanna/tools/python.py +222 -0
- vanna/tools/run_sql.py +165 -0
- vanna/tools/visualize_data.py +195 -0
- vanna/utils/__init__.py +0 -0
- vanna/web_components/__init__.py +44 -0
- vanna-2.0.0rc1.dist-info/METADATA +868 -0
- vanna-2.0.0rc1.dist-info/RECORD +289 -0
- vanna-2.0.0rc1.dist-info/entry_points.txt +3 -0
- vanna/bedrock/__init__.py +0 -1
- vanna/cohere/__init__.py +0 -2
- vanna/faiss/__init__.py +0 -1
- vanna/mock/vectordb.py +0 -55
- vanna/ollama/ollama.py +0 -103
- vanna/opensearch/opensearch_vector.py +0 -392
- vanna/opensearch/opensearch_vector_semantic.py +0 -175
- vanna/oracle/oracle_vector.py +0 -585
- vanna/qianfan/Qianfan_Chat.py +0 -165
- vanna/qianfan/Qianfan_embeddings.py +0 -36
- vanna/qianwen/QianwenAI_chat.py +0 -133
- vanna-0.7.9.dist-info/METADATA +0 -408
- vanna-0.7.9.dist-info/RECORD +0 -79
- /vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_Chat.py +0 -0
- /vanna/{ZhipuAI → legacy/ZhipuAI}/__init__.py +0 -0
- /vanna/{anthropic → legacy/anthropic}/__init__.py +0 -0
- /vanna/{azuresearch → legacy/azuresearch}/__init__.py +0 -0
- /vanna/{base → legacy/base}/__init__.py +0 -0
- /vanna/{chromadb → legacy/chromadb}/__init__.py +0 -0
- /vanna/{deepseek → legacy/deepseek}/__init__.py +0 -0
- /vanna/{exceptions → legacy/exceptions}/__init__.py +0 -0
- /vanna/{google → legacy/google}/__init__.py +0 -0
- /vanna/{hf → legacy/hf}/__init__.py +0 -0
- /vanna/{local.py → legacy/local.py} +0 -0
- /vanna/{marqo → legacy/marqo}/__init__.py +0 -0
- /vanna/{marqo → legacy/marqo}/marqo.py +0 -0
- /vanna/{milvus → legacy/milvus}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/mistral.py +0 -0
- /vanna/{mock → legacy/mock}/__init__.py +0 -0
- /vanna/{mock → legacy/mock}/embedding.py +0 -0
- /vanna/{ollama → legacy/ollama}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/openai_embeddings.py +0 -0
- /vanna/{opensearch → legacy/opensearch}/__init__.py +0 -0
- /vanna/{oracle → legacy/oracle}/__init__.py +0 -0
- /vanna/{pgvector → legacy/pgvector}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/pinecone_vector.py +0 -0
- /vanna/{qdrant → legacy/qdrant}/__init__.py +0 -0
- /vanna/{qianfan → legacy/qianfan}/__init__.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/QianwenAI_embeddings.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/__init__.py +0 -0
- /vanna/{types → legacy/types}/__init__.py +0 -0
- /vanna/{vannadb → legacy/vannadb}/__init__.py +0 -0
- /vanna/{vllm → legacy/vllm}/__init__.py +0 -0
- /vanna/{weaviate → legacy/weaviate}/__init__.py +0 -0
- /vanna/{xinference → legacy/xinference}/__init__.py +0 -0
- {vanna-0.7.9.dist-info → vanna-2.0.0rc1.dist-info}/WHEEL +0 -0
- {vanna-0.7.9.dist-info → vanna-2.0.0rc1.dist-info}/licenses/LICENSE +0 -0
|
@@ -17,7 +17,8 @@ class PG_VectorStore(VannaBase):
|
|
|
17
17
|
def __init__(self, config=None):
|
|
18
18
|
if not config or "connection_string" not in config:
|
|
19
19
|
raise ValueError(
|
|
20
|
-
"A valid 'config' dictionary with a 'connection_string' is required."
|
|
20
|
+
"A valid 'config' dictionary with a 'connection_string' is required."
|
|
21
|
+
)
|
|
21
22
|
|
|
22
23
|
VannaBase.__init__(self, config=config)
|
|
23
24
|
|
|
@@ -29,7 +30,10 @@ class PG_VectorStore(VannaBase):
|
|
|
29
30
|
self.embedding_function = config.get("embedding_function")
|
|
30
31
|
else:
|
|
31
32
|
from langchain_huggingface import HuggingFaceEmbeddings
|
|
32
|
-
|
|
33
|
+
|
|
34
|
+
self.embedding_function = HuggingFaceEmbeddings(
|
|
35
|
+
model_name="all-MiniLM-L6-v2"
|
|
36
|
+
)
|
|
33
37
|
|
|
34
38
|
self.sql_collection = PGVector(
|
|
35
39
|
embeddings=self.embedding_function,
|
|
@@ -95,15 +99,21 @@ class PG_VectorStore(VannaBase):
|
|
|
95
99
|
raise ValueError("Specified collection does not exist.")
|
|
96
100
|
|
|
97
101
|
def get_similar_question_sql(self, question: str) -> list:
    """Return the stored question/SQL records most similar to ``question``.

    Queries ``self.sql_collection`` for up to ``self.n_results`` documents
    and parses each document's ``page_content`` (a Python-literal string)
    back into an object with ``ast.literal_eval``.
    """
    matches = self.sql_collection.similarity_search(
        query=question,
        k=self.n_results,
    )
    parsed = []
    for match in matches:
        # page_content holds the literal repr of the stored record.
        parsed.append(ast.literal_eval(match.page_content))
    return parsed
|
|
100
106
|
|
|
101
107
|
def get_related_ddl(self, question: str, **kwargs) -> list:
    """Return the text of the DDL documents most similar to ``question``.

    Queries ``self.ddl_collection`` for up to ``self.n_results`` documents
    and returns each document's ``page_content``. Extra keyword arguments
    are accepted but not used.
    """
    matches = self.ddl_collection.similarity_search(
        query=question,
        k=self.n_results,
    )
    contents = []
    for match in matches:
        contents.append(match.page_content)
    return contents
|
|
104
112
|
|
|
105
113
|
def get_related_documentation(self, question: str, **kwargs) -> list:
    """Return the text of the documentation entries most similar to ``question``.

    Queries ``self.documentation_collection`` for up to ``self.n_results``
    documents and returns each document's ``page_content``. Extra keyword
    arguments are accepted but not used.
    """
    matches = self.documentation_collection.similarity_search(
        query=question,
        k=self.n_results,
    )
    contents = []
    for match in matches:
        contents.append(match.page_content)
    return contents
|
|
108
118
|
|
|
109
119
|
def train(
|
|
@@ -123,7 +133,9 @@ class PG_VectorStore(VannaBase):
|
|
|
123
133
|
return self.add_documentation(documentation)
|
|
124
134
|
|
|
125
135
|
if sql and question:
|
|
126
|
-
return self.add_question_sql(
|
|
136
|
+
return self.add_question_sql(
|
|
137
|
+
question=question, sql=sql, createdat=createdat
|
|
138
|
+
)
|
|
127
139
|
|
|
128
140
|
if ddl:
|
|
129
141
|
logging.info(f"Adding ddl: {ddl}")
|
|
@@ -135,7 +147,9 @@ class PG_VectorStore(VannaBase):
|
|
|
135
147
|
self.add_ddl(item.item_value)
|
|
136
148
|
elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
|
|
137
149
|
self.add_documentation(item.item_value)
|
|
138
|
-
elif
|
|
150
|
+
elif (
|
|
151
|
+
item.item_type == TrainingPlanItem.ITEM_TYPE_SQL and item.item_name
|
|
152
|
+
):
|
|
139
153
|
self.add_question_sql(question=item.item_name, sql=item.item_value)
|
|
140
154
|
|
|
141
155
|
def get_training_data(self, **kwargs) -> pd.DataFrame:
|
|
@@ -153,7 +167,9 @@ class PG_VectorStore(VannaBase):
|
|
|
153
167
|
for _, row in df_embedding.iterrows():
|
|
154
168
|
custom_id = row["cmetadata"]["id"]
|
|
155
169
|
document = row["document"]
|
|
156
|
-
training_data_type =
|
|
170
|
+
training_data_type = (
|
|
171
|
+
"documentation" if custom_id[-3:] == "doc" else custom_id[-3:]
|
|
172
|
+
)
|
|
157
173
|
|
|
158
174
|
if training_data_type == "sql":
|
|
159
175
|
# Convert the document string to a dictionary
|
|
@@ -162,19 +178,28 @@ class PG_VectorStore(VannaBase):
|
|
|
162
178
|
question = doc_dict.get("question")
|
|
163
179
|
content = doc_dict.get("sql")
|
|
164
180
|
except (ValueError, SyntaxError):
|
|
165
|
-
logging.info(
|
|
181
|
+
logging.info(
|
|
182
|
+
f"Skipping row with custom_id {custom_id} due to parsing error."
|
|
183
|
+
)
|
|
166
184
|
continue
|
|
167
185
|
elif training_data_type in ["documentation", "ddl"]:
|
|
168
186
|
question = None # Default value for question
|
|
169
187
|
content = document
|
|
170
188
|
else:
|
|
171
189
|
# If the suffix is not recognized, skip this row
|
|
172
|
-
logging.info(
|
|
190
|
+
logging.info(
|
|
191
|
+
f"Skipping row with custom_id {custom_id} due to unrecognized training data type."
|
|
192
|
+
)
|
|
173
193
|
continue
|
|
174
194
|
|
|
175
195
|
# Append the processed data to the list
|
|
176
196
|
processed_rows.append(
|
|
177
|
-
{
|
|
197
|
+
{
|
|
198
|
+
"id": custom_id,
|
|
199
|
+
"question": question,
|
|
200
|
+
"content": content,
|
|
201
|
+
"training_data_type": training_data_type,
|
|
202
|
+
}
|
|
178
203
|
)
|
|
179
204
|
|
|
180
205
|
# Create a DataFrame from the list of processed rows
|
|
@@ -218,7 +243,9 @@ class PG_VectorStore(VannaBase):
|
|
|
218
243
|
suffix = suffix_map.get(collection_name)
|
|
219
244
|
|
|
220
245
|
if not suffix:
|
|
221
|
-
logging.info(
|
|
246
|
+
logging.info(
|
|
247
|
+
"Invalid collection name. Choose from 'ddl', 'sql', or 'documentation'."
|
|
248
|
+
)
|
|
222
249
|
return False
|
|
223
250
|
|
|
224
251
|
# SQL query to delete rows based on the condition
|
|
@@ -242,7 +269,9 @@ class PG_VectorStore(VannaBase):
|
|
|
242
269
|
)
|
|
243
270
|
return True
|
|
244
271
|
else:
|
|
245
|
-
logging.info(
|
|
272
|
+
logging.info(
|
|
273
|
+
f"No rows deleted for collection {collection_name}."
|
|
274
|
+
)
|
|
246
275
|
return False
|
|
247
276
|
except Exception as e:
|
|
248
277
|
logging.error(f"An error occurred: {e}")
|
|
@@ -71,12 +71,8 @@ class Qdrant_VectorStore(VannaBase):
|
|
|
71
71
|
self.documentation_collection_name = config.get(
|
|
72
72
|
"documentation_collection_name", "documentation"
|
|
73
73
|
)
|
|
74
|
-
self.ddl_collection_name = config.get(
|
|
75
|
-
|
|
76
|
-
)
|
|
77
|
-
self.sql_collection_name = config.get(
|
|
78
|
-
"sql_collection_name", "sql"
|
|
79
|
-
)
|
|
74
|
+
self.ddl_collection_name = config.get("ddl_collection_name", "ddl")
|
|
75
|
+
self.sql_collection_name = config.get("sql_collection_name", "sql")
|
|
80
76
|
|
|
81
77
|
self.id_suffixes = {
|
|
82
78
|
self.ddl_collection_name: "ddl",
|
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
import qianfan
|
|
2
|
+
|
|
3
|
+
from ..base import VannaBase
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Qianfan_Chat(VannaBase):
    """Chat backend that submits prompts through ``qianfan.ChatCompletion``."""

    def __init__(self, client=None, config=None):
        """Store credentials and sampling settings and set up the chat client.

        Args:
            client: Optional pre-built client. When given it is used as-is
                instead of constructing ``qianfan.ChatCompletion``.
            config: Mapping that must contain ``api_key`` and ``secret_key``
                and may contain ``temperature``, ``max_tokens`` and ``model``.

        Raises:
            Exception: If ``api_key`` or ``secret_key`` is missing (including
                when ``config`` is ``None``).
        """
        VannaBase.__init__(self, config=config)

        # Treat a missing config like an empty one so the checks below raise
        # the intended "Missing ..." error instead of a TypeError on None.
        config = {} if config is None else config

        if "api_key" not in config:
            raise Exception("Missing api_key in config")
        self.api_key = config["api_key"]

        if "secret_key" not in config:
            raise Exception("Missing secret_key in config")
        self.secret_key = config["secret_key"]

        # Default parameters - can be overridden using config.
        self.temperature = config.get("temperature", 0.9)
        self.max_tokens = config.get("max_tokens", 1024)
        self.model = config.get("model", "ERNIE-Speed")

        if client is not None:
            self.client = client
            return

        self.client = qianfan.ChatCompletion(ak=self.api_key, sk=self.secret_key)

    def system_message(self, message: str) -> any:
        """Wrap ``message`` as a system-role chat message dict."""
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> any:
        """Wrap ``message`` as a user-role chat message dict."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        """Wrap ``message`` as an assistant-role chat message dict."""
        return {"role": "assistant", "content": message}

    def get_sql_prompt(
        self,
        initial_prompt: str,
        question: str,
        question_sql_list: list,
        ddl_list: list,
        doc_list: list,
        **kwargs,
    ):
        """
        Example:
        ```python
        vn.get_sql_prompt(
            question="What are the top 10 customers by sales?",
            question_sql_list=[{"question": "What are the top 10 customers by sales?", "sql": "SELECT * FROM customers ORDER BY sales DESC LIMIT 10"}],
            ddl_list=["CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)"],
            doc_list=["The customers table contains information about customers and their sales."],
        )

        ```

        This method is used to generate a prompt for the LLM to generate SQL.

        The instruction text (with DDL and documentation appended) is merged
        into the first user message. Each usable example then contributes a
        user/assistant message pair, and the actual question is sent last.

        Args:
            initial_prompt (str): Instruction text; a default is built from
                ``self.dialect`` when this is None.
            question (str): The question to generate SQL for.
            question_sql_list (list): A list of questions and their corresponding SQL statements.
            ddl_list (list): A list of DDL statements.
            doc_list (list): A list of documentation.

        Returns:
            any: The prompt for the LLM to generate SQL.
        """

        if initial_prompt is None:
            initial_prompt = (
                f"You are a {self.dialect} expert. "
                + "Please help to generate a SQL to answer the question based on some context.Please don't give any explanation for your answer. Just only generate a SQL \n"
            )

        initial_prompt = self.add_ddl_to_prompt(
            initial_prompt, ddl_list, max_tokens=self.max_tokens
        )

        if self.static_documentation != "":
            # Build a new list so the caller's doc_list is not mutated.
            doc_list = [*doc_list, self.static_documentation]

        initial_prompt = self.add_documentation_to_prompt(
            initial_prompt, doc_list, max_tokens=self.max_tokens
        )

        message_log = []
        prompt_sent = False

        # enumerate() is required here: iterating the list directly would
        # try to unpack each example dict into (i, example) and silently
        # discard every example.
        for example in question_sql_list or []:
            if example is None:
                print("example is None")
                continue
            if "question" not in example or "sql" not in example:
                continue

            if not prompt_sent:
                # The first usable example carries the instruction text, even
                # if earlier entries in the list were skipped.
                initial_prompt = initial_prompt + f"question: {example['question']}"
                message_log.append(self.user_message(initial_prompt))
                prompt_sent = True
            else:
                message_log.append(self.user_message(example["question"]))
            message_log.append(self.assistant_message(example["sql"]))

        if prompt_sent:
            message_log.append(self.user_message(question))
        else:
            # No usable examples: the instruction text and the question are
            # sent together as a single user message.
            initial_prompt = initial_prompt + f"question: {question}"
            message_log.append(self.user_message(initial_prompt))

        return message_log

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send ``prompt`` (a list of message dicts) and return the result text.

        Model choice, in priority order: the ``model`` keyword argument, the
        ``model`` entry of ``self.config``, then a size-based default.

        Raises:
            Exception: If ``prompt`` is None or empty.
        """
        if prompt is None:
            raise Exception("Prompt is None")

        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        # Count the number of tokens in the message log
        # Use 4 as an approximation for the number of characters per token
        num_tokens = 0
        for message in prompt:
            num_tokens += len(message["content"]) / 4

        if kwargs.get("model", None) is not None:
            # Use the per-call override that is also reported in the log line
            # (previously the log named this model but self.model was sent).
            model = kwargs.get("model", None)
        elif self.config is not None and "model" in self.config:
            model = self.config.get("model")
        elif num_tokens > 3500:
            model = "ERNIE-Speed-128K"
        else:
            model = "ERNIE-Speed-8K"

        print(f"Using model {model} for {num_tokens} tokens (approx)")
        response = self.client.do(
            model=model,
            messages=prompt,
            max_output_tokens=self.max_tokens,
            stop=None,
            temperature=self.temperature,
        )

        return response.body.get("result")
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
import qianfan
|
|
2
|
+
|
|
3
|
+
from ..base import VannaBase
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Qianfan_Embeddings(VannaBase):
    """Embedding backend that calls ``qianfan.Embedding``."""

    def __init__(self, client=None, config=None):
        """Set up the embedding client.

        Args:
            client: Optional pre-built client. When given it is used as-is
                and no credentials are read from ``config``.
            config: Mapping with ``api_key`` and ``secret_key`` (required
                when ``client`` is not given) and an optional ``model``.

        Raises:
            Exception: If no client is given and ``api_key`` or
                ``secret_key`` is missing (including when ``config`` is
                ``None``).
        """
        VannaBase.__init__(self, config=config)

        if client is not None:
            self.client = client
            return

        # Treat a missing config like an empty one so the checks below raise
        # the intended "Missing ..." error instead of a TypeError on None.
        config = {} if config is None else config

        if "api_key" not in config:
            raise Exception("Missing api_key in config")
        self.api_key = config["api_key"]

        if "secret_key" not in config:
            raise Exception("Missing secret_key in config")
        self.secret_key = config["secret_key"]

        self.client = qianfan.Embedding(ak=self.api_key, sk=self.secret_key)

    def generate_embedding(self, data: str, **kwargs) -> list[float]:
        """Return the embedding for ``data``.

        Uses the ``model`` entry of ``self.config`` when present, otherwise
        ``"bge-large-zh"``. The text is sent as a one-element batch and the
        first entry of the response's ``data`` list is returned.
        """
        if self.config is not None and "model" in self.config:
            model = self.config["model"]
        else:
            model = "bge-large-zh"

        embedding = self.client.do(
            model=model,
            input=[data],
        )

        return embedding.get("data")[0]["embedding"]
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
from openai import OpenAI
|
|
4
|
+
|
|
5
|
+
from ..base import VannaBase
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class QianWenAI_Chat(VannaBase):
    """Chat backend that talks to an OpenAI-style chat completions client."""

    def __init__(self, client=None, config=None):
        """Store the sampling temperature and set up the client.

        Client selection, in order:
          1. ``client`` if given.
          2. If ``config`` is None: ``OpenAI`` with the ``OPENAI_API_KEY``
             environment variable.
          3. If ``config`` has ``api_key``: ``OpenAI`` with that key and
             ``config["base_url"]``, defaulting to the DashScope
             compatible-mode URL.

        Raises:
            Exception: If ``config`` contains the deprecated ``api_type``,
                ``api_base`` or ``api_version`` keys.
        """
        VannaBase.__init__(self, config=config)

        # Membership tests below must also work when no config was passed;
        # testing against None would raise TypeError and make the
        # environment-variable fallback unreachable.
        settings = {} if config is None else config

        # Default parameters - can be overridden using config.
        self.temperature = settings.get("temperature", 0.7)

        for deprecated_key in ("api_type", "api_base", "api_version"):
            if deprecated_key in settings:
                raise Exception(
                    f"Passing {deprecated_key} is now deprecated. Please pass an OpenAI client instead."
                )

        if client is not None:
            self.client = client
            return

        if config is None:
            self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
            return

        # NOTE(review): when config lacks "api_key", no client is created and
        # self.client stays unset, as in the original code.
        if "api_key" in config:
            self.client = OpenAI(
                api_key=config["api_key"],
                base_url=config.get(
                    "base_url",
                    "https://dashscope.aliyuncs.com/compatible-mode/v1",
                ),
            )

    def system_message(self, message: str) -> any:
        """Wrap ``message`` as a system-role chat message dict."""
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> any:
        """Wrap ``message`` as a user-role chat message dict."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        """Wrap ``message`` as an assistant-role chat message dict."""
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send ``prompt`` (a list of message dicts) and return the reply text.

        Target selection, in priority order: ``model`` keyword argument,
        ``engine`` keyword argument, ``engine`` in ``self.config``, ``model``
        in ``self.config``, then a size-based default model.

        Raises:
            Exception: If ``prompt`` is None or empty.
        """
        if prompt is None:
            raise Exception("Prompt is None")

        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        # Count the number of tokens in the message log
        # Use 4 as an approximation for the number of characters per token
        num_tokens = 0
        for message in prompt:
            num_tokens += len(message["content"]) / 4

        # `label` is only the word used in the log line; `target` is the
        # keyword argument that selects what the client runs.
        if kwargs.get("model", None) is not None:
            label, target = "model", {"model": kwargs.get("model", None)}
        elif kwargs.get("engine", None) is not None:
            # The log line for a per-call engine says "model".
            label, target = "model", {"engine": kwargs.get("engine", None)}
        elif self.config is not None and "engine" in self.config:
            label, target = "engine", {"engine": self.config["engine"]}
        elif self.config is not None and "model" in self.config:
            label, target = "model", {"model": self.config["model"]}
        elif num_tokens > 3500:
            label, target = "model", {"model": "qwen-long"}
        else:
            label, target = "model", {"model": "qwen-plus"}

        (chosen,) = target.values()
        print(f"Using {label} {chosen} for {num_tokens} tokens (approx)")
        response = self.client.chat.completions.create(
            **target,
            messages=prompt,
            stop=None,
            temperature=self.temperature,
        )

        # Find the first response from the chatbot that has text in it (some responses may not have text)
        for choice in response.choices:
            if "text" in choice:
                return choice.text

        # If no response with text is found, return the first response's content (which may be empty)
        return response.choices[0].message.content
|
|
@@ -8,31 +8,31 @@ import requests
|
|
|
8
8
|
|
|
9
9
|
from .base import VannaBase
|
|
10
10
|
from .types import (
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
11
|
+
AccuracyStats,
|
|
12
|
+
ApiKey,
|
|
13
|
+
DataFrameJSON,
|
|
14
|
+
DataResult,
|
|
15
|
+
Explanation,
|
|
16
|
+
FullQuestionDocument,
|
|
17
|
+
NewOrganization,
|
|
18
|
+
NewOrganizationMember,
|
|
19
|
+
Organization,
|
|
20
|
+
OrganizationList,
|
|
21
|
+
PlotlyResult,
|
|
22
|
+
Question,
|
|
23
|
+
QuestionCategory,
|
|
24
|
+
QuestionId,
|
|
25
|
+
QuestionList,
|
|
26
|
+
QuestionSQLPair,
|
|
27
|
+
QuestionStringList,
|
|
28
|
+
SQLAnswer,
|
|
29
|
+
Status,
|
|
30
|
+
StatusWithId,
|
|
31
|
+
StringData,
|
|
32
|
+
TrainingData,
|
|
33
|
+
UserEmail,
|
|
34
|
+
UserOTP,
|
|
35
|
+
Visibility,
|
|
36
36
|
)
|
|
37
37
|
from .vannadb import VannaDB_VectorStore
|
|
38
38
|
|
|
@@ -40,7 +40,9 @@ from .vannadb import VannaDB_VectorStore
|
|
|
40
40
|
class VannaDefault(VannaDB_VectorStore):
|
|
41
41
|
def __init__(self, model: str, api_key: str, config=None):
|
|
42
42
|
VannaBase.__init__(self, config=config)
|
|
43
|
-
VannaDB_VectorStore.__init__(
|
|
43
|
+
VannaDB_VectorStore.__init__(
|
|
44
|
+
self, vanna_model=model, vanna_api_key=api_key, config=config
|
|
45
|
+
)
|
|
44
46
|
|
|
45
47
|
self._model = model
|
|
46
48
|
self._api_key = api_key
|
|
@@ -9,18 +9,14 @@ from .exceptions import ImproperlyConfigured, ValidationError
|
|
|
9
9
|
|
|
10
10
|
def validate_config_path(path):
|
|
11
11
|
if not os.path.exists(path):
|
|
12
|
-
raise ImproperlyConfigured(
|
|
13
|
-
f'No such configuration file: {path}'
|
|
14
|
-
)
|
|
12
|
+
raise ImproperlyConfigured(f"No such configuration file: {path}")
|
|
15
13
|
|
|
16
14
|
if not os.path.isfile(path):
|
|
17
|
-
raise ImproperlyConfigured(
|
|
18
|
-
f'Config should be a file: {path}'
|
|
19
|
-
)
|
|
15
|
+
raise ImproperlyConfigured(f"Config should be a file: {path}")
|
|
20
16
|
|
|
21
17
|
if not os.access(path, os.R_OK):
|
|
22
18
|
raise ImproperlyConfigured(
|
|
23
|
-
f
|
|
19
|
+
f"Cannot read the config file. Please grant read privileges: {path}"
|
|
24
20
|
)
|
|
25
21
|
|
|
26
22
|
|
|
@@ -31,13 +27,12 @@ def sanitize_model_name(model_name):
|
|
|
31
27
|
# Replace spaces with a hyphen
|
|
32
28
|
model_name = model_name.replace(" ", "-")
|
|
33
29
|
|
|
34
|
-
if
|
|
35
|
-
|
|
30
|
+
if "-" in model_name:
|
|
36
31
|
# remove double hyphones
|
|
37
32
|
model_name = re.sub(r"-+", "-", model_name)
|
|
38
|
-
if
|
|
33
|
+
if "_" in model_name:
|
|
39
34
|
# If name contains both underscores and hyphen replace all underscores with hyphens
|
|
40
|
-
model_name = re.sub(r
|
|
35
|
+
model_name = re.sub(r"_", "-", model_name)
|
|
41
36
|
|
|
42
37
|
# Remove special characters only allow underscore
|
|
43
38
|
model_name = re.sub(r"[^a-zA-Z0-9-_]", "", model_name)
|