vanna 0.7.8__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/__init__.py +167 -395
- vanna/agents/__init__.py +7 -0
- vanna/capabilities/__init__.py +17 -0
- vanna/capabilities/agent_memory/__init__.py +21 -0
- vanna/capabilities/agent_memory/base.py +103 -0
- vanna/capabilities/agent_memory/models.py +53 -0
- vanna/capabilities/file_system/__init__.py +14 -0
- vanna/capabilities/file_system/base.py +71 -0
- vanna/capabilities/file_system/models.py +25 -0
- vanna/capabilities/sql_runner/__init__.py +13 -0
- vanna/capabilities/sql_runner/base.py +37 -0
- vanna/capabilities/sql_runner/models.py +13 -0
- vanna/components/__init__.py +92 -0
- vanna/components/base.py +11 -0
- vanna/components/rich/__init__.py +83 -0
- vanna/components/rich/containers/__init__.py +7 -0
- vanna/components/rich/containers/card.py +20 -0
- vanna/components/rich/data/__init__.py +9 -0
- vanna/components/rich/data/chart.py +17 -0
- vanna/components/rich/data/dataframe.py +93 -0
- vanna/components/rich/feedback/__init__.py +21 -0
- vanna/components/rich/feedback/badge.py +16 -0
- vanna/components/rich/feedback/icon_text.py +14 -0
- vanna/components/rich/feedback/log_viewer.py +41 -0
- vanna/components/rich/feedback/notification.py +19 -0
- vanna/components/rich/feedback/progress.py +37 -0
- vanna/components/rich/feedback/status_card.py +28 -0
- vanna/components/rich/feedback/status_indicator.py +14 -0
- vanna/components/rich/interactive/__init__.py +21 -0
- vanna/components/rich/interactive/button.py +95 -0
- vanna/components/rich/interactive/task_list.py +58 -0
- vanna/components/rich/interactive/ui_state.py +93 -0
- vanna/components/rich/specialized/__init__.py +7 -0
- vanna/components/rich/specialized/artifact.py +20 -0
- vanna/components/rich/text.py +16 -0
- vanna/components/simple/__init__.py +15 -0
- vanna/components/simple/image.py +15 -0
- vanna/components/simple/link.py +15 -0
- vanna/components/simple/text.py +11 -0
- vanna/core/__init__.py +193 -0
- vanna/core/_compat.py +19 -0
- vanna/core/agent/__init__.py +10 -0
- vanna/core/agent/agent.py +1407 -0
- vanna/core/agent/config.py +123 -0
- vanna/core/audit/__init__.py +28 -0
- vanna/core/audit/base.py +299 -0
- vanna/core/audit/models.py +131 -0
- vanna/core/component_manager.py +329 -0
- vanna/core/components.py +53 -0
- vanna/core/enhancer/__init__.py +11 -0
- vanna/core/enhancer/base.py +94 -0
- vanna/core/enhancer/default.py +118 -0
- vanna/core/enricher/__init__.py +10 -0
- vanna/core/enricher/base.py +59 -0
- vanna/core/errors.py +47 -0
- vanna/core/evaluation/__init__.py +81 -0
- vanna/core/evaluation/base.py +186 -0
- vanna/core/evaluation/dataset.py +254 -0
- vanna/core/evaluation/evaluators.py +376 -0
- vanna/core/evaluation/report.py +289 -0
- vanna/core/evaluation/runner.py +313 -0
- vanna/core/filter/__init__.py +10 -0
- vanna/core/filter/base.py +67 -0
- vanna/core/lifecycle/__init__.py +10 -0
- vanna/core/lifecycle/base.py +83 -0
- vanna/core/llm/__init__.py +16 -0
- vanna/core/llm/base.py +40 -0
- vanna/core/llm/models.py +61 -0
- vanna/core/middleware/__init__.py +10 -0
- vanna/core/middleware/base.py +69 -0
- vanna/core/observability/__init__.py +11 -0
- vanna/core/observability/base.py +88 -0
- vanna/core/observability/models.py +47 -0
- vanna/core/recovery/__init__.py +11 -0
- vanna/core/recovery/base.py +84 -0
- vanna/core/recovery/models.py +32 -0
- vanna/core/registry.py +278 -0
- vanna/core/rich_component.py +156 -0
- vanna/core/simple_component.py +27 -0
- vanna/core/storage/__init__.py +14 -0
- vanna/core/storage/base.py +46 -0
- vanna/core/storage/models.py +46 -0
- vanna/core/system_prompt/__init__.py +13 -0
- vanna/core/system_prompt/base.py +36 -0
- vanna/core/system_prompt/default.py +157 -0
- vanna/core/tool/__init__.py +18 -0
- vanna/core/tool/base.py +70 -0
- vanna/core/tool/models.py +84 -0
- vanna/core/user/__init__.py +17 -0
- vanna/core/user/base.py +29 -0
- vanna/core/user/models.py +25 -0
- vanna/core/user/request_context.py +70 -0
- vanna/core/user/resolver.py +42 -0
- vanna/core/validation.py +164 -0
- vanna/core/workflow/__init__.py +12 -0
- vanna/core/workflow/base.py +254 -0
- vanna/core/workflow/default.py +789 -0
- vanna/examples/__init__.py +1 -0
- vanna/examples/__main__.py +44 -0
- vanna/examples/anthropic_quickstart.py +80 -0
- vanna/examples/artifact_example.py +293 -0
- vanna/examples/claude_sqlite_example.py +236 -0
- vanna/examples/coding_agent_example.py +300 -0
- vanna/examples/custom_system_prompt_example.py +174 -0
- vanna/examples/default_workflow_handler_example.py +208 -0
- vanna/examples/email_auth_example.py +340 -0
- vanna/examples/evaluation_example.py +269 -0
- vanna/examples/extensibility_example.py +262 -0
- vanna/examples/minimal_example.py +67 -0
- vanna/examples/mock_auth_example.py +227 -0
- vanna/examples/mock_custom_tool.py +311 -0
- vanna/examples/mock_quickstart.py +79 -0
- vanna/examples/mock_quota_example.py +145 -0
- vanna/examples/mock_rich_components_demo.py +396 -0
- vanna/examples/mock_sqlite_example.py +223 -0
- vanna/examples/openai_quickstart.py +83 -0
- vanna/examples/primitive_components_demo.py +305 -0
- vanna/examples/quota_lifecycle_example.py +139 -0
- vanna/examples/visualization_example.py +251 -0
- vanna/integrations/__init__.py +17 -0
- vanna/integrations/anthropic/__init__.py +9 -0
- vanna/integrations/anthropic/llm.py +270 -0
- vanna/integrations/azureopenai/__init__.py +9 -0
- vanna/integrations/azureopenai/llm.py +329 -0
- vanna/integrations/azuresearch/__init__.py +7 -0
- vanna/integrations/azuresearch/agent_memory.py +413 -0
- vanna/integrations/bigquery/__init__.py +5 -0
- vanna/integrations/bigquery/sql_runner.py +81 -0
- vanna/integrations/chromadb/__init__.py +104 -0
- vanna/integrations/chromadb/agent_memory.py +416 -0
- vanna/integrations/clickhouse/__init__.py +5 -0
- vanna/integrations/clickhouse/sql_runner.py +82 -0
- vanna/integrations/duckdb/__init__.py +5 -0
- vanna/integrations/duckdb/sql_runner.py +65 -0
- vanna/integrations/faiss/__init__.py +7 -0
- vanna/integrations/faiss/agent_memory.py +431 -0
- vanna/integrations/google/__init__.py +9 -0
- vanna/integrations/google/gemini.py +370 -0
- vanna/integrations/hive/__init__.py +5 -0
- vanna/integrations/hive/sql_runner.py +87 -0
- vanna/integrations/local/__init__.py +17 -0
- vanna/integrations/local/agent_memory/__init__.py +7 -0
- vanna/integrations/local/agent_memory/in_memory.py +285 -0
- vanna/integrations/local/audit.py +59 -0
- vanna/integrations/local/file_system.py +242 -0
- vanna/integrations/local/file_system_conversation_store.py +255 -0
- vanna/integrations/local/storage.py +62 -0
- vanna/integrations/marqo/__init__.py +7 -0
- vanna/integrations/marqo/agent_memory.py +354 -0
- vanna/integrations/milvus/__init__.py +7 -0
- vanna/integrations/milvus/agent_memory.py +458 -0
- vanna/integrations/mock/__init__.py +9 -0
- vanna/integrations/mock/llm.py +65 -0
- vanna/integrations/mssql/__init__.py +5 -0
- vanna/integrations/mssql/sql_runner.py +66 -0
- vanna/integrations/mysql/__init__.py +5 -0
- vanna/integrations/mysql/sql_runner.py +92 -0
- vanna/integrations/ollama/__init__.py +7 -0
- vanna/integrations/ollama/llm.py +252 -0
- vanna/integrations/openai/__init__.py +10 -0
- vanna/integrations/openai/llm.py +267 -0
- vanna/integrations/openai/responses.py +163 -0
- vanna/integrations/opensearch/__init__.py +7 -0
- vanna/integrations/opensearch/agent_memory.py +411 -0
- vanna/integrations/oracle/__init__.py +5 -0
- vanna/integrations/oracle/sql_runner.py +75 -0
- vanna/integrations/pinecone/__init__.py +7 -0
- vanna/integrations/pinecone/agent_memory.py +329 -0
- vanna/integrations/plotly/__init__.py +5 -0
- vanna/integrations/plotly/chart_generator.py +313 -0
- vanna/integrations/postgres/__init__.py +9 -0
- vanna/integrations/postgres/sql_runner.py +112 -0
- vanna/integrations/premium/agent_memory/__init__.py +7 -0
- vanna/integrations/premium/agent_memory/premium.py +186 -0
- vanna/integrations/presto/__init__.py +5 -0
- vanna/integrations/presto/sql_runner.py +107 -0
- vanna/integrations/qdrant/__init__.py +7 -0
- vanna/integrations/qdrant/agent_memory.py +461 -0
- vanna/integrations/snowflake/__init__.py +5 -0
- vanna/integrations/snowflake/sql_runner.py +147 -0
- vanna/integrations/sqlite/__init__.py +9 -0
- vanna/integrations/sqlite/sql_runner.py +65 -0
- vanna/integrations/weaviate/__init__.py +7 -0
- vanna/integrations/weaviate/agent_memory.py +428 -0
- vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_embeddings.py +11 -11
- vanna/legacy/__init__.py +403 -0
- vanna/legacy/adapter.py +463 -0
- vanna/{advanced → legacy/advanced}/__init__.py +3 -1
- vanna/{anthropic → legacy/anthropic}/anthropic_chat.py +9 -7
- vanna/{azuresearch → legacy/azuresearch}/azuresearch_vector.py +79 -41
- vanna/{base → legacy/base}/base.py +247 -223
- vanna/legacy/bedrock/__init__.py +1 -0
- vanna/{bedrock → legacy/bedrock}/bedrock_converse.py +13 -12
- vanna/{chromadb → legacy/chromadb}/chromadb_vector.py +3 -1
- vanna/legacy/cohere/__init__.py +2 -0
- vanna/{cohere → legacy/cohere}/cohere_chat.py +19 -14
- vanna/{cohere → legacy/cohere}/cohere_embeddings.py +25 -19
- vanna/{deepseek → legacy/deepseek}/deepseek_chat.py +5 -6
- vanna/legacy/faiss/__init__.py +1 -0
- vanna/{faiss → legacy/faiss}/faiss.py +113 -59
- vanna/{flask → legacy/flask}/__init__.py +84 -43
- vanna/{flask → legacy/flask}/assets.py +5 -5
- vanna/{flask → legacy/flask}/auth.py +5 -4
- vanna/{google → legacy/google}/bigquery_vector.py +75 -42
- vanna/{google → legacy/google}/gemini_chat.py +7 -3
- vanna/{hf → legacy/hf}/hf.py +0 -1
- vanna/{milvus → legacy/milvus}/milvus_vector.py +58 -35
- vanna/{mock → legacy/mock}/llm.py +0 -1
- vanna/legacy/mock/vectordb.py +67 -0
- vanna/legacy/ollama/ollama.py +110 -0
- vanna/{openai → legacy/openai}/openai_chat.py +2 -6
- vanna/legacy/opensearch/opensearch_vector.py +369 -0
- vanna/legacy/opensearch/opensearch_vector_semantic.py +200 -0
- vanna/legacy/oracle/oracle_vector.py +584 -0
- vanna/{pgvector → legacy/pgvector}/pgvector.py +42 -13
- vanna/{qdrant → legacy/qdrant}/qdrant.py +2 -6
- vanna/legacy/qianfan/Qianfan_Chat.py +170 -0
- vanna/legacy/qianfan/Qianfan_embeddings.py +36 -0
- vanna/legacy/qianwen/QianwenAI_chat.py +132 -0
- vanna/{remote.py → legacy/remote.py} +28 -26
- vanna/{utils.py → legacy/utils.py} +6 -11
- vanna/{vannadb → legacy/vannadb}/vannadb_vector.py +115 -46
- vanna/{vllm → legacy/vllm}/vllm.py +5 -6
- vanna/{weaviate → legacy/weaviate}/weaviate_vector.py +59 -40
- vanna/{xinference → legacy/xinference}/xinference.py +6 -6
- vanna/py.typed +0 -0
- vanna/servers/__init__.py +16 -0
- vanna/servers/__main__.py +8 -0
- vanna/servers/base/__init__.py +18 -0
- vanna/servers/base/chat_handler.py +65 -0
- vanna/servers/base/models.py +111 -0
- vanna/servers/base/rich_chat_handler.py +141 -0
- vanna/servers/base/templates.py +331 -0
- vanna/servers/cli/__init__.py +7 -0
- vanna/servers/cli/server_runner.py +204 -0
- vanna/servers/fastapi/__init__.py +7 -0
- vanna/servers/fastapi/app.py +163 -0
- vanna/servers/fastapi/routes.py +183 -0
- vanna/servers/flask/__init__.py +7 -0
- vanna/servers/flask/app.py +132 -0
- vanna/servers/flask/routes.py +137 -0
- vanna/tools/__init__.py +41 -0
- vanna/tools/agent_memory.py +322 -0
- vanna/tools/file_system.py +879 -0
- vanna/tools/python.py +222 -0
- vanna/tools/run_sql.py +165 -0
- vanna/tools/visualize_data.py +195 -0
- vanna/utils/__init__.py +0 -0
- vanna/web_components/__init__.py +44 -0
- vanna-2.0.0.dist-info/METADATA +485 -0
- vanna-2.0.0.dist-info/RECORD +289 -0
- vanna-2.0.0.dist-info/entry_points.txt +3 -0
- vanna/bedrock/__init__.py +0 -1
- vanna/cohere/__init__.py +0 -2
- vanna/faiss/__init__.py +0 -1
- vanna/mock/vectordb.py +0 -55
- vanna/ollama/ollama.py +0 -103
- vanna/opensearch/opensearch_vector.py +0 -392
- vanna/opensearch/opensearch_vector_semantic.py +0 -175
- vanna/oracle/oracle_vector.py +0 -585
- vanna/qianfan/Qianfan_Chat.py +0 -165
- vanna/qianfan/Qianfan_embeddings.py +0 -36
- vanna/qianwen/QianwenAI_chat.py +0 -133
- vanna-0.7.8.dist-info/METADATA +0 -408
- vanna-0.7.8.dist-info/RECORD +0 -79
- /vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_Chat.py +0 -0
- /vanna/{ZhipuAI → legacy/ZhipuAI}/__init__.py +0 -0
- /vanna/{anthropic → legacy/anthropic}/__init__.py +0 -0
- /vanna/{azuresearch → legacy/azuresearch}/__init__.py +0 -0
- /vanna/{base → legacy/base}/__init__.py +0 -0
- /vanna/{chromadb → legacy/chromadb}/__init__.py +0 -0
- /vanna/{deepseek → legacy/deepseek}/__init__.py +0 -0
- /vanna/{exceptions → legacy/exceptions}/__init__.py +0 -0
- /vanna/{google → legacy/google}/__init__.py +0 -0
- /vanna/{hf → legacy/hf}/__init__.py +0 -0
- /vanna/{local.py → legacy/local.py} +0 -0
- /vanna/{marqo → legacy/marqo}/__init__.py +0 -0
- /vanna/{marqo → legacy/marqo}/marqo.py +0 -0
- /vanna/{milvus → legacy/milvus}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/mistral.py +0 -0
- /vanna/{mock → legacy/mock}/__init__.py +0 -0
- /vanna/{mock → legacy/mock}/embedding.py +0 -0
- /vanna/{ollama → legacy/ollama}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/openai_embeddings.py +0 -0
- /vanna/{opensearch → legacy/opensearch}/__init__.py +0 -0
- /vanna/{oracle → legacy/oracle}/__init__.py +0 -0
- /vanna/{pgvector → legacy/pgvector}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/pinecone_vector.py +0 -0
- /vanna/{qdrant → legacy/qdrant}/__init__.py +0 -0
- /vanna/{qianfan → legacy/qianfan}/__init__.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/QianwenAI_embeddings.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/__init__.py +0 -0
- /vanna/{types → legacy/types}/__init__.py +0 -0
- /vanna/{vannadb → legacy/vannadb}/__init__.py +0 -0
- /vanna/{vllm → legacy/vllm}/__init__.py +0 -0
- /vanna/{weaviate → legacy/weaviate}/__init__.py +0 -0
- /vanna/{xinference → legacy/xinference}/__init__.py +0 -0
- {vanna-0.7.8.dist-info → vanna-2.0.0.dist-info}/WHEEL +0 -0
- {vanna-0.7.8.dist-info → vanna-2.0.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,416 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Local vector database implementation of AgentMemory.
|
|
3
|
+
|
|
4
|
+
This implementation uses ChromaDB for local vector storage of tool usage patterns.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import json
|
|
8
|
+
import hashlib
|
|
9
|
+
from datetime import datetime
|
|
10
|
+
from typing import Any, Dict, List, Optional
|
|
11
|
+
import asyncio
|
|
12
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
import chromadb
|
|
16
|
+
from chromadb.config import Settings
|
|
17
|
+
from chromadb.utils import embedding_functions
|
|
18
|
+
|
|
19
|
+
CHROMADB_AVAILABLE = True
|
|
20
|
+
except ImportError:
|
|
21
|
+
CHROMADB_AVAILABLE = False
|
|
22
|
+
|
|
23
|
+
from vanna.capabilities.agent_memory import (
|
|
24
|
+
AgentMemory,
|
|
25
|
+
TextMemory,
|
|
26
|
+
TextMemorySearchResult,
|
|
27
|
+
ToolMemory,
|
|
28
|
+
ToolMemorySearchResult,
|
|
29
|
+
)
|
|
30
|
+
from vanna.core.tool import ToolContext
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class ChromaAgentMemory(AgentMemory):
    """ChromaDB-based implementation of AgentMemory.

    Tool-usage memories and free-text memories are stored side by side in a
    single ChromaDB collection. Text memories are tagged with the metadata
    flag ``is_text_memory``; tool-usage memories carry ``question``,
    ``tool_name``, ``args_json``, ``success`` and ``metadata_json``.

    All ChromaDB calls are synchronous, so every public coroutine runs its
    work on a private thread pool.
    """

    def __init__(
        self,
        persist_directory: str = "./chroma_memory",
        collection_name: str = "tool_memories",
        embedding_function: Optional[Any] = None,
    ):
        """Store configuration; the client and collection are created lazily.

        Args:
            persist_directory: Passed as ``path`` to ``chromadb.PersistentClient``.
            collection_name: Name of the collection holding the memories.
            embedding_function: Embedding function handed to ChromaDB. When
                ``None``, ``embedding_functions.DefaultEmbeddingFunction()``
                is used on first access.

        Raises:
            ImportError: If the ``chromadb`` package could not be imported.
        """
        if not CHROMADB_AVAILABLE:
            raise ImportError(
                "ChromaDB is required for ChromaAgentMemory. Install with: pip install chromadb"
            )

        self.persist_directory = persist_directory
        self.collection_name = collection_name
        self._client = None
        self._collection = None
        self._executor = ThreadPoolExecutor(max_workers=2)
        self._embedding_function = embedding_function

    # ------------------------------------------------------------------
    # Lazy resources
    # ------------------------------------------------------------------

    def _get_client(self):
        """Get or create ChromaDB client."""
        if self._client is None:
            self._client = chromadb.PersistentClient(
                path=self.persist_directory,
                settings=Settings(anonymized_telemetry=False, allow_reset=True),
            )
        return self._client

    def _get_embedding_function(self):
        """Get or create the embedding function.

        If no embedding function was provided during initialization,
        uses ChromaDB's default embedding function.
        """
        if self._embedding_function is None:
            # Use ChromaDB's default embedding function
            # This avoids requiring sentence-transformers as a hard dependency
            self._embedding_function = embedding_functions.DefaultEmbeddingFunction()
        return self._embedding_function

    def _get_collection(self):
        """Get or create ChromaDB collection."""
        if self._collection is None:
            client = self._get_client()
            embedding_func = self._get_embedding_function()
            # get_or_create_collection is idempotent. The previous
            # get_collection / except Exception / create_collection sequence
            # masked the real failure reason and could fail with
            # "already exists" when two pool workers initialised concurrently.
            self._collection = client.get_or_create_collection(
                name=self.collection_name,
                embedding_function=embedding_func,
                metadata={"description": "Tool usage memories for learning"},
            )
        return self._collection

    def _create_memory_id(self) -> str:
        """Create a unique ID for a memory."""
        import uuid

        return str(uuid.uuid4())

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    async def _run_in_executor(self, func):
        """Run the blocking callable ``func`` on this instance's thread pool."""
        # We are always inside a coroutine here, so the running loop exists.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self._executor, func)

    @staticmethod
    def _tool_memory_from_metadata(memory_id: str, metadata: Dict[str, Any]) -> ToolMemory:
        """Build a ToolMemory from a stored ChromaDB metadata record."""
        # Complex values were serialized to JSON strings when saved.
        return ToolMemory(
            memory_id=memory_id,
            question=metadata["question"],
            tool_name=metadata["tool_name"],
            args=json.loads(metadata.get("args_json", "{}")),
            timestamp=metadata.get("timestamp"),
            success=metadata.get("success", True),
            metadata=json.loads(metadata.get("metadata_json", "{}")),
        )

    @staticmethod
    def _iter_matches(results, similarity_threshold: float):
        """Yield ``(rank, id, similarity_score, metadata)`` for qualifying hits.

        ``rank`` is the 1-based position of the hit in the raw query result,
        so it is not renumbered after threshold filtering.
        """
        if not results["ids"] or len(results["ids"][0]) == 0:
            return

        hits = zip(results["ids"][0], results["distances"][0], results["metadatas"][0])
        for rank, (id_, distance, metadata) in enumerate(hits, start=1):
            # NOTE(review): this treats the distance as lying in [0, 1]; any
            # larger distance is clamped to a score of 0. Confirm this matches
            # the distance metric the collection is configured with.
            similarity_score = max(0, 1 - distance)
            if similarity_score >= similarity_threshold:
                yield rank, id_, similarity_score, metadata

    def _delete_if_present(self, memory_id: str) -> bool:
        """Delete ``memory_id`` from the collection if it exists.

        Returns True when an entry was found and deleted, False when it was
        not found or when ChromaDB raised while looking it up or deleting it.
        """
        collection = self._get_collection()
        try:
            results = collection.get(ids=[memory_id])
            if results["ids"] and len(results["ids"]) > 0:
                collection.delete(ids=[memory_id])
                return True
            return False
        except Exception:
            # Best-effort by design: callers only get a boolean outcome.
            return False

    # ------------------------------------------------------------------
    # Tool-usage memories
    # ------------------------------------------------------------------

    async def save_tool_usage(
        self,
        question: str,
        tool_name: str,
        args: Dict[str, Any],
        context: ToolContext,
        success: bool = True,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Save a tool usage pattern.

        Args:
            question: The question; also used as the embedded document text.
            tool_name: Name of the tool that was invoked.
            args: Tool arguments; must be JSON-serializable.
            context: Tool execution context (not used by this implementation).
            success: Whether the tool call succeeded.
            metadata: Optional extra metadata; must be JSON-serializable.
        """

        def _save():
            collection = self._get_collection()

            memory_id = self._create_memory_id()
            timestamp = datetime.now().isoformat()

            # ChromaDB only accepts primitive types in metadata
            # Serialize complex objects to JSON strings
            memory_data = {
                "question": question,
                "tool_name": tool_name,
                "args_json": json.dumps(args),
                "timestamp": timestamp,
                "success": success,
                "metadata_json": json.dumps(metadata or {}),
            }

            # Use question as document text for embedding
            collection.upsert(
                ids=[memory_id], documents=[question], metadatas=[memory_data]
            )

        await self._run_in_executor(_save)

    async def search_similar_usage(
        self,
        question: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
        tool_name_filter: Optional[str] = None,
    ) -> List[ToolMemorySearchResult]:
        """Search for similar tool usage patterns.

        Only memories stored with ``success=True`` are considered.

        Args:
            question: Query text.
            context: Tool execution context (not used by this implementation).
            limit: Maximum number of results requested from ChromaDB.
            similarity_threshold: Minimum similarity score a hit must reach.
            tool_name_filter: When given, restrict results to this tool.

        Returns:
            Matching memories with their similarity score and rank.
        """

        def _search():
            collection = self._get_collection()

            # Prepare where filter - ChromaDB requires $and for multiple conditions
            if tool_name_filter:
                where_filter = {
                    "$and": [{"success": True}, {"tool_name": tool_name_filter}]
                }
            else:
                where_filter = {"success": True}

            results = collection.query(
                query_texts=[question], n_results=limit, where=where_filter
            )

            # Use the ChromaDB document ID as the memory ID
            return [
                ToolMemorySearchResult(
                    memory=self._tool_memory_from_metadata(id_, metadata),
                    similarity_score=similarity_score,
                    rank=rank,
                )
                for rank, id_, similarity_score, metadata in self._iter_matches(
                    results, similarity_threshold
                )
            ]

        return await self._run_in_executor(_search)

    async def get_recent_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[ToolMemory]:
        """Get recently added memories. Returns most recent memories first.

        Text memories stored in the same collection are skipped.

        Args:
            context: Tool execution context (not used by this implementation).
            limit: Maximum number of memories to return.
        """

        def _get_recent():
            collection = self._get_collection()

            # Get all memories and sort by timestamp
            results = collection.get()

            if not results["metadatas"] or not results["ids"]:
                return []

            memories_with_time = []
            for doc_id, metadata in zip(results["ids"], results["metadatas"]):
                # Text memories live in the same collection but have no
                # "question"/"tool_name" keys; building a ToolMemory from
                # them would raise KeyError, so leave them out.
                if not metadata or metadata.get("is_text_memory"):
                    continue
                memory = self._tool_memory_from_metadata(doc_id, metadata)
                memories_with_time.append((memory, metadata.get("timestamp", "")))

            # Sort by timestamp descending (most recent first)
            memories_with_time.sort(key=lambda x: x[1], reverse=True)

            # Return only the memory objects, limited to the requested amount
            return [m[0] for m in memories_with_time[:limit]]

        return await self._run_in_executor(_get_recent)

    async def delete_by_id(self, context: ToolContext, memory_id: str) -> bool:
        """Delete a memory by its ID. Returns True if deleted, False if not found."""
        return await self._run_in_executor(lambda: self._delete_if_present(memory_id))

    # ------------------------------------------------------------------
    # Text memories
    # ------------------------------------------------------------------

    async def save_text_memory(self, content: str, context: ToolContext) -> TextMemory:
        """Save a text memory.

        Args:
            content: Text to store; also used as the embedded document text.
            context: Tool execution context (not used by this implementation).

        Returns:
            The stored memory, including its generated ID and timestamp.
        """

        def _save():
            collection = self._get_collection()

            memory_id = self._create_memory_id()
            timestamp = datetime.now().isoformat()

            memory_data = {
                "content": content,
                "timestamp": timestamp,
                "is_text_memory": True,
            }

            collection.upsert(
                ids=[memory_id], documents=[content], metadatas=[memory_data]
            )

            return TextMemory(memory_id=memory_id, content=content, timestamp=timestamp)

        return await self._run_in_executor(_save)

    async def search_text_memories(
        self,
        query: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
    ) -> List[TextMemorySearchResult]:
        """Search for similar text memories.

        Args:
            query: Query text.
            context: Tool execution context (not used by this implementation).
            limit: Maximum number of results requested from ChromaDB.
            similarity_threshold: Minimum similarity score a hit must reach.

        Returns:
            Matching text memories with their similarity score and rank.
        """

        def _search():
            collection = self._get_collection()

            where_filter = {"is_text_memory": True}

            results = collection.query(
                query_texts=[query], n_results=limit, where=where_filter
            )

            return [
                TextMemorySearchResult(
                    memory=TextMemory(
                        memory_id=id_,
                        content=metadata.get("content", ""),
                        timestamp=metadata.get("timestamp"),
                    ),
                    similarity_score=similarity_score,
                    rank=rank,
                )
                for rank, id_, similarity_score, metadata in self._iter_matches(
                    results, similarity_threshold
                )
            ]

        return await self._run_in_executor(_search)

    async def get_recent_text_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[TextMemory]:
        """Get recently added text memories, most recent first.

        Args:
            context: Tool execution context (not used by this implementation).
            limit: Maximum number of memories to return.
        """

        def _get_recent():
            collection = self._get_collection()

            results = collection.get(where={"is_text_memory": True})

            if not results["metadatas"] or not results["ids"]:
                return []

            memories_with_time = []
            for doc_id, metadata in zip(results["ids"], results["metadatas"]):
                memory = TextMemory(
                    memory_id=doc_id,
                    content=metadata.get("content", ""),
                    timestamp=metadata.get("timestamp"),
                )
                memories_with_time.append((memory, metadata.get("timestamp", "")))

            memories_with_time.sort(key=lambda x: x[1], reverse=True)

            return [m[0] for m in memories_with_time[:limit]]

        return await self._run_in_executor(_get_recent)

    async def delete_text_memory(self, context: ToolContext, memory_id: str) -> bool:
        """Delete a text memory by its ID. Returns True if deleted, False otherwise."""
        return await self._run_in_executor(lambda: self._delete_if_present(memory_id))

    # ------------------------------------------------------------------
    # Bulk maintenance
    # ------------------------------------------------------------------

    async def clear_memories(
        self,
        context: ToolContext,
        tool_name: Optional[str] = None,
        before_date: Optional[str] = None,
    ) -> int:
        """Clear stored memories.

        Args:
            context: Tool execution context (not used by this implementation).
            tool_name: When given, only memories recorded for this tool are
                considered; otherwise every entry in the collection is.
            before_date: When given, only entries whose stored timestamp
                string compares less than this value are deleted. Entries
                without a timestamp are kept.

        Returns:
            Number of entries deleted.
        """

        def _clear():
            collection = self._get_collection()

            # Build where filter
            where_filter = {}
            if tool_name:
                where_filter["tool_name"] = tool_name

            # Get memories to delete
            results = collection.get(where=where_filter if where_filter else None)

            if not results["ids"]:
                return 0

            ids_to_delete = []
            for doc_id, metadata in zip(results["ids"], results["metadatas"]):
                if before_date:
                    # Timestamps are compared as strings, as they were stored.
                    memory_date = metadata.get("timestamp", "")
                    if memory_date and memory_date < before_date:
                        ids_to_delete.append(doc_id)
                else:
                    ids_to_delete.append(doc_id)

            if ids_to_delete:
                collection.delete(ids=ids_to_delete)

            return len(ids_to_delete)

        return await self._run_in_executor(_clear)
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
"""ClickHouse implementation of SqlRunner interface."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
import pandas as pd
|
|
5
|
+
|
|
6
|
+
from vanna.capabilities.sql_runner import SqlRunner, RunSqlToolArgs
|
|
7
|
+
from vanna.core.tool import ToolContext
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ClickHouseRunner(SqlRunner):
    """SqlRunner that executes queries on ClickHouse through ``clickhouse_connect``."""

    def __init__(
        self,
        host: str,
        database: str,
        user: str,
        password: str,
        port: int = 8123,
        **kwargs,
    ):
        """Record the connection settings; no connection is opened here.

        Args:
            host: Database host address
            database: Database name
            user: Database user
            password: Database password
            port: Database port (default: 8123)
            **kwargs: Additional clickhouse_connect connection parameters

        Raises:
            ImportError: If the ``clickhouse_connect`` module cannot be imported.
        """
        try:
            import clickhouse_connect
        except ImportError as e:
            raise ImportError(
                "clickhouse-connect package is required. "
                "Install with: pip install 'vanna[clickhouse]'"
            ) from e

        # Keep a handle on the module so run_sql can create clients later.
        self.clickhouse_connect = clickhouse_connect

        self.host, self.port = host, port
        self.user, self.password = user, password
        self.database = database
        self.kwargs = kwargs

    async def run_sql(self, args: RunSqlToolArgs, context: ToolContext) -> pd.DataFrame:
        """Run ``args.sql`` on ClickHouse and return the rows as a DataFrame.

        A fresh client is created for the call and closed afterwards,
        whether or not the query succeeds.

        Args:
            args: SQL query arguments
            context: Tool execution context

        Returns:
            DataFrame with query results

        Raises:
            Exception: If query execution fails
        """
        client = self.clickhouse_connect.get_client(
            host=self.host,
            port=self.port,
            username=self.user,
            password=self.password,
            database=self.database,
            **self.kwargs,
        )

        try:
            query_result = client.query(args.sql)
            return pd.DataFrame(
                query_result.result_rows, columns=query_result.column_names
            )
        finally:
            client.close()
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
"""DuckDB implementation of SqlRunner interface."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
import pandas as pd
|
|
5
|
+
|
|
6
|
+
from vanna.capabilities.sql_runner import SqlRunner, RunSqlToolArgs
|
|
7
|
+
from vanna.core.tool import ToolContext
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class DuckDBRunner(SqlRunner):
    """SqlRunner that executes queries on a lazily opened DuckDB connection."""

    def __init__(
        self, database_path: str = ":memory:", init_sql: Optional[str] = None, **kwargs
    ):
        """Record the connection settings; the connection is opened on first use.

        Args:
            database_path: Path to the DuckDB database file.
                Use ":memory:" for in-memory database (default).
                Use "md:" or "motherduck:" for MotherDuck database.
            init_sql: Optional SQL to run when connecting to the database
            **kwargs: Additional duckdb connection parameters

        Raises:
            ImportError: If the ``duckdb`` module cannot be imported.
        """
        try:
            import duckdb
        except ImportError as e:
            raise ImportError(
                "duckdb package is required. Install with: pip install 'vanna[duckdb]'"
            ) from e

        # Keep a handle on the module so the connection can be opened later.
        self.duckdb = duckdb

        self._conn = None
        self.kwargs = kwargs
        self.init_sql = init_sql
        self.database_path = database_path

    def _get_connection(self):
        """Return the cached connection, opening it on the first call."""
        if self._conn is not None:
            return self._conn

        self._conn = self.duckdb.connect(self.database_path, **self.kwargs)
        # The init SQL runs once, right after the connection is opened.
        if self.init_sql:
            self._conn.query(self.init_sql)
        return self._conn

    async def run_sql(self, args: RunSqlToolArgs, context: ToolContext) -> pd.DataFrame:
        """Run ``args.sql`` on DuckDB and return the result as a DataFrame.

        Args:
            args: SQL query arguments
            context: Tool execution context

        Returns:
            DataFrame with query results

        Raises:
            duckdb.Error: If query execution fails
        """
        connection = self._get_connection()
        return connection.query(args.sql).to_df()
|