vanna 0.7.9__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/__init__.py +167 -395
- vanna/agents/__init__.py +7 -0
- vanna/capabilities/__init__.py +17 -0
- vanna/capabilities/agent_memory/__init__.py +21 -0
- vanna/capabilities/agent_memory/base.py +103 -0
- vanna/capabilities/agent_memory/models.py +53 -0
- vanna/capabilities/file_system/__init__.py +14 -0
- vanna/capabilities/file_system/base.py +71 -0
- vanna/capabilities/file_system/models.py +25 -0
- vanna/capabilities/sql_runner/__init__.py +13 -0
- vanna/capabilities/sql_runner/base.py +37 -0
- vanna/capabilities/sql_runner/models.py +13 -0
- vanna/components/__init__.py +92 -0
- vanna/components/base.py +11 -0
- vanna/components/rich/__init__.py +83 -0
- vanna/components/rich/containers/__init__.py +7 -0
- vanna/components/rich/containers/card.py +20 -0
- vanna/components/rich/data/__init__.py +9 -0
- vanna/components/rich/data/chart.py +17 -0
- vanna/components/rich/data/dataframe.py +93 -0
- vanna/components/rich/feedback/__init__.py +21 -0
- vanna/components/rich/feedback/badge.py +16 -0
- vanna/components/rich/feedback/icon_text.py +14 -0
- vanna/components/rich/feedback/log_viewer.py +41 -0
- vanna/components/rich/feedback/notification.py +19 -0
- vanna/components/rich/feedback/progress.py +37 -0
- vanna/components/rich/feedback/status_card.py +28 -0
- vanna/components/rich/feedback/status_indicator.py +14 -0
- vanna/components/rich/interactive/__init__.py +21 -0
- vanna/components/rich/interactive/button.py +95 -0
- vanna/components/rich/interactive/task_list.py +58 -0
- vanna/components/rich/interactive/ui_state.py +93 -0
- vanna/components/rich/specialized/__init__.py +7 -0
- vanna/components/rich/specialized/artifact.py +20 -0
- vanna/components/rich/text.py +16 -0
- vanna/components/simple/__init__.py +15 -0
- vanna/components/simple/image.py +15 -0
- vanna/components/simple/link.py +15 -0
- vanna/components/simple/text.py +11 -0
- vanna/core/__init__.py +193 -0
- vanna/core/_compat.py +19 -0
- vanna/core/agent/__init__.py +10 -0
- vanna/core/agent/agent.py +1407 -0
- vanna/core/agent/config.py +123 -0
- vanna/core/audit/__init__.py +28 -0
- vanna/core/audit/base.py +299 -0
- vanna/core/audit/models.py +131 -0
- vanna/core/component_manager.py +329 -0
- vanna/core/components.py +53 -0
- vanna/core/enhancer/__init__.py +11 -0
- vanna/core/enhancer/base.py +94 -0
- vanna/core/enhancer/default.py +118 -0
- vanna/core/enricher/__init__.py +10 -0
- vanna/core/enricher/base.py +59 -0
- vanna/core/errors.py +47 -0
- vanna/core/evaluation/__init__.py +81 -0
- vanna/core/evaluation/base.py +186 -0
- vanna/core/evaluation/dataset.py +254 -0
- vanna/core/evaluation/evaluators.py +376 -0
- vanna/core/evaluation/report.py +289 -0
- vanna/core/evaluation/runner.py +313 -0
- vanna/core/filter/__init__.py +10 -0
- vanna/core/filter/base.py +67 -0
- vanna/core/lifecycle/__init__.py +10 -0
- vanna/core/lifecycle/base.py +83 -0
- vanna/core/llm/__init__.py +16 -0
- vanna/core/llm/base.py +40 -0
- vanna/core/llm/models.py +61 -0
- vanna/core/middleware/__init__.py +10 -0
- vanna/core/middleware/base.py +69 -0
- vanna/core/observability/__init__.py +11 -0
- vanna/core/observability/base.py +88 -0
- vanna/core/observability/models.py +47 -0
- vanna/core/recovery/__init__.py +11 -0
- vanna/core/recovery/base.py +84 -0
- vanna/core/recovery/models.py +32 -0
- vanna/core/registry.py +278 -0
- vanna/core/rich_component.py +156 -0
- vanna/core/simple_component.py +27 -0
- vanna/core/storage/__init__.py +14 -0
- vanna/core/storage/base.py +46 -0
- vanna/core/storage/models.py +46 -0
- vanna/core/system_prompt/__init__.py +13 -0
- vanna/core/system_prompt/base.py +36 -0
- vanna/core/system_prompt/default.py +157 -0
- vanna/core/tool/__init__.py +18 -0
- vanna/core/tool/base.py +70 -0
- vanna/core/tool/models.py +84 -0
- vanna/core/user/__init__.py +17 -0
- vanna/core/user/base.py +29 -0
- vanna/core/user/models.py +25 -0
- vanna/core/user/request_context.py +70 -0
- vanna/core/user/resolver.py +42 -0
- vanna/core/validation.py +164 -0
- vanna/core/workflow/__init__.py +12 -0
- vanna/core/workflow/base.py +254 -0
- vanna/core/workflow/default.py +789 -0
- vanna/examples/__init__.py +1 -0
- vanna/examples/__main__.py +44 -0
- vanna/examples/anthropic_quickstart.py +80 -0
- vanna/examples/artifact_example.py +293 -0
- vanna/examples/claude_sqlite_example.py +236 -0
- vanna/examples/coding_agent_example.py +300 -0
- vanna/examples/custom_system_prompt_example.py +174 -0
- vanna/examples/default_workflow_handler_example.py +208 -0
- vanna/examples/email_auth_example.py +340 -0
- vanna/examples/evaluation_example.py +269 -0
- vanna/examples/extensibility_example.py +262 -0
- vanna/examples/minimal_example.py +67 -0
- vanna/examples/mock_auth_example.py +227 -0
- vanna/examples/mock_custom_tool.py +311 -0
- vanna/examples/mock_quickstart.py +79 -0
- vanna/examples/mock_quota_example.py +145 -0
- vanna/examples/mock_rich_components_demo.py +396 -0
- vanna/examples/mock_sqlite_example.py +223 -0
- vanna/examples/openai_quickstart.py +83 -0
- vanna/examples/primitive_components_demo.py +305 -0
- vanna/examples/quota_lifecycle_example.py +139 -0
- vanna/examples/visualization_example.py +251 -0
- vanna/integrations/__init__.py +17 -0
- vanna/integrations/anthropic/__init__.py +9 -0
- vanna/integrations/anthropic/llm.py +270 -0
- vanna/integrations/azureopenai/__init__.py +9 -0
- vanna/integrations/azureopenai/llm.py +329 -0
- vanna/integrations/azuresearch/__init__.py +7 -0
- vanna/integrations/azuresearch/agent_memory.py +413 -0
- vanna/integrations/bigquery/__init__.py +5 -0
- vanna/integrations/bigquery/sql_runner.py +81 -0
- vanna/integrations/chromadb/__init__.py +104 -0
- vanna/integrations/chromadb/agent_memory.py +416 -0
- vanna/integrations/clickhouse/__init__.py +5 -0
- vanna/integrations/clickhouse/sql_runner.py +82 -0
- vanna/integrations/duckdb/__init__.py +5 -0
- vanna/integrations/duckdb/sql_runner.py +65 -0
- vanna/integrations/faiss/__init__.py +7 -0
- vanna/integrations/faiss/agent_memory.py +431 -0
- vanna/integrations/google/__init__.py +9 -0
- vanna/integrations/google/gemini.py +370 -0
- vanna/integrations/hive/__init__.py +5 -0
- vanna/integrations/hive/sql_runner.py +87 -0
- vanna/integrations/local/__init__.py +17 -0
- vanna/integrations/local/agent_memory/__init__.py +7 -0
- vanna/integrations/local/agent_memory/in_memory.py +285 -0
- vanna/integrations/local/audit.py +59 -0
- vanna/integrations/local/file_system.py +242 -0
- vanna/integrations/local/file_system_conversation_store.py +255 -0
- vanna/integrations/local/storage.py +62 -0
- vanna/integrations/marqo/__init__.py +7 -0
- vanna/integrations/marqo/agent_memory.py +354 -0
- vanna/integrations/milvus/__init__.py +7 -0
- vanna/integrations/milvus/agent_memory.py +458 -0
- vanna/integrations/mock/__init__.py +9 -0
- vanna/integrations/mock/llm.py +65 -0
- vanna/integrations/mssql/__init__.py +5 -0
- vanna/integrations/mssql/sql_runner.py +66 -0
- vanna/integrations/mysql/__init__.py +5 -0
- vanna/integrations/mysql/sql_runner.py +92 -0
- vanna/integrations/ollama/__init__.py +7 -0
- vanna/integrations/ollama/llm.py +252 -0
- vanna/integrations/openai/__init__.py +10 -0
- vanna/integrations/openai/llm.py +267 -0
- vanna/integrations/openai/responses.py +163 -0
- vanna/integrations/opensearch/__init__.py +7 -0
- vanna/integrations/opensearch/agent_memory.py +411 -0
- vanna/integrations/oracle/__init__.py +5 -0
- vanna/integrations/oracle/sql_runner.py +75 -0
- vanna/integrations/pinecone/__init__.py +7 -0
- vanna/integrations/pinecone/agent_memory.py +329 -0
- vanna/integrations/plotly/__init__.py +5 -0
- vanna/integrations/plotly/chart_generator.py +313 -0
- vanna/integrations/postgres/__init__.py +9 -0
- vanna/integrations/postgres/sql_runner.py +112 -0
- vanna/integrations/premium/agent_memory/__init__.py +7 -0
- vanna/integrations/premium/agent_memory/premium.py +186 -0
- vanna/integrations/presto/__init__.py +5 -0
- vanna/integrations/presto/sql_runner.py +107 -0
- vanna/integrations/qdrant/__init__.py +7 -0
- vanna/integrations/qdrant/agent_memory.py +461 -0
- vanna/integrations/snowflake/__init__.py +5 -0
- vanna/integrations/snowflake/sql_runner.py +147 -0
- vanna/integrations/sqlite/__init__.py +9 -0
- vanna/integrations/sqlite/sql_runner.py +65 -0
- vanna/integrations/weaviate/__init__.py +7 -0
- vanna/integrations/weaviate/agent_memory.py +428 -0
- vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_embeddings.py +11 -11
- vanna/legacy/__init__.py +403 -0
- vanna/legacy/adapter.py +463 -0
- vanna/{advanced → legacy/advanced}/__init__.py +3 -1
- vanna/{anthropic → legacy/anthropic}/anthropic_chat.py +9 -7
- vanna/{azuresearch → legacy/azuresearch}/azuresearch_vector.py +79 -41
- vanna/{base → legacy/base}/base.py +224 -217
- vanna/legacy/bedrock/__init__.py +1 -0
- vanna/{bedrock → legacy/bedrock}/bedrock_converse.py +13 -12
- vanna/{chromadb → legacy/chromadb}/chromadb_vector.py +3 -1
- vanna/legacy/cohere/__init__.py +2 -0
- vanna/{cohere → legacy/cohere}/cohere_chat.py +19 -14
- vanna/{cohere → legacy/cohere}/cohere_embeddings.py +25 -19
- vanna/{deepseek → legacy/deepseek}/deepseek_chat.py +5 -6
- vanna/legacy/faiss/__init__.py +1 -0
- vanna/{faiss → legacy/faiss}/faiss.py +113 -59
- vanna/{flask → legacy/flask}/__init__.py +84 -43
- vanna/{flask → legacy/flask}/assets.py +5 -5
- vanna/{flask → legacy/flask}/auth.py +5 -4
- vanna/{google → legacy/google}/bigquery_vector.py +75 -42
- vanna/{google → legacy/google}/gemini_chat.py +7 -3
- vanna/{hf → legacy/hf}/hf.py +0 -1
- vanna/{milvus → legacy/milvus}/milvus_vector.py +58 -35
- vanna/{mock → legacy/mock}/llm.py +0 -1
- vanna/legacy/mock/vectordb.py +67 -0
- vanna/legacy/ollama/ollama.py +110 -0
- vanna/{openai → legacy/openai}/openai_chat.py +2 -6
- vanna/legacy/opensearch/opensearch_vector.py +369 -0
- vanna/legacy/opensearch/opensearch_vector_semantic.py +200 -0
- vanna/legacy/oracle/oracle_vector.py +584 -0
- vanna/{pgvector → legacy/pgvector}/pgvector.py +42 -13
- vanna/{qdrant → legacy/qdrant}/qdrant.py +2 -6
- vanna/legacy/qianfan/Qianfan_Chat.py +170 -0
- vanna/legacy/qianfan/Qianfan_embeddings.py +36 -0
- vanna/legacy/qianwen/QianwenAI_chat.py +132 -0
- vanna/{remote.py → legacy/remote.py} +28 -26
- vanna/{utils.py → legacy/utils.py} +6 -11
- vanna/{vannadb → legacy/vannadb}/vannadb_vector.py +115 -46
- vanna/{vllm → legacy/vllm}/vllm.py +5 -6
- vanna/{weaviate → legacy/weaviate}/weaviate_vector.py +59 -40
- vanna/{xinference → legacy/xinference}/xinference.py +6 -6
- vanna/py.typed +0 -0
- vanna/servers/__init__.py +16 -0
- vanna/servers/__main__.py +8 -0
- vanna/servers/base/__init__.py +18 -0
- vanna/servers/base/chat_handler.py +65 -0
- vanna/servers/base/models.py +111 -0
- vanna/servers/base/rich_chat_handler.py +141 -0
- vanna/servers/base/templates.py +331 -0
- vanna/servers/cli/__init__.py +7 -0
- vanna/servers/cli/server_runner.py +204 -0
- vanna/servers/fastapi/__init__.py +7 -0
- vanna/servers/fastapi/app.py +163 -0
- vanna/servers/fastapi/routes.py +183 -0
- vanna/servers/flask/__init__.py +7 -0
- vanna/servers/flask/app.py +132 -0
- vanna/servers/flask/routes.py +137 -0
- vanna/tools/__init__.py +41 -0
- vanna/tools/agent_memory.py +322 -0
- vanna/tools/file_system.py +879 -0
- vanna/tools/python.py +222 -0
- vanna/tools/run_sql.py +165 -0
- vanna/tools/visualize_data.py +195 -0
- vanna/utils/__init__.py +0 -0
- vanna/web_components/__init__.py +44 -0
- vanna-2.0.0.dist-info/METADATA +485 -0
- vanna-2.0.0.dist-info/RECORD +289 -0
- vanna-2.0.0.dist-info/entry_points.txt +3 -0
- vanna/bedrock/__init__.py +0 -1
- vanna/cohere/__init__.py +0 -2
- vanna/faiss/__init__.py +0 -1
- vanna/mock/vectordb.py +0 -55
- vanna/ollama/ollama.py +0 -103
- vanna/opensearch/opensearch_vector.py +0 -392
- vanna/opensearch/opensearch_vector_semantic.py +0 -175
- vanna/oracle/oracle_vector.py +0 -585
- vanna/qianfan/Qianfan_Chat.py +0 -165
- vanna/qianfan/Qianfan_embeddings.py +0 -36
- vanna/qianwen/QianwenAI_chat.py +0 -133
- vanna-0.7.9.dist-info/METADATA +0 -408
- vanna-0.7.9.dist-info/RECORD +0 -79
- /vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_Chat.py +0 -0
- /vanna/{ZhipuAI → legacy/ZhipuAI}/__init__.py +0 -0
- /vanna/{anthropic → legacy/anthropic}/__init__.py +0 -0
- /vanna/{azuresearch → legacy/azuresearch}/__init__.py +0 -0
- /vanna/{base → legacy/base}/__init__.py +0 -0
- /vanna/{chromadb → legacy/chromadb}/__init__.py +0 -0
- /vanna/{deepseek → legacy/deepseek}/__init__.py +0 -0
- /vanna/{exceptions → legacy/exceptions}/__init__.py +0 -0
- /vanna/{google → legacy/google}/__init__.py +0 -0
- /vanna/{hf → legacy/hf}/__init__.py +0 -0
- /vanna/{local.py → legacy/local.py} +0 -0
- /vanna/{marqo → legacy/marqo}/__init__.py +0 -0
- /vanna/{marqo → legacy/marqo}/marqo.py +0 -0
- /vanna/{milvus → legacy/milvus}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/mistral.py +0 -0
- /vanna/{mock → legacy/mock}/__init__.py +0 -0
- /vanna/{mock → legacy/mock}/embedding.py +0 -0
- /vanna/{ollama → legacy/ollama}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/openai_embeddings.py +0 -0
- /vanna/{opensearch → legacy/opensearch}/__init__.py +0 -0
- /vanna/{oracle → legacy/oracle}/__init__.py +0 -0
- /vanna/{pgvector → legacy/pgvector}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/pinecone_vector.py +0 -0
- /vanna/{qdrant → legacy/qdrant}/__init__.py +0 -0
- /vanna/{qianfan → legacy/qianfan}/__init__.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/QianwenAI_embeddings.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/__init__.py +0 -0
- /vanna/{types → legacy/types}/__init__.py +0 -0
- /vanna/{vannadb → legacy/vannadb}/__init__.py +0 -0
- /vanna/{vllm → legacy/vllm}/__init__.py +0 -0
- /vanna/{weaviate → legacy/weaviate}/__init__.py +0 -0
- /vanna/{xinference → legacy/xinference}/__init__.py +0 -0
- {vanna-0.7.9.dist-info → vanna-2.0.0.dist-info}/WHEEL +0 -0
- {vanna-0.7.9.dist-info → vanna-2.0.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,458 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Milvus vector database implementation of AgentMemory.
|
|
3
|
+
|
|
4
|
+
This implementation uses Milvus for distributed vector storage of tool usage patterns.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import json
|
|
8
|
+
import uuid
|
|
9
|
+
from datetime import datetime
|
|
10
|
+
from typing import Any, Dict, List, Optional
|
|
11
|
+
import asyncio
|
|
12
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
from pymilvus import (
|
|
16
|
+
connections,
|
|
17
|
+
Collection,
|
|
18
|
+
CollectionSchema,
|
|
19
|
+
FieldSchema,
|
|
20
|
+
DataType,
|
|
21
|
+
utility,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
MILVUS_AVAILABLE = True
|
|
25
|
+
except ImportError:
|
|
26
|
+
MILVUS_AVAILABLE = False
|
|
27
|
+
|
|
28
|
+
from vanna.capabilities.agent_memory import (
|
|
29
|
+
AgentMemory,
|
|
30
|
+
TextMemory,
|
|
31
|
+
TextMemorySearchResult,
|
|
32
|
+
ToolMemory,
|
|
33
|
+
ToolMemorySearchResult,
|
|
34
|
+
)
|
|
35
|
+
from vanna.core.tool import ToolContext
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class MilvusAgentMemory(AgentMemory):
    """Milvus-based implementation of AgentMemory.

    Tool-usage memories and free-text memories share one Milvus collection.
    A text memory is stored with an empty ``tool_name`` and an empty
    ``args_json``. Every query below filters on ``tool_name`` so the two kinds
    never leak into each other's results. Decoding a text row as a tool memory
    would otherwise try to JSON-decode an empty ``args_json``.

    pymilvus is synchronous, so each coroutine runs its blocking work on a
    small private thread pool via ``loop.run_in_executor``.
    """

    # Columns needed to rebuild a ToolMemory from a stored row.
    _TOOL_OUTPUT_FIELDS = (
        "id",
        "question",
        "tool_name",
        "args_json",
        "timestamp",
        "success",
        "metadata_json",
    )

    # Rows scanned by the "recent" queries before sorting by timestamp
    # client-side (the query API used here has no ORDER BY). Milvus caps a
    # single query's offset + limit at 16384 by default.
    _RECENT_SCAN_LIMIT = 1000
    _MAX_QUERY_LIMIT = 16384

    def __init__(
        self,
        collection_name: str = "tool_memories",
        host: str = "localhost",
        port: int = 19530,
        alias: str = "default",
        dimension: int = 384,
    ):
        """Configure the store; the Milvus connection is opened lazily.

        Args:
            collection_name: Milvus collection holding all memories.
            host: Milvus server host.
            port: Milvus server port.
            alias: pymilvus connection alias used for every call of this store.
            dimension: Length of the embedding vectors stored in the collection.

        Raises:
            ImportError: If ``pymilvus`` is not installed.
        """
        if not MILVUS_AVAILABLE:
            raise ImportError(
                "Milvus is required for MilvusAgentMemory. Install with: pip install pymilvus"
            )

        self.collection_name = collection_name
        self.host = host
        self.port = port
        self.alias = alias
        self.dimension = dimension
        self._collection = None
        # Blocking pymilvus calls are dispatched onto this pool.
        self._executor = ThreadPoolExecutor(max_workers=2)

    def _get_collection(self):
        """Return the loaded collection, creating it (and its index) on first use.

        Every pymilvus call goes through ``self.alias``. The handle is cached
        only after ``load()`` succeeds, so a failed load is retried next time.
        """
        if self._collection is not None:
            return self._collection

        connections.connect(alias=self.alias, host=self.host, port=self.port)

        if not utility.has_collection(self.collection_name, using=self.alias):
            fields = [
                FieldSchema(
                    name="id",
                    dtype=DataType.VARCHAR,
                    is_primary=True,
                    max_length=100,
                ),
                FieldSchema(
                    name="embedding",
                    dtype=DataType.FLOAT_VECTOR,
                    dim=self.dimension,
                ),
                # VARCHAR limits: inserts with longer values are rejected by Milvus.
                FieldSchema(name="question", dtype=DataType.VARCHAR, max_length=2000),
                FieldSchema(name="tool_name", dtype=DataType.VARCHAR, max_length=200),
                FieldSchema(name="args_json", dtype=DataType.VARCHAR, max_length=5000),
                FieldSchema(name="timestamp", dtype=DataType.VARCHAR, max_length=50),
                FieldSchema(name="success", dtype=DataType.BOOL),
                FieldSchema(
                    name="metadata_json", dtype=DataType.VARCHAR, max_length=5000
                ),
            ]
            schema = CollectionSchema(fields=fields, description="Tool usage memories")
            created = Collection(
                name=self.collection_name, schema=schema, using=self.alias
            )
            index_params = {
                "index_type": "IVF_FLAT",
                "metric_type": "IP",
                "params": {"nlist": 128},
            }
            created.create_index(field_name="embedding", index_params=index_params)

        collection = Collection(self.collection_name, using=self.alias)
        collection.load()
        self._collection = collection
        return collection

    def _create_embedding(self, text: str) -> List[float]:
        """Create a deterministic placeholder embedding from *text*.

        Each component is a value in [0, 0.99] derived from the 128-bit MD5
        of the text, so only the first 128 components can be non-zero. These
        vectors are not normalized. Inner-product scores are therefore not
        bounded by 1 and carry no semantic meaning.
        TODO: replace with a real embedding model.
        """
        import hashlib

        hash_val = int(hashlib.md5(text.encode()).hexdigest(), 16)
        return [(hash_val >> i) % 100 / 100.0 for i in range(self.dimension)]

    @staticmethod
    def _quote_literal(value: Any) -> str:
        """Render *value* as a double-quoted Milvus expression string literal.

        Backslashes, double quotes and line breaks are escaped, so caller-supplied
        ids, tool names and dates cannot break out of the literal. Unescaped,
        such a value could rewrite a filter, e.g. turn a single-id delete into
        a delete-all.
        """
        escaped = (
            str(value)
            .replace("\\", "\\\\")
            .replace('"', '\\"')
            .replace("\n", "\\n")
            .replace("\r", "\\r")
        )
        return f'"{escaped}"'

    @staticmethod
    def _parse_json_object(raw: Optional[str]) -> Dict[str, Any]:
        """Decode a stored JSON column; empty, invalid or non-object values yield {}."""
        if not raw:
            return {}
        try:
            value = json.loads(raw)
        except (TypeError, ValueError):
            return {}
        return value if isinstance(value, dict) else {}

    def _scan_limit(self, limit: int) -> int:
        """Rows to fetch so that ``limit`` recent items can be chosen client-side."""
        return min(max(limit, self._RECENT_SCAN_LIMIT), self._MAX_QUERY_LIMIT)

    def _to_tool_memory(self, record: Any) -> ToolMemory:
        """Build a ToolMemory from a search hit entity or a query row (both support ``.get``)."""
        return ToolMemory(
            memory_id=record.get("id"),
            question=record.get("question"),
            tool_name=record.get("tool_name"),
            args=self._parse_json_object(record.get("args_json")),
            timestamp=record.get("timestamp"),
            success=record.get("success", True),
            metadata=self._parse_json_object(record.get("metadata_json")),
        )

    async def save_tool_usage(
        self,
        question: str,
        tool_name: str,
        args: Dict[str, Any],
        context: ToolContext,
        success: bool = True,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Save a tool usage pattern.

        The row gets a fresh UUID id and a local-time ISO-8601 timestamp.
        ``args`` and ``metadata`` are stored as JSON strings. A non-empty
        ``tool_name`` is expected: rows with an empty tool name are treated
        as text memories by every query in this class.
        """

        def _save():
            collection = self._get_collection()
            entities = [
                [str(uuid.uuid4())],
                [self._create_embedding(question)],
                [question],
                [tool_name],
                [json.dumps(args)],
                [datetime.now().isoformat()],
                [success],
                [json.dumps(metadata or {})],
            ]
            collection.insert(entities)
            collection.flush()

        await asyncio.get_running_loop().run_in_executor(self._executor, _save)

    async def search_similar_usage(
        self,
        question: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
        tool_name_filter: Optional[str] = None,
    ) -> List[ToolMemorySearchResult]:
        """Search for similar, successful tool usage patterns.

        Only tool memories are considered: text memories are also stored with
        ``success == true``, so they are excluded explicitly. Hits whose
        inner-product score is below ``similarity_threshold`` are dropped.
        Returns an empty list when ``limit`` is not positive.
        """
        if limit <= 0:
            return []

        def _search():
            collection = self._get_collection()

            expr = 'success == true && tool_name != ""'
            if tool_name_filter:
                expr += f" && tool_name == {self._quote_literal(tool_name_filter)}"

            results = collection.search(
                data=[self._create_embedding(question)],
                anns_field="embedding",
                param={"metric_type": "IP", "params": {"nprobe": 10}},
                limit=limit,
                expr=expr,
                output_fields=list(self._TOOL_OUTPUT_FIELDS),
            )

            matches: List[ToolMemorySearchResult] = []
            for hits in results:
                # rank is the 1-based position of the hit in Milvus' result list.
                for rank, hit in enumerate(hits, start=1):
                    score = hit.distance
                    if score < similarity_threshold:
                        continue
                    matches.append(
                        ToolMemorySearchResult(
                            memory=self._to_tool_memory(hit.entity),
                            similarity_score=score,
                            rank=rank,
                        )
                    )
            return matches

        return await asyncio.get_running_loop().run_in_executor(self._executor, _search)

    async def get_recent_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[ToolMemory]:
        """Get the most recently saved tool memories (successful or not).

        Up to ``_scan_limit(limit)`` rows are fetched and sorted by their ISO
        timestamp strings, newest first. With more stored rows than that, the
        result is only the newest among the fetched subset.
        """
        if limit <= 0:
            return []

        def _get_recent():
            collection = self._get_collection()
            rows = collection.query(
                expr='tool_name != ""',
                output_fields=list(self._TOOL_OUTPUT_FIELDS),
                limit=self._scan_limit(limit),
            )
            newest_first = sorted(
                rows, key=lambda row: row.get("timestamp") or "", reverse=True
            )
            return [self._to_tool_memory(row) for row in newest_first[:limit]]

        return await asyncio.get_running_loop().run_in_executor(
            self._executor, _get_recent
        )

    async def delete_by_id(self, context: ToolContext, memory_id: str) -> bool:
        """Delete a memory by its ID.

        Returns False if the delete call raises. Returns True otherwise, even
        when no row had that id. Connection errors from ``_get_collection``
        propagate.
        """

        def _delete():
            collection = self._get_collection()
            try:
                collection.delete(f"id == {self._quote_literal(memory_id)}")
                return True
            except Exception:
                return False

        return await asyncio.get_running_loop().run_in_executor(self._executor, _delete)

    async def save_text_memory(self, content: str, context: ToolContext) -> TextMemory:
        """Save a free-text memory.

        The content is stored in the ``question`` column with an empty
        ``tool_name``/``args_json`` (which marks it as a text memory),
        ``success`` set to True, and ``{"is_text_memory": true}`` metadata.
        """

        def _save():
            collection = self._get_collection()

            memory_id = str(uuid.uuid4())
            timestamp = datetime.now().isoformat()
            entities = [
                [memory_id],
                [self._create_embedding(content)],
                [content],
                [""],  # tool_name (empty marks a text memory)
                [""],  # args_json (empty for text memories)
                [timestamp],
                [True],  # success (always true for text memories)
                [json.dumps({"is_text_memory": True})],  # metadata_json
            ]
            collection.insert(entities)
            collection.flush()

            return TextMemory(memory_id=memory_id, content=content, timestamp=timestamp)

        return await asyncio.get_running_loop().run_in_executor(self._executor, _save)

    async def search_text_memories(
        self,
        query: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
    ) -> List[TextMemorySearchResult]:
        """Search for similar text memories.

        Hits below ``similarity_threshold`` are dropped. Returns an empty list
        when ``limit`` is not positive.
        """
        if limit <= 0:
            return []

        def _search():
            collection = self._get_collection()
            results = collection.search(
                data=[self._create_embedding(query)],
                anns_field="embedding",
                param={"metric_type": "IP", "params": {"nprobe": 10}},
                limit=limit,
                expr='tool_name == ""',
                output_fields=["id", "question", "timestamp", "metadata_json"],
            )

            matches: List[TextMemorySearchResult] = []
            for hits in results:
                for rank, hit in enumerate(hits, start=1):
                    score = hit.distance
                    if score < similarity_threshold:
                        continue
                    memory = TextMemory(
                        memory_id=hit.entity.get("id"),
                        content=hit.entity.get("question", ""),
                        timestamp=hit.entity.get("timestamp"),
                    )
                    matches.append(
                        TextMemorySearchResult(
                            memory=memory,
                            similarity_score=score,
                            rank=rank,
                        )
                    )
            return matches

        return await asyncio.get_running_loop().run_in_executor(self._executor, _search)

    async def get_recent_text_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[TextMemory]:
        """Get the most recently saved text memories, newest first.

        Same scanning caveat as ``get_recent_memories``.
        """
        if limit <= 0:
            return []

        def _get_recent():
            collection = self._get_collection()
            rows = collection.query(
                expr='tool_name == ""',
                output_fields=["id", "question", "timestamp"],
                limit=self._scan_limit(limit),
            )
            newest_first = sorted(
                rows, key=lambda row: row.get("timestamp") or "", reverse=True
            )
            return [
                TextMemory(
                    memory_id=row.get("id"),
                    content=row.get("question", ""),
                    timestamp=row.get("timestamp"),
                )
                for row in newest_first[:limit]
            ]

        return await asyncio.get_running_loop().run_in_executor(
            self._executor, _get_recent
        )

    async def delete_text_memory(self, context: ToolContext, memory_id: str) -> bool:
        """Delete a text memory by its ID.

        Same return semantics as ``delete_by_id``. The filter matches on id
        only, so it also deletes a tool memory with that id.
        """

        def _delete():
            collection = self._get_collection()
            try:
                collection.delete(f"id == {self._quote_literal(memory_id)}")
                return True
            except Exception:
                return False

        return await asyncio.get_running_loop().run_in_executor(self._executor, _delete)

    async def clear_memories(
        self,
        context: ToolContext,
        tool_name: Optional[str] = None,
        before_date: Optional[str] = None,
    ) -> int:
        """Clear stored memories and return the number of deleted rows.

        Args:
            context: Tool execution context (unused here).
            tool_name: If given, only rows with this tool name are deleted.
            before_date: If given, only rows whose timestamp string sorts
                before it are deleted. The comparison is lexicographic, so it
                should be an ISO-8601 string like the stored timestamps.

        With neither filter, every row (tool and text memories) is deleted.
        The count comes from the delete result's ``delete_count`` (0 if absent).
        """

        def _clear():
            collection = self._get_collection()

            conditions = []
            if tool_name:
                conditions.append(f"tool_name == {self._quote_literal(tool_name)}")
            if before_date:
                conditions.append(f"timestamp < {self._quote_literal(before_date)}")
            expr = " && ".join(conditions) if conditions else "id != ''"

            result = collection.delete(expr)
            return int(getattr(result, "delete_count", 0) or 0)

        return await asyncio.get_running_loop().run_in_executor(self._executor, _clear)
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Mock LLM service implementation for testing.
|
|
3
|
+
|
|
4
|
+
This module provides a simple mock implementation of the LlmService interface,
|
|
5
|
+
useful for testing and development without requiring actual LLM API calls.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import asyncio
|
|
9
|
+
from typing import AsyncGenerator, List
|
|
10
|
+
|
|
11
|
+
from vanna.core.llm import LlmService, LlmRequest, LlmResponse, LlmStreamChunk
|
|
12
|
+
from vanna.core.tool import ToolSchema
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class MockLlmService(LlmService):
    """Mock LLM service that returns predefined responses.

    Every request, streamed or not, increments ``call_count``. The current
    count is embedded in the returned text so tests can tell calls apart.
    Artificial delays simulate latency; set them to 0 for fast tests.
    """

    def __init__(
        self,
        response_content: str = "Hello! This is a mock response.",
        *,
        response_delay: float = 0.1,
        stream_delay: float = 0.05,
    ):
        """Create the mock.

        Args:
            response_content: Base text returned by every request.
            response_delay: Seconds ``send_request`` sleeps before answering.
            stream_delay: Seconds ``stream_request`` sleeps before each chunk.
        """
        self.response_content = response_content
        self.response_delay = response_delay
        self.stream_delay = stream_delay
        self.call_count = 0

    async def send_request(self, request: LlmRequest) -> LlmResponse:
        """Return one canned response; the request itself is ignored.

        The usage numbers are fixed placeholders.
        """
        self.call_count += 1

        # Simulate processing delay
        await asyncio.sleep(self.response_delay)

        return LlmResponse(
            content=f"{self.response_content} (Request #{self.call_count})",
            finish_reason="stop",
            usage={"prompt_tokens": 50, "completion_tokens": 20, "total_tokens": 70},
        )

    async def stream_request(
        self, request: LlmRequest
    ) -> AsyncGenerator[LlmStreamChunk, None]:
        """Stream the canned response one whitespace-separated word per chunk.

        Every chunk but the last carries a trailing space. Only the last
        chunk has ``finish_reason="stop"``. Because the text is split on
        whitespace, runs of whitespace (including newlines) are streamed as
        single spaces.
        """
        self.call_count += 1

        words = f"{self.response_content} (Streamed #{self.call_count})".split()
        last_index = len(words) - 1

        for i, word in enumerate(words):
            await asyncio.sleep(self.stream_delay)  # Simulate streaming delay

            is_last = i == last_index
            yield LlmStreamChunk(
                content=word if is_last else word + " ",
                finish_reason="stop" if is_last else None,
            )

    async def validate_tools(self, tools: List[ToolSchema]) -> List[str]:
        """Validate tool schemas; the mock accepts everything and returns no errors."""
        return []

    def set_response(self, content: str) -> None:
        """Replace the base response text used by subsequent requests."""
        self.response_content = content

    def reset_call_count(self) -> None:
        """Reset the call counter to 0."""
        self.call_count = 0
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
"""Microsoft SQL Server implementation of SqlRunner interface."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
import pandas as pd
|
|
5
|
+
|
|
6
|
+
from vanna.capabilities.sql_runner import SqlRunner, RunSqlToolArgs
|
|
7
|
+
from vanna.core.tool import ToolContext
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class MSSQLRunner(SqlRunner):
    """Run SQL against Microsoft SQL Server through SQLAlchemy and pyodbc.

    The target database is described by a raw ODBC connection string. It is
    passed to SQLAlchemy's ``mssql+pyodbc`` URL via the ``odbc_connect``
    query parameter. A single engine is built at construction time.
    """

    def __init__(self, odbc_conn_str: str, **kwargs):
        """Validate the optional dependencies and build the SQLAlchemy engine.

        Args:
            odbc_conn_str: Full ODBC connection string for the SQL Server instance.
            **kwargs: Extra keyword arguments forwarded to ``create_engine``.

        Raises:
            ImportError: If ``pyodbc`` or ``sqlalchemy`` is not installed.
        """
        # Imported eagerly so a missing driver fails at construction time,
        # with an install hint.
        try:
            import pyodbc
        except ImportError as exc:
            raise ImportError(
                "pyodbc package is required. Install with: pip install pyodbc"
            ) from exc
        self.pyodbc = pyodbc

        try:
            import sqlalchemy as sa
            from sqlalchemy.engine import URL
            from sqlalchemy import create_engine
        except ImportError as exc:
            raise ImportError(
                "sqlalchemy package is required. Install with: pip install sqlalchemy"
            ) from exc
        self.sa = sa
        self.URL = URL
        self.create_engine = create_engine

        url = self.URL.create("mssql+pyodbc", query={"odbc_connect": odbc_conn_str})
        self.engine = self.create_engine(url, **kwargs)

    async def run_sql(self, args: RunSqlToolArgs, context: ToolContext) -> pd.DataFrame:
        """Execute ``args.sql`` on SQL Server and return the rows as a DataFrame.

        The statement runs inside ``engine.begin()``. The transaction is
        therefore committed when the block exits normally and rolled back if
        an exception escapes.

        NOTE(review): nothing in this method is awaited, so the database call
        blocks the event loop while it runs.

        Args:
            args: Tool arguments; ``args.sql`` is the statement to run.
            context: Tool execution context (unused here).

        Returns:
            A DataFrame containing the query results.

        Raises:
            sqlalchemy.exc.SQLAlchemyError: If executing the query fails.
        """
        with self.engine.begin() as connection:
            statement = self.sa.text(args.sql)
            result_frame = pd.read_sql_query(statement, connection)
        return result_frame
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
"""MySQL implementation of SqlRunner interface."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
import pandas as pd
|
|
5
|
+
|
|
6
|
+
from vanna.capabilities.sql_runner import SqlRunner, RunSqlToolArgs
|
|
7
|
+
from vanna.core.tool import ToolContext
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class MySQLRunner(SqlRunner):
    """MySQL implementation of the SqlRunner interface.

    Connection parameters are stored at construction time. A fresh PyMySQL
    connection is opened, and always closed, for every ``run_sql`` call.
    """

    def __init__(
        self,
        host: str,
        database: str,
        user: str,
        password: str,
        port: int = 3306,
        **kwargs,
    ):
        """Store MySQL connection parameters.

        Args:
            host: Database host address.
            database: Database name.
            user: Database user.
            password: Database password.
            port: Database port (default: 3306).
            **kwargs: Additional keyword arguments forwarded to
                ``pymysql.connect`` on every query.

        Raises:
            ImportError: If PyMySQL is not installed.
        """
        try:
            import pymysql.cursors

            self.pymysql = pymysql
        except ImportError as e:
            raise ImportError(
                "PyMySQL package is required. Install with: pip install 'vanna[mysql]'"
            ) from e

        self.host = host
        self.database = database
        self.user = user
        self.password = password
        self.port = port
        self.kwargs = kwargs

    async def run_sql(self, args: RunSqlToolArgs, context: ToolContext) -> pd.DataFrame:
        """Execute SQL query against MySQL database and return results as DataFrame.

        Column names come from ``cursor.description`` and are kept in order,
        including duplicates (e.g. two ``id`` columns from a JOIN). A
        statement that returns no result set (no description) produces an
        empty DataFrame with no columns.

        NOTE(review): every database call below is synchronous, so this
        coroutine blocks the event loop while the query runs.

        Args:
            args: SQL query arguments; ``args.sql`` is the statement to run.
            context: Tool execution context (unused here).

        Returns:
            DataFrame with query results.

        Raises:
            pymysql.Error: If connecting or executing the query fails.
        """
        conn = self.pymysql.connect(
            host=self.host,
            user=self.user,
            password=self.password,
            database=self.database,
            port=self.port,
            **self.kwargs,
        )

        try:
            # Use a plain tuple cursor explicitly. A DictCursor keys each row
            # by column name, so duplicate column names cannot map one-to-one
            # onto cursor.description and values get collapsed. Tuples stay
            # positionally aligned with the description.
            # The ``with`` guarantees the cursor is closed even if execute fails.
            with conn.cursor(self.pymysql.cursors.Cursor) as cursor:
                cursor.execute(args.sql)
                rows = cursor.fetchall()
                columns = (
                    [desc[0] for desc in cursor.description]
                    if cursor.description
                    else []
                )
            return pd.DataFrame(list(rows), columns=columns)
        finally:
            conn.close()
|