vanna 0.7.9__py3-none-any.whl → 2.0.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/__init__.py +167 -395
- vanna/agents/__init__.py +7 -0
- vanna/capabilities/__init__.py +17 -0
- vanna/capabilities/agent_memory/__init__.py +21 -0
- vanna/capabilities/agent_memory/base.py +103 -0
- vanna/capabilities/agent_memory/models.py +53 -0
- vanna/capabilities/file_system/__init__.py +14 -0
- vanna/capabilities/file_system/base.py +71 -0
- vanna/capabilities/file_system/models.py +25 -0
- vanna/capabilities/sql_runner/__init__.py +13 -0
- vanna/capabilities/sql_runner/base.py +37 -0
- vanna/capabilities/sql_runner/models.py +13 -0
- vanna/components/__init__.py +92 -0
- vanna/components/base.py +11 -0
- vanna/components/rich/__init__.py +83 -0
- vanna/components/rich/containers/__init__.py +7 -0
- vanna/components/rich/containers/card.py +20 -0
- vanna/components/rich/data/__init__.py +9 -0
- vanna/components/rich/data/chart.py +17 -0
- vanna/components/rich/data/dataframe.py +93 -0
- vanna/components/rich/feedback/__init__.py +21 -0
- vanna/components/rich/feedback/badge.py +16 -0
- vanna/components/rich/feedback/icon_text.py +14 -0
- vanna/components/rich/feedback/log_viewer.py +41 -0
- vanna/components/rich/feedback/notification.py +19 -0
- vanna/components/rich/feedback/progress.py +37 -0
- vanna/components/rich/feedback/status_card.py +28 -0
- vanna/components/rich/feedback/status_indicator.py +14 -0
- vanna/components/rich/interactive/__init__.py +21 -0
- vanna/components/rich/interactive/button.py +95 -0
- vanna/components/rich/interactive/task_list.py +58 -0
- vanna/components/rich/interactive/ui_state.py +93 -0
- vanna/components/rich/specialized/__init__.py +7 -0
- vanna/components/rich/specialized/artifact.py +20 -0
- vanna/components/rich/text.py +16 -0
- vanna/components/simple/__init__.py +15 -0
- vanna/components/simple/image.py +15 -0
- vanna/components/simple/link.py +15 -0
- vanna/components/simple/text.py +11 -0
- vanna/core/__init__.py +193 -0
- vanna/core/_compat.py +19 -0
- vanna/core/agent/__init__.py +10 -0
- vanna/core/agent/agent.py +1407 -0
- vanna/core/agent/config.py +123 -0
- vanna/core/audit/__init__.py +28 -0
- vanna/core/audit/base.py +299 -0
- vanna/core/audit/models.py +131 -0
- vanna/core/component_manager.py +329 -0
- vanna/core/components.py +53 -0
- vanna/core/enhancer/__init__.py +11 -0
- vanna/core/enhancer/base.py +94 -0
- vanna/core/enhancer/default.py +118 -0
- vanna/core/enricher/__init__.py +10 -0
- vanna/core/enricher/base.py +59 -0
- vanna/core/errors.py +47 -0
- vanna/core/evaluation/__init__.py +81 -0
- vanna/core/evaluation/base.py +186 -0
- vanna/core/evaluation/dataset.py +254 -0
- vanna/core/evaluation/evaluators.py +376 -0
- vanna/core/evaluation/report.py +289 -0
- vanna/core/evaluation/runner.py +313 -0
- vanna/core/filter/__init__.py +10 -0
- vanna/core/filter/base.py +67 -0
- vanna/core/lifecycle/__init__.py +10 -0
- vanna/core/lifecycle/base.py +83 -0
- vanna/core/llm/__init__.py +16 -0
- vanna/core/llm/base.py +40 -0
- vanna/core/llm/models.py +61 -0
- vanna/core/middleware/__init__.py +10 -0
- vanna/core/middleware/base.py +69 -0
- vanna/core/observability/__init__.py +11 -0
- vanna/core/observability/base.py +88 -0
- vanna/core/observability/models.py +47 -0
- vanna/core/recovery/__init__.py +11 -0
- vanna/core/recovery/base.py +84 -0
- vanna/core/recovery/models.py +32 -0
- vanna/core/registry.py +278 -0
- vanna/core/rich_component.py +156 -0
- vanna/core/simple_component.py +27 -0
- vanna/core/storage/__init__.py +14 -0
- vanna/core/storage/base.py +46 -0
- vanna/core/storage/models.py +46 -0
- vanna/core/system_prompt/__init__.py +13 -0
- vanna/core/system_prompt/base.py +36 -0
- vanna/core/system_prompt/default.py +157 -0
- vanna/core/tool/__init__.py +18 -0
- vanna/core/tool/base.py +70 -0
- vanna/core/tool/models.py +84 -0
- vanna/core/user/__init__.py +17 -0
- vanna/core/user/base.py +29 -0
- vanna/core/user/models.py +25 -0
- vanna/core/user/request_context.py +70 -0
- vanna/core/user/resolver.py +42 -0
- vanna/core/validation.py +164 -0
- vanna/core/workflow/__init__.py +12 -0
- vanna/core/workflow/base.py +254 -0
- vanna/core/workflow/default.py +789 -0
- vanna/examples/__init__.py +1 -0
- vanna/examples/__main__.py +44 -0
- vanna/examples/anthropic_quickstart.py +80 -0
- vanna/examples/artifact_example.py +293 -0
- vanna/examples/claude_sqlite_example.py +236 -0
- vanna/examples/coding_agent_example.py +300 -0
- vanna/examples/custom_system_prompt_example.py +174 -0
- vanna/examples/default_workflow_handler_example.py +208 -0
- vanna/examples/email_auth_example.py +340 -0
- vanna/examples/evaluation_example.py +269 -0
- vanna/examples/extensibility_example.py +262 -0
- vanna/examples/minimal_example.py +67 -0
- vanna/examples/mock_auth_example.py +227 -0
- vanna/examples/mock_custom_tool.py +311 -0
- vanna/examples/mock_quickstart.py +79 -0
- vanna/examples/mock_quota_example.py +145 -0
- vanna/examples/mock_rich_components_demo.py +396 -0
- vanna/examples/mock_sqlite_example.py +223 -0
- vanna/examples/openai_quickstart.py +83 -0
- vanna/examples/primitive_components_demo.py +305 -0
- vanna/examples/quota_lifecycle_example.py +139 -0
- vanna/examples/visualization_example.py +251 -0
- vanna/integrations/__init__.py +17 -0
- vanna/integrations/anthropic/__init__.py +9 -0
- vanna/integrations/anthropic/llm.py +270 -0
- vanna/integrations/azureopenai/__init__.py +9 -0
- vanna/integrations/azureopenai/llm.py +329 -0
- vanna/integrations/azuresearch/__init__.py +7 -0
- vanna/integrations/azuresearch/agent_memory.py +413 -0
- vanna/integrations/bigquery/__init__.py +5 -0
- vanna/integrations/bigquery/sql_runner.py +81 -0
- vanna/integrations/chromadb/__init__.py +104 -0
- vanna/integrations/chromadb/agent_memory.py +416 -0
- vanna/integrations/clickhouse/__init__.py +5 -0
- vanna/integrations/clickhouse/sql_runner.py +82 -0
- vanna/integrations/duckdb/__init__.py +5 -0
- vanna/integrations/duckdb/sql_runner.py +65 -0
- vanna/integrations/faiss/__init__.py +7 -0
- vanna/integrations/faiss/agent_memory.py +431 -0
- vanna/integrations/google/__init__.py +9 -0
- vanna/integrations/google/gemini.py +370 -0
- vanna/integrations/hive/__init__.py +5 -0
- vanna/integrations/hive/sql_runner.py +87 -0
- vanna/integrations/local/__init__.py +17 -0
- vanna/integrations/local/agent_memory/__init__.py +7 -0
- vanna/integrations/local/agent_memory/in_memory.py +285 -0
- vanna/integrations/local/audit.py +59 -0
- vanna/integrations/local/file_system.py +242 -0
- vanna/integrations/local/file_system_conversation_store.py +255 -0
- vanna/integrations/local/storage.py +62 -0
- vanna/integrations/marqo/__init__.py +7 -0
- vanna/integrations/marqo/agent_memory.py +354 -0
- vanna/integrations/milvus/__init__.py +7 -0
- vanna/integrations/milvus/agent_memory.py +458 -0
- vanna/integrations/mock/__init__.py +9 -0
- vanna/integrations/mock/llm.py +65 -0
- vanna/integrations/mssql/__init__.py +5 -0
- vanna/integrations/mssql/sql_runner.py +66 -0
- vanna/integrations/mysql/__init__.py +5 -0
- vanna/integrations/mysql/sql_runner.py +92 -0
- vanna/integrations/ollama/__init__.py +7 -0
- vanna/integrations/ollama/llm.py +252 -0
- vanna/integrations/openai/__init__.py +10 -0
- vanna/integrations/openai/llm.py +267 -0
- vanna/integrations/openai/responses.py +163 -0
- vanna/integrations/opensearch/__init__.py +7 -0
- vanna/integrations/opensearch/agent_memory.py +411 -0
- vanna/integrations/oracle/__init__.py +5 -0
- vanna/integrations/oracle/sql_runner.py +75 -0
- vanna/integrations/pinecone/__init__.py +7 -0
- vanna/integrations/pinecone/agent_memory.py +329 -0
- vanna/integrations/plotly/__init__.py +5 -0
- vanna/integrations/plotly/chart_generator.py +313 -0
- vanna/integrations/postgres/__init__.py +9 -0
- vanna/integrations/postgres/sql_runner.py +112 -0
- vanna/integrations/premium/agent_memory/__init__.py +7 -0
- vanna/integrations/premium/agent_memory/premium.py +186 -0
- vanna/integrations/presto/__init__.py +5 -0
- vanna/integrations/presto/sql_runner.py +107 -0
- vanna/integrations/qdrant/__init__.py +7 -0
- vanna/integrations/qdrant/agent_memory.py +439 -0
- vanna/integrations/snowflake/__init__.py +5 -0
- vanna/integrations/snowflake/sql_runner.py +147 -0
- vanna/integrations/sqlite/__init__.py +9 -0
- vanna/integrations/sqlite/sql_runner.py +65 -0
- vanna/integrations/weaviate/__init__.py +7 -0
- vanna/integrations/weaviate/agent_memory.py +428 -0
- vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_embeddings.py +11 -11
- vanna/legacy/__init__.py +403 -0
- vanna/legacy/adapter.py +463 -0
- vanna/{advanced → legacy/advanced}/__init__.py +3 -1
- vanna/{anthropic → legacy/anthropic}/anthropic_chat.py +9 -7
- vanna/{azuresearch → legacy/azuresearch}/azuresearch_vector.py +79 -41
- vanna/{base → legacy/base}/base.py +224 -217
- vanna/legacy/bedrock/__init__.py +1 -0
- vanna/{bedrock → legacy/bedrock}/bedrock_converse.py +13 -12
- vanna/{chromadb → legacy/chromadb}/chromadb_vector.py +3 -1
- vanna/legacy/cohere/__init__.py +2 -0
- vanna/{cohere → legacy/cohere}/cohere_chat.py +19 -14
- vanna/{cohere → legacy/cohere}/cohere_embeddings.py +25 -19
- vanna/{deepseek → legacy/deepseek}/deepseek_chat.py +5 -6
- vanna/legacy/faiss/__init__.py +1 -0
- vanna/{faiss → legacy/faiss}/faiss.py +113 -59
- vanna/{flask → legacy/flask}/__init__.py +84 -43
- vanna/{flask → legacy/flask}/assets.py +5 -5
- vanna/{flask → legacy/flask}/auth.py +5 -4
- vanna/{google → legacy/google}/bigquery_vector.py +75 -42
- vanna/{google → legacy/google}/gemini_chat.py +7 -3
- vanna/{hf → legacy/hf}/hf.py +0 -1
- vanna/{milvus → legacy/milvus}/milvus_vector.py +58 -35
- vanna/{mock → legacy/mock}/llm.py +0 -1
- vanna/legacy/mock/vectordb.py +67 -0
- vanna/legacy/ollama/ollama.py +110 -0
- vanna/{openai → legacy/openai}/openai_chat.py +2 -6
- vanna/legacy/opensearch/opensearch_vector.py +369 -0
- vanna/legacy/opensearch/opensearch_vector_semantic.py +200 -0
- vanna/legacy/oracle/oracle_vector.py +584 -0
- vanna/{pgvector → legacy/pgvector}/pgvector.py +42 -13
- vanna/{qdrant → legacy/qdrant}/qdrant.py +2 -6
- vanna/legacy/qianfan/Qianfan_Chat.py +170 -0
- vanna/legacy/qianfan/Qianfan_embeddings.py +36 -0
- vanna/legacy/qianwen/QianwenAI_chat.py +132 -0
- vanna/{remote.py → legacy/remote.py} +28 -26
- vanna/{utils.py → legacy/utils.py} +6 -11
- vanna/{vannadb → legacy/vannadb}/vannadb_vector.py +115 -46
- vanna/{vllm → legacy/vllm}/vllm.py +5 -6
- vanna/{weaviate → legacy/weaviate}/weaviate_vector.py +59 -40
- vanna/{xinference → legacy/xinference}/xinference.py +6 -6
- vanna/py.typed +0 -0
- vanna/servers/__init__.py +16 -0
- vanna/servers/__main__.py +8 -0
- vanna/servers/base/__init__.py +18 -0
- vanna/servers/base/chat_handler.py +65 -0
- vanna/servers/base/models.py +111 -0
- vanna/servers/base/rich_chat_handler.py +141 -0
- vanna/servers/base/templates.py +331 -0
- vanna/servers/cli/__init__.py +7 -0
- vanna/servers/cli/server_runner.py +204 -0
- vanna/servers/fastapi/__init__.py +7 -0
- vanna/servers/fastapi/app.py +163 -0
- vanna/servers/fastapi/routes.py +183 -0
- vanna/servers/flask/__init__.py +7 -0
- vanna/servers/flask/app.py +132 -0
- vanna/servers/flask/routes.py +137 -0
- vanna/tools/__init__.py +41 -0
- vanna/tools/agent_memory.py +322 -0
- vanna/tools/file_system.py +879 -0
- vanna/tools/python.py +222 -0
- vanna/tools/run_sql.py +165 -0
- vanna/tools/visualize_data.py +195 -0
- vanna/utils/__init__.py +0 -0
- vanna/web_components/__init__.py +44 -0
- vanna-2.0.0rc1.dist-info/METADATA +868 -0
- vanna-2.0.0rc1.dist-info/RECORD +289 -0
- vanna-2.0.0rc1.dist-info/entry_points.txt +3 -0
- vanna/bedrock/__init__.py +0 -1
- vanna/cohere/__init__.py +0 -2
- vanna/faiss/__init__.py +0 -1
- vanna/mock/vectordb.py +0 -55
- vanna/ollama/ollama.py +0 -103
- vanna/opensearch/opensearch_vector.py +0 -392
- vanna/opensearch/opensearch_vector_semantic.py +0 -175
- vanna/oracle/oracle_vector.py +0 -585
- vanna/qianfan/Qianfan_Chat.py +0 -165
- vanna/qianfan/Qianfan_embeddings.py +0 -36
- vanna/qianwen/QianwenAI_chat.py +0 -133
- vanna-0.7.9.dist-info/METADATA +0 -408
- vanna-0.7.9.dist-info/RECORD +0 -79
- /vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_Chat.py +0 -0
- /vanna/{ZhipuAI → legacy/ZhipuAI}/__init__.py +0 -0
- /vanna/{anthropic → legacy/anthropic}/__init__.py +0 -0
- /vanna/{azuresearch → legacy/azuresearch}/__init__.py +0 -0
- /vanna/{base → legacy/base}/__init__.py +0 -0
- /vanna/{chromadb → legacy/chromadb}/__init__.py +0 -0
- /vanna/{deepseek → legacy/deepseek}/__init__.py +0 -0
- /vanna/{exceptions → legacy/exceptions}/__init__.py +0 -0
- /vanna/{google → legacy/google}/__init__.py +0 -0
- /vanna/{hf → legacy/hf}/__init__.py +0 -0
- /vanna/{local.py → legacy/local.py} +0 -0
- /vanna/{marqo → legacy/marqo}/__init__.py +0 -0
- /vanna/{marqo → legacy/marqo}/marqo.py +0 -0
- /vanna/{milvus → legacy/milvus}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/mistral.py +0 -0
- /vanna/{mock → legacy/mock}/__init__.py +0 -0
- /vanna/{mock → legacy/mock}/embedding.py +0 -0
- /vanna/{ollama → legacy/ollama}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/openai_embeddings.py +0 -0
- /vanna/{opensearch → legacy/opensearch}/__init__.py +0 -0
- /vanna/{oracle → legacy/oracle}/__init__.py +0 -0
- /vanna/{pgvector → legacy/pgvector}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/pinecone_vector.py +0 -0
- /vanna/{qdrant → legacy/qdrant}/__init__.py +0 -0
- /vanna/{qianfan → legacy/qianfan}/__init__.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/QianwenAI_embeddings.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/__init__.py +0 -0
- /vanna/{types → legacy/types}/__init__.py +0 -0
- /vanna/{vannadb → legacy/vannadb}/__init__.py +0 -0
- /vanna/{vllm → legacy/vllm}/__init__.py +0 -0
- /vanna/{weaviate → legacy/weaviate}/__init__.py +0 -0
- /vanna/{xinference → legacy/xinference}/__init__.py +0 -0
- {vanna-0.7.9.dist-info → vanna-2.0.0rc1.dist-info}/WHEEL +0 -0
- {vanna-0.7.9.dist-info → vanna-2.0.0rc1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
from vanna.core.llm import LlmService, LlmRequest, LlmResponse, LlmStreamChunk
|
|
8
|
+
from vanna.core.tool import ToolCall, ToolSchema
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from openai.types.responses import Response
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class OpenAIResponsesService(LlmService):
    """``LlmService`` implementation backed by the OpenAI Responses API.

    Vanna ``LlmRequest`` objects are translated into ``client.responses``
    payloads, and the returned ``Response`` objects are translated back into
    ``LlmResponse`` / ``LlmStreamChunk`` values, including function (tool)
    calls requested by the model.
    """

    def __init__(
        self, api_key: Optional[str] = None, model: Optional[str] = None
    ) -> None:
        """Create the async OpenAI client.

        Args:
            api_key: OpenAI API key; falls back to the ``OPENAI_API_KEY`` env var.
            model: Model name; falls back to the ``OPENAI_MODEL`` env var, then
                ``"gpt-5"``.

        Raises:
            ImportError: if the ``openai`` package (with Responses types) is
                not installed.
        """
        try:
            from openai import AsyncOpenAI
            from openai.types.responses import Response  # noqa: F401  (availability check)
        except Exception as e:  # pragma: no cover
            raise ImportError(
                "openai package is required. Install with: pip install 'vanna[openai]'"
            ) from e

        self.client = AsyncOpenAI(api_key=api_key or os.getenv("OPENAI_API_KEY"))
        self.model = model or os.getenv("OPENAI_MODEL", "gpt-5")

    async def send_request(self, request: LlmRequest) -> LlmResponse:
        """Send a non-streaming request and convert the result to ``LlmResponse``."""
        payload = self._payload(request)
        resp: Response = await self.client.responses.create(**payload)
        self._debug_print("response", resp)
        text, tools, status, usage = self._extract(resp)
        return LlmResponse(
            content=text,
            tool_calls=tools or None,
            finish_reason=status,
            usage=usage or None,
            metadata={"request_id": getattr(resp, "id", None)},
        )

    async def stream_request(
        self, request: LlmRequest
    ) -> AsyncGenerator[LlmStreamChunk, None]:
        """Stream text deltas, then yield one final chunk with tool calls/status.

        Tool calls are read from the final aggregated response. They are
        therefore delivered in the last chunk rather than incrementally.
        """
        payload = self._payload(request)
        async with self.client.responses.stream(**payload) as stream:
            async for event in stream:
                self._debug_print("stream_event", event)
                if getattr(event, "type", None) == "response.output_text.delta":
                    delta = getattr(event, "delta", None)
                    if delta:
                        yield LlmStreamChunk(content=delta)
            final: Response = await stream.get_final_response()
            self._debug_print("final_response", final)

        _text, tools, status, _usage = self._extract(final)
        yield LlmStreamChunk(tool_calls=tools or None, finish_reason=status)

    async def validate_tools(self, tools: List[Any]) -> List[str]:
        """Return validation errors for ``tools``; this minimal version accepts all."""
        return []  # minimal: accept whatever's passed through

    # ---- helpers ----

    def _payload(self, request: LlmRequest) -> Dict[str, Any]:
        """Build keyword arguments for ``responses.create`` / ``responses.stream``."""
        p: Dict[str, Any] = {"model": self.model, "input": self._input_items(request)}
        if request.system_prompt:
            p["instructions"] = request.system_prompt
        if request.max_tokens:
            p["max_output_tokens"] = request.max_tokens
        if request.tools:
            p["tools"] = [self._serialize_tool(t) for t in request.tools]
        return p

    def _input_items(self, request: LlmRequest) -> List[Dict[str, Any]]:
        """Convert conversation messages into Responses API ``input`` items.

        The Responses API does not accept chat-completions style
        ``role: "tool"`` messages or ``tool_calls`` attached to assistant
        messages. Instead:

        * a tool result becomes a ``function_call_output`` item;
        * each earlier assistant tool call becomes a ``function_call`` item.

        The two are linked by ``call_id``. Message attributes other than
        ``role``/``content`` are read with ``getattr``, so message objects
        without tool fields still work.
        """
        items: List[Dict[str, Any]] = []
        for m in request.messages:
            tool_call_id = getattr(m, "tool_call_id", None)
            if m.role == "tool" and tool_call_id:
                output = (
                    m.content
                    if isinstance(m.content, str)
                    else json.dumps(m.content, default=str)
                )
                items.append(
                    {
                        "type": "function_call_output",
                        "call_id": tool_call_id,
                        "output": output,
                    }
                )
                continue

            prior_calls = getattr(m, "tool_calls", None) or []
            if m.role == "assistant" and prior_calls:
                if m.content:
                    items.append({"role": "assistant", "content": m.content})
                for tc in prior_calls:
                    arguments = getattr(tc, "arguments", None)
                    items.append(
                        {
                            "type": "function_call",
                            "call_id": getattr(tc, "id", None),
                            "name": getattr(tc, "name", None),
                            "arguments": arguments
                            if isinstance(arguments, str)
                            else json.dumps(arguments or {}),
                        }
                    )
                continue

            items.append({"role": m.role, "content": m.content})
        return items

    def _debug_print(self, label: str, obj: Any) -> None:
        """Log ``obj`` at DEBUG level on this module's logger.

        This previously printed every payload to stdout unconditionally. The
        dump is now skipped entirely unless DEBUG logging is enabled.
        """
        import logging

        logger = logging.getLogger(__name__)
        if not logger.isEnabledFor(logging.DEBUG):
            return  # avoid the (potentially large) model_dump when not needed
        try:
            payload = obj.model_dump()
        except AttributeError:
            try:
                payload = obj.dict()
            except AttributeError:
                payload = obj
        logger.debug("[OpenAIResponsesService] %s: %s", label, payload)

    @staticmethod
    def _parse_arguments(raw: Any) -> Dict[str, Any]:
        """Decode function-call arguments (a JSON string) into a dict.

        Undecodable or non-object JSON is preserved under ``"_raw"`` instead
        of raising.
        """
        if isinstance(raw, dict):
            return raw
        if not raw:
            return {}
        try:
            parsed = json.loads(raw)
        except (TypeError, ValueError):
            return {"_raw": raw}
        return parsed if isinstance(parsed, dict) else {"_raw": raw}

    def _extract(
        self, resp: Response
    ) -> Tuple[
        Optional[str], Optional[List[ToolCall]], Optional[str], Optional[Dict[str, int]]
    ]:
        """Pull text, tool calls, status and token usage out of a ``Response``.

        Function calls are top-level ``output`` items of type
        ``"function_call"``, carrying ``call_id``, ``name`` and JSON-encoded
        ``arguments``. The ``call_id`` (falling back to the item ``id``) is
        kept as the ``ToolCall`` id so the tool's result can be sent back as
        a matching ``function_call_output``.

        Returns:
            ``(text, tool_calls_or_None, status, usage_or_None)``.
        """
        text = getattr(resp, "output_text", None)

        tool_calls: List[ToolCall] = []
        for item in getattr(resp, "output", None) or []:
            if getattr(item, "type", None) != "function_call":
                continue
            call_id = getattr(item, "call_id", None) or getattr(item, "id", None)
            tool_calls.append(
                ToolCall(
                    id=call_id,
                    name=getattr(item, "name", None),
                    arguments=self._parse_arguments(getattr(item, "arguments", None)),
                )
            )

        usage = None
        raw_usage = getattr(resp, "usage", None)
        if raw_usage:
            input_tokens = getattr(raw_usage, "input_tokens", 0) or 0
            output_tokens = getattr(raw_usage, "output_tokens", 0) or 0
            usage = {
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "total_tokens": getattr(raw_usage, "total_tokens", None)
                or (input_tokens + output_tokens),
            }

        status = getattr(resp, "status", None)  # e.g. "completed"
        return text, (tool_calls or None), status, usage

    def _serialize_tool(self, tool: Any) -> Dict[str, Any]:
        """Convert a tool schema into the dict format expected by OpenAI Responses.

        Accepts ``ToolSchema`` instances, pydantic-style objects exposing
        ``model_dump``, or plain dicts. Dicts that already carry a ``"type"``
        key are passed through unchanged.

        Raises:
            TypeError: for any other input type.
        """

        if isinstance(tool, ToolSchema):
            return {
                "type": "function",
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.parameters,
                "strict": False,
            }

        # Support generic pydantic/BaseModel style objects without importing pydantic here.
        if hasattr(tool, "model_dump"):
            data = tool.model_dump()
            if all(key in data for key in ("name", "description", "parameters")):
                return {
                    "type": "function",
                    "name": data["name"],
                    "description": data["description"],
                    "parameters": data["parameters"],
                    "strict": data.get("strict", False),
                }
            return data

        if isinstance(tool, dict):
            if "type" in tool:
                return tool
            if all(k in tool for k in ("name", "description", "parameters")):
                return {
                    "type": "function",
                    "name": tool["name"],
                    "description": tool["description"],
                    "parameters": tool["parameters"],
                    "strict": tool.get("strict", False),
                }
            return tool

        raise TypeError(f"Unsupported tool schema type: {type(tool)!r}")
|
|
@@ -0,0 +1,411 @@
|
|
|
1
|
+
"""
|
|
2
|
+
OpenSearch vector database implementation of AgentMemory.
|
|
3
|
+
|
|
4
|
+
This implementation uses OpenSearch for distributed search and storage of tool usage patterns.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import json
|
|
8
|
+
import uuid
|
|
9
|
+
from datetime import datetime
|
|
10
|
+
from typing import Any, Dict, List, Optional
|
|
11
|
+
import asyncio
|
|
12
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
from opensearchpy import OpenSearch, helpers
|
|
16
|
+
|
|
17
|
+
OPENSEARCH_AVAILABLE = True
|
|
18
|
+
except ImportError:
|
|
19
|
+
OPENSEARCH_AVAILABLE = False
|
|
20
|
+
|
|
21
|
+
from vanna.capabilities.agent_memory import (
|
|
22
|
+
AgentMemory,
|
|
23
|
+
TextMemory,
|
|
24
|
+
TextMemorySearchResult,
|
|
25
|
+
ToolMemory,
|
|
26
|
+
ToolMemorySearchResult,
|
|
27
|
+
)
|
|
28
|
+
from vanna.core.tool import ToolContext
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class OpenSearchAgentMemory(AgentMemory):
|
|
32
|
+
"""OpenSearch-based implementation of AgentMemory."""
|
|
33
|
+
|
|
34
|
+
def __init__(
    self,
    index_name: str = "tool_memories",
    hosts: Optional[List[str]] = None,
    http_auth: Optional[tuple] = None,
    use_ssl: bool = False,
    verify_certs: bool = False,
    dimension: int = 384,
):
    """Store connection settings; no network access happens here.

    The OpenSearch client itself is built lazily by ``_get_client``.
    Blocking client calls are later dispatched onto a private thread pool
    with two workers.

    Raises:
        ImportError: if the ``opensearch-py`` package could not be imported.
    """
    if OPENSEARCH_AVAILABLE is False:
        raise ImportError(
            "OpenSearch is required for OpenSearchAgentMemory. Install with: pip install opensearch-py"
        )

    # Target index and vector size used by the k-NN mapping.
    self.index_name = index_name
    self.dimension = dimension

    # Connection options forwarded verbatim to the OpenSearch client.
    self.hosts = hosts if hosts else ["localhost:9200"]
    self.http_auth = http_auth
    self.use_ssl = use_ssl
    self.verify_certs = verify_certs

    # Lazily-created client plus the pool that runs its blocking calls.
    self._client = None
    self._executor = ThreadPoolExecutor(max_workers=2)
|
|
56
|
+
|
|
57
|
+
def _get_client(self):
    """Return the OpenSearch client, creating it and the k-NN index on first use.

    The client is cached only after the index is known to exist. Previously
    it was cached before the exists/create calls, so one failure there meant
    index creation was never retried. Later writes would then target a
    missing index, which OpenSearch may auto-create with a dynamic (non-k-NN)
    mapping.
    """
    if self._client is not None:
        return self._client

    client = OpenSearch(
        hosts=self.hosts,
        http_auth=self.http_auth,
        use_ssl=self.use_ssl,
        verify_certs=self.verify_certs,
        ssl_show_warn=False,
    )

    # Create index if it doesn't exist
    if not client.indices.exists(index=self.index_name):
        index_body = {
            "settings": {
                "index": {"knn": True, "knn.algo_param.ef_search": 100}
            },
            "mappings": {
                "properties": {
                    "memory_id": {"type": "keyword"},
                    "question": {"type": "text"},
                    "tool_name": {"type": "keyword"},
                    # Stored but not indexed: arbitrary tool args / metadata.
                    "args": {"type": "object", "enabled": False},
                    "timestamp": {"type": "date"},
                    "success": {"type": "boolean"},
                    "metadata": {"type": "object", "enabled": False},
                    "embedding": {
                        "type": "knn_vector",
                        "dimension": self.dimension,
                        "method": {
                            "name": "hnsw",
                            "space_type": "cosinesimil",
                            "engine": "nmslib",
                        },
                    },
                }
            },
        }
        client.indices.create(index=self.index_name, body=index_body)

    # Publish only once the index is ready, so a failure above is retried.
    self._client = client
    return client
|
|
98
|
+
|
|
99
|
+
def _create_embedding(self, text: str) -> List[float]:
|
|
100
|
+
"""Create a simple embedding from text (placeholder)."""
|
|
101
|
+
import hashlib
|
|
102
|
+
|
|
103
|
+
hash_val = int(hashlib.md5(text.encode()).hexdigest(), 16)
|
|
104
|
+
return [(hash_val >> i) % 100 / 100.0 for i in range(self.dimension)]
|
|
105
|
+
|
|
106
|
+
async def save_tool_usage(
    self,
    question: str,
    tool_name: str,
    args: Dict[str, Any],
    context: ToolContext,
    success: bool = True,
    metadata: Optional[Dict[str, Any]] = None,
) -> None:
    """Index one tool-usage record (question, tool, args, outcome).

    A fresh UUID is used as both ``memory_id`` and the document id. The
    write uses ``refresh=True`` so it is searchable immediately. The blocking
    client call runs on ``self._executor``. ``context`` is not used by this
    implementation.
    """
    loop = asyncio.get_running_loop()

    def _index_document():
        client = self._get_client()
        doc_id = str(uuid.uuid4())
        client.index(
            index=self.index_name,
            body={
                "memory_id": doc_id,
                "question": question,
                "tool_name": tool_name,
                "args": args,
                "timestamp": datetime.now().isoformat(),
                "success": success,
                "metadata": metadata or {},
                "embedding": self._create_embedding(question),
            },
            id=doc_id,
            refresh=True,
        )

    await loop.run_in_executor(self._executor, _index_document)
|
|
140
|
+
|
|
141
|
+
async def search_similar_usage(
    self,
    question: str,
    context: ToolContext,
    *,
    limit: int = 10,
    similarity_threshold: float = 0.7,
    tool_name_filter: Optional[str] = None,
) -> List[ToolMemorySearchResult]:
    """Find successful tool usages whose stored question is nearest to ``question``.

    The k-NN clause is placed in the scoring (``must``) context, so each
    hit's ``_score`` is the k-NN score. Before, it sat in ``filter``, where it
    contributed no score, and the ``score / 10`` normalisation discarded
    nearly every hit.

    For the index's ``cosinesimil`` / nmslib mapping, OpenSearch reports
    ``score = 1 / (2 - cosine)``. That is inverted here to recover the cosine
    similarity, which is clamped to ``[0, 1]`` and compared with
    ``similarity_threshold``.

    The term filters (``success`` and the optional ``tool_name``) are applied
    after the approximate search. Fewer than ``limit`` results may therefore
    come back.
    """

    def _search():
        client = self._get_client()
        embedding = self._create_embedding(question)

        # Non-scoring restrictions: successful tool memories, optionally one tool.
        filters: List[Dict[str, Any]] = [{"term": {"success": True}}]
        if tool_name_filter:
            filters.append({"term": {"tool_name": tool_name_filter}})

        query = {
            "size": limit,
            "query": {
                "bool": {
                    "must": [
                        {"knn": {"embedding": {"vector": embedding, "k": limit}}}
                    ],
                    "filter": filters,
                }
            },
        }

        response = client.search(index=self.index_name, body=query)

        search_results: List[ToolMemorySearchResult] = []
        for hit in response["hits"]["hits"]:
            source = hit["_source"]
            raw_score = hit.get("_score") or 0.0

            # Invert nmslib cosinesimil scoring: score = 1 / (2 - cos).
            similarity_score = (
                max(0.0, min(1.0, 2.0 - 1.0 / raw_score)) if raw_score > 0 else 0.0
            )
            if similarity_score < similarity_threshold:
                continue

            memory = ToolMemory(
                memory_id=source["memory_id"],
                question=source["question"],
                tool_name=source["tool_name"],
                args=source["args"],
                timestamp=source.get("timestamp"),
                success=source.get("success", True),
                metadata=source.get("metadata", {}),
            )
            search_results.append(
                ToolMemorySearchResult(
                    memory=memory,
                    similarity_score=similarity_score,
                    rank=len(search_results) + 1,
                )
            )

        return search_results

    return await asyncio.get_running_loop().run_in_executor(self._executor, _search)
|
|
204
|
+
|
|
205
|
+
async def get_recent_memories(
    self, context: ToolContext, limit: int = 10
) -> List[ToolMemory]:
    """Return up to ``limit`` tool memories, newest ``timestamp`` first.

    Text memories (saved by ``save_text_memory``) share this index but have
    no ``question`` / ``tool_name`` / ``args`` fields. The previous
    ``match_all`` query returned them too, and building a ``ToolMemory`` then
    raised ``KeyError``. They are now excluded via ``must_not``.
    """

    def _get_recent():
        client = self._get_client()

        query = {
            "size": limit,
            # A bool with only must_not matches every other document.
            "query": {"bool": {"must_not": [{"term": {"is_text_memory": True}}]}},
            "sort": [{"timestamp": {"order": "desc"}}],
        }

        response = client.search(index=self.index_name, body=query)

        memories: List[ToolMemory] = []
        for hit in response["hits"]["hits"]:
            source = hit["_source"]
            memories.append(
                ToolMemory(
                    memory_id=source["memory_id"],
                    question=source["question"],
                    tool_name=source["tool_name"],
                    args=source["args"],
                    timestamp=source.get("timestamp"),
                    success=source.get("success", True),
                    metadata=source.get("metadata", {}),
                )
            )

        return memories

    return await asyncio.get_running_loop().run_in_executor(
        self._executor, _get_recent
    )
|
|
241
|
+
|
|
242
|
+
async def delete_by_id(self, context: ToolContext, memory_id: str) -> bool:
    """Delete the document whose id is ``memory_id``.

    Returns ``True`` when the delete call succeeds. Any exception raised by
    ``client.delete`` (including a missing document) is reported as
    ``False``. Errors from obtaining the client itself are not caught.
    """

    def _remove() -> bool:
        client = self._get_client()
        try:
            client.delete(index=self.index_name, id=memory_id, refresh=True)
        except Exception:
            return False
        return True

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(self._executor, _remove)
|
|
255
|
+
|
|
256
|
+
async def save_text_memory(self, content: str, context: ToolContext) -> TextMemory:
    """Index free-form ``content`` as a text memory and return it.

    The document is tagged ``is_text_memory: True`` and embedded with
    ``_create_embedding``. It is written with ``refresh=True`` on the
    executor thread pool.
    """
    loop = asyncio.get_running_loop()

    def _index_text() -> TextMemory:
        client = self._get_client()

        new_id = str(uuid.uuid4())
        created_at = datetime.now().isoformat()
        vector = self._create_embedding(content)

        client.index(
            index=self.index_name,
            body={
                "memory_id": new_id,
                "content": content,
                "timestamp": created_at,
                "is_text_memory": True,
                "embedding": vector,
            },
            id=new_id,
            refresh=True,
        )
        return TextMemory(memory_id=new_id, content=content, timestamp=created_at)

    return await loop.run_in_executor(self._executor, _index_text)
|
|
281
|
+
|
|
282
|
+
async def search_text_memories(
    self,
    query: str,
    context: ToolContext,
    *,
    limit: int = 10,
    similarity_threshold: float = 0.7,
) -> List[TextMemorySearchResult]:
    """Find text memories whose content is nearest to ``query``.

    The k-NN clause is placed in scoring (``must``) context and the
    ``is_text_memory`` restriction in ``filter``. Before, the k-NN clause
    itself was in ``filter``, so it contributed no score, and ``score / 10``
    dropped nearly every hit.

    With the index's ``cosinesimil`` / nmslib mapping, OpenSearch reports
    ``score = 1 / (2 - cosine)``. That is inverted here to a cosine
    similarity clamped to ``[0, 1]``.

    The filter is applied after the approximate search, so tool memories
    sharing the index can reduce the number of results below ``limit``.
    """

    def _search():
        client = self._get_client()
        embedding = self._create_embedding(query)

        query_body = {
            "size": limit,
            "query": {
                "bool": {
                    "must": [
                        {"knn": {"embedding": {"vector": embedding, "k": limit}}}
                    ],
                    "filter": [{"term": {"is_text_memory": True}}],
                }
            },
        }

        response = client.search(index=self.index_name, body=query_body)

        search_results: List[TextMemorySearchResult] = []
        for hit in response["hits"]["hits"]:
            source = hit["_source"]
            raw_score = hit.get("_score") or 0.0

            # Invert nmslib cosinesimil scoring: score = 1 / (2 - cos).
            similarity_score = (
                max(0.0, min(1.0, 2.0 - 1.0 / raw_score)) if raw_score > 0 else 0.0
            )
            if similarity_score < similarity_threshold:
                continue

            memory = TextMemory(
                memory_id=source["memory_id"],
                content=source.get("content", ""),
                timestamp=source.get("timestamp"),
            )
            search_results.append(
                TextMemorySearchResult(
                    memory=memory,
                    similarity_score=similarity_score,
                    rank=len(search_results) + 1,
                )
            )

        return search_results

    return await asyncio.get_running_loop().run_in_executor(self._executor, _search)
|
|
334
|
+
|
|
335
|
+
async def get_recent_text_memories(
    self, context: ToolContext, limit: int = 10
) -> List[TextMemory]:
    """Get recently added text memories.

    Queries ``self.index_name`` for documents with ``is_text_memory`` set,
    sorted by ``timestamp`` descending, on ``self._executor``.

    Args:
        context: Tool execution context (not referenced by this implementation).
        limit: Maximum number of memories to return (request ``size``).

    Returns:
        ``TextMemory`` objects, newest first. A missing ``content`` becomes
        ``""`` and a missing ``timestamp`` becomes ``None``.
    """

    def _fetch_latest():
        client = self._get_client()
        request_body = {
            "size": limit,
            "query": {"term": {"is_text_memory": True}},
            "sort": [{"timestamp": {"order": "desc"}}],
        }
        hits = client.search(index=self.index_name, body=request_body)["hits"]["hits"]
        return [
            TextMemory(
                memory_id=hit["_source"]["memory_id"],
                content=hit["_source"].get("content", ""),
                timestamp=hit["_source"].get("timestamp"),
            )
            for hit in hits
        ]

    return await asyncio.get_event_loop().run_in_executor(
        self._executor, _fetch_latest
    )
|
|
367
|
+
|
|
368
|
+
async def delete_text_memory(self, context: ToolContext, memory_id: str) -> bool:
    """Delete a text memory by its ID.

    The delete call is issued with ``refresh=True`` on ``self._executor``.

    Args:
        context: Tool execution context (not referenced by this implementation).
        memory_id: Document ID to delete from ``self.index_name``.

    Returns:
        True if the delete call completed, False if it raised (for example
        when no document with that ID exists). Failures are logged rather
        than silently discarded.

    Raises:
        Any exception raised by ``self._get_client()`` is propagated.
    """
    import logging

    logger = logging.getLogger(__name__)

    def _delete():
        client = self._get_client()

        try:
            client.delete(index=self.index_name, id=memory_id, refresh=True)
        except Exception as exc:
            # Best-effort contract: report failure as False instead of raising,
            # but record why, so connection/auth errors are not lost silently.
            # Assumes client errors may carry an HTTP ``status_code`` (a 404
            # meaning "no such document"); getattr keeps this safe otherwise.
            level = (
                logging.DEBUG
                if getattr(exc, "status_code", None) == 404
                else logging.WARNING
            )
            logger.log(
                level,
                "Failed to delete text memory %s from index %s",
                memory_id,
                self.index_name,
                exc_info=True,
            )
            return False
        return True

    return await asyncio.get_event_loop().run_in_executor(self._executor, _delete)
|
|
381
|
+
|
|
382
|
+
async def clear_memories(
    self,
    context: ToolContext,
    tool_name: Optional[str] = None,
    before_date: Optional[str] = None,
) -> int:
    """Clear stored memories.

    Issues a ``delete_by_query`` (with ``refresh=True``) against
    ``self.index_name`` on ``self._executor``.

    Args:
        context: Tool execution context (not referenced by this implementation).
        tool_name: If truthy, only documents whose ``tool_name`` term matches
            are deleted.
        before_date: If truthy, only documents with ``timestamp`` strictly
            earlier than this value are deleted.

    Returns:
        The ``deleted`` count from the response, or 0 if it is absent.
        With neither filter given, every document in the index is matched.
    """

    def _delete_matching():
        client = self._get_client()

        # Both filters, when present, must hold (AND semantics).
        filters = []
        if tool_name:
            filters.append({"term": {"tool_name": tool_name}})
        if before_date:
            filters.append({"range": {"timestamp": {"lt": before_date}}})

        selector = {"bool": {"must": filters}} if filters else {"match_all": {}}
        outcome = client.delete_by_query(
            index=self.index_name, body={"query": selector}, refresh=True
        )
        return outcome.get("deleted", 0)

    loop = asyncio.get_event_loop()
    return await loop.run_in_executor(self._executor, _delete_matching)
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
"""Oracle implementation of SqlRunner interface."""
|
|
2
|
+
|
|
3
|
+
import asyncio
from typing import Optional

import pandas as pd

from vanna.capabilities.sql_runner import SqlRunner, RunSqlToolArgs
from vanna.core.tool import ToolContext
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class OracleRunner(SqlRunner):
    """Oracle implementation of the SqlRunner interface.

    Every call to :meth:`run_sql` opens a fresh connection with the stored
    credentials, runs a single statement, and closes the connection again.
    """

    def __init__(self, user: str, password: str, dsn: str, **kwargs):
        """Initialize with Oracle connection parameters.

        Args:
            user: Oracle database user name
            password: Oracle database user password
            dsn: Oracle connect string, e.g. Easy Connect
                ``host:port/service_name``
            **kwargs: Additional oracledb connection parameters

        Raises:
            ImportError: If the optional ``oracledb`` package is not installed.
        """
        try:
            import oracledb
        except ImportError as e:
            raise ImportError(
                "oracledb package is required. Install with: pip install 'vanna[oracle]'"
            ) from e

        # Keep the module on the instance so the driver is only imported when
        # this runner is actually constructed.
        self.oracledb = oracledb
        self.user = user
        self.password = password
        self.dsn = dsn
        self.kwargs = kwargs

    @staticmethod
    def _strip_trailing_semicolon(sql: str) -> str:
        """Return *sql* without trailing whitespace and one trailing ``;``.

        Oracle rejects a terminating semicolon on plain SQL statements sent
        through the driver, so a single trailing one is removed.
        """
        sql = sql.rstrip()
        if sql.endswith(";"):
            sql = sql[:-1]
        return sql

    def _run_sql_blocking(self, sql: str) -> pd.DataFrame:
        """Execute *sql* on a fresh connection and return all rows (blocking).

        Both the connection and the cursor are context-managed, so they are
        closed even if opening the cursor, executing, or fetching fails.
        """
        with self.oracledb.connect(
            user=self.user, password=self.password, dsn=self.dsn, **self.kwargs
        ) as conn:
            with conn.cursor() as cursor:
                try:
                    cursor.execute(sql)
                    rows = cursor.fetchall()
                except self.oracledb.Error:
                    conn.rollback()
                    raise
                columns = [desc[0] for desc in cursor.description]
        return pd.DataFrame(rows, columns=columns)

    async def run_sql(self, args: RunSqlToolArgs, context: ToolContext) -> pd.DataFrame:
        """Execute SQL query against Oracle database and return results as DataFrame.

        The blocking driver calls run in the event loop's default executor so
        they do not stall other coroutines.

        Args:
            args: SQL query arguments
            context: Tool execution context (not referenced by this implementation)

        Returns:
            DataFrame with query results, columns named from the cursor description

        Raises:
            oracledb.Error: If connecting or query execution fails
        """
        sql = self._strip_trailing_semicolon(args.sql)
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._run_sql_blocking, sql)
|