vanna 0.7.9__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/__init__.py +167 -395
- vanna/agents/__init__.py +7 -0
- vanna/capabilities/__init__.py +17 -0
- vanna/capabilities/agent_memory/__init__.py +21 -0
- vanna/capabilities/agent_memory/base.py +103 -0
- vanna/capabilities/agent_memory/models.py +53 -0
- vanna/capabilities/file_system/__init__.py +14 -0
- vanna/capabilities/file_system/base.py +71 -0
- vanna/capabilities/file_system/models.py +25 -0
- vanna/capabilities/sql_runner/__init__.py +13 -0
- vanna/capabilities/sql_runner/base.py +37 -0
- vanna/capabilities/sql_runner/models.py +13 -0
- vanna/components/__init__.py +92 -0
- vanna/components/base.py +11 -0
- vanna/components/rich/__init__.py +83 -0
- vanna/components/rich/containers/__init__.py +7 -0
- vanna/components/rich/containers/card.py +20 -0
- vanna/components/rich/data/__init__.py +9 -0
- vanna/components/rich/data/chart.py +17 -0
- vanna/components/rich/data/dataframe.py +93 -0
- vanna/components/rich/feedback/__init__.py +21 -0
- vanna/components/rich/feedback/badge.py +16 -0
- vanna/components/rich/feedback/icon_text.py +14 -0
- vanna/components/rich/feedback/log_viewer.py +41 -0
- vanna/components/rich/feedback/notification.py +19 -0
- vanna/components/rich/feedback/progress.py +37 -0
- vanna/components/rich/feedback/status_card.py +28 -0
- vanna/components/rich/feedback/status_indicator.py +14 -0
- vanna/components/rich/interactive/__init__.py +21 -0
- vanna/components/rich/interactive/button.py +95 -0
- vanna/components/rich/interactive/task_list.py +58 -0
- vanna/components/rich/interactive/ui_state.py +93 -0
- vanna/components/rich/specialized/__init__.py +7 -0
- vanna/components/rich/specialized/artifact.py +20 -0
- vanna/components/rich/text.py +16 -0
- vanna/components/simple/__init__.py +15 -0
- vanna/components/simple/image.py +15 -0
- vanna/components/simple/link.py +15 -0
- vanna/components/simple/text.py +11 -0
- vanna/core/__init__.py +193 -0
- vanna/core/_compat.py +19 -0
- vanna/core/agent/__init__.py +10 -0
- vanna/core/agent/agent.py +1407 -0
- vanna/core/agent/config.py +123 -0
- vanna/core/audit/__init__.py +28 -0
- vanna/core/audit/base.py +299 -0
- vanna/core/audit/models.py +131 -0
- vanna/core/component_manager.py +329 -0
- vanna/core/components.py +53 -0
- vanna/core/enhancer/__init__.py +11 -0
- vanna/core/enhancer/base.py +94 -0
- vanna/core/enhancer/default.py +118 -0
- vanna/core/enricher/__init__.py +10 -0
- vanna/core/enricher/base.py +59 -0
- vanna/core/errors.py +47 -0
- vanna/core/evaluation/__init__.py +81 -0
- vanna/core/evaluation/base.py +186 -0
- vanna/core/evaluation/dataset.py +254 -0
- vanna/core/evaluation/evaluators.py +376 -0
- vanna/core/evaluation/report.py +289 -0
- vanna/core/evaluation/runner.py +313 -0
- vanna/core/filter/__init__.py +10 -0
- vanna/core/filter/base.py +67 -0
- vanna/core/lifecycle/__init__.py +10 -0
- vanna/core/lifecycle/base.py +83 -0
- vanna/core/llm/__init__.py +16 -0
- vanna/core/llm/base.py +40 -0
- vanna/core/llm/models.py +61 -0
- vanna/core/middleware/__init__.py +10 -0
- vanna/core/middleware/base.py +69 -0
- vanna/core/observability/__init__.py +11 -0
- vanna/core/observability/base.py +88 -0
- vanna/core/observability/models.py +47 -0
- vanna/core/recovery/__init__.py +11 -0
- vanna/core/recovery/base.py +84 -0
- vanna/core/recovery/models.py +32 -0
- vanna/core/registry.py +278 -0
- vanna/core/rich_component.py +156 -0
- vanna/core/simple_component.py +27 -0
- vanna/core/storage/__init__.py +14 -0
- vanna/core/storage/base.py +46 -0
- vanna/core/storage/models.py +46 -0
- vanna/core/system_prompt/__init__.py +13 -0
- vanna/core/system_prompt/base.py +36 -0
- vanna/core/system_prompt/default.py +157 -0
- vanna/core/tool/__init__.py +18 -0
- vanna/core/tool/base.py +70 -0
- vanna/core/tool/models.py +84 -0
- vanna/core/user/__init__.py +17 -0
- vanna/core/user/base.py +29 -0
- vanna/core/user/models.py +25 -0
- vanna/core/user/request_context.py +70 -0
- vanna/core/user/resolver.py +42 -0
- vanna/core/validation.py +164 -0
- vanna/core/workflow/__init__.py +12 -0
- vanna/core/workflow/base.py +254 -0
- vanna/core/workflow/default.py +789 -0
- vanna/examples/__init__.py +1 -0
- vanna/examples/__main__.py +44 -0
- vanna/examples/anthropic_quickstart.py +80 -0
- vanna/examples/artifact_example.py +293 -0
- vanna/examples/claude_sqlite_example.py +236 -0
- vanna/examples/coding_agent_example.py +300 -0
- vanna/examples/custom_system_prompt_example.py +174 -0
- vanna/examples/default_workflow_handler_example.py +208 -0
- vanna/examples/email_auth_example.py +340 -0
- vanna/examples/evaluation_example.py +269 -0
- vanna/examples/extensibility_example.py +262 -0
- vanna/examples/minimal_example.py +67 -0
- vanna/examples/mock_auth_example.py +227 -0
- vanna/examples/mock_custom_tool.py +311 -0
- vanna/examples/mock_quickstart.py +79 -0
- vanna/examples/mock_quota_example.py +145 -0
- vanna/examples/mock_rich_components_demo.py +396 -0
- vanna/examples/mock_sqlite_example.py +223 -0
- vanna/examples/openai_quickstart.py +83 -0
- vanna/examples/primitive_components_demo.py +305 -0
- vanna/examples/quota_lifecycle_example.py +139 -0
- vanna/examples/visualization_example.py +251 -0
- vanna/integrations/__init__.py +17 -0
- vanna/integrations/anthropic/__init__.py +9 -0
- vanna/integrations/anthropic/llm.py +270 -0
- vanna/integrations/azureopenai/__init__.py +9 -0
- vanna/integrations/azureopenai/llm.py +329 -0
- vanna/integrations/azuresearch/__init__.py +7 -0
- vanna/integrations/azuresearch/agent_memory.py +413 -0
- vanna/integrations/bigquery/__init__.py +5 -0
- vanna/integrations/bigquery/sql_runner.py +81 -0
- vanna/integrations/chromadb/__init__.py +104 -0
- vanna/integrations/chromadb/agent_memory.py +416 -0
- vanna/integrations/clickhouse/__init__.py +5 -0
- vanna/integrations/clickhouse/sql_runner.py +82 -0
- vanna/integrations/duckdb/__init__.py +5 -0
- vanna/integrations/duckdb/sql_runner.py +65 -0
- vanna/integrations/faiss/__init__.py +7 -0
- vanna/integrations/faiss/agent_memory.py +431 -0
- vanna/integrations/google/__init__.py +9 -0
- vanna/integrations/google/gemini.py +370 -0
- vanna/integrations/hive/__init__.py +5 -0
- vanna/integrations/hive/sql_runner.py +87 -0
- vanna/integrations/local/__init__.py +17 -0
- vanna/integrations/local/agent_memory/__init__.py +7 -0
- vanna/integrations/local/agent_memory/in_memory.py +285 -0
- vanna/integrations/local/audit.py +59 -0
- vanna/integrations/local/file_system.py +242 -0
- vanna/integrations/local/file_system_conversation_store.py +255 -0
- vanna/integrations/local/storage.py +62 -0
- vanna/integrations/marqo/__init__.py +7 -0
- vanna/integrations/marqo/agent_memory.py +354 -0
- vanna/integrations/milvus/__init__.py +7 -0
- vanna/integrations/milvus/agent_memory.py +458 -0
- vanna/integrations/mock/__init__.py +9 -0
- vanna/integrations/mock/llm.py +65 -0
- vanna/integrations/mssql/__init__.py +5 -0
- vanna/integrations/mssql/sql_runner.py +66 -0
- vanna/integrations/mysql/__init__.py +5 -0
- vanna/integrations/mysql/sql_runner.py +92 -0
- vanna/integrations/ollama/__init__.py +7 -0
- vanna/integrations/ollama/llm.py +252 -0
- vanna/integrations/openai/__init__.py +10 -0
- vanna/integrations/openai/llm.py +267 -0
- vanna/integrations/openai/responses.py +163 -0
- vanna/integrations/opensearch/__init__.py +7 -0
- vanna/integrations/opensearch/agent_memory.py +411 -0
- vanna/integrations/oracle/__init__.py +5 -0
- vanna/integrations/oracle/sql_runner.py +75 -0
- vanna/integrations/pinecone/__init__.py +7 -0
- vanna/integrations/pinecone/agent_memory.py +329 -0
- vanna/integrations/plotly/__init__.py +5 -0
- vanna/integrations/plotly/chart_generator.py +313 -0
- vanna/integrations/postgres/__init__.py +9 -0
- vanna/integrations/postgres/sql_runner.py +112 -0
- vanna/integrations/premium/agent_memory/__init__.py +7 -0
- vanna/integrations/premium/agent_memory/premium.py +186 -0
- vanna/integrations/presto/__init__.py +5 -0
- vanna/integrations/presto/sql_runner.py +107 -0
- vanna/integrations/qdrant/__init__.py +7 -0
- vanna/integrations/qdrant/agent_memory.py +461 -0
- vanna/integrations/snowflake/__init__.py +5 -0
- vanna/integrations/snowflake/sql_runner.py +147 -0
- vanna/integrations/sqlite/__init__.py +9 -0
- vanna/integrations/sqlite/sql_runner.py +65 -0
- vanna/integrations/weaviate/__init__.py +7 -0
- vanna/integrations/weaviate/agent_memory.py +428 -0
- vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_embeddings.py +11 -11
- vanna/legacy/__init__.py +403 -0
- vanna/legacy/adapter.py +463 -0
- vanna/{advanced → legacy/advanced}/__init__.py +3 -1
- vanna/{anthropic → legacy/anthropic}/anthropic_chat.py +9 -7
- vanna/{azuresearch → legacy/azuresearch}/azuresearch_vector.py +79 -41
- vanna/{base → legacy/base}/base.py +224 -217
- vanna/legacy/bedrock/__init__.py +1 -0
- vanna/{bedrock → legacy/bedrock}/bedrock_converse.py +13 -12
- vanna/{chromadb → legacy/chromadb}/chromadb_vector.py +3 -1
- vanna/legacy/cohere/__init__.py +2 -0
- vanna/{cohere → legacy/cohere}/cohere_chat.py +19 -14
- vanna/{cohere → legacy/cohere}/cohere_embeddings.py +25 -19
- vanna/{deepseek → legacy/deepseek}/deepseek_chat.py +5 -6
- vanna/legacy/faiss/__init__.py +1 -0
- vanna/{faiss → legacy/faiss}/faiss.py +113 -59
- vanna/{flask → legacy/flask}/__init__.py +84 -43
- vanna/{flask → legacy/flask}/assets.py +5 -5
- vanna/{flask → legacy/flask}/auth.py +5 -4
- vanna/{google → legacy/google}/bigquery_vector.py +75 -42
- vanna/{google → legacy/google}/gemini_chat.py +7 -3
- vanna/{hf → legacy/hf}/hf.py +0 -1
- vanna/{milvus → legacy/milvus}/milvus_vector.py +58 -35
- vanna/{mock → legacy/mock}/llm.py +0 -1
- vanna/legacy/mock/vectordb.py +67 -0
- vanna/legacy/ollama/ollama.py +110 -0
- vanna/{openai → legacy/openai}/openai_chat.py +2 -6
- vanna/legacy/opensearch/opensearch_vector.py +369 -0
- vanna/legacy/opensearch/opensearch_vector_semantic.py +200 -0
- vanna/legacy/oracle/oracle_vector.py +584 -0
- vanna/{pgvector → legacy/pgvector}/pgvector.py +42 -13
- vanna/{qdrant → legacy/qdrant}/qdrant.py +2 -6
- vanna/legacy/qianfan/Qianfan_Chat.py +170 -0
- vanna/legacy/qianfan/Qianfan_embeddings.py +36 -0
- vanna/legacy/qianwen/QianwenAI_chat.py +132 -0
- vanna/{remote.py → legacy/remote.py} +28 -26
- vanna/{utils.py → legacy/utils.py} +6 -11
- vanna/{vannadb → legacy/vannadb}/vannadb_vector.py +115 -46
- vanna/{vllm → legacy/vllm}/vllm.py +5 -6
- vanna/{weaviate → legacy/weaviate}/weaviate_vector.py +59 -40
- vanna/{xinference → legacy/xinference}/xinference.py +6 -6
- vanna/py.typed +0 -0
- vanna/servers/__init__.py +16 -0
- vanna/servers/__main__.py +8 -0
- vanna/servers/base/__init__.py +18 -0
- vanna/servers/base/chat_handler.py +65 -0
- vanna/servers/base/models.py +111 -0
- vanna/servers/base/rich_chat_handler.py +141 -0
- vanna/servers/base/templates.py +331 -0
- vanna/servers/cli/__init__.py +7 -0
- vanna/servers/cli/server_runner.py +204 -0
- vanna/servers/fastapi/__init__.py +7 -0
- vanna/servers/fastapi/app.py +163 -0
- vanna/servers/fastapi/routes.py +183 -0
- vanna/servers/flask/__init__.py +7 -0
- vanna/servers/flask/app.py +132 -0
- vanna/servers/flask/routes.py +137 -0
- vanna/tools/__init__.py +41 -0
- vanna/tools/agent_memory.py +322 -0
- vanna/tools/file_system.py +879 -0
- vanna/tools/python.py +222 -0
- vanna/tools/run_sql.py +165 -0
- vanna/tools/visualize_data.py +195 -0
- vanna/utils/__init__.py +0 -0
- vanna/web_components/__init__.py +44 -0
- vanna-2.0.0.dist-info/METADATA +485 -0
- vanna-2.0.0.dist-info/RECORD +289 -0
- vanna-2.0.0.dist-info/entry_points.txt +3 -0
- vanna/bedrock/__init__.py +0 -1
- vanna/cohere/__init__.py +0 -2
- vanna/faiss/__init__.py +0 -1
- vanna/mock/vectordb.py +0 -55
- vanna/ollama/ollama.py +0 -103
- vanna/opensearch/opensearch_vector.py +0 -392
- vanna/opensearch/opensearch_vector_semantic.py +0 -175
- vanna/oracle/oracle_vector.py +0 -585
- vanna/qianfan/Qianfan_Chat.py +0 -165
- vanna/qianfan/Qianfan_embeddings.py +0 -36
- vanna/qianwen/QianwenAI_chat.py +0 -133
- vanna-0.7.9.dist-info/METADATA +0 -408
- vanna-0.7.9.dist-info/RECORD +0 -79
- /vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_Chat.py +0 -0
- /vanna/{ZhipuAI → legacy/ZhipuAI}/__init__.py +0 -0
- /vanna/{anthropic → legacy/anthropic}/__init__.py +0 -0
- /vanna/{azuresearch → legacy/azuresearch}/__init__.py +0 -0
- /vanna/{base → legacy/base}/__init__.py +0 -0
- /vanna/{chromadb → legacy/chromadb}/__init__.py +0 -0
- /vanna/{deepseek → legacy/deepseek}/__init__.py +0 -0
- /vanna/{exceptions → legacy/exceptions}/__init__.py +0 -0
- /vanna/{google → legacy/google}/__init__.py +0 -0
- /vanna/{hf → legacy/hf}/__init__.py +0 -0
- /vanna/{local.py → legacy/local.py} +0 -0
- /vanna/{marqo → legacy/marqo}/__init__.py +0 -0
- /vanna/{marqo → legacy/marqo}/marqo.py +0 -0
- /vanna/{milvus → legacy/milvus}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/mistral.py +0 -0
- /vanna/{mock → legacy/mock}/__init__.py +0 -0
- /vanna/{mock → legacy/mock}/embedding.py +0 -0
- /vanna/{ollama → legacy/ollama}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/openai_embeddings.py +0 -0
- /vanna/{opensearch → legacy/opensearch}/__init__.py +0 -0
- /vanna/{oracle → legacy/oracle}/__init__.py +0 -0
- /vanna/{pgvector → legacy/pgvector}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/pinecone_vector.py +0 -0
- /vanna/{qdrant → legacy/qdrant}/__init__.py +0 -0
- /vanna/{qianfan → legacy/qianfan}/__init__.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/QianwenAI_embeddings.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/__init__.py +0 -0
- /vanna/{types → legacy/types}/__init__.py +0 -0
- /vanna/{vannadb → legacy/vannadb}/__init__.py +0 -0
- /vanna/{vllm → legacy/vllm}/__init__.py +0 -0
- /vanna/{weaviate → legacy/weaviate}/__init__.py +0 -0
- /vanna/{xinference → legacy/xinference}/__init__.py +0 -0
- {vanna-0.7.9.dist-info → vanna-2.0.0.dist-info}/WHEEL +0 -0
- {vanna-0.7.9.dist-info → vanna-2.0.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,461 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Qdrant vector database implementation of AgentMemory.
|
|
3
|
+
|
|
4
|
+
This implementation uses Qdrant for vector storage of tool usage patterns.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import json
|
|
8
|
+
import uuid
|
|
9
|
+
from datetime import datetime
|
|
10
|
+
from typing import Any, Dict, List, Optional
|
|
11
|
+
import asyncio
|
|
12
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
from qdrant_client import QdrantClient
|
|
16
|
+
from qdrant_client.models import (
|
|
17
|
+
Distance,
|
|
18
|
+
VectorParams,
|
|
19
|
+
PointStruct,
|
|
20
|
+
Filter,
|
|
21
|
+
FieldCondition,
|
|
22
|
+
MatchValue,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
QDRANT_AVAILABLE = True
|
|
26
|
+
except ImportError:
|
|
27
|
+
QDRANT_AVAILABLE = False
|
|
28
|
+
|
|
29
|
+
from vanna.capabilities.agent_memory import (
|
|
30
|
+
AgentMemory,
|
|
31
|
+
TextMemory,
|
|
32
|
+
TextMemorySearchResult,
|
|
33
|
+
ToolMemory,
|
|
34
|
+
ToolMemorySearchResult,
|
|
35
|
+
)
|
|
36
|
+
from vanna.core.tool import ToolContext
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class QdrantAgentMemory(AgentMemory):
    """Qdrant-based implementation of AgentMemory.

    Tool-usage memories and free-text memories share one Qdrant collection.
    Text memories carry an ``is_text_memory: True`` payload flag; tool memories
    carry ``question``, ``tool_name``, ``args``, ``timestamp``, ``success`` and
    ``metadata`` payload keys and never set ``is_text_memory``.

    Every blocking qdrant-client call runs on a private two-worker thread pool
    so the async API does not block the event loop. The ``context`` argument of
    the public methods is accepted for interface compatibility but not used.
    """

    # Number of points requested per page when scrolling through the collection.
    _SCROLL_PAGE_SIZE = 256

    def __init__(
        self,
        collection_name: str = "tool_memories",
        url: Optional[str] = None,
        path: Optional[str] = None,
        api_key: Optional[str] = None,
        dimension: int = 384,
    ):
        """Store connection settings; the Qdrant client is created lazily.

        Args:
            collection_name: Qdrant collection that holds all memories.
            url: URL of a Qdrant server. When set, ``path`` is ignored.
            path: Storage directory for an embedded (local) Qdrant instance.
                When neither ``url`` nor ``path`` is set, ``":memory:"`` is used.
            api_key: API key passed to the client (only used together with ``url``).
            dimension: Vector size used when the collection is created.

        Raises:
            ImportError: If ``qdrant-client`` is not installed.
        """
        import threading

        if not QDRANT_AVAILABLE:
            raise ImportError(
                "Qdrant is required for QdrantAgentMemory. Install with: pip install qdrant-client"
            )

        self.collection_name = collection_name
        self.url = url
        self.path = path
        self.api_key = api_key
        self.dimension = dimension
        self._client = None
        # Guards lazy client creation. The executor runs up to two jobs at once;
        # without the lock both could build a client, and a second embedded
        # client on the same path fails, while a second ":memory:" client
        # silently splits the data.
        self._client_lock = threading.Lock()
        self._executor = ThreadPoolExecutor(max_workers=2)

    # ------------------------------------------------------------------
    # Client / collection helpers (run on executor threads)
    # ------------------------------------------------------------------

    def _get_client(self):
        """Return the shared Qdrant client, creating it and the collection once."""
        if self._client is not None:
            return self._client
        with self._client_lock:
            # Re-check under the lock: another worker may have finished first.
            if self._client is None:
                if self.url:
                    client = QdrantClient(url=self.url, api_key=self.api_key)
                else:
                    client = QdrantClient(path=self.path or ":memory:")
                self._ensure_collection(client)
                # Publish only after the collection exists. Otherwise a thread
                # taking the lock-free fast path above could use it too early.
                self._client = client
        return self._client

    def _vectors_config(self):
        """Vector parameters for the collection: ``self.dimension``, cosine distance."""
        return VectorParams(size=self.dimension, distance=Distance.COSINE)

    def _ensure_collection(self, client) -> None:
        """Create ``self.collection_name`` on ``client`` if it does not exist."""
        existing = client.get_collections().collections
        if not any(c.name == self.collection_name for c in existing):
            client.create_collection(
                collection_name=self.collection_name,
                vectors_config=self._vectors_config(),
            )

    def _create_embedding(self, text: str) -> List[float]:
        """Create a deterministic placeholder embedding from ``text``.

        This is not a semantic embedding. The text is hashed with MD5 and each
        component is derived from the 128-bit digest shifted right by the
        component index. Identical texts therefore always get identical vectors,
        but similarity between different texts is essentially arbitrary. Every
        component at index >= 128 is 0.0.

        Kept unchanged on purpose: vectors already stored in a collection were
        produced by this function.
        """
        import hashlib

        hash_val = int(hashlib.md5(text.encode()).hexdigest(), 16)
        return [(hash_val >> i) % 100 / 100.0 for i in range(self.dimension)]

    def _query(self, client, embedding, *, limit, query_filter, score_threshold):
        """Run a filtered nearest-neighbour search on old and new client APIs."""
        if hasattr(client, "query_points"):
            # Newer qdrant-client releases expose the query API; ``search`` is
            # the legacy equivalent kept as a fallback for older clients.
            return client.query_points(
                collection_name=self.collection_name,
                query=embedding,
                limit=limit,
                query_filter=query_filter,
                score_threshold=score_threshold,
            ).points
        return client.search(
            collection_name=self.collection_name,
            query_vector=embedding,
            limit=limit,
            query_filter=query_filter,
            score_threshold=score_threshold,
        )

    def _scroll_all(self, client, scroll_filter=None):
        """Yield every point matching ``scroll_filter`` (payload only, no vectors).

        Follows the scroll cursor page by page, so the result is not capped at
        a single page.
        """
        offset = None
        while True:
            points, offset = client.scroll(
                collection_name=self.collection_name,
                scroll_filter=scroll_filter,
                limit=self._SCROLL_PAGE_SIZE,
                offset=offset,
                with_payload=True,
                with_vectors=False,
            )
            yield from points
            if offset is None:
                break

    @staticmethod
    def _text_memory_condition():
        """Filter condition matching points flagged ``is_text_memory == True``."""
        return FieldCondition(key="is_text_memory", match=MatchValue(value=True))

    @staticmethod
    def _newest(points, limit: int):
        """Return the ``limit`` points with the latest ``timestamp``, newest first."""
        import heapq

        # Stored timestamps are ISO-8601 strings of one format, so lexicographic
        # order equals chronological order.
        return heapq.nlargest(
            limit, points, key=lambda p: p.payload.get("timestamp", "")
        )

    @staticmethod
    def _to_tool_memory(point) -> ToolMemory:
        """Build a ToolMemory from a point written by ``save_tool_usage``."""
        payload = point.payload
        return ToolMemory(
            memory_id=str(point.id),
            question=payload["question"],
            tool_name=payload["tool_name"],
            args=payload["args"],
            timestamp=payload.get("timestamp"),
            success=payload.get("success", True),
            metadata=payload.get("metadata", {}),
        )

    @staticmethod
    def _to_text_memory(point) -> TextMemory:
        """Build a TextMemory from a point written by ``save_text_memory``."""
        payload = point.payload
        return TextMemory(
            memory_id=str(point.id),
            content=payload.get("content", ""),
            timestamp=payload.get("timestamp"),
        )

    def _delete_point(self, memory_id: str, *, text_memory: bool) -> bool:
        """Delete ``memory_id`` if it exists and is of the requested kind.

        Returns True if a point was deleted, otherwise False.
        """
        client = self._get_client()
        try:
            points = client.retrieve(
                collection_name=self.collection_name,
                ids=[memory_id],
                with_payload=True,
                with_vectors=False,
            )
            if not points:
                return False
            is_text = bool((points[0].payload or {}).get("is_text_memory", False))
            if is_text != text_memory:
                # The ID exists but belongs to the other memory kind; leave it alone.
                return False
            client.delete(
                collection_name=self.collection_name,
                points_selector=[memory_id],
            )
            return True
        except Exception:
            # Deliberate best effort: errors (e.g. an ID the server rejects) are
            # reported as "not deleted" rather than raised.
            return False

    async def _run(self, func):
        """Run the blocking ``func`` on the private executor and await its result."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self._executor, func)

    # ------------------------------------------------------------------
    # Tool memories
    # ------------------------------------------------------------------

    async def save_tool_usage(
        self,
        question: str,
        tool_name: str,
        args: Dict[str, Any],
        context: ToolContext,
        success: bool = True,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Persist one tool invocation as a tool memory.

        The question is embedded with ``_create_embedding``. ``args`` and
        ``metadata`` are stored as-is in the point payload; presumably they
        must be JSON-serializable for the Qdrant client (verify).
        """

        def _save():
            client = self._get_client()
            payload = {
                "question": question,
                "tool_name": tool_name,
                "args": args,
                "timestamp": datetime.now().isoformat(),
                "success": success,
                "metadata": metadata or {},
            }
            point = PointStruct(
                id=str(uuid.uuid4()),
                vector=self._create_embedding(question),
                payload=payload,
            )
            client.upsert(collection_name=self.collection_name, points=[point])

        await self._run(_save)

    async def search_similar_usage(
        self,
        question: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
        tool_name_filter: Optional[str] = None,
    ) -> List[ToolMemorySearchResult]:
        """Return successful tool memories whose question vector is closest.

        Only points with ``success == True`` are searched; text memories never
        have that field and are therefore excluded. ``tool_name_filter``
        further restricts results to one tool. Results are ranked from 1 in the
        order returned by Qdrant and carry Qdrant's score as
        ``similarity_score``.
        """

        def _search():
            client = self._get_client()
            conditions = [FieldCondition(key="success", match=MatchValue(value=True))]
            if tool_name_filter:
                conditions.append(
                    FieldCondition(
                        key="tool_name", match=MatchValue(value=tool_name_filter)
                    )
                )
            hits = self._query(
                client,
                self._create_embedding(question),
                limit=limit,
                query_filter=Filter(must=conditions),
                score_threshold=similarity_threshold,
            )
            return [
                ToolMemorySearchResult(
                    memory=self._to_tool_memory(hit),
                    similarity_score=hit.score,
                    rank=rank,
                )
                for rank, hit in enumerate(hits, start=1)
            ]

        return await self._run(_search)

    async def get_recent_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[ToolMemory]:
        """Return up to ``limit`` tool memories, newest ``timestamp`` first.

        Text memories are excluded, and the whole collection is scanned so the
        result is correct regardless of collection size.
        """

        def _get_recent():
            client = self._get_client()
            # Points without the is_text_memory flag satisfy must_not, so this
            # selects exactly the tool memories.
            tool_only = Filter(must_not=[self._text_memory_condition()])
            newest = self._newest(self._scroll_all(client, tool_only), limit)
            return [self._to_tool_memory(point) for point in newest]

        return await self._run(_get_recent)

    async def delete_by_id(self, context: ToolContext, memory_id: str) -> bool:
        """Delete a tool memory by its ID. Returns True if deleted, False if not found.

        IDs that belong to a text memory are treated as not found.
        """
        return await self._run(lambda: self._delete_point(memory_id, text_memory=False))

    # ------------------------------------------------------------------
    # Text memories
    # ------------------------------------------------------------------

    async def save_text_memory(self, content: str, context: ToolContext) -> TextMemory:
        """Persist ``content`` as a text memory and return the stored record."""

        def _save():
            client = self._get_client()
            memory_id = str(uuid.uuid4())
            timestamp = datetime.now().isoformat()
            point = PointStruct(
                id=memory_id,
                vector=self._create_embedding(content),
                payload={
                    "content": content,
                    "timestamp": timestamp,
                    "is_text_memory": True,
                },
            )
            client.upsert(collection_name=self.collection_name, points=[point])
            return TextMemory(memory_id=memory_id, content=content, timestamp=timestamp)

        return await self._run(_save)

    async def search_text_memories(
        self,
        query: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
    ) -> List[TextMemorySearchResult]:
        """Return the text memories whose vector is closest to ``query``.

        Results are ranked from 1 in the order returned by Qdrant.
        """

        def _search():
            client = self._get_client()
            hits = self._query(
                client,
                self._create_embedding(query),
                limit=limit,
                query_filter=Filter(must=[self._text_memory_condition()]),
                score_threshold=similarity_threshold,
            )
            return [
                TextMemorySearchResult(
                    memory=self._to_text_memory(hit),
                    similarity_score=hit.score,
                    rank=rank,
                )
                for rank, hit in enumerate(hits, start=1)
            ]

        return await self._run(_search)

    async def get_recent_text_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[TextMemory]:
        """Return up to ``limit`` text memories, newest ``timestamp`` first."""

        def _get_recent():
            client = self._get_client()
            text_only = Filter(must=[self._text_memory_condition()])
            newest = self._newest(self._scroll_all(client, text_only), limit)
            return [self._to_text_memory(point) for point in newest]

        return await self._run(_get_recent)

    async def delete_text_memory(self, context: ToolContext, memory_id: str) -> bool:
        """Delete a text memory by its ID. Returns True if deleted, False otherwise.

        IDs that belong to a tool memory are treated as not found.
        """
        return await self._run(lambda: self._delete_point(memory_id, text_memory=True))

    # ------------------------------------------------------------------
    # Bulk deletion
    # ------------------------------------------------------------------

    async def clear_memories(
        self,
        context: ToolContext,
        tool_name: Optional[str] = None,
        before_date: Optional[str] = None,
    ) -> int:
        """Delete stored memories and return how many points were removed.

        Args:
            context: Unused; kept for interface compatibility.
            tool_name: If given, only tool memories with this ``tool_name`` are
                removed.
            before_date: If given, only points whose ``timestamp`` sorts before
                this value are removed. Timestamps are stored as
                ``datetime.now().isoformat()`` strings and compared as strings,
                so pass an ISO-8601 date or datetime such as ``"2024-01-31"``.
                Points without a string timestamp are kept.

        With neither filter, the whole collection (tool and text memories) is
        dropped and recreated empty.
        """

        def _clear():
            client = self._get_client()

            if tool_name is None and before_date is None:
                removed = client.count(
                    collection_name=self.collection_name, exact=True
                ).count
                client.delete_collection(collection_name=self.collection_name)
                client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=self._vectors_config(),
                )
                return removed

            scroll_filter = None
            if tool_name is not None:
                scroll_filter = Filter(
                    must=[
                        FieldCondition(key="tool_name", match=MatchValue(value=tool_name))
                    ]
                )

            def _older_than_cutoff(point) -> bool:
                if before_date is None:
                    return True
                ts = (point.payload or {}).get("timestamp")
                return isinstance(ts, str) and ts < before_date

            # Collect the IDs first so no points are deleted while scrolling.
            doomed = [
                point.id
                for point in self._scroll_all(client, scroll_filter)
                if _older_than_cutoff(point)
            ]
            if doomed:
                client.delete(
                    collection_name=self.collection_name,
                    points_selector=doomed,
                )
            return len(doomed)

        return await self._run(_clear)
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
"""Snowflake implementation of SqlRunner interface."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional, Union
|
|
4
|
+
import os
|
|
5
|
+
import pandas as pd
|
|
6
|
+
|
|
7
|
+
from vanna.capabilities.sql_runner import SqlRunner, RunSqlToolArgs
|
|
8
|
+
from vanna.core.tool import ToolContext
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class SnowflakeRunner(SqlRunner):
    """Snowflake implementation of the SqlRunner interface."""

    def __init__(
        self,
        account: str,
        username: str,
        password: Optional[str] = None,
        database: str = "",
        role: Optional[str] = None,
        warehouse: Optional[str] = None,
        private_key_path: Optional[str] = None,
        private_key_passphrase: Optional[str] = None,
        private_key_content: Optional[bytes] = None,
        **kwargs,
    ):
        """Initialize with Snowflake connection parameters.

        Args:
            account: Snowflake account identifier
            username: Database user
            password: Database password (optional if using key-pair auth)
            database: Database name
            role: Snowflake role to use (optional)
            warehouse: Snowflake warehouse to use (optional)
            private_key_path: Path to private key file for RSA key-pair
                authentication (optional). Forwarded to the connector as
                ``private_key_file``.
            private_key_passphrase: Passphrase for an encrypted private key
                file (optional). Forwarded as ``private_key_file_pwd`` and only
                used together with ``private_key_path``.
            private_key_content: Private key content as bytes (optional,
                alternative to private_key_path). Forwarded as the connector's
                ``private_key`` argument; NOTE(review): the connector documents
                this as an unencrypted (DER) key, so an encrypted key should be
                supplied via ``private_key_path`` instead.
            **kwargs: Additional snowflake.connector connection parameters;
                these override any parameter built from the arguments above.

        Raises:
            ImportError: If snowflake-connector-python is not installed.
            ValueError: If neither a password nor a private key is provided.
            FileNotFoundError: If ``private_key_path`` does not point to a file.

        Note:
            Either password OR private_key_path/private_key_content must be provided.
            RSA key-pair authentication is recommended for production systems as Snowflake
            is deprecating user/password authentication.
        """
        try:
            import snowflake.connector

            self.snowflake = snowflake.connector
        except ImportError as e:
            raise ImportError(
                "snowflake-connector-python package is required. "
                "Install with: pip install 'vanna[snowflake]'"
            ) from e

        # Validate that at least one authentication method is provided
        if not password and not private_key_path and not private_key_content:
            raise ValueError(
                "Either password or private_key_path/private_key_content must be provided for authentication"
            )

        # Validate private key path exists if provided (fail fast at construction time)
        if private_key_path and not os.path.isfile(private_key_path):
            raise FileNotFoundError(f"Private key file not found: {private_key_path}")

        self.account = account
        self.username = username
        self.password = password
        self.database = database
        self.role = role
        self.warehouse = warehouse
        self.private_key_path = private_key_path
        self.private_key_passphrase = private_key_passphrase
        self.private_key_content = private_key_content
        self.kwargs = kwargs

    def _build_connection_params(self) -> dict:
        """Return the keyword arguments passed to ``snowflake.connector.connect``."""
        conn_params = {
            "user": self.username,
            "account": self.account,
            "client_session_keep_alive": True,
        }

        # Add database if specified
        if self.database:
            conn_params["database"] = self.database

        # Configure authentication method. The connector's key-pair parameters
        # are named private_key_file / private_key_file_pwd / private_key; the
        # names used by this class's constructor are NOT connector parameters.
        if self.private_key_path or self.private_key_content:
            if self.private_key_path:
                conn_params["private_key_file"] = self.private_key_path
                # The passphrase decrypts the key file only.
                if self.private_key_passphrase:
                    conn_params["private_key_file_pwd"] = self.private_key_passphrase
            else:
                conn_params["private_key"] = self.private_key_content
        else:
            # Use password authentication (fallback)
            conn_params["password"] = self.password

        # User-supplied kwargs take precedence over everything built above.
        conn_params.update(self.kwargs)
        return conn_params

    def _execute_query(self, sql: str) -> pd.DataFrame:
        """Open a connection, run *sql* and return the result as a DataFrame.

        This is synchronous; ``run_sql`` calls it from a worker thread.
        The cursor and the connection are always closed, even when a
        statement fails.
        """
        conn = self.snowflake.connect(**self._build_connection_params())
        try:
            cursor = conn.cursor()
            try:
                # Role/warehouse/database names are interpolated verbatim
                # (unquoted) from constructor configuration; a name that needs
                # quoting must be passed already quoted.
                if self.role:
                    cursor.execute(f"USE ROLE {self.role}")
                if self.warehouse:
                    cursor.execute(f"USE WAREHOUSE {self.warehouse}")
                if self.database:
                    cursor.execute(f"USE DATABASE {self.database}")

                cursor.execute(sql)

                # A statement that produces no result set has no description.
                if cursor.description is None:
                    return pd.DataFrame()

                columns = [desc[0] for desc in cursor.description]
                return pd.DataFrame(cursor.fetchall(), columns=columns)
            finally:
                cursor.close()
        finally:
            conn.close()

    async def run_sql(self, args: RunSqlToolArgs, context: ToolContext) -> pd.DataFrame:
        """Execute SQL query against Snowflake database and return results as DataFrame.

        The blocking connector calls run in the event loop's default executor,
        so a long-running query does not stall other coroutines.

        Args:
            args: SQL query arguments
            context: Tool execution context

        Returns:
            DataFrame with query results (empty when the statement returns no
            result set)

        Raises:
            snowflake.connector.Error: If query execution fails
        """
        import asyncio

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._execute_query, args.sql)
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
"""SQLite implementation of SqlRunner interface."""
|
|
2
|
+
|
|
3
|
+
import sqlite3
|
|
4
|
+
import pandas as pd
|
|
5
|
+
|
|
6
|
+
from vanna.capabilities.sql_runner import SqlRunner, RunSqlToolArgs
|
|
7
|
+
from vanna.core.tool import ToolContext
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class SqliteRunner(SqlRunner):
    """SQLite implementation of the SqlRunner interface."""

    def __init__(self, database_path: str):
        """Initialize with a SQLite database path.

        Args:
            database_path: Path to the SQLite database file
        """
        self.database_path = database_path

    def _execute_query(self, sql: str) -> pd.DataFrame:
        """Run *sql* on a fresh connection and return the outcome as a DataFrame.

        Whether a statement produced rows is decided by ``cursor.description``
        rather than by its first keyword. ``WITH ... SELECT``, ``PRAGMA``,
        ``VALUES``, comment-prefixed queries and ``... RETURNING`` therefore
        all return their rows. Statements without a result set return a
        single-row frame with a ``rows_affected`` column.

        The connection is created, used and closed within this call, so it is
        safe to run from a worker thread.
        """
        conn = sqlite3.connect(self.database_path)
        try:
            cursor = conn.cursor()
            try:
                cursor.execute(sql)

                if cursor.description is not None:
                    columns = [desc[0] for desc in cursor.description]
                    rows = cursor.fetchall()
                    # Commit after fetching so DML with RETURNING is persisted;
                    # for a plain SELECT no transaction is open and this is a no-op.
                    conn.commit()
                    # Build from tuples + explicit names so duplicate column names
                    # (e.g. "SELECT a.id, b.id") are preserved and an empty
                    # result still carries its columns.
                    return pd.DataFrame(rows, columns=columns)

                # For statements without a result set (INSERT, UPDATE, DELETE, DDL, ...)
                conn.commit()
                return pd.DataFrame({"rows_affected": [cursor.rowcount]})
            finally:
                cursor.close()
        finally:
            conn.close()

    async def run_sql(self, args: RunSqlToolArgs, context: ToolContext) -> pd.DataFrame:
        """Execute SQL query against SQLite database and return results as DataFrame.

        The blocking sqlite3 work runs in the event loop's default executor so
        it does not stall other coroutines.

        Args:
            args: SQL query arguments
            context: Tool execution context

        Returns:
            DataFrame with query results, or a one-row ``rows_affected`` frame
            for statements that return no rows.

        Raises:
            sqlite3.Error: If query execution fails
        """
        import asyncio

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._execute_query, args.sql)
|