vanna 0.7.8__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/__init__.py +167 -395
- vanna/agents/__init__.py +7 -0
- vanna/capabilities/__init__.py +17 -0
- vanna/capabilities/agent_memory/__init__.py +21 -0
- vanna/capabilities/agent_memory/base.py +103 -0
- vanna/capabilities/agent_memory/models.py +53 -0
- vanna/capabilities/file_system/__init__.py +14 -0
- vanna/capabilities/file_system/base.py +71 -0
- vanna/capabilities/file_system/models.py +25 -0
- vanna/capabilities/sql_runner/__init__.py +13 -0
- vanna/capabilities/sql_runner/base.py +37 -0
- vanna/capabilities/sql_runner/models.py +13 -0
- vanna/components/__init__.py +92 -0
- vanna/components/base.py +11 -0
- vanna/components/rich/__init__.py +83 -0
- vanna/components/rich/containers/__init__.py +7 -0
- vanna/components/rich/containers/card.py +20 -0
- vanna/components/rich/data/__init__.py +9 -0
- vanna/components/rich/data/chart.py +17 -0
- vanna/components/rich/data/dataframe.py +93 -0
- vanna/components/rich/feedback/__init__.py +21 -0
- vanna/components/rich/feedback/badge.py +16 -0
- vanna/components/rich/feedback/icon_text.py +14 -0
- vanna/components/rich/feedback/log_viewer.py +41 -0
- vanna/components/rich/feedback/notification.py +19 -0
- vanna/components/rich/feedback/progress.py +37 -0
- vanna/components/rich/feedback/status_card.py +28 -0
- vanna/components/rich/feedback/status_indicator.py +14 -0
- vanna/components/rich/interactive/__init__.py +21 -0
- vanna/components/rich/interactive/button.py +95 -0
- vanna/components/rich/interactive/task_list.py +58 -0
- vanna/components/rich/interactive/ui_state.py +93 -0
- vanna/components/rich/specialized/__init__.py +7 -0
- vanna/components/rich/specialized/artifact.py +20 -0
- vanna/components/rich/text.py +16 -0
- vanna/components/simple/__init__.py +15 -0
- vanna/components/simple/image.py +15 -0
- vanna/components/simple/link.py +15 -0
- vanna/components/simple/text.py +11 -0
- vanna/core/__init__.py +193 -0
- vanna/core/_compat.py +19 -0
- vanna/core/agent/__init__.py +10 -0
- vanna/core/agent/agent.py +1407 -0
- vanna/core/agent/config.py +123 -0
- vanna/core/audit/__init__.py +28 -0
- vanna/core/audit/base.py +299 -0
- vanna/core/audit/models.py +131 -0
- vanna/core/component_manager.py +329 -0
- vanna/core/components.py +53 -0
- vanna/core/enhancer/__init__.py +11 -0
- vanna/core/enhancer/base.py +94 -0
- vanna/core/enhancer/default.py +118 -0
- vanna/core/enricher/__init__.py +10 -0
- vanna/core/enricher/base.py +59 -0
- vanna/core/errors.py +47 -0
- vanna/core/evaluation/__init__.py +81 -0
- vanna/core/evaluation/base.py +186 -0
- vanna/core/evaluation/dataset.py +254 -0
- vanna/core/evaluation/evaluators.py +376 -0
- vanna/core/evaluation/report.py +289 -0
- vanna/core/evaluation/runner.py +313 -0
- vanna/core/filter/__init__.py +10 -0
- vanna/core/filter/base.py +67 -0
- vanna/core/lifecycle/__init__.py +10 -0
- vanna/core/lifecycle/base.py +83 -0
- vanna/core/llm/__init__.py +16 -0
- vanna/core/llm/base.py +40 -0
- vanna/core/llm/models.py +61 -0
- vanna/core/middleware/__init__.py +10 -0
- vanna/core/middleware/base.py +69 -0
- vanna/core/observability/__init__.py +11 -0
- vanna/core/observability/base.py +88 -0
- vanna/core/observability/models.py +47 -0
- vanna/core/recovery/__init__.py +11 -0
- vanna/core/recovery/base.py +84 -0
- vanna/core/recovery/models.py +32 -0
- vanna/core/registry.py +278 -0
- vanna/core/rich_component.py +156 -0
- vanna/core/simple_component.py +27 -0
- vanna/core/storage/__init__.py +14 -0
- vanna/core/storage/base.py +46 -0
- vanna/core/storage/models.py +46 -0
- vanna/core/system_prompt/__init__.py +13 -0
- vanna/core/system_prompt/base.py +36 -0
- vanna/core/system_prompt/default.py +157 -0
- vanna/core/tool/__init__.py +18 -0
- vanna/core/tool/base.py +70 -0
- vanna/core/tool/models.py +84 -0
- vanna/core/user/__init__.py +17 -0
- vanna/core/user/base.py +29 -0
- vanna/core/user/models.py +25 -0
- vanna/core/user/request_context.py +70 -0
- vanna/core/user/resolver.py +42 -0
- vanna/core/validation.py +164 -0
- vanna/core/workflow/__init__.py +12 -0
- vanna/core/workflow/base.py +254 -0
- vanna/core/workflow/default.py +789 -0
- vanna/examples/__init__.py +1 -0
- vanna/examples/__main__.py +44 -0
- vanna/examples/anthropic_quickstart.py +80 -0
- vanna/examples/artifact_example.py +293 -0
- vanna/examples/claude_sqlite_example.py +236 -0
- vanna/examples/coding_agent_example.py +300 -0
- vanna/examples/custom_system_prompt_example.py +174 -0
- vanna/examples/default_workflow_handler_example.py +208 -0
- vanna/examples/email_auth_example.py +340 -0
- vanna/examples/evaluation_example.py +269 -0
- vanna/examples/extensibility_example.py +262 -0
- vanna/examples/minimal_example.py +67 -0
- vanna/examples/mock_auth_example.py +227 -0
- vanna/examples/mock_custom_tool.py +311 -0
- vanna/examples/mock_quickstart.py +79 -0
- vanna/examples/mock_quota_example.py +145 -0
- vanna/examples/mock_rich_components_demo.py +396 -0
- vanna/examples/mock_sqlite_example.py +223 -0
- vanna/examples/openai_quickstart.py +83 -0
- vanna/examples/primitive_components_demo.py +305 -0
- vanna/examples/quota_lifecycle_example.py +139 -0
- vanna/examples/visualization_example.py +251 -0
- vanna/integrations/__init__.py +17 -0
- vanna/integrations/anthropic/__init__.py +9 -0
- vanna/integrations/anthropic/llm.py +270 -0
- vanna/integrations/azureopenai/__init__.py +9 -0
- vanna/integrations/azureopenai/llm.py +329 -0
- vanna/integrations/azuresearch/__init__.py +7 -0
- vanna/integrations/azuresearch/agent_memory.py +413 -0
- vanna/integrations/bigquery/__init__.py +5 -0
- vanna/integrations/bigquery/sql_runner.py +81 -0
- vanna/integrations/chromadb/__init__.py +104 -0
- vanna/integrations/chromadb/agent_memory.py +416 -0
- vanna/integrations/clickhouse/__init__.py +5 -0
- vanna/integrations/clickhouse/sql_runner.py +82 -0
- vanna/integrations/duckdb/__init__.py +5 -0
- vanna/integrations/duckdb/sql_runner.py +65 -0
- vanna/integrations/faiss/__init__.py +7 -0
- vanna/integrations/faiss/agent_memory.py +431 -0
- vanna/integrations/google/__init__.py +9 -0
- vanna/integrations/google/gemini.py +370 -0
- vanna/integrations/hive/__init__.py +5 -0
- vanna/integrations/hive/sql_runner.py +87 -0
- vanna/integrations/local/__init__.py +17 -0
- vanna/integrations/local/agent_memory/__init__.py +7 -0
- vanna/integrations/local/agent_memory/in_memory.py +285 -0
- vanna/integrations/local/audit.py +59 -0
- vanna/integrations/local/file_system.py +242 -0
- vanna/integrations/local/file_system_conversation_store.py +255 -0
- vanna/integrations/local/storage.py +62 -0
- vanna/integrations/marqo/__init__.py +7 -0
- vanna/integrations/marqo/agent_memory.py +354 -0
- vanna/integrations/milvus/__init__.py +7 -0
- vanna/integrations/milvus/agent_memory.py +458 -0
- vanna/integrations/mock/__init__.py +9 -0
- vanna/integrations/mock/llm.py +65 -0
- vanna/integrations/mssql/__init__.py +5 -0
- vanna/integrations/mssql/sql_runner.py +66 -0
- vanna/integrations/mysql/__init__.py +5 -0
- vanna/integrations/mysql/sql_runner.py +92 -0
- vanna/integrations/ollama/__init__.py +7 -0
- vanna/integrations/ollama/llm.py +252 -0
- vanna/integrations/openai/__init__.py +10 -0
- vanna/integrations/openai/llm.py +267 -0
- vanna/integrations/openai/responses.py +163 -0
- vanna/integrations/opensearch/__init__.py +7 -0
- vanna/integrations/opensearch/agent_memory.py +411 -0
- vanna/integrations/oracle/__init__.py +5 -0
- vanna/integrations/oracle/sql_runner.py +75 -0
- vanna/integrations/pinecone/__init__.py +7 -0
- vanna/integrations/pinecone/agent_memory.py +329 -0
- vanna/integrations/plotly/__init__.py +5 -0
- vanna/integrations/plotly/chart_generator.py +313 -0
- vanna/integrations/postgres/__init__.py +9 -0
- vanna/integrations/postgres/sql_runner.py +112 -0
- vanna/integrations/premium/agent_memory/__init__.py +7 -0
- vanna/integrations/premium/agent_memory/premium.py +186 -0
- vanna/integrations/presto/__init__.py +5 -0
- vanna/integrations/presto/sql_runner.py +107 -0
- vanna/integrations/qdrant/__init__.py +7 -0
- vanna/integrations/qdrant/agent_memory.py +461 -0
- vanna/integrations/snowflake/__init__.py +5 -0
- vanna/integrations/snowflake/sql_runner.py +147 -0
- vanna/integrations/sqlite/__init__.py +9 -0
- vanna/integrations/sqlite/sql_runner.py +65 -0
- vanna/integrations/weaviate/__init__.py +7 -0
- vanna/integrations/weaviate/agent_memory.py +428 -0
- vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_embeddings.py +11 -11
- vanna/legacy/__init__.py +403 -0
- vanna/legacy/adapter.py +463 -0
- vanna/{advanced → legacy/advanced}/__init__.py +3 -1
- vanna/{anthropic → legacy/anthropic}/anthropic_chat.py +9 -7
- vanna/{azuresearch → legacy/azuresearch}/azuresearch_vector.py +79 -41
- vanna/{base → legacy/base}/base.py +247 -223
- vanna/legacy/bedrock/__init__.py +1 -0
- vanna/{bedrock → legacy/bedrock}/bedrock_converse.py +13 -12
- vanna/{chromadb → legacy/chromadb}/chromadb_vector.py +3 -1
- vanna/legacy/cohere/__init__.py +2 -0
- vanna/{cohere → legacy/cohere}/cohere_chat.py +19 -14
- vanna/{cohere → legacy/cohere}/cohere_embeddings.py +25 -19
- vanna/{deepseek → legacy/deepseek}/deepseek_chat.py +5 -6
- vanna/legacy/faiss/__init__.py +1 -0
- vanna/{faiss → legacy/faiss}/faiss.py +113 -59
- vanna/{flask → legacy/flask}/__init__.py +84 -43
- vanna/{flask → legacy/flask}/assets.py +5 -5
- vanna/{flask → legacy/flask}/auth.py +5 -4
- vanna/{google → legacy/google}/bigquery_vector.py +75 -42
- vanna/{google → legacy/google}/gemini_chat.py +7 -3
- vanna/{hf → legacy/hf}/hf.py +0 -1
- vanna/{milvus → legacy/milvus}/milvus_vector.py +58 -35
- vanna/{mock → legacy/mock}/llm.py +0 -1
- vanna/legacy/mock/vectordb.py +67 -0
- vanna/legacy/ollama/ollama.py +110 -0
- vanna/{openai → legacy/openai}/openai_chat.py +2 -6
- vanna/legacy/opensearch/opensearch_vector.py +369 -0
- vanna/legacy/opensearch/opensearch_vector_semantic.py +200 -0
- vanna/legacy/oracle/oracle_vector.py +584 -0
- vanna/{pgvector → legacy/pgvector}/pgvector.py +42 -13
- vanna/{qdrant → legacy/qdrant}/qdrant.py +2 -6
- vanna/legacy/qianfan/Qianfan_Chat.py +170 -0
- vanna/legacy/qianfan/Qianfan_embeddings.py +36 -0
- vanna/legacy/qianwen/QianwenAI_chat.py +132 -0
- vanna/{remote.py → legacy/remote.py} +28 -26
- vanna/{utils.py → legacy/utils.py} +6 -11
- vanna/{vannadb → legacy/vannadb}/vannadb_vector.py +115 -46
- vanna/{vllm → legacy/vllm}/vllm.py +5 -6
- vanna/{weaviate → legacy/weaviate}/weaviate_vector.py +59 -40
- vanna/{xinference → legacy/xinference}/xinference.py +6 -6
- vanna/py.typed +0 -0
- vanna/servers/__init__.py +16 -0
- vanna/servers/__main__.py +8 -0
- vanna/servers/base/__init__.py +18 -0
- vanna/servers/base/chat_handler.py +65 -0
- vanna/servers/base/models.py +111 -0
- vanna/servers/base/rich_chat_handler.py +141 -0
- vanna/servers/base/templates.py +331 -0
- vanna/servers/cli/__init__.py +7 -0
- vanna/servers/cli/server_runner.py +204 -0
- vanna/servers/fastapi/__init__.py +7 -0
- vanna/servers/fastapi/app.py +163 -0
- vanna/servers/fastapi/routes.py +183 -0
- vanna/servers/flask/__init__.py +7 -0
- vanna/servers/flask/app.py +132 -0
- vanna/servers/flask/routes.py +137 -0
- vanna/tools/__init__.py +41 -0
- vanna/tools/agent_memory.py +322 -0
- vanna/tools/file_system.py +879 -0
- vanna/tools/python.py +222 -0
- vanna/tools/run_sql.py +165 -0
- vanna/tools/visualize_data.py +195 -0
- vanna/utils/__init__.py +0 -0
- vanna/web_components/__init__.py +44 -0
- vanna-2.0.0.dist-info/METADATA +485 -0
- vanna-2.0.0.dist-info/RECORD +289 -0
- vanna-2.0.0.dist-info/entry_points.txt +3 -0
- vanna/bedrock/__init__.py +0 -1
- vanna/cohere/__init__.py +0 -2
- vanna/faiss/__init__.py +0 -1
- vanna/mock/vectordb.py +0 -55
- vanna/ollama/ollama.py +0 -103
- vanna/opensearch/opensearch_vector.py +0 -392
- vanna/opensearch/opensearch_vector_semantic.py +0 -175
- vanna/oracle/oracle_vector.py +0 -585
- vanna/qianfan/Qianfan_Chat.py +0 -165
- vanna/qianfan/Qianfan_embeddings.py +0 -36
- vanna/qianwen/QianwenAI_chat.py +0 -133
- vanna-0.7.8.dist-info/METADATA +0 -408
- vanna-0.7.8.dist-info/RECORD +0 -79
- /vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_Chat.py +0 -0
- /vanna/{ZhipuAI → legacy/ZhipuAI}/__init__.py +0 -0
- /vanna/{anthropic → legacy/anthropic}/__init__.py +0 -0
- /vanna/{azuresearch → legacy/azuresearch}/__init__.py +0 -0
- /vanna/{base → legacy/base}/__init__.py +0 -0
- /vanna/{chromadb → legacy/chromadb}/__init__.py +0 -0
- /vanna/{deepseek → legacy/deepseek}/__init__.py +0 -0
- /vanna/{exceptions → legacy/exceptions}/__init__.py +0 -0
- /vanna/{google → legacy/google}/__init__.py +0 -0
- /vanna/{hf → legacy/hf}/__init__.py +0 -0
- /vanna/{local.py → legacy/local.py} +0 -0
- /vanna/{marqo → legacy/marqo}/__init__.py +0 -0
- /vanna/{marqo → legacy/marqo}/marqo.py +0 -0
- /vanna/{milvus → legacy/milvus}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/mistral.py +0 -0
- /vanna/{mock → legacy/mock}/__init__.py +0 -0
- /vanna/{mock → legacy/mock}/embedding.py +0 -0
- /vanna/{ollama → legacy/ollama}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/openai_embeddings.py +0 -0
- /vanna/{opensearch → legacy/opensearch}/__init__.py +0 -0
- /vanna/{oracle → legacy/oracle}/__init__.py +0 -0
- /vanna/{pgvector → legacy/pgvector}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/pinecone_vector.py +0 -0
- /vanna/{qdrant → legacy/qdrant}/__init__.py +0 -0
- /vanna/{qianfan → legacy/qianfan}/__init__.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/QianwenAI_embeddings.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/__init__.py +0 -0
- /vanna/{types → legacy/types}/__init__.py +0 -0
- /vanna/{vannadb → legacy/vannadb}/__init__.py +0 -0
- /vanna/{vllm → legacy/vllm}/__init__.py +0 -0
- /vanna/{weaviate → legacy/weaviate}/__init__.py +0 -0
- /vanna/{xinference → legacy/xinference}/__init__.py +0 -0
- {vanna-0.7.8.dist-info → vanna-2.0.0.dist-info}/WHEEL +0 -0
- {vanna-0.7.8.dist-info → vanna-2.0.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
"""PostgreSQL implementation of SqlRunner interface."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
import pandas as pd
|
|
5
|
+
|
|
6
|
+
from vanna.capabilities.sql_runner import SqlRunner, RunSqlToolArgs
|
|
7
|
+
from vanna.core.tool import ToolContext
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class PostgresRunner(SqlRunner):
    """PostgreSQL implementation of the SqlRunner interface.

    A new psycopg2 connection is opened for every ``run_sql`` call and is
    always closed before the call returns.
    """

    def __init__(
        self,
        connection_string: Optional[str] = None,
        host: Optional[str] = None,
        port: Optional[int] = 5432,
        database: Optional[str] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
        **kwargs,
    ):
        """Initialize with PostgreSQL connection parameters.

        You can either provide a connection_string OR individual parameters (host, database, etc.).
        If connection_string is provided, it takes precedence.

        Args:
            connection_string: PostgreSQL connection string (e.g., "postgresql://user:password@host:port/database")
            host: Database host address
            port: Database port (default: 5432)
            database: Database name
            user: Database user
            password: Database password
            **kwargs: Additional psycopg2 connection parameters (sslmode, connect_timeout, etc.).
                Applied in both modes: merged into the individual parameters,
                or passed to ``psycopg2.connect`` alongside the connection string.

        Raises:
            ImportError: If psycopg2 is not installed.
            ValueError: If neither connection_string nor (host, database, user) is given.
        """
        try:
            import psycopg2
            import psycopg2.extras
        except ImportError as e:
            raise ImportError(
                "psycopg2 package is required. Install with: pip install 'vanna[postgres]'"
            ) from e

        # psycopg2.extras (RealDictCursor) is reached later through this
        # attribute; the explicit submodule import above makes it available.
        self.psycopg2 = psycopg2

        if connection_string:
            self.connection_string = connection_string
            self.connection_params = None
            # Extra options such as sslmode must not be dropped in DSN mode:
            # psycopg2.connect(dsn, **kwargs) merges keyword params into the DSN.
            self.connection_kwargs = dict(kwargs)
        elif host and database and user:
            self.connection_string = None
            self.connection_params = {
                "host": host,
                "port": port,
                "database": database,
                "user": user,
                "password": password,
                **kwargs,
            }
            # Already merged into connection_params above.
            self.connection_kwargs = {}
        else:
            raise ValueError(
                "Either provide connection_string OR (host, database, and user) parameters"
            )

    async def run_sql(self, args: RunSqlToolArgs, context: ToolContext) -> pd.DataFrame:
        """Execute SQL query against PostgreSQL database and return results as DataFrame.

        Whether the statement produced a result set is decided from
        ``cursor.description`` rather than from the first SQL keyword, so
        ``WITH ... SELECT``, ``SHOW``, ``EXPLAIN`` and ``... RETURNING``
        statements return their rows. The transaction is committed in both
        cases so data-modifying statements persist.

        Args:
            args: SQL query arguments
            context: Tool execution context

        Returns:
            DataFrame with query results (column names are kept even when no
            rows match), or a one-row DataFrame with a ``rows_affected`` column
            for statements that return no result set.

        Raises:
            psycopg2.Error: If connecting or query execution fails
        """
        # Connect to the database using either connection string or parameters
        if self.connection_string:
            conn = self.psycopg2.connect(
                self.connection_string, **self.connection_kwargs
            )
        else:
            conn = self.psycopg2.connect(**self.connection_params)

        try:
            # Cursor creation is inside the try so the connection is closed
            # even if it fails; the with-block closes the cursor.
            with conn.cursor(
                cursor_factory=self.psycopg2.extras.RealDictCursor
            ) as cursor:
                cursor.execute(args.sql)

                if cursor.description is not None:
                    # dict.fromkeys de-duplicates names the same way the
                    # RealDictCursor rows collapse duplicate keys.
                    columns = list(dict.fromkeys(col[0] for col in cursor.description))
                    rows = cursor.fetchall()
                    conn.commit()
                    if not rows:
                        return pd.DataFrame(columns=columns)
                    return pd.DataFrame([dict(row) for row in rows])

                # No result set (INSERT, UPDATE, DELETE, DDL, ...).
                rows_affected = cursor.rowcount
                conn.commit()
                return pd.DataFrame({"rows_affected": [rows_affected]})
        finally:
            # Closing without commit discards any transaction left open by an error.
            conn.close()
|
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Cloud-based implementation of AgentMemory.
|
|
3
|
+
|
|
4
|
+
This implementation uses Vanna's premium cloud service for storing and searching
|
|
5
|
+
tool usage patterns with advanced similarity search and analytics.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
from datetime import datetime
|
|
10
|
+
from typing import Any, Dict, List, Optional
|
|
11
|
+
import httpx
|
|
12
|
+
|
|
13
|
+
from vanna.capabilities.agent_memory import (
|
|
14
|
+
AgentMemory,
|
|
15
|
+
TextMemory,
|
|
16
|
+
TextMemorySearchResult,
|
|
17
|
+
ToolMemory,
|
|
18
|
+
ToolMemorySearchResult,
|
|
19
|
+
)
|
|
20
|
+
from vanna.core.tool import ToolContext
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class CloudAgentMemory(AgentMemory):
    """Cloud-based implementation of AgentMemory.

    Tool-usage memories are stored and searched through Vanna's premium HTTP
    API. All requests go through one ``httpx.AsyncClient`` owned by this
    instance. Release it with ``await memory.aclose()`` or by using the
    instance as an async context manager (``async with CloudAgentMemory(...)``).
    Text memories are not supported by this backend yet.
    """

    def __init__(
        self,
        api_base_url: str = "https://api.vanna.ai",
        api_key: Optional[str] = None,
        organization_id: Optional[str] = None,
    ):
        """Create a client for the premium memory API.

        Args:
            api_base_url: Base URL of the API; a trailing "/" is removed.
            api_key: Optional key sent as a ``Bearer`` token.
            organization_id: Optional id sent in the ``X-Organization-ID`` header.
        """
        self.api_base_url = api_base_url.rstrip("/")
        self.api_key = api_key
        self.organization_id = organization_id
        self._client = httpx.AsyncClient(base_url=self.api_base_url, timeout=30.0)

    async def aclose(self) -> None:
        """Close the underlying HTTP client and release its connections."""
        await self._client.aclose()

    async def __aenter__(self) -> "CloudAgentMemory":
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.aclose()

    def _get_headers(self) -> Dict[str, str]:
        """Get request headers with authentication."""
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        if self.organization_id:
            headers["X-Organization-ID"] = self.organization_id
        return headers

    async def save_tool_usage(
        self,
        question: str,
        tool_name: str,
        args: Dict[str, Any],
        context: ToolContext,
        success: bool = True,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Save a tool usage pattern to premium cloud storage.

        A new UUID4 is generated client-side as the memory id.

        Raises:
            httpx.HTTPStatusError: If the server returns a non-success status.
        """
        import uuid

        payload = {
            "id": str(uuid.uuid4()),
            "question": question,
            "tool_name": tool_name,
            "args": args,
            "success": success,
            "metadata": metadata or {},
            # NOTE(review): naive local time without a UTC offset; confirm the
            # timezone the server expects.
            "timestamp": datetime.now().isoformat(),
        }

        response = await self._client.post(
            "/memory/tool-usage", json=payload, headers=self._get_headers()
        )
        response.raise_for_status()

    async def search_similar_usage(
        self,
        question: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
        tool_name_filter: Optional[str] = None,
    ) -> List[ToolMemorySearchResult]:
        """Search for similar tool usage patterns in premium cloud storage.

        Each entry of the response's ``results`` list is expected to carry
        ``memory``, ``similarity_score`` and ``rank`` keys.

        Raises:
            httpx.HTTPStatusError: If the server returns a non-success status.
        """
        params = {
            "question": question,
            "limit": limit,
            "similarity_threshold": similarity_threshold,
        }
        if tool_name_filter:
            params["tool_name_filter"] = tool_name_filter

        response = await self._client.get(
            "/memory/search-similar", params=params, headers=self._get_headers()
        )
        response.raise_for_status()

        data = response.json()
        results = []

        for item in data.get("results", []):
            memory = ToolMemory(**item["memory"])
            result = ToolMemorySearchResult(
                memory=memory,
                similarity_score=item["similarity_score"],
                rank=item["rank"],
            )
            results.append(result)

        return results

    async def get_recent_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[ToolMemory]:
        """Get recently added memories from premium cloud storage.

        Raises:
            httpx.HTTPStatusError: If the server returns a non-success status.
        """
        params = {"limit": limit}

        response = await self._client.get(
            "/memory/recent", params=params, headers=self._get_headers()
        )
        response.raise_for_status()

        data = response.json()
        memories = []

        for item in data.get("memories", []):
            memory = ToolMemory(**item)
            memories.append(memory)

        return memories

    async def delete_by_id(self, context: ToolContext, memory_id: str) -> bool:
        """Delete a memory by its ID from premium cloud storage.

        Returns:
            True if the server accepted the delete. False if it answered 404,
            or if the id is empty / a bare dot segment that cannot name a
            single memory.

        Raises:
            httpx.HTTPStatusError: For any other non-success status.
        """
        from urllib.parse import quote

        # An empty id would target the collection path "/memory/", and "."/".."
        # are dot segments that URL normalization may collapse into another path.
        if memory_id in ("", ".", ".."):
            return False

        # Percent-encode every reserved character (including "/") so the id
        # stays a single path segment and cannot reach e.g. "/memory/clear".
        response = await self._client.delete(
            f"/memory/{quote(memory_id, safe='')}", headers=self._get_headers()
        )

        if response.status_code == 404:
            return False

        response.raise_for_status()
        return True

    async def save_text_memory(self, content: str, context: ToolContext) -> TextMemory:
        """Cloud implementation does not yet support text memories."""
        raise NotImplementedError("CloudAgentMemory does not support text memories.")

    async def search_text_memories(
        self,
        query: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
    ) -> List[TextMemorySearchResult]:
        """Cloud implementation does not yet support text memories."""
        return []

    async def get_recent_text_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[TextMemory]:
        """Cloud implementation does not yet support text memories."""
        return []

    async def delete_text_memory(self, context: ToolContext, memory_id: str) -> bool:
        """Cloud implementation does not yet support text memories."""
        return False

    async def clear_memories(
        self,
        context: ToolContext,
        tool_name: Optional[str] = None,
        before_date: Optional[str] = None,
    ) -> int:
        """Clear stored memories from premium cloud storage.

        Args:
            context: Tool execution context (not sent to the server).
            tool_name: If given, only memories for this tool are cleared.
            before_date: If given, forwarded to the server as ``before_date``.

        Returns:
            The server-reported ``deleted_count``, or 0 if the response has no body.

        Raises:
            httpx.HTTPStatusError: If the server returns a non-success status.
        """
        payload: Dict[str, str] = {}
        if tool_name:
            payload["tool_name"] = tool_name
        if before_date:
            payload["before_date"] = before_date

        # httpx's .delete() shortcut accepts no request body (no json=/content=),
        # so a DELETE carrying a JSON payload must go through .request().
        response = await self._client.request(
            "DELETE", "/memory/clear", json=payload, headers=self._get_headers()
        )
        response.raise_for_status()

        # Tolerate bodiless success responses such as 204 No Content.
        if not response.content:
            return 0

        data = response.json()
        return data.get("deleted_count", 0)
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
"""Presto implementation of SqlRunner interface."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
import pandas as pd
|
|
5
|
+
|
|
6
|
+
from vanna.capabilities.sql_runner import SqlRunner, RunSqlToolArgs
|
|
7
|
+
from vanna.core.tool import ToolContext
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class PrestoRunner(SqlRunner):
    """Presto implementation of the SqlRunner interface.

    A new pyhive Presto connection is created for every ``run_sql`` call and
    closed before the call returns.
    """

    def __init__(
        self,
        host: str,
        catalog: str = "hive",
        schema: str = "default",
        user: Optional[str] = None,
        password: Optional[str] = None,
        port: int = 443,
        combined_pem_path: Optional[str] = None,
        protocol: str = "https",
        requests_kwargs: Optional[dict] = None,
        **kwargs,
    ):
        """Initialize with Presto connection parameters.

        Args:
            host: The host address of the Presto database
            catalog: The catalog to use in the Presto environment (default: 'hive')
            schema: The schema to use in the Presto environment (default: 'default')
            user: The username for authentication
            password: The password for authentication
            port: The port number for the Presto connection (default: 443)
            combined_pem_path: The path to the combined pem file for SSL connection.
                Used as ``requests_kwargs["verify"]`` unless requests_kwargs
                already sets ``verify``.
            protocol: The protocol to use for the connection (default: 'https')
            requests_kwargs: Additional keyword arguments for requests
            **kwargs: Additional pyhive connection parameters

        Raises:
            ImportError: If pyhive is not installed.
        """
        try:
            from pyhive import presto

            self.presto = presto
        except ImportError as e:
            raise ImportError(
                "pyhive package is required. Install with: pip install pyhive"
            ) from e

        self.host = host
        self.catalog = catalog
        self.schema = schema
        self.user = user
        self.password = password
        self.port = port
        self.protocol = protocol
        self.kwargs = kwargs

        if combined_pem_path is not None:
            # Merge into a copy instead of silently ignoring the pem path when
            # requests_kwargs is also given; an explicit "verify" still wins,
            # and the caller's dict is never mutated.
            merged = dict(requests_kwargs or {})
            merged.setdefault("verify", combined_pem_path)
            self.requests_kwargs = merged
        else:
            self.requests_kwargs = requests_kwargs

    async def run_sql(self, args: RunSqlToolArgs, context: ToolContext) -> pd.DataFrame:
        """Execute SQL query against Presto database and return results as DataFrame.

        Args:
            args: SQL query arguments
            context: Tool execution context

        Returns:
            DataFrame with query results. Column names come from
            ``cursor.description``; if it is unavailable, pandas' default
            integer column labels are used.

        Raises:
            presto.Error: If query execution fails
        """
        # Presto rejects statement terminators, so strip every trailing ";"
        # together with surrounding whitespace (handles "x;;" and "x; ;").
        sql = args.sql.rstrip()
        while sql.endswith(";"):
            sql = sql[:-1].rstrip()

        # Connect to the database
        conn = self.presto.Connection(
            host=self.host,
            username=self.user,
            password=self.password,
            catalog=self.catalog,
            schema=self.schema,
            port=self.port,
            protocol=self.protocol,
            requests_kwargs=self.requests_kwargs,
            **self.kwargs,
        )

        try:
            cursor = conn.cursor()
            try:
                cursor.execute(sql)
                results = cursor.fetchall()
                description = cursor.description
                # description can be None when the statement has no columns.
                columns = [desc[0] for desc in description] if description else None
                return pd.DataFrame(results, columns=columns)
            finally:
                # Close the cursor on failure as well as on success.
                cursor.close()
        finally:
            conn.close()
|