vanna 0.7.8__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/__init__.py +167 -395
- vanna/agents/__init__.py +7 -0
- vanna/capabilities/__init__.py +17 -0
- vanna/capabilities/agent_memory/__init__.py +21 -0
- vanna/capabilities/agent_memory/base.py +103 -0
- vanna/capabilities/agent_memory/models.py +53 -0
- vanna/capabilities/file_system/__init__.py +14 -0
- vanna/capabilities/file_system/base.py +71 -0
- vanna/capabilities/file_system/models.py +25 -0
- vanna/capabilities/sql_runner/__init__.py +13 -0
- vanna/capabilities/sql_runner/base.py +37 -0
- vanna/capabilities/sql_runner/models.py +13 -0
- vanna/components/__init__.py +92 -0
- vanna/components/base.py +11 -0
- vanna/components/rich/__init__.py +83 -0
- vanna/components/rich/containers/__init__.py +7 -0
- vanna/components/rich/containers/card.py +20 -0
- vanna/components/rich/data/__init__.py +9 -0
- vanna/components/rich/data/chart.py +17 -0
- vanna/components/rich/data/dataframe.py +93 -0
- vanna/components/rich/feedback/__init__.py +21 -0
- vanna/components/rich/feedback/badge.py +16 -0
- vanna/components/rich/feedback/icon_text.py +14 -0
- vanna/components/rich/feedback/log_viewer.py +41 -0
- vanna/components/rich/feedback/notification.py +19 -0
- vanna/components/rich/feedback/progress.py +37 -0
- vanna/components/rich/feedback/status_card.py +28 -0
- vanna/components/rich/feedback/status_indicator.py +14 -0
- vanna/components/rich/interactive/__init__.py +21 -0
- vanna/components/rich/interactive/button.py +95 -0
- vanna/components/rich/interactive/task_list.py +58 -0
- vanna/components/rich/interactive/ui_state.py +93 -0
- vanna/components/rich/specialized/__init__.py +7 -0
- vanna/components/rich/specialized/artifact.py +20 -0
- vanna/components/rich/text.py +16 -0
- vanna/components/simple/__init__.py +15 -0
- vanna/components/simple/image.py +15 -0
- vanna/components/simple/link.py +15 -0
- vanna/components/simple/text.py +11 -0
- vanna/core/__init__.py +193 -0
- vanna/core/_compat.py +19 -0
- vanna/core/agent/__init__.py +10 -0
- vanna/core/agent/agent.py +1407 -0
- vanna/core/agent/config.py +123 -0
- vanna/core/audit/__init__.py +28 -0
- vanna/core/audit/base.py +299 -0
- vanna/core/audit/models.py +131 -0
- vanna/core/component_manager.py +329 -0
- vanna/core/components.py +53 -0
- vanna/core/enhancer/__init__.py +11 -0
- vanna/core/enhancer/base.py +94 -0
- vanna/core/enhancer/default.py +118 -0
- vanna/core/enricher/__init__.py +10 -0
- vanna/core/enricher/base.py +59 -0
- vanna/core/errors.py +47 -0
- vanna/core/evaluation/__init__.py +81 -0
- vanna/core/evaluation/base.py +186 -0
- vanna/core/evaluation/dataset.py +254 -0
- vanna/core/evaluation/evaluators.py +376 -0
- vanna/core/evaluation/report.py +289 -0
- vanna/core/evaluation/runner.py +313 -0
- vanna/core/filter/__init__.py +10 -0
- vanna/core/filter/base.py +67 -0
- vanna/core/lifecycle/__init__.py +10 -0
- vanna/core/lifecycle/base.py +83 -0
- vanna/core/llm/__init__.py +16 -0
- vanna/core/llm/base.py +40 -0
- vanna/core/llm/models.py +61 -0
- vanna/core/middleware/__init__.py +10 -0
- vanna/core/middleware/base.py +69 -0
- vanna/core/observability/__init__.py +11 -0
- vanna/core/observability/base.py +88 -0
- vanna/core/observability/models.py +47 -0
- vanna/core/recovery/__init__.py +11 -0
- vanna/core/recovery/base.py +84 -0
- vanna/core/recovery/models.py +32 -0
- vanna/core/registry.py +278 -0
- vanna/core/rich_component.py +156 -0
- vanna/core/simple_component.py +27 -0
- vanna/core/storage/__init__.py +14 -0
- vanna/core/storage/base.py +46 -0
- vanna/core/storage/models.py +46 -0
- vanna/core/system_prompt/__init__.py +13 -0
- vanna/core/system_prompt/base.py +36 -0
- vanna/core/system_prompt/default.py +157 -0
- vanna/core/tool/__init__.py +18 -0
- vanna/core/tool/base.py +70 -0
- vanna/core/tool/models.py +84 -0
- vanna/core/user/__init__.py +17 -0
- vanna/core/user/base.py +29 -0
- vanna/core/user/models.py +25 -0
- vanna/core/user/request_context.py +70 -0
- vanna/core/user/resolver.py +42 -0
- vanna/core/validation.py +164 -0
- vanna/core/workflow/__init__.py +12 -0
- vanna/core/workflow/base.py +254 -0
- vanna/core/workflow/default.py +789 -0
- vanna/examples/__init__.py +1 -0
- vanna/examples/__main__.py +44 -0
- vanna/examples/anthropic_quickstart.py +80 -0
- vanna/examples/artifact_example.py +293 -0
- vanna/examples/claude_sqlite_example.py +236 -0
- vanna/examples/coding_agent_example.py +300 -0
- vanna/examples/custom_system_prompt_example.py +174 -0
- vanna/examples/default_workflow_handler_example.py +208 -0
- vanna/examples/email_auth_example.py +340 -0
- vanna/examples/evaluation_example.py +269 -0
- vanna/examples/extensibility_example.py +262 -0
- vanna/examples/minimal_example.py +67 -0
- vanna/examples/mock_auth_example.py +227 -0
- vanna/examples/mock_custom_tool.py +311 -0
- vanna/examples/mock_quickstart.py +79 -0
- vanna/examples/mock_quota_example.py +145 -0
- vanna/examples/mock_rich_components_demo.py +396 -0
- vanna/examples/mock_sqlite_example.py +223 -0
- vanna/examples/openai_quickstart.py +83 -0
- vanna/examples/primitive_components_demo.py +305 -0
- vanna/examples/quota_lifecycle_example.py +139 -0
- vanna/examples/visualization_example.py +251 -0
- vanna/integrations/__init__.py +17 -0
- vanna/integrations/anthropic/__init__.py +9 -0
- vanna/integrations/anthropic/llm.py +270 -0
- vanna/integrations/azureopenai/__init__.py +9 -0
- vanna/integrations/azureopenai/llm.py +329 -0
- vanna/integrations/azuresearch/__init__.py +7 -0
- vanna/integrations/azuresearch/agent_memory.py +413 -0
- vanna/integrations/bigquery/__init__.py +5 -0
- vanna/integrations/bigquery/sql_runner.py +81 -0
- vanna/integrations/chromadb/__init__.py +104 -0
- vanna/integrations/chromadb/agent_memory.py +416 -0
- vanna/integrations/clickhouse/__init__.py +5 -0
- vanna/integrations/clickhouse/sql_runner.py +82 -0
- vanna/integrations/duckdb/__init__.py +5 -0
- vanna/integrations/duckdb/sql_runner.py +65 -0
- vanna/integrations/faiss/__init__.py +7 -0
- vanna/integrations/faiss/agent_memory.py +431 -0
- vanna/integrations/google/__init__.py +9 -0
- vanna/integrations/google/gemini.py +370 -0
- vanna/integrations/hive/__init__.py +5 -0
- vanna/integrations/hive/sql_runner.py +87 -0
- vanna/integrations/local/__init__.py +17 -0
- vanna/integrations/local/agent_memory/__init__.py +7 -0
- vanna/integrations/local/agent_memory/in_memory.py +285 -0
- vanna/integrations/local/audit.py +59 -0
- vanna/integrations/local/file_system.py +242 -0
- vanna/integrations/local/file_system_conversation_store.py +255 -0
- vanna/integrations/local/storage.py +62 -0
- vanna/integrations/marqo/__init__.py +7 -0
- vanna/integrations/marqo/agent_memory.py +354 -0
- vanna/integrations/milvus/__init__.py +7 -0
- vanna/integrations/milvus/agent_memory.py +458 -0
- vanna/integrations/mock/__init__.py +9 -0
- vanna/integrations/mock/llm.py +65 -0
- vanna/integrations/mssql/__init__.py +5 -0
- vanna/integrations/mssql/sql_runner.py +66 -0
- vanna/integrations/mysql/__init__.py +5 -0
- vanna/integrations/mysql/sql_runner.py +92 -0
- vanna/integrations/ollama/__init__.py +7 -0
- vanna/integrations/ollama/llm.py +252 -0
- vanna/integrations/openai/__init__.py +10 -0
- vanna/integrations/openai/llm.py +267 -0
- vanna/integrations/openai/responses.py +163 -0
- vanna/integrations/opensearch/__init__.py +7 -0
- vanna/integrations/opensearch/agent_memory.py +411 -0
- vanna/integrations/oracle/__init__.py +5 -0
- vanna/integrations/oracle/sql_runner.py +75 -0
- vanna/integrations/pinecone/__init__.py +7 -0
- vanna/integrations/pinecone/agent_memory.py +329 -0
- vanna/integrations/plotly/__init__.py +5 -0
- vanna/integrations/plotly/chart_generator.py +313 -0
- vanna/integrations/postgres/__init__.py +9 -0
- vanna/integrations/postgres/sql_runner.py +112 -0
- vanna/integrations/premium/agent_memory/__init__.py +7 -0
- vanna/integrations/premium/agent_memory/premium.py +186 -0
- vanna/integrations/presto/__init__.py +5 -0
- vanna/integrations/presto/sql_runner.py +107 -0
- vanna/integrations/qdrant/__init__.py +7 -0
- vanna/integrations/qdrant/agent_memory.py +461 -0
- vanna/integrations/snowflake/__init__.py +5 -0
- vanna/integrations/snowflake/sql_runner.py +147 -0
- vanna/integrations/sqlite/__init__.py +9 -0
- vanna/integrations/sqlite/sql_runner.py +65 -0
- vanna/integrations/weaviate/__init__.py +7 -0
- vanna/integrations/weaviate/agent_memory.py +428 -0
- vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_embeddings.py +11 -11
- vanna/legacy/__init__.py +403 -0
- vanna/legacy/adapter.py +463 -0
- vanna/{advanced → legacy/advanced}/__init__.py +3 -1
- vanna/{anthropic → legacy/anthropic}/anthropic_chat.py +9 -7
- vanna/{azuresearch → legacy/azuresearch}/azuresearch_vector.py +79 -41
- vanna/{base → legacy/base}/base.py +247 -223
- vanna/legacy/bedrock/__init__.py +1 -0
- vanna/{bedrock → legacy/bedrock}/bedrock_converse.py +13 -12
- vanna/{chromadb → legacy/chromadb}/chromadb_vector.py +3 -1
- vanna/legacy/cohere/__init__.py +2 -0
- vanna/{cohere → legacy/cohere}/cohere_chat.py +19 -14
- vanna/{cohere → legacy/cohere}/cohere_embeddings.py +25 -19
- vanna/{deepseek → legacy/deepseek}/deepseek_chat.py +5 -6
- vanna/legacy/faiss/__init__.py +1 -0
- vanna/{faiss → legacy/faiss}/faiss.py +113 -59
- vanna/{flask → legacy/flask}/__init__.py +84 -43
- vanna/{flask → legacy/flask}/assets.py +5 -5
- vanna/{flask → legacy/flask}/auth.py +5 -4
- vanna/{google → legacy/google}/bigquery_vector.py +75 -42
- vanna/{google → legacy/google}/gemini_chat.py +7 -3
- vanna/{hf → legacy/hf}/hf.py +0 -1
- vanna/{milvus → legacy/milvus}/milvus_vector.py +58 -35
- vanna/{mock → legacy/mock}/llm.py +0 -1
- vanna/legacy/mock/vectordb.py +67 -0
- vanna/legacy/ollama/ollama.py +110 -0
- vanna/{openai → legacy/openai}/openai_chat.py +2 -6
- vanna/legacy/opensearch/opensearch_vector.py +369 -0
- vanna/legacy/opensearch/opensearch_vector_semantic.py +200 -0
- vanna/legacy/oracle/oracle_vector.py +584 -0
- vanna/{pgvector → legacy/pgvector}/pgvector.py +42 -13
- vanna/{qdrant → legacy/qdrant}/qdrant.py +2 -6
- vanna/legacy/qianfan/Qianfan_Chat.py +170 -0
- vanna/legacy/qianfan/Qianfan_embeddings.py +36 -0
- vanna/legacy/qianwen/QianwenAI_chat.py +132 -0
- vanna/{remote.py → legacy/remote.py} +28 -26
- vanna/{utils.py → legacy/utils.py} +6 -11
- vanna/{vannadb → legacy/vannadb}/vannadb_vector.py +115 -46
- vanna/{vllm → legacy/vllm}/vllm.py +5 -6
- vanna/{weaviate → legacy/weaviate}/weaviate_vector.py +59 -40
- vanna/{xinference → legacy/xinference}/xinference.py +6 -6
- vanna/py.typed +0 -0
- vanna/servers/__init__.py +16 -0
- vanna/servers/__main__.py +8 -0
- vanna/servers/base/__init__.py +18 -0
- vanna/servers/base/chat_handler.py +65 -0
- vanna/servers/base/models.py +111 -0
- vanna/servers/base/rich_chat_handler.py +141 -0
- vanna/servers/base/templates.py +331 -0
- vanna/servers/cli/__init__.py +7 -0
- vanna/servers/cli/server_runner.py +204 -0
- vanna/servers/fastapi/__init__.py +7 -0
- vanna/servers/fastapi/app.py +163 -0
- vanna/servers/fastapi/routes.py +183 -0
- vanna/servers/flask/__init__.py +7 -0
- vanna/servers/flask/app.py +132 -0
- vanna/servers/flask/routes.py +137 -0
- vanna/tools/__init__.py +41 -0
- vanna/tools/agent_memory.py +322 -0
- vanna/tools/file_system.py +879 -0
- vanna/tools/python.py +222 -0
- vanna/tools/run_sql.py +165 -0
- vanna/tools/visualize_data.py +195 -0
- vanna/utils/__init__.py +0 -0
- vanna/web_components/__init__.py +44 -0
- vanna-2.0.0.dist-info/METADATA +485 -0
- vanna-2.0.0.dist-info/RECORD +289 -0
- vanna-2.0.0.dist-info/entry_points.txt +3 -0
- vanna/bedrock/__init__.py +0 -1
- vanna/cohere/__init__.py +0 -2
- vanna/faiss/__init__.py +0 -1
- vanna/mock/vectordb.py +0 -55
- vanna/ollama/ollama.py +0 -103
- vanna/opensearch/opensearch_vector.py +0 -392
- vanna/opensearch/opensearch_vector_semantic.py +0 -175
- vanna/oracle/oracle_vector.py +0 -585
- vanna/qianfan/Qianfan_Chat.py +0 -165
- vanna/qianfan/Qianfan_embeddings.py +0 -36
- vanna/qianwen/QianwenAI_chat.py +0 -133
- vanna-0.7.8.dist-info/METADATA +0 -408
- vanna-0.7.8.dist-info/RECORD +0 -79
- /vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_Chat.py +0 -0
- /vanna/{ZhipuAI → legacy/ZhipuAI}/__init__.py +0 -0
- /vanna/{anthropic → legacy/anthropic}/__init__.py +0 -0
- /vanna/{azuresearch → legacy/azuresearch}/__init__.py +0 -0
- /vanna/{base → legacy/base}/__init__.py +0 -0
- /vanna/{chromadb → legacy/chromadb}/__init__.py +0 -0
- /vanna/{deepseek → legacy/deepseek}/__init__.py +0 -0
- /vanna/{exceptions → legacy/exceptions}/__init__.py +0 -0
- /vanna/{google → legacy/google}/__init__.py +0 -0
- /vanna/{hf → legacy/hf}/__init__.py +0 -0
- /vanna/{local.py → legacy/local.py} +0 -0
- /vanna/{marqo → legacy/marqo}/__init__.py +0 -0
- /vanna/{marqo → legacy/marqo}/marqo.py +0 -0
- /vanna/{milvus → legacy/milvus}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/mistral.py +0 -0
- /vanna/{mock → legacy/mock}/__init__.py +0 -0
- /vanna/{mock → legacy/mock}/embedding.py +0 -0
- /vanna/{ollama → legacy/ollama}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/openai_embeddings.py +0 -0
- /vanna/{opensearch → legacy/opensearch}/__init__.py +0 -0
- /vanna/{oracle → legacy/oracle}/__init__.py +0 -0
- /vanna/{pgvector → legacy/pgvector}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/pinecone_vector.py +0 -0
- /vanna/{qdrant → legacy/qdrant}/__init__.py +0 -0
- /vanna/{qianfan → legacy/qianfan}/__init__.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/QianwenAI_embeddings.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/__init__.py +0 -0
- /vanna/{types → legacy/types}/__init__.py +0 -0
- /vanna/{vannadb → legacy/vannadb}/__init__.py +0 -0
- /vanna/{vllm → legacy/vllm}/__init__.py +0 -0
- /vanna/{weaviate → legacy/weaviate}/__init__.py +0 -0
- /vanna/{xinference → legacy/xinference}/__init__.py +0 -0
- {vanna-0.7.8.dist-info → vanna-2.0.0.dist-info}/WHEEL +0 -0
- {vanna-0.7.8.dist-info → vanna-2.0.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,428 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Weaviate vector database implementation of AgentMemory.
|
|
3
|
+
|
|
4
|
+
This implementation uses Weaviate for semantic search and storage of tool usage patterns.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import json
|
|
8
|
+
import uuid
|
|
9
|
+
from datetime import datetime
|
|
10
|
+
from typing import Any, Dict, List, Optional
|
|
11
|
+
import asyncio
|
|
12
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
import weaviate
|
|
16
|
+
from weaviate.classes.config import (
|
|
17
|
+
Configure,
|
|
18
|
+
Property,
|
|
19
|
+
DataType as WeaviateDataType,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
WEAVIATE_AVAILABLE = True
|
|
23
|
+
except ImportError:
|
|
24
|
+
WEAVIATE_AVAILABLE = False
|
|
25
|
+
|
|
26
|
+
from vanna.capabilities.agent_memory import (
|
|
27
|
+
AgentMemory,
|
|
28
|
+
TextMemory,
|
|
29
|
+
TextMemorySearchResult,
|
|
30
|
+
ToolMemory,
|
|
31
|
+
ToolMemorySearchResult,
|
|
32
|
+
)
|
|
33
|
+
from vanna.core.tool import ToolContext
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class WeaviateAgentMemory(AgentMemory):
    """Weaviate-backed implementation of AgentMemory.

    Tool-usage memories and free-text memories share a single Weaviate
    collection. Text memories are stored with an empty ``tool_name`` and keep
    their content in the ``question`` property; tool-memory readers skip them.

    Every Weaviate call is blocking, so each public coroutine runs its work on
    a small private thread pool through the running event loop.

    Note: ``_create_embedding`` is a deterministic hash-based placeholder, not
    a semantic embedding model, so similarity scores do not reflect meaning.
    """

    def __init__(
        self,
        collection_name: str = "ToolMemory",
        url: str = "http://localhost:8080",
        api_key: Optional[str] = None,
        dimension: int = 384,
    ):
        """Configure the store; the Weaviate connection is opened lazily.

        Args:
            collection_name: Name of the Weaviate collection to use or create.
            url: Cluster URL when ``api_key`` is given (Weaviate Cloud);
                otherwise ``[scheme://]host[:port]`` of a local instance.
            api_key: Weaviate Cloud API key. When omitted, a local instance at
                ``url`` is used.
            dimension: Length of the vectors produced by ``_create_embedding``.

        Raises:
            ImportError: If the ``weaviate-client`` package is not installed.
        """
        if not WEAVIATE_AVAILABLE:
            raise ImportError(
                "Weaviate is required for WeaviateAgentMemory. Install with: pip install weaviate-client"
            )

        import threading

        self.collection_name = collection_name
        self.url = url
        self.api_key = api_key
        self.dimension = dimension
        self._client = None
        # _get_client runs on up to two executor threads concurrently; the
        # lock keeps lazy connection + collection creation single-shot.
        self._client_lock = threading.Lock()
        self._executor = ThreadPoolExecutor(max_workers=2)

    # ------------------------------------------------------------------ #
    # Connection / infrastructure helpers
    # ------------------------------------------------------------------ #

    def _get_client(self):
        """Return the shared Weaviate client, connecting on first use.

        The collection is created if missing. If collection setup fails the
        new client is closed and creation is retried on the next call.
        """
        with self._client_lock:
            if self._client is None:
                client = self._connect()
                try:
                    self._ensure_collection(client)
                except Exception:
                    client.close()
                    raise
                self._client = client
            return self._client

    def _connect(self):
        """Open a new Weaviate client for the configured URL/credentials."""
        if self.api_key:
            return weaviate.connect_to_weaviate_cloud(
                cluster_url=self.url,
                auth_credentials=weaviate.auth.AuthApiKey(self.api_key),
            )

        from urllib.parse import urlsplit

        # connect_to_local takes host and port as separate arguments, so split
        # them out of the URL instead of passing "localhost:8080" as the host.
        full_url = self.url if "://" in self.url else f"http://{self.url}"
        parts = urlsplit(full_url)
        return weaviate.connect_to_local(
            host=parts.hostname or "localhost",
            port=parts.port or 8080,
        )

    def _ensure_collection(self, client) -> None:
        """Create the backing collection with the expected schema if missing."""
        if client.collections.exists(self.collection_name):
            return
        # FIELD tokenization indexes the whole value as a single token, so
        # equality filters on tool_name match exact names rather than words,
        # and timestamp range filters compare the whole ISO string.
        field = weaviate.classes.config.Tokenization.FIELD
        client.collections.create(
            name=self.collection_name,
            vectorizer_config=Configure.Vectorizer.none(),
            properties=[
                Property(name="question", data_type=WeaviateDataType.TEXT),
                Property(
                    name="tool_name",
                    data_type=WeaviateDataType.TEXT,
                    tokenization=field,
                ),
                Property(name="args_json", data_type=WeaviateDataType.TEXT),
                Property(
                    name="timestamp",
                    data_type=WeaviateDataType.TEXT,
                    tokenization=field,
                ),
                Property(name="success", data_type=WeaviateDataType.BOOL),
                Property(name="metadata_json", data_type=WeaviateDataType.TEXT),
            ],
        )

    def _collection(self):
        """Return the collection handle, connecting first if necessary."""
        return self._get_client().collections.get(self.collection_name)

    async def _run(self, func):
        """Run blocking ``func`` on the private executor and await its result."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self._executor, func)

    def close(self) -> None:
        """Close the Weaviate connection (if open) and stop the worker threads.

        The instance cannot run further operations after this call.
        """
        with self._client_lock:
            if self._client is not None:
                self._client.close()
                self._client = None
        self._executor.shutdown(wait=False)

    # ------------------------------------------------------------------ #
    # Pure helpers
    # ------------------------------------------------------------------ #

    def _create_embedding(self, text: str) -> List[float]:
        """Return a deterministic placeholder vector for ``text``.

        The vector is derived from the MD5 digest of the text, not from a
        language model, so it carries no semantic meaning. MD5 yields 128
        bits, so every component at index >= 128 is 0.0. Left unchanged so
        query vectors stay compatible with vectors already stored.
        """
        import hashlib

        hash_val = int(hashlib.md5(text.encode()).hexdigest(), 16)
        return [(hash_val >> i) % 100 / 100.0 for i in range(self.dimension)]

    @staticmethod
    def _parse_json_dict(raw: Optional[str]) -> Dict[str, Any]:
        """Decode a stored JSON object; return {} for empty or invalid data.

        Text memories store an empty ``args_json``, which ``json.loads``
        would reject with a JSONDecodeError.
        """
        if not raw:
            return {}
        try:
            value = json.loads(raw)
        except (TypeError, ValueError):
            return {}
        return value if isinstance(value, dict) else {}

    @staticmethod
    def _similarity(obj) -> float:
        """Convert an object's returned distance into a similarity score.

        A missing distance is treated as 1.0, i.e. similarity 0.0.
        """
        metadata = obj.metadata
        distance = getattr(metadata, "distance", None) if metadata else None
        if distance is None:
            distance = 1.0
        return 1 - distance

    @staticmethod
    def _is_text_memory(properties: Dict[str, Any]) -> bool:
        """Text memories are the objects stored with an empty tool_name."""
        return not properties.get("tool_name")

    def _to_tool_memory(self, obj) -> ToolMemory:
        """Build a ToolMemory from a Weaviate object."""
        properties = obj.properties
        return ToolMemory(
            memory_id=str(obj.uuid),
            question=properties.get("question"),
            tool_name=properties.get("tool_name"),
            args=self._parse_json_dict(properties.get("args_json")),
            timestamp=properties.get("timestamp"),
            success=properties.get("success", True),
            metadata=self._parse_json_dict(properties.get("metadata_json")),
        )

    @staticmethod
    def _to_text_memory(obj) -> TextMemory:
        """Build a TextMemory from a Weaviate object (content is in 'question')."""
        properties = obj.properties
        return TextMemory(
            memory_id=str(obj.uuid),
            content=properties.get("question", ""),
            timestamp=properties.get("timestamp"),
        )

    def _fetch_recent(self, collection, limit: int, keep, filters=None) -> list:
        """Return up to ``limit`` objects, newest first, for which ``keep`` is true.

        Sorting is done by Weaviate on the ISO ``timestamp`` string and the
        results are paged, so the newest objects are found even when the
        collection holds more objects than one fetch returns.
        """
        sort = weaviate.classes.query.Sort.by_property("timestamp", ascending=False)
        page_size = max(limit, 100)
        offset = 0
        selected = []
        while len(selected) < limit:
            response = collection.query.fetch_objects(
                limit=page_size,
                offset=offset,
                filters=filters,
                sort=sort,
            )
            page = list(response.objects)
            selected.extend(obj for obj in page if keep(obj))
            if len(page) < page_size:
                break
            offset += page_size
        return selected[:limit]

    def _delete_object(self, memory_id: str) -> bool:
        """Delete one object by UUID; return whether it was deleted."""
        collection = self._collection()
        try:
            # NOTE(review): in weaviate-client v4 delete_by_id returns a bool
            # (False when no object matched) — confirm for the pinned version.
            return bool(collection.data.delete_by_id(uuid=memory_id))
        except Exception:
            # Malformed UUIDs and server-side errors are reported as "not
            # deleted", matching the bool contract of the public methods.
            return False

    # ------------------------------------------------------------------ #
    # Tool memories
    # ------------------------------------------------------------------ #

    async def save_tool_usage(
        self,
        question: str,
        tool_name: str,
        args: Dict[str, Any],
        context: ToolContext,
        success: bool = True,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Store one tool invocation together with the question that led to it.

        Args:
            question: The user question; also the text that is embedded.
            tool_name: Name of the tool that was invoked.
            args: Tool arguments; must be JSON-serializable.
            context: Tool context (unused by this backend).
            success: Whether the invocation succeeded.
            metadata: Optional extra JSON-serializable data.
        """

        def _save() -> None:
            properties = {
                "question": question,
                "tool_name": tool_name,
                "args_json": json.dumps(args),
                "timestamp": datetime.now().isoformat(),
                "success": success,
                "metadata_json": json.dumps(metadata or {}),
            }
            self._collection().data.insert(
                properties=properties,
                vector=self._create_embedding(question),
                uuid=str(uuid.uuid4()),
            )

        await self._run(_save)

    async def search_similar_usage(
        self,
        question: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
        tool_name_filter: Optional[str] = None,
    ) -> List[ToolMemorySearchResult]:
        """Find successful tool usages whose question vector is near ``question``.

        Text memories share the collection and are excluded, so fewer than
        ``limit`` results may be returned. Ranks are consecutive from 1 in
        the order Weaviate returns results.
        """

        def _search() -> List[ToolMemorySearchResult]:
            collection = self._collection()
            query = weaviate.classes.query

            filters = query.Filter.by_property("success").equal(True)
            if tool_name_filter:
                filters = filters & query.Filter.by_property("tool_name").equal(
                    tool_name_filter
                )

            response = collection.query.near_vector(
                near_vector=self._create_embedding(question),
                limit=limit,
                filters=filters,
                return_metadata=query.MetadataQuery(distance=True),
            )

            results: List[ToolMemorySearchResult] = []
            for obj in response.objects:
                similarity_score = self._similarity(obj)
                if similarity_score < similarity_threshold:
                    continue
                if self._is_text_memory(obj.properties):
                    continue
                results.append(
                    ToolMemorySearchResult(
                        memory=self._to_tool_memory(obj),
                        similarity_score=similarity_score,
                        rank=len(results) + 1,
                    )
                )
            return results

        return await self._run(_search)

    async def get_recent_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[ToolMemory]:
        """Return up to ``limit`` tool memories, newest first (text memories excluded)."""

        def _get_recent() -> List[ToolMemory]:
            objects = self._fetch_recent(
                self._collection(),
                limit,
                keep=lambda obj: not self._is_text_memory(obj.properties),
            )
            return [self._to_tool_memory(obj) for obj in objects]

        return await self._run(_get_recent)

    async def delete_by_id(self, context: ToolContext, memory_id: str) -> bool:
        """Delete a memory by its ID; return whether an object was deleted."""
        return await self._run(lambda: self._delete_object(memory_id))

    # ------------------------------------------------------------------ #
    # Text memories
    # ------------------------------------------------------------------ #

    async def save_text_memory(self, content: str, context: ToolContext) -> TextMemory:
        """Store free-form text as a memory and return the stored record.

        The content goes into the ``question`` property with an empty
        ``tool_name``, which is how text memories are recognised later.
        """

        def _save() -> TextMemory:
            memory_id = str(uuid.uuid4())
            timestamp = datetime.now().isoformat()
            properties = {
                "question": content,  # Using question field for content
                "tool_name": "",  # Empty for text memories
                "args_json": "",
                "timestamp": timestamp,
                "success": True,
                "metadata_json": json.dumps({"is_text_memory": True}),
            }
            self._collection().data.insert(
                properties=properties,
                vector=self._create_embedding(content),
                uuid=memory_id,
            )
            return TextMemory(memory_id=memory_id, content=content, timestamp=timestamp)

        return await self._run(_save)

    async def search_text_memories(
        self,
        query: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
    ) -> List[TextMemorySearchResult]:
        """Find text memories whose vector is near ``query``.

        Ranks are consecutive from 1 in the order Weaviate returns results.
        """

        def _search() -> List[TextMemorySearchResult]:
            collection = self._collection()
            query_classes = weaviate.classes.query

            response = collection.query.near_vector(
                near_vector=self._create_embedding(query),
                limit=limit,
                # Text memories are stored with an empty tool_name.
                filters=query_classes.Filter.by_property("tool_name").equal(""),
                return_metadata=query_classes.MetadataQuery(distance=True),
            )

            results: List[TextMemorySearchResult] = []
            for obj in response.objects:
                similarity_score = self._similarity(obj)
                if similarity_score < similarity_threshold:
                    continue
                results.append(
                    TextMemorySearchResult(
                        memory=self._to_text_memory(obj),
                        similarity_score=similarity_score,
                        rank=len(results) + 1,
                    )
                )
            return results

        return await self._run(_search)

    async def get_recent_text_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[TextMemory]:
        """Return up to ``limit`` text memories, newest first."""

        def _get_recent() -> List[TextMemory]:
            objects = self._fetch_recent(
                self._collection(),
                limit,
                keep=lambda obj: self._is_text_memory(obj.properties),
                filters=weaviate.classes.query.Filter.by_property("tool_name").equal(""),
            )
            return [self._to_text_memory(obj) for obj in objects]

        return await self._run(_get_recent)

    async def delete_text_memory(self, context: ToolContext, memory_id: str) -> bool:
        """Delete a text memory by its ID; return whether an object was deleted."""
        return await self._run(lambda: self._delete_object(memory_id))

    # ------------------------------------------------------------------ #
    # Bulk operations
    # ------------------------------------------------------------------ #

    async def clear_memories(
        self,
        context: ToolContext,
        tool_name: Optional[str] = None,
        before_date: Optional[str] = None,
    ) -> int:
        """Delete stored memories and return how many objects were removed.

        Args:
            context: Tool context (unused by this backend).
            tool_name: If given, only delete memories for this tool.
            before_date: If given, only delete memories whose ISO timestamp
                string sorts before this value.

        With neither filter, every object in the collection (tool and text
        memories) is deleted.
        """

        def _clear() -> int:
            collection = self._collection()
            Filter = weaviate.classes.query.Filter

            conditions = []
            if tool_name:
                conditions.append(Filter.by_property("tool_name").equal(tool_name))
            if before_date:
                conditions.append(
                    Filter.by_property("timestamp").less_than(before_date)
                )

            if conditions:
                where = conditions[0]
                for condition in conditions[1:]:
                    where = where & condition
            else:
                # Matches every object: "success" is always True or False.
                where = Filter.by_property("success").contains_any([True, False])

            # A single delete_many call is capped server-side, so repeat until
            # a pass removes nothing.
            deleted = 0
            while True:
                result = collection.data.delete_many(where=where)
                deleted += result.successful
                if result.successful == 0:
                    break
            return deleted

        return await self._run(_clear)
|
|
@@ -3,6 +3,7 @@ from zhipuai import ZhipuAI
|
|
|
3
3
|
from chromadb import Documents, EmbeddingFunction, Embeddings
|
|
4
4
|
from ..base import VannaBase
|
|
5
5
|
|
|
6
|
+
|
|
6
7
|
class ZhipuAI_Embeddings(VannaBase):
|
|
7
8
|
"""
|
|
8
9
|
[future functionality] This function is used to generate embeddings from ZhipuAI.
|
|
@@ -10,6 +11,7 @@ class ZhipuAI_Embeddings(VannaBase):
|
|
|
10
11
|
Args:
|
|
11
12
|
VannaBase (_type_): _description_
|
|
12
13
|
"""
|
|
14
|
+
|
|
13
15
|
def __init__(self, config=None):
|
|
14
16
|
VannaBase.__init__(self, config=config)
|
|
15
17
|
if "api_key" not in config:
|
|
@@ -18,39 +20,38 @@ class ZhipuAI_Embeddings(VannaBase):
|
|
|
18
20
|
self.client = ZhipuAI(api_key=self.api_key)
|
|
19
21
|
|
|
20
22
|
def generate_embedding(self, data: str, **kwargs) -> List[float]:
|
|
21
|
-
|
|
22
23
|
embedding = self.client.embeddings.create(
|
|
23
24
|
model="embedding-2",
|
|
24
25
|
input=data,
|
|
25
26
|
)
|
|
26
27
|
|
|
27
28
|
return embedding.data[0].embedding
|
|
28
|
-
|
|
29
29
|
|
|
30
30
|
|
|
31
31
|
class ZhipuAIEmbeddingFunction(EmbeddingFunction[Documents]):
|
|
32
32
|
"""
|
|
33
33
|
A embeddingFunction that uses ZhipuAI to generate embeddings which can use in chromadb.
|
|
34
|
-
usage:
|
|
34
|
+
usage:
|
|
35
35
|
class MyVanna(ChromaDB_VectorStore, ZhipuAI_Chat):
|
|
36
36
|
def __init__(self, config=None):
|
|
37
37
|
ChromaDB_VectorStore.__init__(self, config=config)
|
|
38
38
|
ZhipuAI_Chat.__init__(self, config=config)
|
|
39
|
-
|
|
39
|
+
|
|
40
40
|
config={'api_key': 'xxx'}
|
|
41
41
|
zhipu_embedding_function = ZhipuAIEmbeddingFunction(config=config)
|
|
42
42
|
config = {"api_key": "xxx", "model": "glm-4","path":"xy","embedding_function":zhipu_embedding_function}
|
|
43
|
-
|
|
43
|
+
|
|
44
44
|
vn = MyVanna(config)
|
|
45
|
-
|
|
45
|
+
|
|
46
46
|
"""
|
|
47
|
+
|
|
47
48
|
def __init__(self, config=None):
|
|
48
49
|
if config is None or "api_key" not in config:
|
|
49
50
|
raise ValueError("Missing 'api_key' in config")
|
|
50
|
-
|
|
51
|
+
|
|
51
52
|
self.api_key = config["api_key"]
|
|
52
53
|
self.model_name = config.get("model_name", "embedding-2")
|
|
53
|
-
|
|
54
|
+
|
|
54
55
|
try:
|
|
55
56
|
self.client = ZhipuAI(api_key=self.api_key)
|
|
56
57
|
except Exception as e:
|
|
@@ -66,8 +67,7 @@ class ZhipuAIEmbeddingFunction(EmbeddingFunction[Documents]):
|
|
|
66
67
|
for document in input:
|
|
67
68
|
try:
|
|
68
69
|
response = self.client.embeddings.create(
|
|
69
|
-
model=self.model_name,
|
|
70
|
-
input=document
|
|
70
|
+
model=self.model_name, input=document
|
|
71
71
|
)
|
|
72
72
|
# print(response)
|
|
73
73
|
embedding = response.data[0].embedding
|
|
@@ -76,4 +76,4 @@ class ZhipuAIEmbeddingFunction(EmbeddingFunction[Documents]):
|
|
|
76
76
|
except Exception as e:
|
|
77
77
|
raise ValueError(f"Error generating embedding for document: {e}")
|
|
78
78
|
|
|
79
|
-
return all_embeddings
|
|
79
|
+
return all_embeddings
|