vanna 0.7.9__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/__init__.py +167 -395
- vanna/agents/__init__.py +7 -0
- vanna/capabilities/__init__.py +17 -0
- vanna/capabilities/agent_memory/__init__.py +21 -0
- vanna/capabilities/agent_memory/base.py +103 -0
- vanna/capabilities/agent_memory/models.py +53 -0
- vanna/capabilities/file_system/__init__.py +14 -0
- vanna/capabilities/file_system/base.py +71 -0
- vanna/capabilities/file_system/models.py +25 -0
- vanna/capabilities/sql_runner/__init__.py +13 -0
- vanna/capabilities/sql_runner/base.py +37 -0
- vanna/capabilities/sql_runner/models.py +13 -0
- vanna/components/__init__.py +92 -0
- vanna/components/base.py +11 -0
- vanna/components/rich/__init__.py +83 -0
- vanna/components/rich/containers/__init__.py +7 -0
- vanna/components/rich/containers/card.py +20 -0
- vanna/components/rich/data/__init__.py +9 -0
- vanna/components/rich/data/chart.py +17 -0
- vanna/components/rich/data/dataframe.py +93 -0
- vanna/components/rich/feedback/__init__.py +21 -0
- vanna/components/rich/feedback/badge.py +16 -0
- vanna/components/rich/feedback/icon_text.py +14 -0
- vanna/components/rich/feedback/log_viewer.py +41 -0
- vanna/components/rich/feedback/notification.py +19 -0
- vanna/components/rich/feedback/progress.py +37 -0
- vanna/components/rich/feedback/status_card.py +28 -0
- vanna/components/rich/feedback/status_indicator.py +14 -0
- vanna/components/rich/interactive/__init__.py +21 -0
- vanna/components/rich/interactive/button.py +95 -0
- vanna/components/rich/interactive/task_list.py +58 -0
- vanna/components/rich/interactive/ui_state.py +93 -0
- vanna/components/rich/specialized/__init__.py +7 -0
- vanna/components/rich/specialized/artifact.py +20 -0
- vanna/components/rich/text.py +16 -0
- vanna/components/simple/__init__.py +15 -0
- vanna/components/simple/image.py +15 -0
- vanna/components/simple/link.py +15 -0
- vanna/components/simple/text.py +11 -0
- vanna/core/__init__.py +193 -0
- vanna/core/_compat.py +19 -0
- vanna/core/agent/__init__.py +10 -0
- vanna/core/agent/agent.py +1407 -0
- vanna/core/agent/config.py +123 -0
- vanna/core/audit/__init__.py +28 -0
- vanna/core/audit/base.py +299 -0
- vanna/core/audit/models.py +131 -0
- vanna/core/component_manager.py +329 -0
- vanna/core/components.py +53 -0
- vanna/core/enhancer/__init__.py +11 -0
- vanna/core/enhancer/base.py +94 -0
- vanna/core/enhancer/default.py +118 -0
- vanna/core/enricher/__init__.py +10 -0
- vanna/core/enricher/base.py +59 -0
- vanna/core/errors.py +47 -0
- vanna/core/evaluation/__init__.py +81 -0
- vanna/core/evaluation/base.py +186 -0
- vanna/core/evaluation/dataset.py +254 -0
- vanna/core/evaluation/evaluators.py +376 -0
- vanna/core/evaluation/report.py +289 -0
- vanna/core/evaluation/runner.py +313 -0
- vanna/core/filter/__init__.py +10 -0
- vanna/core/filter/base.py +67 -0
- vanna/core/lifecycle/__init__.py +10 -0
- vanna/core/lifecycle/base.py +83 -0
- vanna/core/llm/__init__.py +16 -0
- vanna/core/llm/base.py +40 -0
- vanna/core/llm/models.py +61 -0
- vanna/core/middleware/__init__.py +10 -0
- vanna/core/middleware/base.py +69 -0
- vanna/core/observability/__init__.py +11 -0
- vanna/core/observability/base.py +88 -0
- vanna/core/observability/models.py +47 -0
- vanna/core/recovery/__init__.py +11 -0
- vanna/core/recovery/base.py +84 -0
- vanna/core/recovery/models.py +32 -0
- vanna/core/registry.py +278 -0
- vanna/core/rich_component.py +156 -0
- vanna/core/simple_component.py +27 -0
- vanna/core/storage/__init__.py +14 -0
- vanna/core/storage/base.py +46 -0
- vanna/core/storage/models.py +46 -0
- vanna/core/system_prompt/__init__.py +13 -0
- vanna/core/system_prompt/base.py +36 -0
- vanna/core/system_prompt/default.py +157 -0
- vanna/core/tool/__init__.py +18 -0
- vanna/core/tool/base.py +70 -0
- vanna/core/tool/models.py +84 -0
- vanna/core/user/__init__.py +17 -0
- vanna/core/user/base.py +29 -0
- vanna/core/user/models.py +25 -0
- vanna/core/user/request_context.py +70 -0
- vanna/core/user/resolver.py +42 -0
- vanna/core/validation.py +164 -0
- vanna/core/workflow/__init__.py +12 -0
- vanna/core/workflow/base.py +254 -0
- vanna/core/workflow/default.py +789 -0
- vanna/examples/__init__.py +1 -0
- vanna/examples/__main__.py +44 -0
- vanna/examples/anthropic_quickstart.py +80 -0
- vanna/examples/artifact_example.py +293 -0
- vanna/examples/claude_sqlite_example.py +236 -0
- vanna/examples/coding_agent_example.py +300 -0
- vanna/examples/custom_system_prompt_example.py +174 -0
- vanna/examples/default_workflow_handler_example.py +208 -0
- vanna/examples/email_auth_example.py +340 -0
- vanna/examples/evaluation_example.py +269 -0
- vanna/examples/extensibility_example.py +262 -0
- vanna/examples/minimal_example.py +67 -0
- vanna/examples/mock_auth_example.py +227 -0
- vanna/examples/mock_custom_tool.py +311 -0
- vanna/examples/mock_quickstart.py +79 -0
- vanna/examples/mock_quota_example.py +145 -0
- vanna/examples/mock_rich_components_demo.py +396 -0
- vanna/examples/mock_sqlite_example.py +223 -0
- vanna/examples/openai_quickstart.py +83 -0
- vanna/examples/primitive_components_demo.py +305 -0
- vanna/examples/quota_lifecycle_example.py +139 -0
- vanna/examples/visualization_example.py +251 -0
- vanna/integrations/__init__.py +17 -0
- vanna/integrations/anthropic/__init__.py +9 -0
- vanna/integrations/anthropic/llm.py +270 -0
- vanna/integrations/azureopenai/__init__.py +9 -0
- vanna/integrations/azureopenai/llm.py +329 -0
- vanna/integrations/azuresearch/__init__.py +7 -0
- vanna/integrations/azuresearch/agent_memory.py +413 -0
- vanna/integrations/bigquery/__init__.py +5 -0
- vanna/integrations/bigquery/sql_runner.py +81 -0
- vanna/integrations/chromadb/__init__.py +104 -0
- vanna/integrations/chromadb/agent_memory.py +416 -0
- vanna/integrations/clickhouse/__init__.py +5 -0
- vanna/integrations/clickhouse/sql_runner.py +82 -0
- vanna/integrations/duckdb/__init__.py +5 -0
- vanna/integrations/duckdb/sql_runner.py +65 -0
- vanna/integrations/faiss/__init__.py +7 -0
- vanna/integrations/faiss/agent_memory.py +431 -0
- vanna/integrations/google/__init__.py +9 -0
- vanna/integrations/google/gemini.py +370 -0
- vanna/integrations/hive/__init__.py +5 -0
- vanna/integrations/hive/sql_runner.py +87 -0
- vanna/integrations/local/__init__.py +17 -0
- vanna/integrations/local/agent_memory/__init__.py +7 -0
- vanna/integrations/local/agent_memory/in_memory.py +285 -0
- vanna/integrations/local/audit.py +59 -0
- vanna/integrations/local/file_system.py +242 -0
- vanna/integrations/local/file_system_conversation_store.py +255 -0
- vanna/integrations/local/storage.py +62 -0
- vanna/integrations/marqo/__init__.py +7 -0
- vanna/integrations/marqo/agent_memory.py +354 -0
- vanna/integrations/milvus/__init__.py +7 -0
- vanna/integrations/milvus/agent_memory.py +458 -0
- vanna/integrations/mock/__init__.py +9 -0
- vanna/integrations/mock/llm.py +65 -0
- vanna/integrations/mssql/__init__.py +5 -0
- vanna/integrations/mssql/sql_runner.py +66 -0
- vanna/integrations/mysql/__init__.py +5 -0
- vanna/integrations/mysql/sql_runner.py +92 -0
- vanna/integrations/ollama/__init__.py +7 -0
- vanna/integrations/ollama/llm.py +252 -0
- vanna/integrations/openai/__init__.py +10 -0
- vanna/integrations/openai/llm.py +267 -0
- vanna/integrations/openai/responses.py +163 -0
- vanna/integrations/opensearch/__init__.py +7 -0
- vanna/integrations/opensearch/agent_memory.py +411 -0
- vanna/integrations/oracle/__init__.py +5 -0
- vanna/integrations/oracle/sql_runner.py +75 -0
- vanna/integrations/pinecone/__init__.py +7 -0
- vanna/integrations/pinecone/agent_memory.py +329 -0
- vanna/integrations/plotly/__init__.py +5 -0
- vanna/integrations/plotly/chart_generator.py +313 -0
- vanna/integrations/postgres/__init__.py +9 -0
- vanna/integrations/postgres/sql_runner.py +112 -0
- vanna/integrations/premium/agent_memory/__init__.py +7 -0
- vanna/integrations/premium/agent_memory/premium.py +186 -0
- vanna/integrations/presto/__init__.py +5 -0
- vanna/integrations/presto/sql_runner.py +107 -0
- vanna/integrations/qdrant/__init__.py +7 -0
- vanna/integrations/qdrant/agent_memory.py +461 -0
- vanna/integrations/snowflake/__init__.py +5 -0
- vanna/integrations/snowflake/sql_runner.py +147 -0
- vanna/integrations/sqlite/__init__.py +9 -0
- vanna/integrations/sqlite/sql_runner.py +65 -0
- vanna/integrations/weaviate/__init__.py +7 -0
- vanna/integrations/weaviate/agent_memory.py +428 -0
- vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_embeddings.py +11 -11
- vanna/legacy/__init__.py +403 -0
- vanna/legacy/adapter.py +463 -0
- vanna/{advanced → legacy/advanced}/__init__.py +3 -1
- vanna/{anthropic → legacy/anthropic}/anthropic_chat.py +9 -7
- vanna/{azuresearch → legacy/azuresearch}/azuresearch_vector.py +79 -41
- vanna/{base → legacy/base}/base.py +224 -217
- vanna/legacy/bedrock/__init__.py +1 -0
- vanna/{bedrock → legacy/bedrock}/bedrock_converse.py +13 -12
- vanna/{chromadb → legacy/chromadb}/chromadb_vector.py +3 -1
- vanna/legacy/cohere/__init__.py +2 -0
- vanna/{cohere → legacy/cohere}/cohere_chat.py +19 -14
- vanna/{cohere → legacy/cohere}/cohere_embeddings.py +25 -19
- vanna/{deepseek → legacy/deepseek}/deepseek_chat.py +5 -6
- vanna/legacy/faiss/__init__.py +1 -0
- vanna/{faiss → legacy/faiss}/faiss.py +113 -59
- vanna/{flask → legacy/flask}/__init__.py +84 -43
- vanna/{flask → legacy/flask}/assets.py +5 -5
- vanna/{flask → legacy/flask}/auth.py +5 -4
- vanna/{google → legacy/google}/bigquery_vector.py +75 -42
- vanna/{google → legacy/google}/gemini_chat.py +7 -3
- vanna/{hf → legacy/hf}/hf.py +0 -1
- vanna/{milvus → legacy/milvus}/milvus_vector.py +58 -35
- vanna/{mock → legacy/mock}/llm.py +0 -1
- vanna/legacy/mock/vectordb.py +67 -0
- vanna/legacy/ollama/ollama.py +110 -0
- vanna/{openai → legacy/openai}/openai_chat.py +2 -6
- vanna/legacy/opensearch/opensearch_vector.py +369 -0
- vanna/legacy/opensearch/opensearch_vector_semantic.py +200 -0
- vanna/legacy/oracle/oracle_vector.py +584 -0
- vanna/{pgvector → legacy/pgvector}/pgvector.py +42 -13
- vanna/{qdrant → legacy/qdrant}/qdrant.py +2 -6
- vanna/legacy/qianfan/Qianfan_Chat.py +170 -0
- vanna/legacy/qianfan/Qianfan_embeddings.py +36 -0
- vanna/legacy/qianwen/QianwenAI_chat.py +132 -0
- vanna/{remote.py → legacy/remote.py} +28 -26
- vanna/{utils.py → legacy/utils.py} +6 -11
- vanna/{vannadb → legacy/vannadb}/vannadb_vector.py +115 -46
- vanna/{vllm → legacy/vllm}/vllm.py +5 -6
- vanna/{weaviate → legacy/weaviate}/weaviate_vector.py +59 -40
- vanna/{xinference → legacy/xinference}/xinference.py +6 -6
- vanna/py.typed +0 -0
- vanna/servers/__init__.py +16 -0
- vanna/servers/__main__.py +8 -0
- vanna/servers/base/__init__.py +18 -0
- vanna/servers/base/chat_handler.py +65 -0
- vanna/servers/base/models.py +111 -0
- vanna/servers/base/rich_chat_handler.py +141 -0
- vanna/servers/base/templates.py +331 -0
- vanna/servers/cli/__init__.py +7 -0
- vanna/servers/cli/server_runner.py +204 -0
- vanna/servers/fastapi/__init__.py +7 -0
- vanna/servers/fastapi/app.py +163 -0
- vanna/servers/fastapi/routes.py +183 -0
- vanna/servers/flask/__init__.py +7 -0
- vanna/servers/flask/app.py +132 -0
- vanna/servers/flask/routes.py +137 -0
- vanna/tools/__init__.py +41 -0
- vanna/tools/agent_memory.py +322 -0
- vanna/tools/file_system.py +879 -0
- vanna/tools/python.py +222 -0
- vanna/tools/run_sql.py +165 -0
- vanna/tools/visualize_data.py +195 -0
- vanna/utils/__init__.py +0 -0
- vanna/web_components/__init__.py +44 -0
- vanna-2.0.0.dist-info/METADATA +485 -0
- vanna-2.0.0.dist-info/RECORD +289 -0
- vanna-2.0.0.dist-info/entry_points.txt +3 -0
- vanna/bedrock/__init__.py +0 -1
- vanna/cohere/__init__.py +0 -2
- vanna/faiss/__init__.py +0 -1
- vanna/mock/vectordb.py +0 -55
- vanna/ollama/ollama.py +0 -103
- vanna/opensearch/opensearch_vector.py +0 -392
- vanna/opensearch/opensearch_vector_semantic.py +0 -175
- vanna/oracle/oracle_vector.py +0 -585
- vanna/qianfan/Qianfan_Chat.py +0 -165
- vanna/qianfan/Qianfan_embeddings.py +0 -36
- vanna/qianwen/QianwenAI_chat.py +0 -133
- vanna-0.7.9.dist-info/METADATA +0 -408
- vanna-0.7.9.dist-info/RECORD +0 -79
- /vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_Chat.py +0 -0
- /vanna/{ZhipuAI → legacy/ZhipuAI}/__init__.py +0 -0
- /vanna/{anthropic → legacy/anthropic}/__init__.py +0 -0
- /vanna/{azuresearch → legacy/azuresearch}/__init__.py +0 -0
- /vanna/{base → legacy/base}/__init__.py +0 -0
- /vanna/{chromadb → legacy/chromadb}/__init__.py +0 -0
- /vanna/{deepseek → legacy/deepseek}/__init__.py +0 -0
- /vanna/{exceptions → legacy/exceptions}/__init__.py +0 -0
- /vanna/{google → legacy/google}/__init__.py +0 -0
- /vanna/{hf → legacy/hf}/__init__.py +0 -0
- /vanna/{local.py → legacy/local.py} +0 -0
- /vanna/{marqo → legacy/marqo}/__init__.py +0 -0
- /vanna/{marqo → legacy/marqo}/marqo.py +0 -0
- /vanna/{milvus → legacy/milvus}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/mistral.py +0 -0
- /vanna/{mock → legacy/mock}/__init__.py +0 -0
- /vanna/{mock → legacy/mock}/embedding.py +0 -0
- /vanna/{ollama → legacy/ollama}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/openai_embeddings.py +0 -0
- /vanna/{opensearch → legacy/opensearch}/__init__.py +0 -0
- /vanna/{oracle → legacy/oracle}/__init__.py +0 -0
- /vanna/{pgvector → legacy/pgvector}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/pinecone_vector.py +0 -0
- /vanna/{qdrant → legacy/qdrant}/__init__.py +0 -0
- /vanna/{qianfan → legacy/qianfan}/__init__.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/QianwenAI_embeddings.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/__init__.py +0 -0
- /vanna/{types → legacy/types}/__init__.py +0 -0
- /vanna/{vannadb → legacy/vannadb}/__init__.py +0 -0
- /vanna/{vllm → legacy/vllm}/__init__.py +0 -0
- /vanna/{weaviate → legacy/weaviate}/__init__.py +0 -0
- /vanna/{xinference → legacy/xinference}/__init__.py +0 -0
- {vanna-0.7.9.dist-info → vanna-2.0.0.dist-info}/WHEEL +0 -0
- {vanna-0.7.9.dist-info → vanna-2.0.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,370 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Google Gemini LLM service implementation.
|
|
3
|
+
|
|
4
|
+
Implements the LlmService interface using Google's Gen AI SDK
|
|
5
|
+
(google-genai). Supports non-streaming and streaming text output,
|
|
6
|
+
as well as function calling (tool use).
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import json
|
|
12
|
+
import logging
|
|
13
|
+
import os
|
|
14
|
+
from typing import Any, AsyncGenerator, Dict, List, Optional
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
from vanna.core.llm import (
|
|
19
|
+
LlmService,
|
|
20
|
+
LlmRequest,
|
|
21
|
+
LlmResponse,
|
|
22
|
+
LlmStreamChunk,
|
|
23
|
+
)
|
|
24
|
+
from vanna.core.tool import ToolCall, ToolSchema
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class GeminiLlmService(LlmService):
    """Google Gemini-backed LLM service.

    Args:
        model: Gemini model name (e.g., "gemini-2.5-pro", "gemini-2.5-flash").
            Defaults to "gemini-2.5-pro". Can also be set via GEMINI_MODEL env var.
        api_key: API key; falls back to env `GOOGLE_API_KEY` or `GEMINI_API_KEY`.
            GOOGLE_API_KEY takes precedence if both are set.
        temperature: Temperature for generation (0.0-2.0). Default 0.7. Used
            whenever a request's own ``temperature`` is None.
        extra_config: Extra kwargs forwarded to GenerateContentConfig. Entries
            here override the computed temperature, but not max_output_tokens,
            tools or system_instruction, which are set afterwards.

    Raises:
        ImportError: If the ``google-genai`` package cannot be imported.
        ValueError: If no API key is passed or found in the environment.
    """

    # Prefix of the tool-call IDs synthesized for Gemini function calls. The
    # function name follows the prefix, so a bare ID can be mapped back to it.
    _TOOL_CALL_ID_PREFIX = "call_"

    def __init__(
        self,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        temperature: float = 0.7,
        **extra_config: Any,
    ) -> None:
        try:
            from google import genai
            from google.genai import types
        except Exception as e:  # pragma: no cover
            raise ImportError(
                "google-genai package is required. "
                "Install with: pip install 'vanna[gemini]'"
            ) from e

        self.model_name = model or os.getenv("GEMINI_MODEL", "gemini-2.5-pro")
        # Check GOOGLE_API_KEY first (takes precedence), then GEMINI_API_KEY
        api_key = api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")

        if not api_key:
            raise ValueError(
                "Google API key is required. Set GOOGLE_API_KEY or GEMINI_API_KEY "
                "environment variable, or pass api_key parameter."
            )

        # Keep the SDK modules so the helpers can build typed payloads.
        self._genai = genai
        self._types = types

        self._client = genai.Client(api_key=api_key)

        # Service-level generation defaults (applied in _build_payload).
        self.temperature = temperature
        self.extra_config = extra_config

    async def send_request(self, request: LlmRequest) -> LlmResponse:
        """Send a non-streaming request to Gemini and return the response.

        Uses the SDK's async client (``client.aio``) so the event loop is not
        blocked while the request is in flight.

        Raises:
            Any exception raised by the SDK; it is logged and re-raised.
        """
        contents, config = self._build_payload(request)

        try:
            response = await self._client.aio.models.generate_content(
                model=self.model_name,
                contents=contents,
                config=config,
            )

            # Whole responses can be large and contain user data, so they
            # are logged at DEBUG rather than INFO.
            logger.debug("Gemini response: %s", response)

            text_content, tool_calls = self._parse_response(response)

            finish_reason = None
            if response.candidates:
                finish_reason = self._normalize_finish_reason(
                    response.candidates[0].finish_reason
                )

            usage = self._extract_usage(response)

            return LlmResponse(
                content=text_content or None,
                tool_calls=tool_calls or None,
                finish_reason=finish_reason,
                usage=usage or None,
            )

        except Exception as e:
            logger.error(f"Error calling Gemini API: {e}")
            raise

    async def stream_request(
        self, request: LlmRequest
    ) -> AsyncGenerator[LlmStreamChunk, None]:
        """Stream a request to Gemini.

        Yields text chunks as they arrive. Function calls may show up in any
        chunk of the stream, not just the last. They are collected from every
        chunk and emitted together in one final chunk with the finish reason.
        If no tool calls occurred, the final chunk carries only the finish
        reason, defaulting to "stop".

        Raises:
            Any exception raised by the SDK; it is logged and re-raised.
        """
        contents, config = self._build_payload(request)

        logger.info(f"Gemini streaming request with model: {self.model_name}")

        try:
            stream = await self._client.aio.models.generate_content_stream(
                model=self.model_name,
                contents=contents,
                config=config,
            )

            tool_calls: List[ToolCall] = []
            # Shared across chunks so repeated calls get distinct IDs.
            used_ids: set[str] = set()
            finish_reason: Optional[str] = None
            saw_chunk = False

            async for chunk in stream:
                saw_chunk = True

                # Yield text content as it arrives
                if hasattr(chunk, "text") and chunk.text:
                    yield LlmStreamChunk(content=chunk.text)

                _, chunk_tool_calls = self._parse_response_chunk(chunk, used_ids)
                tool_calls.extend(chunk_tool_calls)

                # Remember the most recent finish reason reported by any chunk.
                if chunk.candidates:
                    reason = self._normalize_finish_reason(
                        chunk.candidates[0].finish_reason
                    )
                    if reason:
                        finish_reason = reason

            if saw_chunk:
                if tool_calls:
                    yield LlmStreamChunk(
                        tool_calls=tool_calls,
                        finish_reason=finish_reason,
                    )
                else:
                    yield LlmStreamChunk(finish_reason=finish_reason or "stop")

        except Exception as e:
            logger.error(f"Error streaming from Gemini API: {e}")
            raise

    async def validate_tools(self, tools: List[ToolSchema]) -> List[str]:
        """Basic validation of tool schemas for Gemini.

        Returns:
            One message per missing tool name or description; empty if valid.
        """
        errors: List[str] = []
        for t in tools:
            if not t.name:
                errors.append("Tool name is required")
            if not t.description:
                errors.append(f"Tool {t.name}: description is required")
        return errors

    # Internal helpers
    def _build_payload(self, request: LlmRequest) -> tuple[List[Any], Any]:
        """Build the payload for Gemini API.

        Role mapping:
        - user -> "user"
        - assistant -> "model" (text plus function calls)
        - tool -> "function" (function responses)

        Tool results are matched back to their function name first via the
        tool calls seen in earlier assistant messages. If no match is found,
        the name is recovered by stripping the ID prefix.

        Returns:
            Tuple of (contents, config)
        """
        contents: List[Any] = []

        # System prompt handling - Gemini supports system instructions in config
        system_instruction = request.system_prompt or None

        # Tool-call ID -> function name, learned from assistant turns.
        call_names: Dict[str, str] = {}

        for m in request.messages:
            if m.role == "user":
                contents.append(
                    self._types.Content(
                        role="user", parts=[self._types.Part(text=m.content)]
                    )
                )
            elif m.role == "assistant":
                parts = []

                # Add text content if present
                if m.content and m.content.strip():
                    parts.append(self._types.Part(text=m.content))

                # Add tool calls if present
                for tc in m.tool_calls or []:
                    call_names[tc.id] = tc.name
                    parts.append(
                        self._types.Part(
                            function_call=self._types.FunctionCall(
                                name=tc.name, args=tc.arguments
                            )
                        )
                    )

                if parts:
                    contents.append(self._types.Content(role="model", parts=parts))

            elif m.role == "tool" and m.tool_call_id:
                function_name = call_names.get(
                    m.tool_call_id
                ) or self._function_name_from_call_id(m.tool_call_id)

                contents.append(
                    self._types.Content(
                        role="function",
                        parts=[
                            self._types.Part(
                                function_response=self._types.FunctionResponse(
                                    name=function_name,
                                    response=self._tool_result_payload(m.content),
                                )
                            )
                        ],
                    )
                )

        # Build tools configuration if tools are provided
        tools = None
        if request.tools:
            function_declarations = [
                {
                    "name": tool.name,
                    "description": tool.description,
                    # Clean schema to remove unsupported fields
                    "parameters": self._clean_schema_for_gemini(tool.parameters),
                }
                for tool in request.tools
            ]
            if function_declarations:
                tools = [self._types.Tool(function_declarations=function_declarations)]

        # The request's temperature wins; the service default fills the gap.
        temperature = (
            request.temperature
            if request.temperature is not None
            else self.temperature
        )

        # Build generation config
        config_dict: Dict[str, Any] = {
            "temperature": temperature,
            **self.extra_config,
        }

        if request.max_tokens is not None:
            config_dict["max_output_tokens"] = request.max_tokens

        if tools:
            config_dict["tools"] = tools

        if system_instruction:
            config_dict["system_instruction"] = system_instruction

        config = self._types.GenerateContentConfig(**config_dict)

        return contents, config

    def _parse_response(
        self, response: Any, used_ids: Optional[set[str]] = None
    ) -> tuple[str, List[ToolCall]]:
        """Parse a Gemini response into text and tool calls.

        Only the first candidate is inspected.

        Args:
            response: SDK response (or stream chunk).
            used_ids: Tool-call IDs already handed out. Callers parsing several
                chunks of one stream pass the same set so IDs stay unique.
                It is updated in place.

        Returns:
            (concatenated text, tool calls); ("", []) when there are no candidates.
        """
        if used_ids is None:
            used_ids = set()

        if not response.candidates:
            return "", []

        candidate = response.candidates[0]
        content = getattr(candidate, "content", None)
        parts = getattr(content, "parts", None) if content else None

        text_parts: List[str] = []
        tool_calls: List[ToolCall] = []

        for part in parts or []:
            text = getattr(part, "text", None)
            if text:
                text_parts.append(text)

            fc = getattr(part, "function_call", None)
            if fc:
                # args may be None for zero-argument calls; dict(None) would raise.
                raw_args = getattr(fc, "args", None)
                tool_calls.append(
                    ToolCall(
                        id=self._make_tool_call_id(fc.name, used_ids),
                        name=fc.name,
                        arguments=dict(raw_args) if raw_args else {},
                    )
                )

        return "".join(text_parts), tool_calls

    def _parse_response_chunk(
        self, chunk: Any, used_ids: Optional[set[str]] = None
    ) -> tuple[str, List[ToolCall]]:
        """Parse a streaming chunk (same logic as _parse_response)."""
        return self._parse_response(chunk, used_ids)

    @classmethod
    def _make_tool_call_id(cls, name: str, used_ids: set[str]) -> str:
        """Return a tool-call ID for ``name`` that is not yet in ``used_ids``.

        The first call to a function gets ``call_<name>``; repeats get
        ``call_<name>_2``, ``call_<name>_3``, ... The chosen ID is added to
        ``used_ids``.
        """
        base = f"{cls._TOOL_CALL_ID_PREFIX}{name}"
        call_id = base
        suffix = 2
        while call_id in used_ids:
            call_id = f"{base}_{suffix}"
            suffix += 1
        used_ids.add(call_id)
        return call_id

    @classmethod
    def _function_name_from_call_id(cls, tool_call_id: str) -> str:
        """Recover a function name from an ID by stripping only the leading prefix.

        ``str.replace`` would also strip "call_" occurring inside the name.
        """
        prefix = cls._TOOL_CALL_ID_PREFIX
        if tool_call_id.startswith(prefix):
            return tool_call_id[len(prefix):]
        return tool_call_id

    @staticmethod
    def _tool_result_payload(content: Any) -> Dict[str, Any]:
        """Turn a tool message body into the dict Gemini expects as a response.

        JSON objects are used as-is. Other JSON values (lists, numbers, strings)
        and non-JSON text are wrapped as ``{"result": ...}``.
        """
        try:
            parsed = json.loads(content)
        except (json.JSONDecodeError, TypeError):
            return {"result": content}
        return parsed if isinstance(parsed, dict) else {"result": parsed}

    @staticmethod
    def _normalize_finish_reason(reason: Any) -> Optional[str]:
        """Convert an SDK finish reason (enum, string or None) to lowercase text.

        ``str()`` of a ``(str, Enum)`` member can yield ``"ClassName.MEMBER"``,
        so the member's ``value`` is used when available. For example
        ``FinishReason.STOP`` becomes ``"stop"``. None stays None.
        """
        if reason is None:
            return None
        text = str(getattr(reason, "value", reason)).strip().lower()
        return text or None

    @staticmethod
    def _extract_usage(response: Any) -> Dict[str, int]:
        """Return token usage from ``response.usage_metadata``, or ``{}``.

        Individual counts may be None. They are reported as 0 instead of
        discarding the whole usage record.
        """
        meta = getattr(response, "usage_metadata", None)
        if meta is None:
            return {}
        try:
            return {
                "prompt_tokens": int(getattr(meta, "prompt_token_count", None) or 0),
                "completion_tokens": int(
                    getattr(meta, "candidates_token_count", None) or 0
                ),
                "total_tokens": int(getattr(meta, "total_token_count", None) or 0),
            }
        except (TypeError, ValueError):
            return {}

    def _clean_schema_for_gemini(self, schema: Dict[str, Any]) -> Dict[str, Any]:
        """Clean JSON Schema to only include fields supported by Gemini.

        Gemini only supports a subset of OpenAPI schema. This removes unsupported
        fields like 'title', 'default', '$schema', etc.

        Supported fields:
        - type, description, enum, format
        - properties, required, items (for objects/arrays)

        An ``anyOf`` with exactly one non-null variant collapses into that
        variant, plus ``nullable: True`` when a null variant was present. This
        is how JSON Schema generators such as Pydantic v2 typically render
        ``Optional[X]``. Without it the property would be left with no type.
        """
        if not isinstance(schema, dict):
            return schema

        # Fields that Gemini supports
        allowed_fields = {
            "type",
            "description",
            "enum",
            "properties",
            "required",
            "items",
            "format",
        }

        cleaned: Dict[str, Any] = {}
        for key, value in schema.items():
            if key not in allowed_fields:
                continue
            # Recursively clean nested schemas
            if key == "properties" and isinstance(value, dict):
                cleaned[key] = {
                    prop_name: self._clean_schema_for_gemini(prop_schema)
                    for prop_name, prop_schema in value.items()
                }
            elif key == "items" and isinstance(value, dict):
                cleaned[key] = self._clean_schema_for_gemini(value)
            else:
                cleaned[key] = value

        variants = schema.get("anyOf")
        if isinstance(variants, list) and "type" not in cleaned:
            non_null = [
                v
                for v in variants
                if not (isinstance(v, dict) and v.get("type") == "null")
            ]
            if len(non_null) == 1 and isinstance(non_null[0], dict):
                # Outer keys (e.g. the property's description) take precedence.
                for k, v in self._clean_schema_for_gemini(non_null[0]).items():
                    cleaned.setdefault(k, v)
                if len(non_null) < len(variants):
                    cleaned["nullable"] = True

        return cleaned
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
"""Hive implementation of SqlRunner interface."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
import pandas as pd
|
|
5
|
+
|
|
6
|
+
from vanna.capabilities.sql_runner import SqlRunner, RunSqlToolArgs
|
|
7
|
+
from vanna.core.tool import ToolContext
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class HiveRunner(SqlRunner):
    """Hive implementation of the SqlRunner interface.

    Connection parameters are stored at construction time. A new pyhive
    connection is opened for every query and closed when the query finishes.
    """

    def __init__(
        self,
        host: str,
        database: str = "default",
        user: Optional[str] = None,
        password: Optional[str] = None,
        port: int = 10000,
        auth: str = "CUSTOM",
        **kwargs,
    ):
        """Initialize with Hive connection parameters.

        Args:
            host: The host of the Hive database
            database: The name of the database to connect to (default: 'default')
            user: The username to use for authentication
            password: The password to use for authentication
            port: The port to use for the connection (default: 10000)
            auth: The authentication method to use (default: 'CUSTOM')
            **kwargs: Additional pyhive connection parameters

        Raises:
            ImportError: If the pyhive package is not installed.
        """
        try:
            from pyhive import hive

            self.hive = hive
        except ImportError as e:
            raise ImportError(
                "pyhive package is required. Install with: pip install pyhive"
            ) from e

        self.host = host
        self.database = database
        self.user = user
        self.password = password
        self.port = port
        self.auth = auth
        self.kwargs = kwargs

    async def run_sql(self, args: RunSqlToolArgs, context: ToolContext) -> pd.DataFrame:
        """Execute SQL query against Hive database and return results as DataFrame.

        Args:
            args: SQL query arguments
            context: Tool execution context (not used by this runner)

        Returns:
            DataFrame with query results. Statements that produce no result set
            (``cursor.description`` is None) yield an empty DataFrame.

        Raises:
            hive.Error: If query execution fails
        """
        # NOTE(review): pyhive is synchronous, so this call blocks the event
        # loop for the duration of the query.
        conn = self.hive.Connection(
            host=self.host,
            username=self.user,
            password=self.password,
            database=self.database,
            port=self.port,
            auth=self.auth,
            **self.kwargs,
        )

        try:
            cursor = conn.cursor()
            try:
                cursor.execute(args.sql)

                # No result set (e.g. DDL): nothing to fetch or label.
                if cursor.description is None:
                    return pd.DataFrame()

                results = cursor.fetchall()
                columns = [desc[0] for desc in cursor.description]
                return pd.DataFrame(results, columns=columns)
            finally:
                # Close the cursor on both success and failure paths.
                cursor.close()
        finally:
            conn.close()
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Local integration.
|
|
3
|
+
|
|
4
|
+
This module provides built-in local implementations.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from .audit import LoggingAuditLogger
|
|
8
|
+
from .file_system import LocalFileSystem
|
|
9
|
+
from .storage import MemoryConversationStore
|
|
10
|
+
from .file_system_conversation_store import FileSystemConversationStore
|
|
11
|
+
|
|
12
|
+
# Public API re-exported by the local integration package.
__all__ = [
    # Conversation stores (in-memory and file-backed)
    "MemoryConversationStore",
    "FileSystemConversationStore",
    # File system access and audit logging
    "LocalFileSystem",
    "LoggingAuditLogger",
]
|