vanna 0.7.8__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/__init__.py +167 -395
- vanna/agents/__init__.py +7 -0
- vanna/capabilities/__init__.py +17 -0
- vanna/capabilities/agent_memory/__init__.py +21 -0
- vanna/capabilities/agent_memory/base.py +103 -0
- vanna/capabilities/agent_memory/models.py +53 -0
- vanna/capabilities/file_system/__init__.py +14 -0
- vanna/capabilities/file_system/base.py +71 -0
- vanna/capabilities/file_system/models.py +25 -0
- vanna/capabilities/sql_runner/__init__.py +13 -0
- vanna/capabilities/sql_runner/base.py +37 -0
- vanna/capabilities/sql_runner/models.py +13 -0
- vanna/components/__init__.py +92 -0
- vanna/components/base.py +11 -0
- vanna/components/rich/__init__.py +83 -0
- vanna/components/rich/containers/__init__.py +7 -0
- vanna/components/rich/containers/card.py +20 -0
- vanna/components/rich/data/__init__.py +9 -0
- vanna/components/rich/data/chart.py +17 -0
- vanna/components/rich/data/dataframe.py +93 -0
- vanna/components/rich/feedback/__init__.py +21 -0
- vanna/components/rich/feedback/badge.py +16 -0
- vanna/components/rich/feedback/icon_text.py +14 -0
- vanna/components/rich/feedback/log_viewer.py +41 -0
- vanna/components/rich/feedback/notification.py +19 -0
- vanna/components/rich/feedback/progress.py +37 -0
- vanna/components/rich/feedback/status_card.py +28 -0
- vanna/components/rich/feedback/status_indicator.py +14 -0
- vanna/components/rich/interactive/__init__.py +21 -0
- vanna/components/rich/interactive/button.py +95 -0
- vanna/components/rich/interactive/task_list.py +58 -0
- vanna/components/rich/interactive/ui_state.py +93 -0
- vanna/components/rich/specialized/__init__.py +7 -0
- vanna/components/rich/specialized/artifact.py +20 -0
- vanna/components/rich/text.py +16 -0
- vanna/components/simple/__init__.py +15 -0
- vanna/components/simple/image.py +15 -0
- vanna/components/simple/link.py +15 -0
- vanna/components/simple/text.py +11 -0
- vanna/core/__init__.py +193 -0
- vanna/core/_compat.py +19 -0
- vanna/core/agent/__init__.py +10 -0
- vanna/core/agent/agent.py +1407 -0
- vanna/core/agent/config.py +123 -0
- vanna/core/audit/__init__.py +28 -0
- vanna/core/audit/base.py +299 -0
- vanna/core/audit/models.py +131 -0
- vanna/core/component_manager.py +329 -0
- vanna/core/components.py +53 -0
- vanna/core/enhancer/__init__.py +11 -0
- vanna/core/enhancer/base.py +94 -0
- vanna/core/enhancer/default.py +118 -0
- vanna/core/enricher/__init__.py +10 -0
- vanna/core/enricher/base.py +59 -0
- vanna/core/errors.py +47 -0
- vanna/core/evaluation/__init__.py +81 -0
- vanna/core/evaluation/base.py +186 -0
- vanna/core/evaluation/dataset.py +254 -0
- vanna/core/evaluation/evaluators.py +376 -0
- vanna/core/evaluation/report.py +289 -0
- vanna/core/evaluation/runner.py +313 -0
- vanna/core/filter/__init__.py +10 -0
- vanna/core/filter/base.py +67 -0
- vanna/core/lifecycle/__init__.py +10 -0
- vanna/core/lifecycle/base.py +83 -0
- vanna/core/llm/__init__.py +16 -0
- vanna/core/llm/base.py +40 -0
- vanna/core/llm/models.py +61 -0
- vanna/core/middleware/__init__.py +10 -0
- vanna/core/middleware/base.py +69 -0
- vanna/core/observability/__init__.py +11 -0
- vanna/core/observability/base.py +88 -0
- vanna/core/observability/models.py +47 -0
- vanna/core/recovery/__init__.py +11 -0
- vanna/core/recovery/base.py +84 -0
- vanna/core/recovery/models.py +32 -0
- vanna/core/registry.py +278 -0
- vanna/core/rich_component.py +156 -0
- vanna/core/simple_component.py +27 -0
- vanna/core/storage/__init__.py +14 -0
- vanna/core/storage/base.py +46 -0
- vanna/core/storage/models.py +46 -0
- vanna/core/system_prompt/__init__.py +13 -0
- vanna/core/system_prompt/base.py +36 -0
- vanna/core/system_prompt/default.py +157 -0
- vanna/core/tool/__init__.py +18 -0
- vanna/core/tool/base.py +70 -0
- vanna/core/tool/models.py +84 -0
- vanna/core/user/__init__.py +17 -0
- vanna/core/user/base.py +29 -0
- vanna/core/user/models.py +25 -0
- vanna/core/user/request_context.py +70 -0
- vanna/core/user/resolver.py +42 -0
- vanna/core/validation.py +164 -0
- vanna/core/workflow/__init__.py +12 -0
- vanna/core/workflow/base.py +254 -0
- vanna/core/workflow/default.py +789 -0
- vanna/examples/__init__.py +1 -0
- vanna/examples/__main__.py +44 -0
- vanna/examples/anthropic_quickstart.py +80 -0
- vanna/examples/artifact_example.py +293 -0
- vanna/examples/claude_sqlite_example.py +236 -0
- vanna/examples/coding_agent_example.py +300 -0
- vanna/examples/custom_system_prompt_example.py +174 -0
- vanna/examples/default_workflow_handler_example.py +208 -0
- vanna/examples/email_auth_example.py +340 -0
- vanna/examples/evaluation_example.py +269 -0
- vanna/examples/extensibility_example.py +262 -0
- vanna/examples/minimal_example.py +67 -0
- vanna/examples/mock_auth_example.py +227 -0
- vanna/examples/mock_custom_tool.py +311 -0
- vanna/examples/mock_quickstart.py +79 -0
- vanna/examples/mock_quota_example.py +145 -0
- vanna/examples/mock_rich_components_demo.py +396 -0
- vanna/examples/mock_sqlite_example.py +223 -0
- vanna/examples/openai_quickstart.py +83 -0
- vanna/examples/primitive_components_demo.py +305 -0
- vanna/examples/quota_lifecycle_example.py +139 -0
- vanna/examples/visualization_example.py +251 -0
- vanna/integrations/__init__.py +17 -0
- vanna/integrations/anthropic/__init__.py +9 -0
- vanna/integrations/anthropic/llm.py +270 -0
- vanna/integrations/azureopenai/__init__.py +9 -0
- vanna/integrations/azureopenai/llm.py +329 -0
- vanna/integrations/azuresearch/__init__.py +7 -0
- vanna/integrations/azuresearch/agent_memory.py +413 -0
- vanna/integrations/bigquery/__init__.py +5 -0
- vanna/integrations/bigquery/sql_runner.py +81 -0
- vanna/integrations/chromadb/__init__.py +104 -0
- vanna/integrations/chromadb/agent_memory.py +416 -0
- vanna/integrations/clickhouse/__init__.py +5 -0
- vanna/integrations/clickhouse/sql_runner.py +82 -0
- vanna/integrations/duckdb/__init__.py +5 -0
- vanna/integrations/duckdb/sql_runner.py +65 -0
- vanna/integrations/faiss/__init__.py +7 -0
- vanna/integrations/faiss/agent_memory.py +431 -0
- vanna/integrations/google/__init__.py +9 -0
- vanna/integrations/google/gemini.py +370 -0
- vanna/integrations/hive/__init__.py +5 -0
- vanna/integrations/hive/sql_runner.py +87 -0
- vanna/integrations/local/__init__.py +17 -0
- vanna/integrations/local/agent_memory/__init__.py +7 -0
- vanna/integrations/local/agent_memory/in_memory.py +285 -0
- vanna/integrations/local/audit.py +59 -0
- vanna/integrations/local/file_system.py +242 -0
- vanna/integrations/local/file_system_conversation_store.py +255 -0
- vanna/integrations/local/storage.py +62 -0
- vanna/integrations/marqo/__init__.py +7 -0
- vanna/integrations/marqo/agent_memory.py +354 -0
- vanna/integrations/milvus/__init__.py +7 -0
- vanna/integrations/milvus/agent_memory.py +458 -0
- vanna/integrations/mock/__init__.py +9 -0
- vanna/integrations/mock/llm.py +65 -0
- vanna/integrations/mssql/__init__.py +5 -0
- vanna/integrations/mssql/sql_runner.py +66 -0
- vanna/integrations/mysql/__init__.py +5 -0
- vanna/integrations/mysql/sql_runner.py +92 -0
- vanna/integrations/ollama/__init__.py +7 -0
- vanna/integrations/ollama/llm.py +252 -0
- vanna/integrations/openai/__init__.py +10 -0
- vanna/integrations/openai/llm.py +267 -0
- vanna/integrations/openai/responses.py +163 -0
- vanna/integrations/opensearch/__init__.py +7 -0
- vanna/integrations/opensearch/agent_memory.py +411 -0
- vanna/integrations/oracle/__init__.py +5 -0
- vanna/integrations/oracle/sql_runner.py +75 -0
- vanna/integrations/pinecone/__init__.py +7 -0
- vanna/integrations/pinecone/agent_memory.py +329 -0
- vanna/integrations/plotly/__init__.py +5 -0
- vanna/integrations/plotly/chart_generator.py +313 -0
- vanna/integrations/postgres/__init__.py +9 -0
- vanna/integrations/postgres/sql_runner.py +112 -0
- vanna/integrations/premium/agent_memory/__init__.py +7 -0
- vanna/integrations/premium/agent_memory/premium.py +186 -0
- vanna/integrations/presto/__init__.py +5 -0
- vanna/integrations/presto/sql_runner.py +107 -0
- vanna/integrations/qdrant/__init__.py +7 -0
- vanna/integrations/qdrant/agent_memory.py +461 -0
- vanna/integrations/snowflake/__init__.py +5 -0
- vanna/integrations/snowflake/sql_runner.py +147 -0
- vanna/integrations/sqlite/__init__.py +9 -0
- vanna/integrations/sqlite/sql_runner.py +65 -0
- vanna/integrations/weaviate/__init__.py +7 -0
- vanna/integrations/weaviate/agent_memory.py +428 -0
- vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_embeddings.py +11 -11
- vanna/legacy/__init__.py +403 -0
- vanna/legacy/adapter.py +463 -0
- vanna/{advanced → legacy/advanced}/__init__.py +3 -1
- vanna/{anthropic → legacy/anthropic}/anthropic_chat.py +9 -7
- vanna/{azuresearch → legacy/azuresearch}/azuresearch_vector.py +79 -41
- vanna/{base → legacy/base}/base.py +247 -223
- vanna/legacy/bedrock/__init__.py +1 -0
- vanna/{bedrock → legacy/bedrock}/bedrock_converse.py +13 -12
- vanna/{chromadb → legacy/chromadb}/chromadb_vector.py +3 -1
- vanna/legacy/cohere/__init__.py +2 -0
- vanna/{cohere → legacy/cohere}/cohere_chat.py +19 -14
- vanna/{cohere → legacy/cohere}/cohere_embeddings.py +25 -19
- vanna/{deepseek → legacy/deepseek}/deepseek_chat.py +5 -6
- vanna/legacy/faiss/__init__.py +1 -0
- vanna/{faiss → legacy/faiss}/faiss.py +113 -59
- vanna/{flask → legacy/flask}/__init__.py +84 -43
- vanna/{flask → legacy/flask}/assets.py +5 -5
- vanna/{flask → legacy/flask}/auth.py +5 -4
- vanna/{google → legacy/google}/bigquery_vector.py +75 -42
- vanna/{google → legacy/google}/gemini_chat.py +7 -3
- vanna/{hf → legacy/hf}/hf.py +0 -1
- vanna/{milvus → legacy/milvus}/milvus_vector.py +58 -35
- vanna/{mock → legacy/mock}/llm.py +0 -1
- vanna/legacy/mock/vectordb.py +67 -0
- vanna/legacy/ollama/ollama.py +110 -0
- vanna/{openai → legacy/openai}/openai_chat.py +2 -6
- vanna/legacy/opensearch/opensearch_vector.py +369 -0
- vanna/legacy/opensearch/opensearch_vector_semantic.py +200 -0
- vanna/legacy/oracle/oracle_vector.py +584 -0
- vanna/{pgvector → legacy/pgvector}/pgvector.py +42 -13
- vanna/{qdrant → legacy/qdrant}/qdrant.py +2 -6
- vanna/legacy/qianfan/Qianfan_Chat.py +170 -0
- vanna/legacy/qianfan/Qianfan_embeddings.py +36 -0
- vanna/legacy/qianwen/QianwenAI_chat.py +132 -0
- vanna/{remote.py → legacy/remote.py} +28 -26
- vanna/{utils.py → legacy/utils.py} +6 -11
- vanna/{vannadb → legacy/vannadb}/vannadb_vector.py +115 -46
- vanna/{vllm → legacy/vllm}/vllm.py +5 -6
- vanna/{weaviate → legacy/weaviate}/weaviate_vector.py +59 -40
- vanna/{xinference → legacy/xinference}/xinference.py +6 -6
- vanna/py.typed +0 -0
- vanna/servers/__init__.py +16 -0
- vanna/servers/__main__.py +8 -0
- vanna/servers/base/__init__.py +18 -0
- vanna/servers/base/chat_handler.py +65 -0
- vanna/servers/base/models.py +111 -0
- vanna/servers/base/rich_chat_handler.py +141 -0
- vanna/servers/base/templates.py +331 -0
- vanna/servers/cli/__init__.py +7 -0
- vanna/servers/cli/server_runner.py +204 -0
- vanna/servers/fastapi/__init__.py +7 -0
- vanna/servers/fastapi/app.py +163 -0
- vanna/servers/fastapi/routes.py +183 -0
- vanna/servers/flask/__init__.py +7 -0
- vanna/servers/flask/app.py +132 -0
- vanna/servers/flask/routes.py +137 -0
- vanna/tools/__init__.py +41 -0
- vanna/tools/agent_memory.py +322 -0
- vanna/tools/file_system.py +879 -0
- vanna/tools/python.py +222 -0
- vanna/tools/run_sql.py +165 -0
- vanna/tools/visualize_data.py +195 -0
- vanna/utils/__init__.py +0 -0
- vanna/web_components/__init__.py +44 -0
- vanna-2.0.0.dist-info/METADATA +485 -0
- vanna-2.0.0.dist-info/RECORD +289 -0
- vanna-2.0.0.dist-info/entry_points.txt +3 -0
- vanna/bedrock/__init__.py +0 -1
- vanna/cohere/__init__.py +0 -2
- vanna/faiss/__init__.py +0 -1
- vanna/mock/vectordb.py +0 -55
- vanna/ollama/ollama.py +0 -103
- vanna/opensearch/opensearch_vector.py +0 -392
- vanna/opensearch/opensearch_vector_semantic.py +0 -175
- vanna/oracle/oracle_vector.py +0 -585
- vanna/qianfan/Qianfan_Chat.py +0 -165
- vanna/qianfan/Qianfan_embeddings.py +0 -36
- vanna/qianwen/QianwenAI_chat.py +0 -133
- vanna-0.7.8.dist-info/METADATA +0 -408
- vanna-0.7.8.dist-info/RECORD +0 -79
- /vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_Chat.py +0 -0
- /vanna/{ZhipuAI → legacy/ZhipuAI}/__init__.py +0 -0
- /vanna/{anthropic → legacy/anthropic}/__init__.py +0 -0
- /vanna/{azuresearch → legacy/azuresearch}/__init__.py +0 -0
- /vanna/{base → legacy/base}/__init__.py +0 -0
- /vanna/{chromadb → legacy/chromadb}/__init__.py +0 -0
- /vanna/{deepseek → legacy/deepseek}/__init__.py +0 -0
- /vanna/{exceptions → legacy/exceptions}/__init__.py +0 -0
- /vanna/{google → legacy/google}/__init__.py +0 -0
- /vanna/{hf → legacy/hf}/__init__.py +0 -0
- /vanna/{local.py → legacy/local.py} +0 -0
- /vanna/{marqo → legacy/marqo}/__init__.py +0 -0
- /vanna/{marqo → legacy/marqo}/marqo.py +0 -0
- /vanna/{milvus → legacy/milvus}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/mistral.py +0 -0
- /vanna/{mock → legacy/mock}/__init__.py +0 -0
- /vanna/{mock → legacy/mock}/embedding.py +0 -0
- /vanna/{ollama → legacy/ollama}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/openai_embeddings.py +0 -0
- /vanna/{opensearch → legacy/opensearch}/__init__.py +0 -0
- /vanna/{oracle → legacy/oracle}/__init__.py +0 -0
- /vanna/{pgvector → legacy/pgvector}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/pinecone_vector.py +0 -0
- /vanna/{qdrant → legacy/qdrant}/__init__.py +0 -0
- /vanna/{qianfan → legacy/qianfan}/__init__.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/QianwenAI_embeddings.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/__init__.py +0 -0
- /vanna/{types → legacy/types}/__init__.py +0 -0
- /vanna/{vannadb → legacy/vannadb}/__init__.py +0 -0
- /vanna/{vllm → legacy/vllm}/__init__.py +0 -0
- /vanna/{weaviate → legacy/weaviate}/__init__.py +0 -0
- /vanna/{xinference → legacy/xinference}/__init__.py +0 -0
- {vanna-0.7.8.dist-info → vanna-2.0.0.dist-info}/WHEEL +0 -0
- {vanna-0.7.8.dist-info → vanna-2.0.0.dist-info}/licenses/LICENSE +0 -0
vanna/integrations/azuresearch/agent_memory.py

@@ -0,0 +1,413 @@
+"""
+Azure AI Search implementation of AgentMemory.
+
+This implementation uses Azure Cognitive Search for vector storage of tool usage patterns.
+"""
+
+import json
+import uuid
+from datetime import datetime
+from typing import Any, Dict, List, Optional
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+
+try:
+    from azure.search.documents import SearchClient
+    from azure.search.documents.indexes import SearchIndexClient
+    from azure.search.documents.indexes.models import (
+        SearchIndex,
+        SearchField,
+        SearchFieldDataType,
+        VectorSearch,
+        VectorSearchAlgorithmConfiguration,
+    )
+    from azure.core.credentials import AzureKeyCredential
+
+    AZURE_SEARCH_AVAILABLE = True
+except ImportError:
+    AZURE_SEARCH_AVAILABLE = False
+
+from vanna.capabilities.agent_memory import (
+    AgentMemory,
+    TextMemory,
+    TextMemorySearchResult,
+    ToolMemory,
+    ToolMemorySearchResult,
+)
+from vanna.core.tool import ToolContext
+
+
+class AzureAISearchAgentMemory(AgentMemory):
+    """Azure AI Search-based implementation of AgentMemory."""
+
+    def __init__(
+        self,
+        endpoint: str,
+        api_key: str,
+        index_name: str = "tool-memories",
+        dimension: int = 384,
+    ):
+        if not AZURE_SEARCH_AVAILABLE:
+            raise ImportError(
+                "Azure Search is required for AzureAISearchAgentMemory. "
+                "Install with: pip install azure-search-documents"
+            )
+
+        self.endpoint = endpoint
+        self.api_key = api_key
+        self.index_name = index_name
+        self.dimension = dimension
+        self._credential = AzureKeyCredential(api_key)
+        self._search_client = None
+        self._index_client = None
+        self._executor = ThreadPoolExecutor(max_workers=2)
+
+    def _get_index_client(self):
+        """Get or create index client."""
+        if self._index_client is None:
+            self._index_client = SearchIndexClient(
+                endpoint=self.endpoint, credential=self._credential
+            )
+            self._ensure_index_exists()
+        return self._index_client
+
+    def _get_search_client(self):
+        """Get or create search client."""
+        if self._search_client is None:
+            self._get_index_client()  # Ensure index exists
+            self._search_client = SearchClient(
+                endpoint=self.endpoint,
+                index_name=self.index_name,
+                credential=self._credential,
+            )
+        return self._search_client
+
+    def _ensure_index_exists(self):
+        """Create index if it doesn't exist."""
+        try:
+            self._index_client.get_index(self.index_name)
+        except Exception:
+            # Create index with vector search configuration
+            fields = [
+                SearchField(
+                    name="memory_id", type=SearchFieldDataType.String, key=True
+                ),
+                SearchField(
+                    name="question", type=SearchFieldDataType.String, searchable=True
+                ),
+                SearchField(
+                    name="tool_name", type=SearchFieldDataType.String, filterable=True
+                ),
+                SearchField(name="args_json", type=SearchFieldDataType.String),
+                SearchField(
+                    name="timestamp",
+                    type=SearchFieldDataType.String,
+                    sortable=True,
+                    filterable=True,
+                ),
+                SearchField(
+                    name="success", type=SearchFieldDataType.Boolean, filterable=True
+                ),
+                SearchField(name="metadata_json", type=SearchFieldDataType.String),
+                SearchField(
+                    name="embedding",
+                    type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
+                    searchable=True,
+                    vector_search_dimensions=self.dimension,
+                    vector_search_configuration="vector-config",
+                ),
+            ]
+
+            vector_search = VectorSearch(
+                algorithm_configurations=[
+                    VectorSearchAlgorithmConfiguration(name="vector-config")
+                ]
+            )
+
+            index = SearchIndex(
+                name=self.index_name, fields=fields, vector_search=vector_search
+            )
+
+            self._index_client.create_index(index)
+
+    def _create_embedding(self, text: str) -> List[float]:
+        """Create a simple embedding from text (placeholder)."""
+        import hashlib
+
+        hash_val = int(hashlib.md5(text.encode()).hexdigest(), 16)
+        return [(hash_val >> i) % 100 / 100.0 for i in range(self.dimension)]
+
+    async def save_tool_usage(
+        self,
+        question: str,
+        tool_name: str,
+        args: Dict[str, Any],
+        context: ToolContext,
+        success: bool = True,
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """Save a tool usage pattern."""
+
+        def _save():
+            client = self._get_search_client()
+
+            memory_id = str(uuid.uuid4())
+            timestamp = datetime.now().isoformat()
+            embedding = self._create_embedding(question)
+
+            document = {
+                "memory_id": memory_id,
+                "question": question,
+                "tool_name": tool_name,
+                "args_json": json.dumps(args),
+                "timestamp": timestamp,
+                "success": success,
+                "metadata_json": json.dumps(metadata or {}),
+                "embedding": embedding,
+            }
+
+            client.upload_documents(documents=[document])
+
+        await asyncio.get_event_loop().run_in_executor(self._executor, _save)
+
+    async def search_similar_usage(
+        self,
+        question: str,
+        context: ToolContext,
+        *,
+        limit: int = 10,
+        similarity_threshold: float = 0.7,
+        tool_name_filter: Optional[str] = None,
+    ) -> List[ToolMemorySearchResult]:
+        """Search for similar tool usage patterns."""
+
+        def _search():
+            client = self._get_search_client()
+
+            embedding = self._create_embedding(question)
+
+            # Build filter
+            filter_expr = "success eq true"
+            if tool_name_filter:
+                filter_expr += f" and tool_name eq '{tool_name_filter}'"
+
+            results = client.search(
+                search_text=None, vector=embedding, top_k=limit, filter=filter_expr
+            )
+
+            search_results = []
+            for i, doc in enumerate(results):
+                # Azure returns similarity score in @search.score
+                similarity_score = doc.get("@search.score", 0)
+
+                if similarity_score >= similarity_threshold:
+                    args = json.loads(doc.get("args_json", "{}"))
+                    metadata_dict = json.loads(doc.get("metadata_json", "{}"))
+
+                    memory = ToolMemory(
+                        memory_id=doc["memory_id"],
+                        question=doc["question"],
+                        tool_name=doc["tool_name"],
+                        args=args,
+                        timestamp=doc.get("timestamp"),
+                        success=doc.get("success", True),
+                        metadata=metadata_dict,
+                    )
+
+                    search_results.append(
+                        ToolMemorySearchResult(
+                            memory=memory, similarity_score=similarity_score, rank=i + 1
+                        )
+                    )
+
+            return search_results
+
+        return await asyncio.get_event_loop().run_in_executor(self._executor, _search)
+
+    async def get_recent_memories(
+        self, context: ToolContext, limit: int = 10
+    ) -> List[ToolMemory]:
+        """Get recently added memories."""
+
+        def _get_recent():
+            client = self._get_search_client()
+
+            results = client.search(
+                search_text="*", top=limit, order_by=["timestamp desc"]
+            )
+
+            memories = []
+            for doc in results:
+                args = json.loads(doc.get("args_json", "{}"))
+                metadata_dict = json.loads(doc.get("metadata_json", "{}"))
+
+                memory = ToolMemory(
+                    memory_id=doc["memory_id"],
+                    question=doc["question"],
+                    tool_name=doc["tool_name"],
+                    args=args,
+                    timestamp=doc.get("timestamp"),
+                    success=doc.get("success", True),
+                    metadata=metadata_dict,
+                )
+                memories.append(memory)
+
+            return memories
+
+        return await asyncio.get_event_loop().run_in_executor(
+            self._executor, _get_recent
+        )
+
+    async def delete_by_id(self, context: ToolContext, memory_id: str) -> bool:
+        """Delete a memory by its ID."""
+
+        def _delete():
+            client = self._get_search_client()
+
+            try:
+                client.delete_documents(documents=[{"memory_id": memory_id}])
+                return True
+            except Exception:
+                return False
+
+        return await asyncio.get_event_loop().run_in_executor(self._executor, _delete)
+
+    async def save_text_memory(self, content: str, context: ToolContext) -> TextMemory:
+        """Save a text memory."""
+
+        def _save():
+            client = self._get_search_client()
+
+            memory_id = str(uuid.uuid4())
+            timestamp = datetime.now().isoformat()
+            embedding = self._create_embedding(content)
+
+            document = {
+                "memory_id": memory_id,
+                "content": content,
+                "timestamp": timestamp,
+                "embedding": embedding,
+            }
+
+            client.upload_documents(documents=[document])
+
+            return TextMemory(memory_id=memory_id, content=content, timestamp=timestamp)
+
+        return await asyncio.get_event_loop().run_in_executor(self._executor, _save)
+
+    async def search_text_memories(
+        self,
+        query: str,
+        context: ToolContext,
+        *,
+        limit: int = 10,
+        similarity_threshold: float = 0.7,
+    ) -> List[TextMemorySearchResult]:
+        """Search for similar text memories."""
+
+        def _search():
+            client = self._get_search_client()
+
+            embedding = self._create_embedding(query)
+
+            results = client.search(search_text=None, vector=embedding, top_k=limit)
+
+            search_results = []
+            for i, doc in enumerate(results):
+                similarity_score = doc.get("@search.score", 0)
+
+                if similarity_score >= similarity_threshold:
+                    memory = TextMemory(
+                        memory_id=doc["memory_id"],
+                        content=doc.get("content", ""),
+                        timestamp=doc.get("timestamp"),
+                    )
+
+                    search_results.append(
+                        TextMemorySearchResult(
+                            memory=memory, similarity_score=similarity_score, rank=i + 1
+                        )
+                    )
+
+            return search_results
+
+        return await asyncio.get_event_loop().run_in_executor(self._executor, _search)
+
+    async def get_recent_text_memories(
+        self, context: ToolContext, limit: int = 10
+    ) -> List[TextMemory]:
+        """Get recently added text memories."""
+
+        def _get_recent():
+            client = self._get_search_client()
+
+            results = client.search(
+                search_text="*", top=limit, order_by=["timestamp desc"]
+            )
+
+            memories = []
+            for doc in results:
+                # Skip if this is a tool memory (has tool_name field)
+                if "tool_name" in doc:
+                    continue
+
+                memory = TextMemory(
+                    memory_id=doc["memory_id"],
+                    content=doc.get("content", ""),
+                    timestamp=doc.get("timestamp"),
+                )
+                memories.append(memory)
+
+            return memories[:limit]
+
+        return await asyncio.get_event_loop().run_in_executor(
+            self._executor, _get_recent
+        )
+
+    async def delete_text_memory(self, context: ToolContext, memory_id: str) -> bool:
+        """Delete a text memory by its ID."""
+
+        def _delete():
+            client = self._get_search_client()
+
+            try:
+                client.delete_documents(documents=[{"memory_id": memory_id}])
+                return True
+            except Exception:
+                return False
+
+        return await asyncio.get_event_loop().run_in_executor(self._executor, _delete)
+
+    async def clear_memories(
+        self,
+        context: ToolContext,
+        tool_name: Optional[str] = None,
+        before_date: Optional[str] = None,
+    ) -> int:
+        """Clear stored memories."""
+
+        def _clear():
+            client = self._get_search_client()
+
+            # Build filter
+            filter_parts = []
+            if tool_name:
+                filter_parts.append(f"tool_name eq '{tool_name}'")
+            if before_date:
+                filter_parts.append(f"timestamp lt '{before_date}'")
+
+            filter_expr = " and ".join(filter_parts) if filter_parts else None
+
+            # Search for documents to delete
+            results = client.search(
+                search_text="*", filter=filter_expr, select=["memory_id"]
+            )
+
+            docs_to_delete = [{"memory_id": doc["memory_id"]} for doc in results]
+
+            if docs_to_delete:
+                client.delete_documents(documents=docs_to_delete)
+
+            return len(docs_to_delete)
+
+        return await asyncio.get_event_loop().run_in_executor(self._executor, _clear)
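
For orientation, the new AzureAISearchAgentMemory can be exercised on its own with just an Azure AI Search endpoint and key. The sketch below is illustrative and not part of the package diff: the endpoint and key are placeholders, the similarity threshold is lowered for the placeholder hash embeddings, and None is passed for the ToolContext parameter because this backend never reads it (constructing a real ToolContext is outside this hunk).

import asyncio

from vanna.integrations.azuresearch.agent_memory import AzureAISearchAgentMemory


async def demo() -> None:
    # Hypothetical endpoint/key; the index is created lazily on first use.
    memory = AzureAISearchAgentMemory(
        endpoint="https://my-service.search.windows.net",
        api_key="<api-key>",
        index_name="tool-memories",
    )

    # context is unused by this backend, so None is enough for a smoke test.
    await memory.save_text_memory("Revenue lives in the orders table.", context=None)

    hits = await memory.search_text_memories(
        "Where is revenue stored?", context=None, limit=5, similarity_threshold=0.0
    )
    for hit in hits:
        print(hit.rank, hit.similarity_score, hit.memory.content)


asyncio.run(demo())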

vanna/integrations/bigquery/sql_runner.py

@@ -0,0 +1,81 @@
+"""BigQuery implementation of SqlRunner interface."""
+
+from typing import Optional
+import pandas as pd
+
+from vanna.capabilities.sql_runner import SqlRunner, RunSqlToolArgs
+from vanna.core.tool import ToolContext
+
+
+class BigQueryRunner(SqlRunner):
+    """BigQuery implementation of the SqlRunner interface."""
+
+    def __init__(self, project_id: str, cred_file_path: Optional[str] = None, **kwargs):
+        """Initialize with BigQuery connection parameters.
+
+        Args:
+            project_id: Google Cloud Project ID
+            cred_file_path: Path to Google Cloud credentials JSON file (optional)
+            **kwargs: Additional google.cloud.bigquery.Client parameters
+        """
+        try:
+            from google.cloud import bigquery
+            from google.oauth2 import service_account
+
+            self.bigquery = bigquery
+            self.service_account = service_account
+        except ImportError as e:
+            raise ImportError(
+                "google-cloud-bigquery package is required. "
+                "Install with: pip install 'vanna[bigquery]'"
+            ) from e
+
+        self.project_id = project_id
+        self.cred_file_path = cred_file_path
+        self.kwargs = kwargs
+        self._client = None
+
+    def _get_client(self):
+        """Get or create BigQuery client."""
+        if self._client is not None:
+            return self._client
+
+        if self.cred_file_path:
+            import json
+
+            with open(self.cred_file_path, "r") as f:
+                credentials = (
+                    self.service_account.Credentials.from_service_account_info(
+                        json.loads(f.read()),
+                        scopes=["https://www.googleapis.com/auth/cloud-platform"],
+                    )
+                )
+            self._client = self.bigquery.Client(
+                project=self.project_id, credentials=credentials, **self.kwargs
+            )
+        else:
+            # Use default credentials
+            self._client = self.bigquery.Client(project=self.project_id, **self.kwargs)
+
+        return self._client
+
+    async def run_sql(self, args: RunSqlToolArgs, context: ToolContext) -> pd.DataFrame:
+        """Execute SQL query against BigQuery database and return results as DataFrame.
+
+        Args:
+            args: SQL query arguments
+            context: Tool execution context
+
+        Returns:
+            DataFrame with query results
+
+        Raises:
+            google.api_core.exceptions.GoogleAPIError: If query execution fails
+        """
+        client = self._get_client()
+
+        # Execute the query
+        job = client.query(args.sql)
+        df = job.result().to_dataframe()
+
+        return df
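
The BigQueryRunner above only reads args.sql and ignores its ToolContext, so a standalone smoke test can stay very small. The sketch below is an illustration, not part of the diff: it assumes application-default Google credentials for a hypothetical project, and it assumes RunSqlToolArgs (defined in vanna/capabilities/sql_runner/models.py, not shown in this hunk) can be constructed from the SQL text alone.

import asyncio

from vanna.capabilities.sql_runner import RunSqlToolArgs
from vanna.integrations.bigquery.sql_runner import BigQueryRunner


async def demo() -> None:
    # Hypothetical project id; with no cred_file_path the runner falls back to
    # application-default credentials (see _get_client above).
    runner = BigQueryRunner(project_id="my-gcp-project")

    # Assumption: RunSqlToolArgs accepts the SQL text as its `sql` field.
    args = RunSqlToolArgs(sql="SELECT 1 AS one")

    # run_sql never touches the context argument, so None is fine for a demo.
    df = await runner.run_sql(args, context=None)
    print(df)


asyncio.run(demo())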

vanna/integrations/chromadb/__init__.py

@@ -0,0 +1,104 @@
+"""
+ChromaDB integration for Vanna Agents.
+"""
+
+from .agent_memory import ChromaAgentMemory
+
+
+def get_device() -> str:
+    """Detect the best available device for embeddings.
+
+    This function checks for GPU availability and returns the appropriate device string
+    for use with embedding models. It prioritizes hardware acceleration when available.
+
+    Returns:
+        str: Device string - 'cuda' if NVIDIA GPU available, 'mps' if Apple Silicon,
+        'cpu' otherwise.
+
+    Examples:
+        >>> device = get_device()
+        >>> print(f"Using device: {device}")
+        Using device: cuda
+
+        # Use with ChromaDB SentenceTransformer embeddings
+        >>> from chromadb.utils import embedding_functions
+        >>> ef = embedding_functions.SentenceTransformerEmbeddingFunction(
+        ...     model_name="sentence-transformers/all-MiniLM-L6-v2",
+        ...     device=get_device()
+        ... )
+        >>> memory = ChromaAgentMemory(embedding_function=ef)
+    """
+    try:
+        import torch
+
+        # Check for CUDA (NVIDIA GPUs)
+        if torch.cuda.is_available():
+            return "cuda"
+
+        # Check for MPS (Apple Silicon GPUs)
+        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+            return "mps"
+
+    except ImportError:
+        # PyTorch not installed, fall back to CPU
+        pass
+
+    return "cpu"
+
+
+def create_sentence_transformer_embedding_function(
+    model_name: str = "sentence-transformers/all-MiniLM-L6-v2", device: str = None
+):
+    """Create a SentenceTransformer embedding function with automatic device detection.
+
+    This convenience function creates a ChromaDB-compatible SentenceTransformer embedding
+    function with intelligent device selection. If no device is specified, it automatically
+    detects and uses the best available hardware (CUDA, MPS, or CPU).
+
+    Note: This requires the 'sentence-transformers' package to be installed.
+    Install with: pip install sentence-transformers
+
+    Args:
+        model_name: The name of the sentence-transformer model to use.
+            Defaults to "sentence-transformers/all-MiniLM-L6-v2".
+        device: Optional device string ('cuda', 'mps', or 'cpu'). If None,
+            automatically detects the best available device.
+
+    Returns:
+        A ChromaDB SentenceTransformer embedding function configured for the
+        specified/detected device.
+
+    Examples:
+        # Automatic device detection (uses CUDA/MPS if available)
+        >>> from vanna.integrations.chromadb import ChromaAgentMemory, create_sentence_transformer_embedding_function
+        >>> ef = create_sentence_transformer_embedding_function()
+        >>> memory = ChromaAgentMemory(embedding_function=ef)
+
+        # Explicitly use CUDA
+        >>> ef_cuda = create_sentence_transformer_embedding_function(device="cuda")
+        >>> memory = ChromaAgentMemory(embedding_function=ef_cuda)
+
+        # Use a different model
+        >>> ef_large = create_sentence_transformer_embedding_function(
+        ...     model_name="sentence-transformers/all-mpnet-base-v2"
+        ... )
+        >>> memory = ChromaAgentMemory(embedding_function=ef_large)
+    """
+    try:
+        from chromadb.utils import embedding_functions
+    except ImportError:
+        raise ImportError("ChromaDB is required. Install with: pip install chromadb")
+
+    if device is None:
+        device = get_device()
+
+    return embedding_functions.SentenceTransformerEmbeddingFunction(
+        model_name=model_name, device=device
+    )
+
+
+__all__ = [
+    "ChromaAgentMemory",
+    "get_device",
+    "create_sentence_transformer_embedding_function",
+]
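
Taken together, the helpers in this __init__ give a short setup path that mirrors the docstring examples above. The snippet below is illustrative only; it requires chromadb and sentence-transformers, and the ChromaAgentMemory constructor itself lives in agent_memory.py (also added in this release, not shown in this hunk).

from vanna.integrations.chromadb import (
    ChromaAgentMemory,
    create_sentence_transformer_embedding_function,
    get_device,
)

# Reports 'cuda', 'mps', or 'cpu' depending on available hardware.
print(f"Embedding device: {get_device()}")

# Device is auto-detected when not specified, as in the docstring examples.
ef = create_sentence_transformer_embedding_function(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)
memory = ChromaAgentMemory(embedding_function=ef)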