vanna 0.7.8__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/__init__.py +167 -395
- vanna/agents/__init__.py +7 -0
- vanna/capabilities/__init__.py +17 -0
- vanna/capabilities/agent_memory/__init__.py +21 -0
- vanna/capabilities/agent_memory/base.py +103 -0
- vanna/capabilities/agent_memory/models.py +53 -0
- vanna/capabilities/file_system/__init__.py +14 -0
- vanna/capabilities/file_system/base.py +71 -0
- vanna/capabilities/file_system/models.py +25 -0
- vanna/capabilities/sql_runner/__init__.py +13 -0
- vanna/capabilities/sql_runner/base.py +37 -0
- vanna/capabilities/sql_runner/models.py +13 -0
- vanna/components/__init__.py +92 -0
- vanna/components/base.py +11 -0
- vanna/components/rich/__init__.py +83 -0
- vanna/components/rich/containers/__init__.py +7 -0
- vanna/components/rich/containers/card.py +20 -0
- vanna/components/rich/data/__init__.py +9 -0
- vanna/components/rich/data/chart.py +17 -0
- vanna/components/rich/data/dataframe.py +93 -0
- vanna/components/rich/feedback/__init__.py +21 -0
- vanna/components/rich/feedback/badge.py +16 -0
- vanna/components/rich/feedback/icon_text.py +14 -0
- vanna/components/rich/feedback/log_viewer.py +41 -0
- vanna/components/rich/feedback/notification.py +19 -0
- vanna/components/rich/feedback/progress.py +37 -0
- vanna/components/rich/feedback/status_card.py +28 -0
- vanna/components/rich/feedback/status_indicator.py +14 -0
- vanna/components/rich/interactive/__init__.py +21 -0
- vanna/components/rich/interactive/button.py +95 -0
- vanna/components/rich/interactive/task_list.py +58 -0
- vanna/components/rich/interactive/ui_state.py +93 -0
- vanna/components/rich/specialized/__init__.py +7 -0
- vanna/components/rich/specialized/artifact.py +20 -0
- vanna/components/rich/text.py +16 -0
- vanna/components/simple/__init__.py +15 -0
- vanna/components/simple/image.py +15 -0
- vanna/components/simple/link.py +15 -0
- vanna/components/simple/text.py +11 -0
- vanna/core/__init__.py +193 -0
- vanna/core/_compat.py +19 -0
- vanna/core/agent/__init__.py +10 -0
- vanna/core/agent/agent.py +1407 -0
- vanna/core/agent/config.py +123 -0
- vanna/core/audit/__init__.py +28 -0
- vanna/core/audit/base.py +299 -0
- vanna/core/audit/models.py +131 -0
- vanna/core/component_manager.py +329 -0
- vanna/core/components.py +53 -0
- vanna/core/enhancer/__init__.py +11 -0
- vanna/core/enhancer/base.py +94 -0
- vanna/core/enhancer/default.py +118 -0
- vanna/core/enricher/__init__.py +10 -0
- vanna/core/enricher/base.py +59 -0
- vanna/core/errors.py +47 -0
- vanna/core/evaluation/__init__.py +81 -0
- vanna/core/evaluation/base.py +186 -0
- vanna/core/evaluation/dataset.py +254 -0
- vanna/core/evaluation/evaluators.py +376 -0
- vanna/core/evaluation/report.py +289 -0
- vanna/core/evaluation/runner.py +313 -0
- vanna/core/filter/__init__.py +10 -0
- vanna/core/filter/base.py +67 -0
- vanna/core/lifecycle/__init__.py +10 -0
- vanna/core/lifecycle/base.py +83 -0
- vanna/core/llm/__init__.py +16 -0
- vanna/core/llm/base.py +40 -0
- vanna/core/llm/models.py +61 -0
- vanna/core/middleware/__init__.py +10 -0
- vanna/core/middleware/base.py +69 -0
- vanna/core/observability/__init__.py +11 -0
- vanna/core/observability/base.py +88 -0
- vanna/core/observability/models.py +47 -0
- vanna/core/recovery/__init__.py +11 -0
- vanna/core/recovery/base.py +84 -0
- vanna/core/recovery/models.py +32 -0
- vanna/core/registry.py +278 -0
- vanna/core/rich_component.py +156 -0
- vanna/core/simple_component.py +27 -0
- vanna/core/storage/__init__.py +14 -0
- vanna/core/storage/base.py +46 -0
- vanna/core/storage/models.py +46 -0
- vanna/core/system_prompt/__init__.py +13 -0
- vanna/core/system_prompt/base.py +36 -0
- vanna/core/system_prompt/default.py +157 -0
- vanna/core/tool/__init__.py +18 -0
- vanna/core/tool/base.py +70 -0
- vanna/core/tool/models.py +84 -0
- vanna/core/user/__init__.py +17 -0
- vanna/core/user/base.py +29 -0
- vanna/core/user/models.py +25 -0
- vanna/core/user/request_context.py +70 -0
- vanna/core/user/resolver.py +42 -0
- vanna/core/validation.py +164 -0
- vanna/core/workflow/__init__.py +12 -0
- vanna/core/workflow/base.py +254 -0
- vanna/core/workflow/default.py +789 -0
- vanna/examples/__init__.py +1 -0
- vanna/examples/__main__.py +44 -0
- vanna/examples/anthropic_quickstart.py +80 -0
- vanna/examples/artifact_example.py +293 -0
- vanna/examples/claude_sqlite_example.py +236 -0
- vanna/examples/coding_agent_example.py +300 -0
- vanna/examples/custom_system_prompt_example.py +174 -0
- vanna/examples/default_workflow_handler_example.py +208 -0
- vanna/examples/email_auth_example.py +340 -0
- vanna/examples/evaluation_example.py +269 -0
- vanna/examples/extensibility_example.py +262 -0
- vanna/examples/minimal_example.py +67 -0
- vanna/examples/mock_auth_example.py +227 -0
- vanna/examples/mock_custom_tool.py +311 -0
- vanna/examples/mock_quickstart.py +79 -0
- vanna/examples/mock_quota_example.py +145 -0
- vanna/examples/mock_rich_components_demo.py +396 -0
- vanna/examples/mock_sqlite_example.py +223 -0
- vanna/examples/openai_quickstart.py +83 -0
- vanna/examples/primitive_components_demo.py +305 -0
- vanna/examples/quota_lifecycle_example.py +139 -0
- vanna/examples/visualization_example.py +251 -0
- vanna/integrations/__init__.py +17 -0
- vanna/integrations/anthropic/__init__.py +9 -0
- vanna/integrations/anthropic/llm.py +270 -0
- vanna/integrations/azureopenai/__init__.py +9 -0
- vanna/integrations/azureopenai/llm.py +329 -0
- vanna/integrations/azuresearch/__init__.py +7 -0
- vanna/integrations/azuresearch/agent_memory.py +413 -0
- vanna/integrations/bigquery/__init__.py +5 -0
- vanna/integrations/bigquery/sql_runner.py +81 -0
- vanna/integrations/chromadb/__init__.py +104 -0
- vanna/integrations/chromadb/agent_memory.py +416 -0
- vanna/integrations/clickhouse/__init__.py +5 -0
- vanna/integrations/clickhouse/sql_runner.py +82 -0
- vanna/integrations/duckdb/__init__.py +5 -0
- vanna/integrations/duckdb/sql_runner.py +65 -0
- vanna/integrations/faiss/__init__.py +7 -0
- vanna/integrations/faiss/agent_memory.py +431 -0
- vanna/integrations/google/__init__.py +9 -0
- vanna/integrations/google/gemini.py +370 -0
- vanna/integrations/hive/__init__.py +5 -0
- vanna/integrations/hive/sql_runner.py +87 -0
- vanna/integrations/local/__init__.py +17 -0
- vanna/integrations/local/agent_memory/__init__.py +7 -0
- vanna/integrations/local/agent_memory/in_memory.py +285 -0
- vanna/integrations/local/audit.py +59 -0
- vanna/integrations/local/file_system.py +242 -0
- vanna/integrations/local/file_system_conversation_store.py +255 -0
- vanna/integrations/local/storage.py +62 -0
- vanna/integrations/marqo/__init__.py +7 -0
- vanna/integrations/marqo/agent_memory.py +354 -0
- vanna/integrations/milvus/__init__.py +7 -0
- vanna/integrations/milvus/agent_memory.py +458 -0
- vanna/integrations/mock/__init__.py +9 -0
- vanna/integrations/mock/llm.py +65 -0
- vanna/integrations/mssql/__init__.py +5 -0
- vanna/integrations/mssql/sql_runner.py +66 -0
- vanna/integrations/mysql/__init__.py +5 -0
- vanna/integrations/mysql/sql_runner.py +92 -0
- vanna/integrations/ollama/__init__.py +7 -0
- vanna/integrations/ollama/llm.py +252 -0
- vanna/integrations/openai/__init__.py +10 -0
- vanna/integrations/openai/llm.py +267 -0
- vanna/integrations/openai/responses.py +163 -0
- vanna/integrations/opensearch/__init__.py +7 -0
- vanna/integrations/opensearch/agent_memory.py +411 -0
- vanna/integrations/oracle/__init__.py +5 -0
- vanna/integrations/oracle/sql_runner.py +75 -0
- vanna/integrations/pinecone/__init__.py +7 -0
- vanna/integrations/pinecone/agent_memory.py +329 -0
- vanna/integrations/plotly/__init__.py +5 -0
- vanna/integrations/plotly/chart_generator.py +313 -0
- vanna/integrations/postgres/__init__.py +9 -0
- vanna/integrations/postgres/sql_runner.py +112 -0
- vanna/integrations/premium/agent_memory/__init__.py +7 -0
- vanna/integrations/premium/agent_memory/premium.py +186 -0
- vanna/integrations/presto/__init__.py +5 -0
- vanna/integrations/presto/sql_runner.py +107 -0
- vanna/integrations/qdrant/__init__.py +7 -0
- vanna/integrations/qdrant/agent_memory.py +461 -0
- vanna/integrations/snowflake/__init__.py +5 -0
- vanna/integrations/snowflake/sql_runner.py +147 -0
- vanna/integrations/sqlite/__init__.py +9 -0
- vanna/integrations/sqlite/sql_runner.py +65 -0
- vanna/integrations/weaviate/__init__.py +7 -0
- vanna/integrations/weaviate/agent_memory.py +428 -0
- vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_embeddings.py +11 -11
- vanna/legacy/__init__.py +403 -0
- vanna/legacy/adapter.py +463 -0
- vanna/{advanced → legacy/advanced}/__init__.py +3 -1
- vanna/{anthropic → legacy/anthropic}/anthropic_chat.py +9 -7
- vanna/{azuresearch → legacy/azuresearch}/azuresearch_vector.py +79 -41
- vanna/{base → legacy/base}/base.py +247 -223
- vanna/legacy/bedrock/__init__.py +1 -0
- vanna/{bedrock → legacy/bedrock}/bedrock_converse.py +13 -12
- vanna/{chromadb → legacy/chromadb}/chromadb_vector.py +3 -1
- vanna/legacy/cohere/__init__.py +2 -0
- vanna/{cohere → legacy/cohere}/cohere_chat.py +19 -14
- vanna/{cohere → legacy/cohere}/cohere_embeddings.py +25 -19
- vanna/{deepseek → legacy/deepseek}/deepseek_chat.py +5 -6
- vanna/legacy/faiss/__init__.py +1 -0
- vanna/{faiss → legacy/faiss}/faiss.py +113 -59
- vanna/{flask → legacy/flask}/__init__.py +84 -43
- vanna/{flask → legacy/flask}/assets.py +5 -5
- vanna/{flask → legacy/flask}/auth.py +5 -4
- vanna/{google → legacy/google}/bigquery_vector.py +75 -42
- vanna/{google → legacy/google}/gemini_chat.py +7 -3
- vanna/{hf → legacy/hf}/hf.py +0 -1
- vanna/{milvus → legacy/milvus}/milvus_vector.py +58 -35
- vanna/{mock → legacy/mock}/llm.py +0 -1
- vanna/legacy/mock/vectordb.py +67 -0
- vanna/legacy/ollama/ollama.py +110 -0
- vanna/{openai → legacy/openai}/openai_chat.py +2 -6
- vanna/legacy/opensearch/opensearch_vector.py +369 -0
- vanna/legacy/opensearch/opensearch_vector_semantic.py +200 -0
- vanna/legacy/oracle/oracle_vector.py +584 -0
- vanna/{pgvector → legacy/pgvector}/pgvector.py +42 -13
- vanna/{qdrant → legacy/qdrant}/qdrant.py +2 -6
- vanna/legacy/qianfan/Qianfan_Chat.py +170 -0
- vanna/legacy/qianfan/Qianfan_embeddings.py +36 -0
- vanna/legacy/qianwen/QianwenAI_chat.py +132 -0
- vanna/{remote.py → legacy/remote.py} +28 -26
- vanna/{utils.py → legacy/utils.py} +6 -11
- vanna/{vannadb → legacy/vannadb}/vannadb_vector.py +115 -46
- vanna/{vllm → legacy/vllm}/vllm.py +5 -6
- vanna/{weaviate → legacy/weaviate}/weaviate_vector.py +59 -40
- vanna/{xinference → legacy/xinference}/xinference.py +6 -6
- vanna/py.typed +0 -0
- vanna/servers/__init__.py +16 -0
- vanna/servers/__main__.py +8 -0
- vanna/servers/base/__init__.py +18 -0
- vanna/servers/base/chat_handler.py +65 -0
- vanna/servers/base/models.py +111 -0
- vanna/servers/base/rich_chat_handler.py +141 -0
- vanna/servers/base/templates.py +331 -0
- vanna/servers/cli/__init__.py +7 -0
- vanna/servers/cli/server_runner.py +204 -0
- vanna/servers/fastapi/__init__.py +7 -0
- vanna/servers/fastapi/app.py +163 -0
- vanna/servers/fastapi/routes.py +183 -0
- vanna/servers/flask/__init__.py +7 -0
- vanna/servers/flask/app.py +132 -0
- vanna/servers/flask/routes.py +137 -0
- vanna/tools/__init__.py +41 -0
- vanna/tools/agent_memory.py +322 -0
- vanna/tools/file_system.py +879 -0
- vanna/tools/python.py +222 -0
- vanna/tools/run_sql.py +165 -0
- vanna/tools/visualize_data.py +195 -0
- vanna/utils/__init__.py +0 -0
- vanna/web_components/__init__.py +44 -0
- vanna-2.0.0.dist-info/METADATA +485 -0
- vanna-2.0.0.dist-info/RECORD +289 -0
- vanna-2.0.0.dist-info/entry_points.txt +3 -0
- vanna/bedrock/__init__.py +0 -1
- vanna/cohere/__init__.py +0 -2
- vanna/faiss/__init__.py +0 -1
- vanna/mock/vectordb.py +0 -55
- vanna/ollama/ollama.py +0 -103
- vanna/opensearch/opensearch_vector.py +0 -392
- vanna/opensearch/opensearch_vector_semantic.py +0 -175
- vanna/oracle/oracle_vector.py +0 -585
- vanna/qianfan/Qianfan_Chat.py +0 -165
- vanna/qianfan/Qianfan_embeddings.py +0 -36
- vanna/qianwen/QianwenAI_chat.py +0 -133
- vanna-0.7.8.dist-info/METADATA +0 -408
- vanna-0.7.8.dist-info/RECORD +0 -79
- /vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_Chat.py +0 -0
- /vanna/{ZhipuAI → legacy/ZhipuAI}/__init__.py +0 -0
- /vanna/{anthropic → legacy/anthropic}/__init__.py +0 -0
- /vanna/{azuresearch → legacy/azuresearch}/__init__.py +0 -0
- /vanna/{base → legacy/base}/__init__.py +0 -0
- /vanna/{chromadb → legacy/chromadb}/__init__.py +0 -0
- /vanna/{deepseek → legacy/deepseek}/__init__.py +0 -0
- /vanna/{exceptions → legacy/exceptions}/__init__.py +0 -0
- /vanna/{google → legacy/google}/__init__.py +0 -0
- /vanna/{hf → legacy/hf}/__init__.py +0 -0
- /vanna/{local.py → legacy/local.py} +0 -0
- /vanna/{marqo → legacy/marqo}/__init__.py +0 -0
- /vanna/{marqo → legacy/marqo}/marqo.py +0 -0
- /vanna/{milvus → legacy/milvus}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/mistral.py +0 -0
- /vanna/{mock → legacy/mock}/__init__.py +0 -0
- /vanna/{mock → legacy/mock}/embedding.py +0 -0
- /vanna/{ollama → legacy/ollama}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/openai_embeddings.py +0 -0
- /vanna/{opensearch → legacy/opensearch}/__init__.py +0 -0
- /vanna/{oracle → legacy/oracle}/__init__.py +0 -0
- /vanna/{pgvector → legacy/pgvector}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/pinecone_vector.py +0 -0
- /vanna/{qdrant → legacy/qdrant}/__init__.py +0 -0
- /vanna/{qianfan → legacy/qianfan}/__init__.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/QianwenAI_embeddings.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/__init__.py +0 -0
- /vanna/{types → legacy/types}/__init__.py +0 -0
- /vanna/{vannadb → legacy/vannadb}/__init__.py +0 -0
- /vanna/{vllm → legacy/vllm}/__init__.py +0 -0
- /vanna/{weaviate → legacy/weaviate}/__init__.py +0 -0
- /vanna/{xinference → legacy/xinference}/__init__.py +0 -0
- {vanna-0.7.8.dist-info → vanna-2.0.0.dist-info}/WHEEL +0 -0
- {vanna-0.7.8.dist-info → vanna-2.0.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,431 @@
|
|
|
1
|
+
"""
|
|
2
|
+
FAISS vector database implementation of AgentMemory.
|
|
3
|
+
|
|
4
|
+
This implementation uses FAISS for local vector storage of tool usage patterns.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import json
|
|
8
|
+
import uuid
|
|
9
|
+
import pickle
|
|
10
|
+
import os
|
|
11
|
+
from datetime import datetime
|
|
12
|
+
from typing import Any, Dict, List, Optional
|
|
13
|
+
import asyncio
|
|
14
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
15
|
+
import numpy as np
|
|
16
|
+
|
|
17
|
+
try:
|
|
18
|
+
import faiss
|
|
19
|
+
|
|
20
|
+
FAISS_AVAILABLE = True
|
|
21
|
+
except ImportError:
|
|
22
|
+
FAISS_AVAILABLE = False
|
|
23
|
+
|
|
24
|
+
from vanna.capabilities.agent_memory import (
|
|
25
|
+
AgentMemory,
|
|
26
|
+
TextMemory,
|
|
27
|
+
TextMemorySearchResult,
|
|
28
|
+
ToolMemory,
|
|
29
|
+
ToolMemorySearchResult,
|
|
30
|
+
)
|
|
31
|
+
from vanna.core.tool import ToolContext
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class FAISSAgentMemory(AgentMemory):
    """FAISS-based implementation of AgentMemory.

    Tool-usage memories and text memories share one flat FAISS index. A
    vector's row id in the index is the key into ``self._metadata``, which
    holds that memory's payload. Text memories are marked with
    ``"is_text_memory": True``.

    Deleting a memory only removes its metadata entry. The vector stays in the
    index and is skipped at search time because its row id has no metadata.
    The index is rebuilt only when ``clear_memories`` is called without
    filters.

    After every mutation, the index and metadata are written under
    ``index_path`` as ``index.faiss`` and ``metadata.pkl``.
    """

    def __init__(
        self,
        index_path: Optional[str] = None,
        persist_path: Optional[str] = None,
        dimension: int = 384,
        metric: str = "cosine",
    ):
        """Create the store, loading a previously persisted index if present.

        Args:
            index_path: Directory holding the index files.
            persist_path: Alias for ``index_path``. It wins when both are
                given (kept for backward compatibility).
            dimension: Embedding dimension used when a new index is created.
            metric: ``"cosine"`` selects an inner-product index over
                normalized vectors. Any other value selects L2 distance.

        Raises:
            ImportError: If the ``faiss`` package is not installed.
        """
        if not FAISS_AVAILABLE:
            raise ImportError(
                "FAISS is required for FAISSAgentMemory. Install with: pip install faiss-cpu"
            )

        import threading

        # Accept either index_path or persist_path for backward compatibility
        self.index_path = persist_path or index_path or "./faiss_index"
        self.dimension = dimension
        self.metric = metric
        self._index = None
        self._metadata: Dict[int, Dict[str, Any]] = {}
        self._executor = ThreadPoolExecutor(max_workers=2)
        # The executor has two workers, so two operations can run at once.
        # Every access to the index/metadata happens under this lock so that:
        # - the row id read right after ``add`` belongs to the vector just added;
        # - the metadata dict is never mutated while another worker iterates it.
        self._lock = threading.RLock()
        self._load_index()

    # ------------------------------------------------------------------
    # Persistence helpers
    # ------------------------------------------------------------------

    def _index_files(self):
        """Return the ``(index_file, metadata_file)`` paths under ``index_path``."""
        return (
            os.path.join(self.index_path, "index.faiss"),
            os.path.join(self.index_path, "metadata.pkl"),
        )

    def _new_index(self):
        """Return an empty FAISS index matching ``self.metric``."""
        if self.metric == "cosine":
            return faiss.IndexFlatIP(self.dimension)
        return faiss.IndexFlatL2(self.dimension)

    def _load_index(self):
        """Load the persisted index and metadata, or start with empty ones."""
        index_file, metadata_file = self._index_files()

        if os.path.exists(index_file) and os.path.exists(metadata_file):
            # Load existing index
            self._index = faiss.read_index(index_file)
            # SECURITY NOTE: pickle.load can execute arbitrary code. Only point
            # index_path/persist_path at a directory whose contents you trust.
            with open(metadata_file, "rb") as f:
                self._metadata = pickle.load(f)
        else:
            # Create new index
            os.makedirs(self.index_path, exist_ok=True)
            self._index = self._new_index()
            self._metadata = {}

    def _save_index(self):
        """Save the FAISS index and metadata to disk.

        Each file is first written to a ``.tmp`` sibling and then moved into
        place with ``os.replace``, so a crash mid-write cannot leave a
        truncated file behind. The index is replaced first. If the process dies
        between the two replaces, the index may contain a row with no metadata,
        and searches skip such rows.
        """
        index_file, metadata_file = self._index_files()
        tmp_index_file = index_file + ".tmp"
        tmp_metadata_file = metadata_file + ".tmp"

        faiss.write_index(self._index, tmp_index_file)
        with open(tmp_metadata_file, "wb") as f:
            pickle.dump(self._metadata, f)

        os.replace(tmp_index_file, index_file)
        os.replace(tmp_metadata_file, metadata_file)

    # ------------------------------------------------------------------
    # Embedding / search helpers
    # ------------------------------------------------------------------

    def _create_embedding(self, text: str) -> np.ndarray:
        """Create a simple embedding from text (placeholder).

        NOTE: this is a deterministic hash of the text, not a semantic
        embedding, so only identical texts are guaranteed to score as similar.
        The MD5 digest has 128 bits, so every component with ``i >= 128`` is 0.
        """
        import hashlib

        hash_val = int(hashlib.md5(text.encode()).hexdigest(), 16)
        embedding = np.array(
            [(hash_val >> i) % 100 / 100.0 for i in range(self.dimension)],
            dtype=np.float32,
        )

        # Normalize so that inner product equals cosine similarity
        if self.metric == "cosine":
            norm = np.linalg.norm(embedding)
            if norm > 0:
                embedding = embedding / norm

        return embedding

    def _to_similarity(self, dist: float) -> float:
        """Convert a raw FAISS score into a similarity in which higher is better.

        For the cosine metric the index returns the inner product of
        normalized vectors, which already is the cosine similarity. L2
        distances are mapped into (0, 1] with ``1 / (1 + d)``.
        """
        if self.metric == "cosine":
            return dist
        return 1.0 / (1.0 + dist)

    def _ranked_candidates(self, text: str):
        """Yield ``(similarity_score, metadata)`` for live entries, best first.

        Must be called with ``self._lock`` held. The whole index is ranked
        because filtering (deleted rows, failed runs, tool vs. text memories,
        tool-name filter) happens after the vector search. Truncating the FAISS
        result list first could hide matches that should be returned.
        """
        total = self._index.ntotal
        if total == 0:
            return

        embedding = self._create_embedding(text)
        distances, indices = self._index.search(np.array([embedding]), total)

        for dist, raw_idx in zip(distances[0], indices[0]):
            idx = int(raw_idx)
            if idx == -1:
                continue
            metadata = self._metadata.get(idx)
            if metadata is None:  # deleted memory: vector still in the index
                continue
            yield self._to_similarity(float(dist)), metadata

    # ------------------------------------------------------------------
    # Shared mutation / conversion helpers
    # ------------------------------------------------------------------

    def _add_entry(self, text: str, entry: Dict[str, Any]) -> None:
        """Embed ``text``, append it to the index and store ``entry`` for its row."""
        embedding = self._create_embedding(text)
        with self._lock:
            self._index.add(np.array([embedding]))
            # Row ids are assigned sequentially, so the new vector is the last row.
            self._metadata[self._index.ntotal - 1] = entry
            self._save_index()

    def _delete_entry(self, memory_id: str) -> bool:
        """Remove the metadata entry whose ``memory_id`` matches.

        Returns:
            True if an entry was removed, False if no entry has that ID.
        """
        with self._lock:
            idx_to_remove = next(
                (
                    idx
                    for idx, metadata in self._metadata.items()
                    if metadata["memory_id"] == memory_id
                ),
                None,
            )
            if idx_to_remove is None:
                return False

            del self._metadata[idx_to_remove]
            self._save_index()
            return True

    @staticmethod
    def _to_tool_memory(entry: Dict[str, Any]) -> ToolMemory:
        """Build a ToolMemory from a stored tool-usage metadata entry."""
        return ToolMemory(
            memory_id=entry["memory_id"],
            question=entry["question"],
            tool_name=entry["tool_name"],
            args=entry["args"],
            timestamp=entry.get("timestamp"),
            success=entry.get("success", True),
            metadata=entry.get("metadata", {}),
        )

    @staticmethod
    def _to_text_memory(entry: Dict[str, Any]) -> TextMemory:
        """Build a TextMemory from a stored text-memory metadata entry."""
        return TextMemory(
            memory_id=entry["memory_id"],
            content=entry["content"],
            timestamp=entry.get("timestamp"),
        )

    async def _run(self, func):
        """Run the blocking ``func`` on the executor without blocking the event loop."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self._executor, func)

    # ------------------------------------------------------------------
    # Tool-usage memories
    # ------------------------------------------------------------------

    async def save_tool_usage(
        self,
        question: str,
        tool_name: str,
        args: Dict[str, Any],
        context: ToolContext,
        success: bool = True,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Save a tool usage pattern, keyed for search by ``question``."""

        def _save():
            entry = {
                "memory_id": str(uuid.uuid4()),
                "question": question,
                "tool_name": tool_name,
                "args": args,
                "timestamp": datetime.now().isoformat(),
                "success": success,
                "metadata": metadata or {},
            }
            self._add_entry(question, entry)

        await self._run(_save)

    async def search_similar_usage(
        self,
        question: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
        tool_name_filter: Optional[str] = None,
    ) -> List[ToolMemorySearchResult]:
        """Search for successful tool usages similar to ``question``.

        Text memories and failed tool usages are skipped. Results are ordered
        by similarity, ranked from 1, and capped at ``limit``.
        """

        def _search():
            if limit <= 0:
                return []

            search_results = []
            with self._lock:
                for similarity_score, metadata in self._ranked_candidates(question):
                    # Text memories share the index but have no tool fields.
                    if metadata.get("is_text_memory", False):
                        continue
                    if not metadata.get("success", True):
                        continue
                    if tool_name_filter and metadata.get("tool_name") != tool_name_filter:
                        continue
                    if similarity_score < similarity_threshold:
                        continue

                    search_results.append(
                        ToolMemorySearchResult(
                            memory=self._to_tool_memory(metadata),
                            similarity_score=similarity_score,
                            rank=len(search_results) + 1,
                        )
                    )
                    if len(search_results) >= limit:
                        break

            return search_results

        return await self._run(_search)

    async def get_recent_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[ToolMemory]:
        """Get the most recently added tool-usage memories, newest first."""

        def _get_recent():
            with self._lock:
                tool_entries = [
                    entry
                    for entry in self._metadata.values()
                    if not entry.get("is_text_memory", False)
                ]
            tool_entries.sort(key=lambda m: m.get("timestamp", ""), reverse=True)
            return [self._to_tool_memory(entry) for entry in tool_entries[:limit]]

        return await self._run(_get_recent)

    async def delete_by_id(self, context: ToolContext, memory_id: str) -> bool:
        """Delete a memory by its ID. Returns True if one was removed."""
        return await self._run(lambda: self._delete_entry(memory_id))

    # ------------------------------------------------------------------
    # Text memories
    # ------------------------------------------------------------------

    async def save_text_memory(self, content: str, context: ToolContext) -> TextMemory:
        """Save a text memory and return it."""

        def _save():
            memory_id = str(uuid.uuid4())
            timestamp = datetime.now().isoformat()
            self._add_entry(
                content,
                {
                    "memory_id": memory_id,
                    "content": content,
                    "timestamp": timestamp,
                    "is_text_memory": True,
                },
            )
            return TextMemory(memory_id=memory_id, content=content, timestamp=timestamp)

        return await self._run(_save)

    async def search_text_memories(
        self,
        query: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
    ) -> List[TextMemorySearchResult]:
        """Search for text memories similar to ``query``, best match first."""

        def _search():
            if limit <= 0:
                return []

            search_results = []
            with self._lock:
                for similarity_score, metadata in self._ranked_candidates(query):
                    # Filter for text memories only
                    if not metadata.get("is_text_memory", False):
                        continue
                    if similarity_score < similarity_threshold:
                        continue

                    search_results.append(
                        TextMemorySearchResult(
                            memory=self._to_text_memory(metadata),
                            similarity_score=similarity_score,
                            rank=len(search_results) + 1,
                        )
                    )
                    if len(search_results) >= limit:
                        break

            return search_results

        return await self._run(_search)

    async def get_recent_text_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[TextMemory]:
        """Get the most recently added text memories, newest first."""

        def _get_recent():
            with self._lock:
                text_entries = [
                    entry
                    for entry in self._metadata.values()
                    if entry.get("is_text_memory", False)
                ]
            text_entries.sort(key=lambda m: m.get("timestamp", ""), reverse=True)
            return [self._to_text_memory(entry) for entry in text_entries[:limit]]

        return await self._run(_get_recent)

    async def delete_text_memory(self, context: ToolContext, memory_id: str) -> bool:
        """Delete a text memory by its ID. Returns True if one was removed."""
        return await self._run(lambda: self._delete_entry(memory_id))

    # ------------------------------------------------------------------
    # Bulk removal
    # ------------------------------------------------------------------

    async def clear_memories(
        self,
        context: ToolContext,
        tool_name: Optional[str] = None,
        before_date: Optional[str] = None,
    ) -> int:
        """Clear stored memories and return how many were removed.

        Args:
            tool_name: If given, only entries whose ``tool_name`` equals this
                are removed. Text memories have no tool name, so they are kept.
            before_date: If given, only entries whose ISO ``timestamp`` sorts
                strictly before this string are removed.

        With no filters, every memory is removed and the FAISS index is
        recreated empty.
        """

        def _should_remove(metadata: Dict[str, Any]) -> bool:
            if tool_name and metadata.get("tool_name") != tool_name:
                return False
            if before_date and metadata.get("timestamp", "") >= before_date:
                return False
            return True

        def _clear():
            with self._lock:
                indices_to_remove = [
                    idx
                    for idx, metadata in self._metadata.items()
                    if _should_remove(metadata)
                ]
                for idx in indices_to_remove:
                    del self._metadata[idx]

                # Nothing survives a full clear, so drop the stale vectors too.
                if not tool_name and not before_date:
                    self._index = self._new_index()
                    self._metadata = {}

                self._save_index()
                return len(indices_to_remove)

        return await self._run(_clear)
|