vanna 0.7.9__py3-none-any.whl → 2.0.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/__init__.py +167 -395
- vanna/agents/__init__.py +7 -0
- vanna/capabilities/__init__.py +17 -0
- vanna/capabilities/agent_memory/__init__.py +21 -0
- vanna/capabilities/agent_memory/base.py +103 -0
- vanna/capabilities/agent_memory/models.py +53 -0
- vanna/capabilities/file_system/__init__.py +14 -0
- vanna/capabilities/file_system/base.py +71 -0
- vanna/capabilities/file_system/models.py +25 -0
- vanna/capabilities/sql_runner/__init__.py +13 -0
- vanna/capabilities/sql_runner/base.py +37 -0
- vanna/capabilities/sql_runner/models.py +13 -0
- vanna/components/__init__.py +92 -0
- vanna/components/base.py +11 -0
- vanna/components/rich/__init__.py +83 -0
- vanna/components/rich/containers/__init__.py +7 -0
- vanna/components/rich/containers/card.py +20 -0
- vanna/components/rich/data/__init__.py +9 -0
- vanna/components/rich/data/chart.py +17 -0
- vanna/components/rich/data/dataframe.py +93 -0
- vanna/components/rich/feedback/__init__.py +21 -0
- vanna/components/rich/feedback/badge.py +16 -0
- vanna/components/rich/feedback/icon_text.py +14 -0
- vanna/components/rich/feedback/log_viewer.py +41 -0
- vanna/components/rich/feedback/notification.py +19 -0
- vanna/components/rich/feedback/progress.py +37 -0
- vanna/components/rich/feedback/status_card.py +28 -0
- vanna/components/rich/feedback/status_indicator.py +14 -0
- vanna/components/rich/interactive/__init__.py +21 -0
- vanna/components/rich/interactive/button.py +95 -0
- vanna/components/rich/interactive/task_list.py +58 -0
- vanna/components/rich/interactive/ui_state.py +93 -0
- vanna/components/rich/specialized/__init__.py +7 -0
- vanna/components/rich/specialized/artifact.py +20 -0
- vanna/components/rich/text.py +16 -0
- vanna/components/simple/__init__.py +15 -0
- vanna/components/simple/image.py +15 -0
- vanna/components/simple/link.py +15 -0
- vanna/components/simple/text.py +11 -0
- vanna/core/__init__.py +193 -0
- vanna/core/_compat.py +19 -0
- vanna/core/agent/__init__.py +10 -0
- vanna/core/agent/agent.py +1407 -0
- vanna/core/agent/config.py +123 -0
- vanna/core/audit/__init__.py +28 -0
- vanna/core/audit/base.py +299 -0
- vanna/core/audit/models.py +131 -0
- vanna/core/component_manager.py +329 -0
- vanna/core/components.py +53 -0
- vanna/core/enhancer/__init__.py +11 -0
- vanna/core/enhancer/base.py +94 -0
- vanna/core/enhancer/default.py +118 -0
- vanna/core/enricher/__init__.py +10 -0
- vanna/core/enricher/base.py +59 -0
- vanna/core/errors.py +47 -0
- vanna/core/evaluation/__init__.py +81 -0
- vanna/core/evaluation/base.py +186 -0
- vanna/core/evaluation/dataset.py +254 -0
- vanna/core/evaluation/evaluators.py +376 -0
- vanna/core/evaluation/report.py +289 -0
- vanna/core/evaluation/runner.py +313 -0
- vanna/core/filter/__init__.py +10 -0
- vanna/core/filter/base.py +67 -0
- vanna/core/lifecycle/__init__.py +10 -0
- vanna/core/lifecycle/base.py +83 -0
- vanna/core/llm/__init__.py +16 -0
- vanna/core/llm/base.py +40 -0
- vanna/core/llm/models.py +61 -0
- vanna/core/middleware/__init__.py +10 -0
- vanna/core/middleware/base.py +69 -0
- vanna/core/observability/__init__.py +11 -0
- vanna/core/observability/base.py +88 -0
- vanna/core/observability/models.py +47 -0
- vanna/core/recovery/__init__.py +11 -0
- vanna/core/recovery/base.py +84 -0
- vanna/core/recovery/models.py +32 -0
- vanna/core/registry.py +278 -0
- vanna/core/rich_component.py +156 -0
- vanna/core/simple_component.py +27 -0
- vanna/core/storage/__init__.py +14 -0
- vanna/core/storage/base.py +46 -0
- vanna/core/storage/models.py +46 -0
- vanna/core/system_prompt/__init__.py +13 -0
- vanna/core/system_prompt/base.py +36 -0
- vanna/core/system_prompt/default.py +157 -0
- vanna/core/tool/__init__.py +18 -0
- vanna/core/tool/base.py +70 -0
- vanna/core/tool/models.py +84 -0
- vanna/core/user/__init__.py +17 -0
- vanna/core/user/base.py +29 -0
- vanna/core/user/models.py +25 -0
- vanna/core/user/request_context.py +70 -0
- vanna/core/user/resolver.py +42 -0
- vanna/core/validation.py +164 -0
- vanna/core/workflow/__init__.py +12 -0
- vanna/core/workflow/base.py +254 -0
- vanna/core/workflow/default.py +789 -0
- vanna/examples/__init__.py +1 -0
- vanna/examples/__main__.py +44 -0
- vanna/examples/anthropic_quickstart.py +80 -0
- vanna/examples/artifact_example.py +293 -0
- vanna/examples/claude_sqlite_example.py +236 -0
- vanna/examples/coding_agent_example.py +300 -0
- vanna/examples/custom_system_prompt_example.py +174 -0
- vanna/examples/default_workflow_handler_example.py +208 -0
- vanna/examples/email_auth_example.py +340 -0
- vanna/examples/evaluation_example.py +269 -0
- vanna/examples/extensibility_example.py +262 -0
- vanna/examples/minimal_example.py +67 -0
- vanna/examples/mock_auth_example.py +227 -0
- vanna/examples/mock_custom_tool.py +311 -0
- vanna/examples/mock_quickstart.py +79 -0
- vanna/examples/mock_quota_example.py +145 -0
- vanna/examples/mock_rich_components_demo.py +396 -0
- vanna/examples/mock_sqlite_example.py +223 -0
- vanna/examples/openai_quickstart.py +83 -0
- vanna/examples/primitive_components_demo.py +305 -0
- vanna/examples/quota_lifecycle_example.py +139 -0
- vanna/examples/visualization_example.py +251 -0
- vanna/integrations/__init__.py +17 -0
- vanna/integrations/anthropic/__init__.py +9 -0
- vanna/integrations/anthropic/llm.py +270 -0
- vanna/integrations/azureopenai/__init__.py +9 -0
- vanna/integrations/azureopenai/llm.py +329 -0
- vanna/integrations/azuresearch/__init__.py +7 -0
- vanna/integrations/azuresearch/agent_memory.py +413 -0
- vanna/integrations/bigquery/__init__.py +5 -0
- vanna/integrations/bigquery/sql_runner.py +81 -0
- vanna/integrations/chromadb/__init__.py +104 -0
- vanna/integrations/chromadb/agent_memory.py +416 -0
- vanna/integrations/clickhouse/__init__.py +5 -0
- vanna/integrations/clickhouse/sql_runner.py +82 -0
- vanna/integrations/duckdb/__init__.py +5 -0
- vanna/integrations/duckdb/sql_runner.py +65 -0
- vanna/integrations/faiss/__init__.py +7 -0
- vanna/integrations/faiss/agent_memory.py +431 -0
- vanna/integrations/google/__init__.py +9 -0
- vanna/integrations/google/gemini.py +370 -0
- vanna/integrations/hive/__init__.py +5 -0
- vanna/integrations/hive/sql_runner.py +87 -0
- vanna/integrations/local/__init__.py +17 -0
- vanna/integrations/local/agent_memory/__init__.py +7 -0
- vanna/integrations/local/agent_memory/in_memory.py +285 -0
- vanna/integrations/local/audit.py +59 -0
- vanna/integrations/local/file_system.py +242 -0
- vanna/integrations/local/file_system_conversation_store.py +255 -0
- vanna/integrations/local/storage.py +62 -0
- vanna/integrations/marqo/__init__.py +7 -0
- vanna/integrations/marqo/agent_memory.py +354 -0
- vanna/integrations/milvus/__init__.py +7 -0
- vanna/integrations/milvus/agent_memory.py +458 -0
- vanna/integrations/mock/__init__.py +9 -0
- vanna/integrations/mock/llm.py +65 -0
- vanna/integrations/mssql/__init__.py +5 -0
- vanna/integrations/mssql/sql_runner.py +66 -0
- vanna/integrations/mysql/__init__.py +5 -0
- vanna/integrations/mysql/sql_runner.py +92 -0
- vanna/integrations/ollama/__init__.py +7 -0
- vanna/integrations/ollama/llm.py +252 -0
- vanna/integrations/openai/__init__.py +10 -0
- vanna/integrations/openai/llm.py +267 -0
- vanna/integrations/openai/responses.py +163 -0
- vanna/integrations/opensearch/__init__.py +7 -0
- vanna/integrations/opensearch/agent_memory.py +411 -0
- vanna/integrations/oracle/__init__.py +5 -0
- vanna/integrations/oracle/sql_runner.py +75 -0
- vanna/integrations/pinecone/__init__.py +7 -0
- vanna/integrations/pinecone/agent_memory.py +329 -0
- vanna/integrations/plotly/__init__.py +5 -0
- vanna/integrations/plotly/chart_generator.py +313 -0
- vanna/integrations/postgres/__init__.py +9 -0
- vanna/integrations/postgres/sql_runner.py +112 -0
- vanna/integrations/premium/agent_memory/__init__.py +7 -0
- vanna/integrations/premium/agent_memory/premium.py +186 -0
- vanna/integrations/presto/__init__.py +5 -0
- vanna/integrations/presto/sql_runner.py +107 -0
- vanna/integrations/qdrant/__init__.py +7 -0
- vanna/integrations/qdrant/agent_memory.py +439 -0
- vanna/integrations/snowflake/__init__.py +5 -0
- vanna/integrations/snowflake/sql_runner.py +147 -0
- vanna/integrations/sqlite/__init__.py +9 -0
- vanna/integrations/sqlite/sql_runner.py +65 -0
- vanna/integrations/weaviate/__init__.py +7 -0
- vanna/integrations/weaviate/agent_memory.py +428 -0
- vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_embeddings.py +11 -11
- vanna/legacy/__init__.py +403 -0
- vanna/legacy/adapter.py +463 -0
- vanna/{advanced → legacy/advanced}/__init__.py +3 -1
- vanna/{anthropic → legacy/anthropic}/anthropic_chat.py +9 -7
- vanna/{azuresearch → legacy/azuresearch}/azuresearch_vector.py +79 -41
- vanna/{base → legacy/base}/base.py +224 -217
- vanna/legacy/bedrock/__init__.py +1 -0
- vanna/{bedrock → legacy/bedrock}/bedrock_converse.py +13 -12
- vanna/{chromadb → legacy/chromadb}/chromadb_vector.py +3 -1
- vanna/legacy/cohere/__init__.py +2 -0
- vanna/{cohere → legacy/cohere}/cohere_chat.py +19 -14
- vanna/{cohere → legacy/cohere}/cohere_embeddings.py +25 -19
- vanna/{deepseek → legacy/deepseek}/deepseek_chat.py +5 -6
- vanna/legacy/faiss/__init__.py +1 -0
- vanna/{faiss → legacy/faiss}/faiss.py +113 -59
- vanna/{flask → legacy/flask}/__init__.py +84 -43
- vanna/{flask → legacy/flask}/assets.py +5 -5
- vanna/{flask → legacy/flask}/auth.py +5 -4
- vanna/{google → legacy/google}/bigquery_vector.py +75 -42
- vanna/{google → legacy/google}/gemini_chat.py +7 -3
- vanna/{hf → legacy/hf}/hf.py +0 -1
- vanna/{milvus → legacy/milvus}/milvus_vector.py +58 -35
- vanna/{mock → legacy/mock}/llm.py +0 -1
- vanna/legacy/mock/vectordb.py +67 -0
- vanna/legacy/ollama/ollama.py +110 -0
- vanna/{openai → legacy/openai}/openai_chat.py +2 -6
- vanna/legacy/opensearch/opensearch_vector.py +369 -0
- vanna/legacy/opensearch/opensearch_vector_semantic.py +200 -0
- vanna/legacy/oracle/oracle_vector.py +584 -0
- vanna/{pgvector → legacy/pgvector}/pgvector.py +42 -13
- vanna/{qdrant → legacy/qdrant}/qdrant.py +2 -6
- vanna/legacy/qianfan/Qianfan_Chat.py +170 -0
- vanna/legacy/qianfan/Qianfan_embeddings.py +36 -0
- vanna/legacy/qianwen/QianwenAI_chat.py +132 -0
- vanna/{remote.py → legacy/remote.py} +28 -26
- vanna/{utils.py → legacy/utils.py} +6 -11
- vanna/{vannadb → legacy/vannadb}/vannadb_vector.py +115 -46
- vanna/{vllm → legacy/vllm}/vllm.py +5 -6
- vanna/{weaviate → legacy/weaviate}/weaviate_vector.py +59 -40
- vanna/{xinference → legacy/xinference}/xinference.py +6 -6
- vanna/py.typed +0 -0
- vanna/servers/__init__.py +16 -0
- vanna/servers/__main__.py +8 -0
- vanna/servers/base/__init__.py +18 -0
- vanna/servers/base/chat_handler.py +65 -0
- vanna/servers/base/models.py +111 -0
- vanna/servers/base/rich_chat_handler.py +141 -0
- vanna/servers/base/templates.py +331 -0
- vanna/servers/cli/__init__.py +7 -0
- vanna/servers/cli/server_runner.py +204 -0
- vanna/servers/fastapi/__init__.py +7 -0
- vanna/servers/fastapi/app.py +163 -0
- vanna/servers/fastapi/routes.py +183 -0
- vanna/servers/flask/__init__.py +7 -0
- vanna/servers/flask/app.py +132 -0
- vanna/servers/flask/routes.py +137 -0
- vanna/tools/__init__.py +41 -0
- vanna/tools/agent_memory.py +322 -0
- vanna/tools/file_system.py +879 -0
- vanna/tools/python.py +222 -0
- vanna/tools/run_sql.py +165 -0
- vanna/tools/visualize_data.py +195 -0
- vanna/utils/__init__.py +0 -0
- vanna/web_components/__init__.py +44 -0
- vanna-2.0.0rc1.dist-info/METADATA +868 -0
- vanna-2.0.0rc1.dist-info/RECORD +289 -0
- vanna-2.0.0rc1.dist-info/entry_points.txt +3 -0
- vanna/bedrock/__init__.py +0 -1
- vanna/cohere/__init__.py +0 -2
- vanna/faiss/__init__.py +0 -1
- vanna/mock/vectordb.py +0 -55
- vanna/ollama/ollama.py +0 -103
- vanna/opensearch/opensearch_vector.py +0 -392
- vanna/opensearch/opensearch_vector_semantic.py +0 -175
- vanna/oracle/oracle_vector.py +0 -585
- vanna/qianfan/Qianfan_Chat.py +0 -165
- vanna/qianfan/Qianfan_embeddings.py +0 -36
- vanna/qianwen/QianwenAI_chat.py +0 -133
- vanna-0.7.9.dist-info/METADATA +0 -408
- vanna-0.7.9.dist-info/RECORD +0 -79
- /vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_Chat.py +0 -0
- /vanna/{ZhipuAI → legacy/ZhipuAI}/__init__.py +0 -0
- /vanna/{anthropic → legacy/anthropic}/__init__.py +0 -0
- /vanna/{azuresearch → legacy/azuresearch}/__init__.py +0 -0
- /vanna/{base → legacy/base}/__init__.py +0 -0
- /vanna/{chromadb → legacy/chromadb}/__init__.py +0 -0
- /vanna/{deepseek → legacy/deepseek}/__init__.py +0 -0
- /vanna/{exceptions → legacy/exceptions}/__init__.py +0 -0
- /vanna/{google → legacy/google}/__init__.py +0 -0
- /vanna/{hf → legacy/hf}/__init__.py +0 -0
- /vanna/{local.py → legacy/local.py} +0 -0
- /vanna/{marqo → legacy/marqo}/__init__.py +0 -0
- /vanna/{marqo → legacy/marqo}/marqo.py +0 -0
- /vanna/{milvus → legacy/milvus}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/mistral.py +0 -0
- /vanna/{mock → legacy/mock}/__init__.py +0 -0
- /vanna/{mock → legacy/mock}/embedding.py +0 -0
- /vanna/{ollama → legacy/ollama}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/openai_embeddings.py +0 -0
- /vanna/{opensearch → legacy/opensearch}/__init__.py +0 -0
- /vanna/{oracle → legacy/oracle}/__init__.py +0 -0
- /vanna/{pgvector → legacy/pgvector}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/pinecone_vector.py +0 -0
- /vanna/{qdrant → legacy/qdrant}/__init__.py +0 -0
- /vanna/{qianfan → legacy/qianfan}/__init__.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/QianwenAI_embeddings.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/__init__.py +0 -0
- /vanna/{types → legacy/types}/__init__.py +0 -0
- /vanna/{vannadb → legacy/vannadb}/__init__.py +0 -0
- /vanna/{vllm → legacy/vllm}/__init__.py +0 -0
- /vanna/{weaviate → legacy/weaviate}/__init__.py +0 -0
- /vanna/{xinference → legacy/xinference}/__init__.py +0 -0
- {vanna-0.7.9.dist-info → vanna-2.0.0rc1.dist-info}/WHEEL +0 -0
- {vanna-0.7.9.dist-info → vanna-2.0.0rc1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .bedrock_converse import Bedrock_Converse
|
|
@@ -6,6 +6,7 @@ try:
|
|
|
6
6
|
except ImportError:
|
|
7
7
|
raise ImportError("Please install boto3 and botocore to use Amazon Bedrock models")
|
|
8
8
|
|
|
9
|
+
|
|
9
10
|
class Bedrock_Converse(VannaBase):
|
|
10
11
|
def __init__(self, client=None, config=None):
|
|
11
12
|
VannaBase.__init__(self, config=config)
|
|
@@ -13,29 +14,27 @@ class Bedrock_Converse(VannaBase):
|
|
|
13
14
|
# default parameters
|
|
14
15
|
self.temperature = 0.0
|
|
15
16
|
self.max_tokens = 500
|
|
16
|
-
|
|
17
|
+
|
|
17
18
|
if client is None:
|
|
18
19
|
raise ValueError(
|
|
19
20
|
"A valid Bedrock runtime client must be provided to invoke Bedrock models"
|
|
20
21
|
)
|
|
21
22
|
else:
|
|
22
23
|
self.client = client
|
|
23
|
-
|
|
24
|
+
|
|
24
25
|
if config is None:
|
|
25
26
|
raise ValueError(
|
|
26
27
|
"Config is required with model_id and inference parameters"
|
|
27
28
|
)
|
|
28
|
-
|
|
29
|
+
|
|
29
30
|
if "modelId" not in config:
|
|
30
|
-
raise ValueError(
|
|
31
|
-
"config must contain a modelId to invoke"
|
|
32
|
-
)
|
|
31
|
+
raise ValueError("config must contain a modelId to invoke")
|
|
33
32
|
else:
|
|
34
33
|
self.model = config["modelId"]
|
|
35
|
-
|
|
34
|
+
|
|
36
35
|
if "temperature" in config:
|
|
37
36
|
self.temperature = config["temperature"]
|
|
38
|
-
|
|
37
|
+
|
|
39
38
|
if "max_tokens" in config:
|
|
40
39
|
self.max_tokens = config["max_tokens"]
|
|
41
40
|
|
|
@@ -51,7 +50,7 @@ class Bedrock_Converse(VannaBase):
|
|
|
51
50
|
def submit_prompt(self, prompt, **kwargs) -> str:
|
|
52
51
|
inference_config = {
|
|
53
52
|
"temperature": self.temperature,
|
|
54
|
-
"maxTokens": self.max_tokens
|
|
53
|
+
"maxTokens": self.max_tokens,
|
|
55
54
|
}
|
|
56
55
|
additional_model_fields = {
|
|
57
56
|
"top_p": 1, # setting top_p value for nucleus sampling
|
|
@@ -64,13 +63,15 @@ class Bedrock_Converse(VannaBase):
|
|
|
64
63
|
if role == "system":
|
|
65
64
|
system_message = prompt_message["content"]
|
|
66
65
|
else:
|
|
67
|
-
no_system_prompt.append(
|
|
66
|
+
no_system_prompt.append(
|
|
67
|
+
{"role": role, "content": [{"text": prompt_message["content"]}]}
|
|
68
|
+
)
|
|
68
69
|
|
|
69
70
|
converse_api_params = {
|
|
70
71
|
"modelId": self.model,
|
|
71
72
|
"messages": no_system_prompt,
|
|
72
73
|
"inferenceConfig": inference_config,
|
|
73
|
-
"additionalModelRequestFields": additional_model_fields
|
|
74
|
+
"additionalModelRequestFields": additional_model_fields,
|
|
74
75
|
}
|
|
75
76
|
|
|
76
77
|
if system_message:
|
|
@@ -82,4 +83,4 @@ class Bedrock_Converse(VannaBase):
|
|
|
82
83
|
return text_content
|
|
83
84
|
except ClientError as err:
|
|
84
85
|
message = err.response["Error"]["Message"]
|
|
85
|
-
raise Exception(f"A Bedrock client error occurred: {message}")
|
|
86
|
+
raise Exception(f"A Bedrock client error occurred: {message}")
|
|
@@ -23,7 +23,9 @@ class ChromaDB_VectorStore(VannaBase):
|
|
|
23
23
|
curr_client = config.get("client", "persistent")
|
|
24
24
|
collection_metadata = config.get("collection_metadata", None)
|
|
25
25
|
self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
|
|
26
|
-
self.n_results_documentation = config.get(
|
|
26
|
+
self.n_results_documentation = config.get(
|
|
27
|
+
"n_results_documentation", config.get("n_results", 10)
|
|
28
|
+
)
|
|
27
29
|
self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))
|
|
28
30
|
|
|
29
31
|
if curr_client == "persistent":
|
|
@@ -25,15 +25,17 @@ class Cohere_Chat(VannaBase):
|
|
|
25
25
|
|
|
26
26
|
# Check for API key in environment variable
|
|
27
27
|
api_key = os.getenv("COHERE_API_KEY")
|
|
28
|
-
|
|
28
|
+
|
|
29
29
|
# Check for API key in config
|
|
30
30
|
if config is not None and "api_key" in config:
|
|
31
31
|
api_key = config["api_key"]
|
|
32
|
-
|
|
32
|
+
|
|
33
33
|
# Validate API key
|
|
34
34
|
if not api_key:
|
|
35
|
-
raise ValueError(
|
|
36
|
-
|
|
35
|
+
raise ValueError(
|
|
36
|
+
"Cohere API key is required. Please provide it via config or set the COHERE_API_KEY environment variable."
|
|
37
|
+
)
|
|
38
|
+
|
|
37
39
|
# Initialize client with validated API key
|
|
38
40
|
self.client = OpenAI(
|
|
39
41
|
base_url="https://api.cohere.ai/compatibility/v1",
|
|
@@ -41,7 +43,10 @@ class Cohere_Chat(VannaBase):
|
|
|
41
43
|
)
|
|
42
44
|
|
|
43
45
|
def system_message(self, message: str) -> any:
|
|
44
|
-
return {
|
|
46
|
+
return {
|
|
47
|
+
"role": "developer",
|
|
48
|
+
"content": message,
|
|
49
|
+
} # Cohere uses 'developer' for system role
|
|
45
50
|
|
|
46
51
|
def user_message(self, message: str) -> any:
|
|
47
52
|
return {"role": "user", "content": message}
|
|
@@ -74,21 +79,21 @@ class Cohere_Chat(VannaBase):
|
|
|
74
79
|
messages=prompt,
|
|
75
80
|
temperature=self.temperature,
|
|
76
81
|
)
|
|
77
|
-
|
|
82
|
+
|
|
78
83
|
# Check if response has expected structure
|
|
79
|
-
if not response or not hasattr(response,
|
|
84
|
+
if not response or not hasattr(response, "choices") or not response.choices:
|
|
80
85
|
raise ValueError("Received empty or malformed response from API")
|
|
81
|
-
|
|
82
|
-
if not response.choices[0] or not hasattr(response.choices[0],
|
|
86
|
+
|
|
87
|
+
if not response.choices[0] or not hasattr(response.choices[0], "message"):
|
|
83
88
|
raise ValueError("Response is missing expected 'message' field")
|
|
84
|
-
|
|
85
|
-
if not hasattr(response.choices[0].message,
|
|
89
|
+
|
|
90
|
+
if not hasattr(response.choices[0].message, "content"):
|
|
86
91
|
raise ValueError("Response message is missing expected 'content' field")
|
|
87
|
-
|
|
92
|
+
|
|
88
93
|
return response.choices[0].message.content
|
|
89
|
-
|
|
94
|
+
|
|
90
95
|
except Exception as e:
|
|
91
96
|
# Log the error and raise a more informative exception
|
|
92
97
|
error_msg = f"Error processing Cohere chat response: {str(e)}"
|
|
93
98
|
print(error_msg)
|
|
94
|
-
raise Exception(error_msg)
|
|
99
|
+
raise Exception(error_msg)
|
|
@@ -8,10 +8,10 @@ from ..base import VannaBase
|
|
|
8
8
|
class Cohere_Embeddings(VannaBase):
|
|
9
9
|
def __init__(self, client=None, config=None):
|
|
10
10
|
VannaBase.__init__(self, config=config)
|
|
11
|
-
|
|
11
|
+
|
|
12
12
|
# Default embedding model
|
|
13
13
|
self.model = "embed-multilingual-v3.0"
|
|
14
|
-
|
|
14
|
+
|
|
15
15
|
if config is not None and "model" in config:
|
|
16
16
|
self.model = config["model"]
|
|
17
17
|
|
|
@@ -21,15 +21,17 @@ class Cohere_Embeddings(VannaBase):
|
|
|
21
21
|
|
|
22
22
|
# Check for API key in environment variable
|
|
23
23
|
api_key = os.getenv("COHERE_API_KEY")
|
|
24
|
-
|
|
24
|
+
|
|
25
25
|
# Check for API key in config
|
|
26
26
|
if config is not None and "api_key" in config:
|
|
27
27
|
api_key = config["api_key"]
|
|
28
|
-
|
|
28
|
+
|
|
29
29
|
# Validate API key
|
|
30
30
|
if not api_key:
|
|
31
|
-
raise ValueError(
|
|
32
|
-
|
|
31
|
+
raise ValueError(
|
|
32
|
+
"Cohere API key is required. Please provide it via config or set the COHERE_API_KEY environment variable."
|
|
33
|
+
)
|
|
34
|
+
|
|
33
35
|
# Initialize client with validated API key
|
|
34
36
|
self.client = OpenAI(
|
|
35
37
|
base_url="https://api.cohere.ai/compatibility/v1",
|
|
@@ -39,33 +41,37 @@ class Cohere_Embeddings(VannaBase):
|
|
|
39
41
|
def generate_embedding(self, data: str, **kwargs) -> list[float]:
|
|
40
42
|
if not data:
|
|
41
43
|
raise ValueError("Cannot generate embedding for empty input data")
|
|
42
|
-
|
|
44
|
+
|
|
43
45
|
# Use model from kwargs, config, or default
|
|
44
46
|
model = kwargs.get("model", self.model)
|
|
45
47
|
if self.config is not None and "model" in self.config and model == self.model:
|
|
46
48
|
model = self.config["model"]
|
|
47
|
-
|
|
48
|
-
try:
|
|
49
|
+
|
|
50
|
+
try:
|
|
49
51
|
embedding = self.client.embeddings.create(
|
|
50
52
|
model=model,
|
|
51
53
|
input=data,
|
|
52
54
|
encoding_format="float", # Ensure we get float values
|
|
53
55
|
)
|
|
54
|
-
|
|
56
|
+
|
|
55
57
|
# Check if response has expected structure
|
|
56
|
-
if not embedding or not hasattr(embedding,
|
|
57
|
-
raise ValueError(
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
58
|
+
if not embedding or not hasattr(embedding, "data") or not embedding.data:
|
|
59
|
+
raise ValueError(
|
|
60
|
+
"Received empty or malformed embedding response from API"
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
if not embedding.data[0] or not hasattr(embedding.data[0], "embedding"):
|
|
64
|
+
raise ValueError(
|
|
65
|
+
"Embedding response is missing expected 'embedding' field"
|
|
66
|
+
)
|
|
67
|
+
|
|
62
68
|
if not embedding.data[0].embedding:
|
|
63
69
|
raise ValueError("Received empty embedding vector")
|
|
64
|
-
|
|
70
|
+
|
|
65
71
|
return embedding.data[0].embedding
|
|
66
|
-
|
|
72
|
+
|
|
67
73
|
except Exception as e:
|
|
68
74
|
# Log the error and raise a more informative exception
|
|
69
75
|
error_msg = f"Error generating embedding with Cohere: {str(e)}"
|
|
70
76
|
print(error_msg)
|
|
71
|
-
raise Exception(error_msg)
|
|
77
|
+
raise Exception(error_msg)
|
|
@@ -5,7 +5,6 @@ from openai import OpenAI
|
|
|
5
5
|
from ..base import VannaBase
|
|
6
6
|
|
|
7
7
|
|
|
8
|
-
|
|
9
8
|
# from vanna.chromadb import ChromaDB_VectorStore
|
|
10
9
|
|
|
11
10
|
# class DeepSeekVanna(ChromaDB_VectorStore, DeepSeekChat):
|
|
@@ -27,12 +26,12 @@ class DeepSeekChat(VannaBase):
|
|
|
27
26
|
|
|
28
27
|
if "model" not in config:
|
|
29
28
|
raise ValueError("config must contain a DeepSeek model")
|
|
30
|
-
|
|
29
|
+
|
|
31
30
|
api_key = config["api_key"]
|
|
32
31
|
model = config["model"]
|
|
33
32
|
self.model = model
|
|
34
33
|
self.client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com/v1")
|
|
35
|
-
|
|
34
|
+
|
|
36
35
|
def system_message(self, message: str) -> any:
|
|
37
36
|
return {"role": "system", "content": message}
|
|
38
37
|
|
|
@@ -45,10 +44,10 @@ class DeepSeekChat(VannaBase):
|
|
|
45
44
|
def generate_sql(self, question: str, **kwargs) -> str:
|
|
46
45
|
# 使用父类的 generate_sql
|
|
47
46
|
sql = super().generate_sql(question, **kwargs)
|
|
48
|
-
|
|
47
|
+
|
|
49
48
|
# 替换 "\_" 为 "_"
|
|
50
49
|
sql = sql.replace("\\_", "_")
|
|
51
|
-
|
|
50
|
+
|
|
52
51
|
return sql
|
|
53
52
|
|
|
54
53
|
def submit_prompt(self, prompt, **kwargs) -> str:
|
|
@@ -56,5 +55,5 @@ class DeepSeekChat(VannaBase):
|
|
|
56
55
|
model=self.model,
|
|
57
56
|
messages=prompt,
|
|
58
57
|
)
|
|
59
|
-
|
|
58
|
+
|
|
60
59
|
return chat_response.choices[0].message.content
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .faiss import FAISS
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import os
|
|
1
|
+
import os
|
|
2
2
|
import json
|
|
3
3
|
import uuid
|
|
4
4
|
from typing import List, Dict, Any
|
|
@@ -10,13 +10,14 @@ import pandas as pd
|
|
|
10
10
|
from ..base import VannaBase
|
|
11
11
|
from ..exceptions import DependencyError
|
|
12
12
|
|
|
13
|
+
|
|
13
14
|
class FAISS(VannaBase):
|
|
14
15
|
def __init__(self, config=None):
|
|
15
16
|
if config is None:
|
|
16
17
|
config = {}
|
|
17
|
-
|
|
18
|
+
|
|
18
19
|
VannaBase.__init__(self, config=config)
|
|
19
|
-
|
|
20
|
+
|
|
20
21
|
try:
|
|
21
22
|
import faiss
|
|
22
23
|
except ImportError:
|
|
@@ -30,34 +31,48 @@ class FAISS(VannaBase):
|
|
|
30
31
|
raise DependencyError(
|
|
31
32
|
"SentenceTransformer is not installed. Please install it with 'pip install sentence-transformers'."
|
|
32
33
|
)
|
|
33
|
-
|
|
34
|
+
|
|
34
35
|
self.path = config.get("path", ".")
|
|
35
|
-
self.embedding_dim = config.get(
|
|
36
|
-
self.n_results_sql = config.get(
|
|
37
|
-
self.n_results_ddl = config.get(
|
|
38
|
-
self.n_results_documentation = config.get(
|
|
36
|
+
self.embedding_dim = config.get("embedding_dim", 384)
|
|
37
|
+
self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
|
|
38
|
+
self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))
|
|
39
|
+
self.n_results_documentation = config.get(
|
|
40
|
+
"n_results_documentation", config.get("n_results", 10)
|
|
41
|
+
)
|
|
39
42
|
self.curr_client = config.get("client", "persistent")
|
|
40
43
|
|
|
41
|
-
if self.curr_client ==
|
|
42
|
-
self.sql_index = self._load_or_create_index(
|
|
43
|
-
self.ddl_index = self._load_or_create_index(
|
|
44
|
-
self.doc_index = self._load_or_create_index(
|
|
45
|
-
elif self.curr_client ==
|
|
44
|
+
if self.curr_client == "persistent":
|
|
45
|
+
self.sql_index = self._load_or_create_index("sql_index.faiss")
|
|
46
|
+
self.ddl_index = self._load_or_create_index("ddl_index.faiss")
|
|
47
|
+
self.doc_index = self._load_or_create_index("doc_index.faiss")
|
|
48
|
+
elif self.curr_client == "in-memory":
|
|
46
49
|
self.sql_index = faiss.IndexFlatL2(self.embedding_dim)
|
|
47
50
|
self.ddl_index = faiss.IndexFlatL2(self.embedding_dim)
|
|
48
51
|
self.doc_index = faiss.IndexFlatL2(self.embedding_dim)
|
|
49
|
-
elif
|
|
52
|
+
elif (
|
|
53
|
+
isinstance(self.curr_client, list)
|
|
54
|
+
and len(self.curr_client) == 3
|
|
55
|
+
and all(isinstance(idx, faiss.Index) for idx in self.curr_client)
|
|
56
|
+
):
|
|
50
57
|
self.sql_index = self.curr_client[0]
|
|
51
58
|
self.ddl_index = self.curr_client[1]
|
|
52
59
|
self.doc_index = self.curr_client[2]
|
|
53
60
|
else:
|
|
54
|
-
raise ValueError(
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
self.ddl_metadata: List[Dict[str, str]] = self._load_or_create_metadata('ddl_metadata.json')
|
|
58
|
-
self.doc_metadata: List[Dict[str, str]] = self._load_or_create_metadata('doc_metadata.json')
|
|
61
|
+
raise ValueError(
|
|
62
|
+
f"Unsupported storage type was set in config: {self.curr_client}"
|
|
63
|
+
)
|
|
59
64
|
|
|
60
|
-
|
|
65
|
+
self.sql_metadata: List[Dict[str, Any]] = self._load_or_create_metadata(
|
|
66
|
+
"sql_metadata.json"
|
|
67
|
+
)
|
|
68
|
+
self.ddl_metadata: List[Dict[str, str]] = self._load_or_create_metadata(
|
|
69
|
+
"ddl_metadata.json"
|
|
70
|
+
)
|
|
71
|
+
self.doc_metadata: List[Dict[str, str]] = self._load_or_create_metadata(
|
|
72
|
+
"doc_metadata.json"
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
model_name = config.get("embedding_model", "all-MiniLM-L6-v2")
|
|
61
76
|
self.embedding_model = SentenceTransformer(model_name)
|
|
62
77
|
|
|
63
78
|
def _load_or_create_index(self, filename):
|
|
@@ -69,25 +84,26 @@ class FAISS(VannaBase):
|
|
|
69
84
|
def _load_or_create_metadata(self, filename):
|
|
70
85
|
filepath = os.path.join(self.path, filename)
|
|
71
86
|
if os.path.exists(filepath):
|
|
72
|
-
with open(filepath,
|
|
87
|
+
with open(filepath, "r") as f:
|
|
73
88
|
return json.load(f)
|
|
74
89
|
return []
|
|
75
90
|
|
|
76
91
|
def _save_index(self, index, filename):
|
|
77
|
-
if self.curr_client ==
|
|
92
|
+
if self.curr_client == "persistent":
|
|
78
93
|
filepath = os.path.join(self.path, filename)
|
|
79
94
|
faiss.write_index(index, filepath)
|
|
80
95
|
|
|
81
96
|
def _save_metadata(self, metadata, filename):
|
|
82
|
-
if self.curr_client ==
|
|
97
|
+
if self.curr_client == "persistent":
|
|
83
98
|
filepath = os.path.join(self.path, filename)
|
|
84
|
-
with open(filepath,
|
|
99
|
+
with open(filepath, "w") as f:
|
|
85
100
|
json.dump(metadata, f)
|
|
86
101
|
|
|
87
102
|
def generate_embedding(self, data: str, **kwargs) -> List[float]:
|
|
88
103
|
embedding = self.embedding_model.encode(data)
|
|
89
|
-
assert embedding.shape[0] == self.embedding_dim,
|
|
104
|
+
assert embedding.shape[0] == self.embedding_dim, (
|
|
90
105
|
f"Embedding dimension mismatch: expected {self.embedding_dim}, got {embedding.shape[0]}"
|
|
106
|
+
)
|
|
91
107
|
return embedding.tolist()
|
|
92
108
|
|
|
93
109
|
def _add_to_index(self, index, metadata_list, text, extra_metadata=None) -> str:
|
|
@@ -96,81 +112,119 @@ class FAISS(VannaBase):
|
|
|
96
112
|
entry_id = str(uuid.uuid4())
|
|
97
113
|
metadata_list.append({"id": entry_id, **(extra_metadata or {})})
|
|
98
114
|
return entry_id
|
|
99
|
-
|
|
115
|
+
|
|
100
116
|
def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
|
|
101
|
-
entry_id = self._add_to_index(
|
|
102
|
-
|
|
103
|
-
|
|
117
|
+
entry_id = self._add_to_index(
|
|
118
|
+
self.sql_index,
|
|
119
|
+
self.sql_metadata,
|
|
120
|
+
question + " " + sql,
|
|
121
|
+
{"question": question, "sql": sql},
|
|
122
|
+
)
|
|
123
|
+
self._save_index(self.sql_index, "sql_index.faiss")
|
|
124
|
+
self._save_metadata(self.sql_metadata, "sql_metadata.json")
|
|
104
125
|
return entry_id
|
|
105
126
|
|
|
106
127
|
def add_ddl(self, ddl: str, **kwargs) -> str:
|
|
107
|
-
entry_id = self._add_to_index(
|
|
108
|
-
|
|
109
|
-
|
|
128
|
+
entry_id = self._add_to_index(
|
|
129
|
+
self.ddl_index, self.ddl_metadata, ddl, {"ddl": ddl}
|
|
130
|
+
)
|
|
131
|
+
self._save_index(self.ddl_index, "ddl_index.faiss")
|
|
132
|
+
self._save_metadata(self.ddl_metadata, "ddl_metadata.json")
|
|
110
133
|
return entry_id
|
|
111
134
|
|
|
112
135
|
def add_documentation(self, documentation: str, **kwargs) -> str:
|
|
113
|
-
entry_id = self._add_to_index(
|
|
114
|
-
|
|
115
|
-
|
|
136
|
+
entry_id = self._add_to_index(
|
|
137
|
+
self.doc_index,
|
|
138
|
+
self.doc_metadata,
|
|
139
|
+
documentation,
|
|
140
|
+
{"documentation": documentation},
|
|
141
|
+
)
|
|
142
|
+
self._save_index(self.doc_index, "doc_index.faiss")
|
|
143
|
+
self._save_metadata(self.doc_metadata, "doc_metadata.json")
|
|
116
144
|
return entry_id
|
|
117
145
|
|
|
118
146
|
def _get_similar(self, index, metadata_list, text, n_results) -> list:
|
|
119
147
|
embedding = self.generate_embedding(text)
|
|
120
148
|
D, I = index.search(np.array([embedding], dtype=np.float32), k=n_results)
|
|
121
|
-
return
|
|
149
|
+
return (
|
|
150
|
+
[] if len(I[0]) == 0 or I[0][0] == -1 else [metadata_list[i] for i in I[0]]
|
|
151
|
+
)
|
|
122
152
|
|
|
123
153
|
def get_similar_question_sql(self, question: str, **kwargs) -> list:
|
|
124
|
-
return self._get_similar(
|
|
125
|
-
|
|
154
|
+
return self._get_similar(
|
|
155
|
+
self.sql_index, self.sql_metadata, question, self.n_results_sql
|
|
156
|
+
)
|
|
157
|
+
|
|
126
158
|
def get_related_ddl(self, question: str, **kwargs) -> list:
|
|
127
|
-
return [
|
|
159
|
+
return [
|
|
160
|
+
metadata["ddl"]
|
|
161
|
+
for metadata in self._get_similar(
|
|
162
|
+
self.ddl_index, self.ddl_metadata, question, self.n_results_ddl
|
|
163
|
+
)
|
|
164
|
+
]
|
|
128
165
|
|
|
129
166
|
def get_related_documentation(self, question: str, **kwargs) -> list:
|
|
130
|
-
return [
|
|
167
|
+
return [
|
|
168
|
+
metadata["documentation"]
|
|
169
|
+
for metadata in self._get_similar(
|
|
170
|
+
self.doc_index,
|
|
171
|
+
self.doc_metadata,
|
|
172
|
+
question,
|
|
173
|
+
self.n_results_documentation,
|
|
174
|
+
)
|
|
175
|
+
]
|
|
131
176
|
|
|
132
177
|
def get_training_data(self, **kwargs) -> pd.DataFrame:
|
|
133
178
|
sql_data = pd.DataFrame(self.sql_metadata)
|
|
134
|
-
sql_data[
|
|
179
|
+
sql_data["training_data_type"] = "sql"
|
|
135
180
|
|
|
136
181
|
ddl_data = pd.DataFrame(self.ddl_metadata)
|
|
137
|
-
ddl_data[
|
|
182
|
+
ddl_data["training_data_type"] = "ddl"
|
|
138
183
|
|
|
139
184
|
doc_data = pd.DataFrame(self.doc_metadata)
|
|
140
|
-
doc_data[
|
|
185
|
+
doc_data["training_data_type"] = "documentation"
|
|
141
186
|
|
|
142
187
|
return pd.concat([sql_data, ddl_data, doc_data], ignore_index=True)
|
|
143
188
|
|
|
144
189
|
def remove_training_data(self, id: str, **kwargs) -> bool:
|
|
145
190
|
for metadata_list, index, index_name in [
|
|
146
|
-
(self.sql_metadata, self.sql_index,
|
|
147
|
-
(self.ddl_metadata, self.ddl_index,
|
|
148
|
-
(self.doc_metadata, self.doc_index,
|
|
191
|
+
(self.sql_metadata, self.sql_index, "sql_index.faiss"),
|
|
192
|
+
(self.ddl_metadata, self.ddl_index, "ddl_index.faiss"),
|
|
193
|
+
(self.doc_metadata, self.doc_index, "doc_index.faiss"),
|
|
149
194
|
]:
|
|
150
195
|
for i, item in enumerate(metadata_list):
|
|
151
|
-
if item[
|
|
196
|
+
if item["id"] == id:
|
|
152
197
|
del metadata_list[i]
|
|
153
198
|
new_index = faiss.IndexFlatL2(self.embedding_dim)
|
|
154
|
-
embeddings = [
|
|
199
|
+
embeddings = [
|
|
200
|
+
self.generate_embedding(json.dumps(m)) for m in metadata_list
|
|
201
|
+
]
|
|
155
202
|
if embeddings:
|
|
156
203
|
new_index.add(np.array(embeddings, dtype=np.float32))
|
|
157
|
-
setattr(self, index_name.split(
|
|
158
|
-
|
|
159
|
-
if self.curr_client ==
|
|
204
|
+
setattr(self, index_name.split(".")[0], new_index)
|
|
205
|
+
|
|
206
|
+
if self.curr_client == "persistent":
|
|
160
207
|
self._save_index(new_index, index_name)
|
|
161
|
-
self._save_metadata(
|
|
162
|
-
|
|
208
|
+
self._save_metadata(
|
|
209
|
+
metadata_list, f"{index_name.split('.')[0]}_metadata.json"
|
|
210
|
+
)
|
|
211
|
+
|
|
163
212
|
return True
|
|
164
213
|
return False
|
|
165
214
|
|
|
166
215
|
def remove_collection(self, collection_name: str) -> bool:
|
|
167
216
|
if collection_name in ["sql", "ddl", "documentation"]:
|
|
168
|
-
setattr(
|
|
217
|
+
setattr(
|
|
218
|
+
self, f"{collection_name}_index", faiss.IndexFlatL2(self.embedding_dim)
|
|
219
|
+
)
|
|
169
220
|
setattr(self, f"{collection_name}_metadata", [])
|
|
170
|
-
|
|
171
|
-
if self.curr_client ==
|
|
172
|
-
self._save_index(
|
|
221
|
+
|
|
222
|
+
if self.curr_client == "persistent":
|
|
223
|
+
self._save_index(
|
|
224
|
+
getattr(self, f"{collection_name}_index"),
|
|
225
|
+
f"{collection_name}_index.faiss",
|
|
226
|
+
)
|
|
173
227
|
self._save_metadata([], f"{collection_name}_metadata.json")
|
|
174
|
-
|
|
228
|
+
|
|
175
229
|
return True
|
|
176
|
-
return False
|
|
230
|
+
return False
|