vanna 0.7.9-py3-none-any.whl → 2.0.0rc1-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their respective public registries.
- vanna/__init__.py +167 -395
- vanna/agents/__init__.py +7 -0
- vanna/capabilities/__init__.py +17 -0
- vanna/capabilities/agent_memory/__init__.py +21 -0
- vanna/capabilities/agent_memory/base.py +103 -0
- vanna/capabilities/agent_memory/models.py +53 -0
- vanna/capabilities/file_system/__init__.py +14 -0
- vanna/capabilities/file_system/base.py +71 -0
- vanna/capabilities/file_system/models.py +25 -0
- vanna/capabilities/sql_runner/__init__.py +13 -0
- vanna/capabilities/sql_runner/base.py +37 -0
- vanna/capabilities/sql_runner/models.py +13 -0
- vanna/components/__init__.py +92 -0
- vanna/components/base.py +11 -0
- vanna/components/rich/__init__.py +83 -0
- vanna/components/rich/containers/__init__.py +7 -0
- vanna/components/rich/containers/card.py +20 -0
- vanna/components/rich/data/__init__.py +9 -0
- vanna/components/rich/data/chart.py +17 -0
- vanna/components/rich/data/dataframe.py +93 -0
- vanna/components/rich/feedback/__init__.py +21 -0
- vanna/components/rich/feedback/badge.py +16 -0
- vanna/components/rich/feedback/icon_text.py +14 -0
- vanna/components/rich/feedback/log_viewer.py +41 -0
- vanna/components/rich/feedback/notification.py +19 -0
- vanna/components/rich/feedback/progress.py +37 -0
- vanna/components/rich/feedback/status_card.py +28 -0
- vanna/components/rich/feedback/status_indicator.py +14 -0
- vanna/components/rich/interactive/__init__.py +21 -0
- vanna/components/rich/interactive/button.py +95 -0
- vanna/components/rich/interactive/task_list.py +58 -0
- vanna/components/rich/interactive/ui_state.py +93 -0
- vanna/components/rich/specialized/__init__.py +7 -0
- vanna/components/rich/specialized/artifact.py +20 -0
- vanna/components/rich/text.py +16 -0
- vanna/components/simple/__init__.py +15 -0
- vanna/components/simple/image.py +15 -0
- vanna/components/simple/link.py +15 -0
- vanna/components/simple/text.py +11 -0
- vanna/core/__init__.py +193 -0
- vanna/core/_compat.py +19 -0
- vanna/core/agent/__init__.py +10 -0
- vanna/core/agent/agent.py +1407 -0
- vanna/core/agent/config.py +123 -0
- vanna/core/audit/__init__.py +28 -0
- vanna/core/audit/base.py +299 -0
- vanna/core/audit/models.py +131 -0
- vanna/core/component_manager.py +329 -0
- vanna/core/components.py +53 -0
- vanna/core/enhancer/__init__.py +11 -0
- vanna/core/enhancer/base.py +94 -0
- vanna/core/enhancer/default.py +118 -0
- vanna/core/enricher/__init__.py +10 -0
- vanna/core/enricher/base.py +59 -0
- vanna/core/errors.py +47 -0
- vanna/core/evaluation/__init__.py +81 -0
- vanna/core/evaluation/base.py +186 -0
- vanna/core/evaluation/dataset.py +254 -0
- vanna/core/evaluation/evaluators.py +376 -0
- vanna/core/evaluation/report.py +289 -0
- vanna/core/evaluation/runner.py +313 -0
- vanna/core/filter/__init__.py +10 -0
- vanna/core/filter/base.py +67 -0
- vanna/core/lifecycle/__init__.py +10 -0
- vanna/core/lifecycle/base.py +83 -0
- vanna/core/llm/__init__.py +16 -0
- vanna/core/llm/base.py +40 -0
- vanna/core/llm/models.py +61 -0
- vanna/core/middleware/__init__.py +10 -0
- vanna/core/middleware/base.py +69 -0
- vanna/core/observability/__init__.py +11 -0
- vanna/core/observability/base.py +88 -0
- vanna/core/observability/models.py +47 -0
- vanna/core/recovery/__init__.py +11 -0
- vanna/core/recovery/base.py +84 -0
- vanna/core/recovery/models.py +32 -0
- vanna/core/registry.py +278 -0
- vanna/core/rich_component.py +156 -0
- vanna/core/simple_component.py +27 -0
- vanna/core/storage/__init__.py +14 -0
- vanna/core/storage/base.py +46 -0
- vanna/core/storage/models.py +46 -0
- vanna/core/system_prompt/__init__.py +13 -0
- vanna/core/system_prompt/base.py +36 -0
- vanna/core/system_prompt/default.py +157 -0
- vanna/core/tool/__init__.py +18 -0
- vanna/core/tool/base.py +70 -0
- vanna/core/tool/models.py +84 -0
- vanna/core/user/__init__.py +17 -0
- vanna/core/user/base.py +29 -0
- vanna/core/user/models.py +25 -0
- vanna/core/user/request_context.py +70 -0
- vanna/core/user/resolver.py +42 -0
- vanna/core/validation.py +164 -0
- vanna/core/workflow/__init__.py +12 -0
- vanna/core/workflow/base.py +254 -0
- vanna/core/workflow/default.py +789 -0
- vanna/examples/__init__.py +1 -0
- vanna/examples/__main__.py +44 -0
- vanna/examples/anthropic_quickstart.py +80 -0
- vanna/examples/artifact_example.py +293 -0
- vanna/examples/claude_sqlite_example.py +236 -0
- vanna/examples/coding_agent_example.py +300 -0
- vanna/examples/custom_system_prompt_example.py +174 -0
- vanna/examples/default_workflow_handler_example.py +208 -0
- vanna/examples/email_auth_example.py +340 -0
- vanna/examples/evaluation_example.py +269 -0
- vanna/examples/extensibility_example.py +262 -0
- vanna/examples/minimal_example.py +67 -0
- vanna/examples/mock_auth_example.py +227 -0
- vanna/examples/mock_custom_tool.py +311 -0
- vanna/examples/mock_quickstart.py +79 -0
- vanna/examples/mock_quota_example.py +145 -0
- vanna/examples/mock_rich_components_demo.py +396 -0
- vanna/examples/mock_sqlite_example.py +223 -0
- vanna/examples/openai_quickstart.py +83 -0
- vanna/examples/primitive_components_demo.py +305 -0
- vanna/examples/quota_lifecycle_example.py +139 -0
- vanna/examples/visualization_example.py +251 -0
- vanna/integrations/__init__.py +17 -0
- vanna/integrations/anthropic/__init__.py +9 -0
- vanna/integrations/anthropic/llm.py +270 -0
- vanna/integrations/azureopenai/__init__.py +9 -0
- vanna/integrations/azureopenai/llm.py +329 -0
- vanna/integrations/azuresearch/__init__.py +7 -0
- vanna/integrations/azuresearch/agent_memory.py +413 -0
- vanna/integrations/bigquery/__init__.py +5 -0
- vanna/integrations/bigquery/sql_runner.py +81 -0
- vanna/integrations/chromadb/__init__.py +104 -0
- vanna/integrations/chromadb/agent_memory.py +416 -0
- vanna/integrations/clickhouse/__init__.py +5 -0
- vanna/integrations/clickhouse/sql_runner.py +82 -0
- vanna/integrations/duckdb/__init__.py +5 -0
- vanna/integrations/duckdb/sql_runner.py +65 -0
- vanna/integrations/faiss/__init__.py +7 -0
- vanna/integrations/faiss/agent_memory.py +431 -0
- vanna/integrations/google/__init__.py +9 -0
- vanna/integrations/google/gemini.py +370 -0
- vanna/integrations/hive/__init__.py +5 -0
- vanna/integrations/hive/sql_runner.py +87 -0
- vanna/integrations/local/__init__.py +17 -0
- vanna/integrations/local/agent_memory/__init__.py +7 -0
- vanna/integrations/local/agent_memory/in_memory.py +285 -0
- vanna/integrations/local/audit.py +59 -0
- vanna/integrations/local/file_system.py +242 -0
- vanna/integrations/local/file_system_conversation_store.py +255 -0
- vanna/integrations/local/storage.py +62 -0
- vanna/integrations/marqo/__init__.py +7 -0
- vanna/integrations/marqo/agent_memory.py +354 -0
- vanna/integrations/milvus/__init__.py +7 -0
- vanna/integrations/milvus/agent_memory.py +458 -0
- vanna/integrations/mock/__init__.py +9 -0
- vanna/integrations/mock/llm.py +65 -0
- vanna/integrations/mssql/__init__.py +5 -0
- vanna/integrations/mssql/sql_runner.py +66 -0
- vanna/integrations/mysql/__init__.py +5 -0
- vanna/integrations/mysql/sql_runner.py +92 -0
- vanna/integrations/ollama/__init__.py +7 -0
- vanna/integrations/ollama/llm.py +252 -0
- vanna/integrations/openai/__init__.py +10 -0
- vanna/integrations/openai/llm.py +267 -0
- vanna/integrations/openai/responses.py +163 -0
- vanna/integrations/opensearch/__init__.py +7 -0
- vanna/integrations/opensearch/agent_memory.py +411 -0
- vanna/integrations/oracle/__init__.py +5 -0
- vanna/integrations/oracle/sql_runner.py +75 -0
- vanna/integrations/pinecone/__init__.py +7 -0
- vanna/integrations/pinecone/agent_memory.py +329 -0
- vanna/integrations/plotly/__init__.py +5 -0
- vanna/integrations/plotly/chart_generator.py +313 -0
- vanna/integrations/postgres/__init__.py +9 -0
- vanna/integrations/postgres/sql_runner.py +112 -0
- vanna/integrations/premium/agent_memory/__init__.py +7 -0
- vanna/integrations/premium/agent_memory/premium.py +186 -0
- vanna/integrations/presto/__init__.py +5 -0
- vanna/integrations/presto/sql_runner.py +107 -0
- vanna/integrations/qdrant/__init__.py +7 -0
- vanna/integrations/qdrant/agent_memory.py +439 -0
- vanna/integrations/snowflake/__init__.py +5 -0
- vanna/integrations/snowflake/sql_runner.py +147 -0
- vanna/integrations/sqlite/__init__.py +9 -0
- vanna/integrations/sqlite/sql_runner.py +65 -0
- vanna/integrations/weaviate/__init__.py +7 -0
- vanna/integrations/weaviate/agent_memory.py +428 -0
- vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_embeddings.py +11 -11
- vanna/legacy/__init__.py +403 -0
- vanna/legacy/adapter.py +463 -0
- vanna/{advanced → legacy/advanced}/__init__.py +3 -1
- vanna/{anthropic → legacy/anthropic}/anthropic_chat.py +9 -7
- vanna/{azuresearch → legacy/azuresearch}/azuresearch_vector.py +79 -41
- vanna/{base → legacy/base}/base.py +224 -217
- vanna/legacy/bedrock/__init__.py +1 -0
- vanna/{bedrock → legacy/bedrock}/bedrock_converse.py +13 -12
- vanna/{chromadb → legacy/chromadb}/chromadb_vector.py +3 -1
- vanna/legacy/cohere/__init__.py +2 -0
- vanna/{cohere → legacy/cohere}/cohere_chat.py +19 -14
- vanna/{cohere → legacy/cohere}/cohere_embeddings.py +25 -19
- vanna/{deepseek → legacy/deepseek}/deepseek_chat.py +5 -6
- vanna/legacy/faiss/__init__.py +1 -0
- vanna/{faiss → legacy/faiss}/faiss.py +113 -59
- vanna/{flask → legacy/flask}/__init__.py +84 -43
- vanna/{flask → legacy/flask}/assets.py +5 -5
- vanna/{flask → legacy/flask}/auth.py +5 -4
- vanna/{google → legacy/google}/bigquery_vector.py +75 -42
- vanna/{google → legacy/google}/gemini_chat.py +7 -3
- vanna/{hf → legacy/hf}/hf.py +0 -1
- vanna/{milvus → legacy/milvus}/milvus_vector.py +58 -35
- vanna/{mock → legacy/mock}/llm.py +0 -1
- vanna/legacy/mock/vectordb.py +67 -0
- vanna/legacy/ollama/ollama.py +110 -0
- vanna/{openai → legacy/openai}/openai_chat.py +2 -6
- vanna/legacy/opensearch/opensearch_vector.py +369 -0
- vanna/legacy/opensearch/opensearch_vector_semantic.py +200 -0
- vanna/legacy/oracle/oracle_vector.py +584 -0
- vanna/{pgvector → legacy/pgvector}/pgvector.py +42 -13
- vanna/{qdrant → legacy/qdrant}/qdrant.py +2 -6
- vanna/legacy/qianfan/Qianfan_Chat.py +170 -0
- vanna/legacy/qianfan/Qianfan_embeddings.py +36 -0
- vanna/legacy/qianwen/QianwenAI_chat.py +132 -0
- vanna/{remote.py → legacy/remote.py} +28 -26
- vanna/{utils.py → legacy/utils.py} +6 -11
- vanna/{vannadb → legacy/vannadb}/vannadb_vector.py +115 -46
- vanna/{vllm → legacy/vllm}/vllm.py +5 -6
- vanna/{weaviate → legacy/weaviate}/weaviate_vector.py +59 -40
- vanna/{xinference → legacy/xinference}/xinference.py +6 -6
- vanna/py.typed +0 -0
- vanna/servers/__init__.py +16 -0
- vanna/servers/__main__.py +8 -0
- vanna/servers/base/__init__.py +18 -0
- vanna/servers/base/chat_handler.py +65 -0
- vanna/servers/base/models.py +111 -0
- vanna/servers/base/rich_chat_handler.py +141 -0
- vanna/servers/base/templates.py +331 -0
- vanna/servers/cli/__init__.py +7 -0
- vanna/servers/cli/server_runner.py +204 -0
- vanna/servers/fastapi/__init__.py +7 -0
- vanna/servers/fastapi/app.py +163 -0
- vanna/servers/fastapi/routes.py +183 -0
- vanna/servers/flask/__init__.py +7 -0
- vanna/servers/flask/app.py +132 -0
- vanna/servers/flask/routes.py +137 -0
- vanna/tools/__init__.py +41 -0
- vanna/tools/agent_memory.py +322 -0
- vanna/tools/file_system.py +879 -0
- vanna/tools/python.py +222 -0
- vanna/tools/run_sql.py +165 -0
- vanna/tools/visualize_data.py +195 -0
- vanna/utils/__init__.py +0 -0
- vanna/web_components/__init__.py +44 -0
- vanna-2.0.0rc1.dist-info/METADATA +868 -0
- vanna-2.0.0rc1.dist-info/RECORD +289 -0
- vanna-2.0.0rc1.dist-info/entry_points.txt +3 -0
- vanna/bedrock/__init__.py +0 -1
- vanna/cohere/__init__.py +0 -2
- vanna/faiss/__init__.py +0 -1
- vanna/mock/vectordb.py +0 -55
- vanna/ollama/ollama.py +0 -103
- vanna/opensearch/opensearch_vector.py +0 -392
- vanna/opensearch/opensearch_vector_semantic.py +0 -175
- vanna/oracle/oracle_vector.py +0 -585
- vanna/qianfan/Qianfan_Chat.py +0 -165
- vanna/qianfan/Qianfan_embeddings.py +0 -36
- vanna/qianwen/QianwenAI_chat.py +0 -133
- vanna-0.7.9.dist-info/METADATA +0 -408
- vanna-0.7.9.dist-info/RECORD +0 -79
- /vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_Chat.py +0 -0
- /vanna/{ZhipuAI → legacy/ZhipuAI}/__init__.py +0 -0
- /vanna/{anthropic → legacy/anthropic}/__init__.py +0 -0
- /vanna/{azuresearch → legacy/azuresearch}/__init__.py +0 -0
- /vanna/{base → legacy/base}/__init__.py +0 -0
- /vanna/{chromadb → legacy/chromadb}/__init__.py +0 -0
- /vanna/{deepseek → legacy/deepseek}/__init__.py +0 -0
- /vanna/{exceptions → legacy/exceptions}/__init__.py +0 -0
- /vanna/{google → legacy/google}/__init__.py +0 -0
- /vanna/{hf → legacy/hf}/__init__.py +0 -0
- /vanna/{local.py → legacy/local.py} +0 -0
- /vanna/{marqo → legacy/marqo}/__init__.py +0 -0
- /vanna/{marqo → legacy/marqo}/marqo.py +0 -0
- /vanna/{milvus → legacy/milvus}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/mistral.py +0 -0
- /vanna/{mock → legacy/mock}/__init__.py +0 -0
- /vanna/{mock → legacy/mock}/embedding.py +0 -0
- /vanna/{ollama → legacy/ollama}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/openai_embeddings.py +0 -0
- /vanna/{opensearch → legacy/opensearch}/__init__.py +0 -0
- /vanna/{oracle → legacy/oracle}/__init__.py +0 -0
- /vanna/{pgvector → legacy/pgvector}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/pinecone_vector.py +0 -0
- /vanna/{qdrant → legacy/qdrant}/__init__.py +0 -0
- /vanna/{qianfan → legacy/qianfan}/__init__.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/QianwenAI_embeddings.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/__init__.py +0 -0
- /vanna/{types → legacy/types}/__init__.py +0 -0
- /vanna/{vannadb → legacy/vannadb}/__init__.py +0 -0
- /vanna/{vllm → legacy/vllm}/__init__.py +0 -0
- /vanna/{weaviate → legacy/weaviate}/__init__.py +0 -0
- /vanna/{xinference → legacy/xinference}/__init__.py +0 -0
- {vanna-0.7.9.dist-info → vanna-2.0.0rc1.dist-info}/WHEEL +0 -0
- {vanna-0.7.9.dist-info → vanna-2.0.0rc1.dist-info}/licenses/LICENSE +0 -0
vanna/integrations/ollama/llm.py
@@ -0,0 +1,252 @@
"""
Ollama LLM service implementation.

This module provides an implementation of the LlmService interface backed by
Ollama's local LLM API. It supports non-streaming responses and streaming
of text content. Tool calling support depends on the Ollama model being used.
"""

from __future__ import annotations

import json
import os
from typing import Any, AsyncGenerator, Dict, List, Optional

from vanna.core.llm import (
    LlmService,
    LlmRequest,
    LlmResponse,
    LlmStreamChunk,
)
from vanna.core.tool import ToolCall, ToolSchema


class OllamaLlmService(LlmService):
    """Ollama-backed LLM service for local model inference.

    Args:
        model: Ollama model name (e.g., "gpt-oss:20b").
        host: Ollama server URL; defaults to "http://localhost:11434" or env `OLLAMA_HOST`.
        timeout: Request timeout in seconds; defaults to 240.
        num_ctx: Context window size; defaults to 8192.
        temperature: Sampling temperature; defaults to 0.7.
        extra_options: Additional options passed to Ollama (e.g., num_predict, top_k, top_p).
    """

    def __init__(
        self,
        model: str,
        host: Optional[str] = None,
        timeout: float = 240.0,
        num_ctx: int = 8192,
        temperature: float = 0.7,
        **extra_options: Any,
    ) -> None:
        try:
            import ollama
        except ImportError as e:
            raise ImportError(
                "ollama package is required. Install with: pip install 'vanna[ollama]' or pip install ollama"
            ) from e

        if not model:
            raise ValueError("model parameter is required for Ollama")

        self.model = model
        self.host = host or os.getenv("OLLAMA_HOST", "http://localhost:11434")
        self.timeout = timeout
        self.num_ctx = num_ctx
        self.temperature = temperature
        self.extra_options = extra_options

        # Create Ollama client
        self._client = ollama.Client(host=self.host, timeout=timeout)

    async def send_request(self, request: LlmRequest) -> LlmResponse:
        """Send a non-streaming request to Ollama and return the response."""
        payload = self._build_payload(request)

        # Call the Ollama API
        try:
            resp = self._client.chat(**payload)
        except Exception as e:
            raise RuntimeError(f"Ollama request failed: {str(e)}") from e

        # Extract message from response
        message = resp.get("message", {})
        content = message.get("content")
        tool_calls = self._extract_tool_calls_from_message(message)

        # Extract usage information if available
        usage: Dict[str, int] = {}
        if "prompt_eval_count" in resp or "eval_count" in resp:
            usage = {
                "prompt_tokens": resp.get("prompt_eval_count", 0),
                "completion_tokens": resp.get("eval_count", 0),
                "total_tokens": resp.get("prompt_eval_count", 0)
                + resp.get("eval_count", 0),
            }

        return LlmResponse(
            content=content,
            tool_calls=tool_calls or None,
            finish_reason=resp.get("done_reason")
            or ("stop" if resp.get("done") else None),
            usage=usage or None,
        )

    async def stream_request(
        self, request: LlmRequest
    ) -> AsyncGenerator[LlmStreamChunk, None]:
        """Stream a request to Ollama.

        Emits `LlmStreamChunk` for textual deltas as they arrive. Tool calls are
        accumulated and emitted in a final chunk when the stream ends.
        """
        payload = self._build_payload(request)

        # Ollama streaming
        try:
            stream = self._client.chat(**payload, stream=True)
        except Exception as e:
            raise RuntimeError(f"Ollama streaming request failed: {str(e)}") from e

        # Accumulate tool calls if present
        accumulated_tool_calls: List[ToolCall] = []
        last_finish: Optional[str] = None

        for chunk in stream:
            message = chunk.get("message", {})

            # Yield text content
            content = message.get("content")
            if content:
                yield LlmStreamChunk(content=content)

            # Accumulate tool calls
            tool_calls = self._extract_tool_calls_from_message(message)
            if tool_calls:
                accumulated_tool_calls.extend(tool_calls)

            # Track finish reason
            if chunk.get("done"):
                last_finish = chunk.get("done_reason", "stop")

        # Emit final chunk with tool calls if any
        if accumulated_tool_calls:
            yield LlmStreamChunk(
                tool_calls=accumulated_tool_calls, finish_reason=last_finish or "stop"
            )
        else:
            # Emit terminal chunk to signal completion
            yield LlmStreamChunk(finish_reason=last_finish or "stop")

    async def validate_tools(self, tools: List[ToolSchema]) -> List[str]:
        """Validate tool schemas. Returns a list of error messages."""
        errors: List[str] = []
        # Basic validation; Ollama model support for tools varies
        for t in tools:
            if not t.name:
                errors.append("Tool must have a name")
            if not t.description:
                errors.append(f"Tool '{t.name}' should have a description")
        return errors

    # Internal helpers
    def _build_payload(self, request: LlmRequest) -> Dict[str, Any]:
        """Build the Ollama chat payload from LlmRequest."""
        messages: List[Dict[str, Any]] = []

        # Add system prompt as first message if provided
        if request.system_prompt:
            messages.append({"role": "system", "content": request.system_prompt})

        # Convert messages to Ollama format
        for m in request.messages:
            msg: Dict[str, Any] = {"role": m.role, "content": m.content or ""}

            # Handle tool calls in assistant messages
            if m.role == "assistant" and m.tool_calls:
                # Some Ollama models support tool_calls in message
                tool_calls_payload = []
                for tc in m.tool_calls:
                    tool_calls_payload.append(
                        {"function": {"name": tc.name, "arguments": tc.arguments}}
                    )
                msg["tool_calls"] = tool_calls_payload

            messages.append(msg)

        # Build tools array if tools are provided
        tools_payload: Optional[List[Dict[str, Any]]] = None
        if request.tools:
            tools_payload = []
            for t in request.tools:
                tools_payload.append(
                    {
                        "type": "function",
                        "function": {
                            "name": t.name,
                            "description": t.description,
                            "parameters": t.parameters,
                        },
                    }
                )

        # Build options
        options: Dict[str, Any] = {
            "num_ctx": self.num_ctx,
            "temperature": self.temperature,
            **self.extra_options,
        }

        # Build final payload
        payload: Dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "options": options,
        }

        # Add tools if provided (note: not all Ollama models support tools)
        if tools_payload:
            payload["tools"] = tools_payload

        return payload

    def _extract_tool_calls_from_message(
        self, message: Dict[str, Any]
    ) -> List[ToolCall]:
        """Extract tool calls from Ollama message."""
        tool_calls: List[ToolCall] = []

        # Check for tool_calls in message
        raw_tool_calls = message.get("tool_calls", [])
        if not raw_tool_calls:
            return tool_calls

        for idx, tc in enumerate(raw_tool_calls):
            fn = tc.get("function", {})
            name = fn.get("name")
            if not name:
                continue

            # Parse arguments
            arguments = fn.get("arguments", {})
            if isinstance(arguments, str):
                try:
                    arguments = json.loads(arguments)
                except Exception:
                    arguments = {"_raw": arguments}

            if not isinstance(arguments, dict):
                arguments = {"args": arguments}

            tool_calls.append(
                ToolCall(
                    id=tc.get("id", f"tool_call_{idx}"),
                    name=name,
                    arguments=arguments,
                )
            )

        return tool_calls
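
For orientation, a minimal usage sketch of the new service follows. This is not taken from the package's documentation: it assumes a running local Ollama server with an (illustrative) "llama3.1" model pulled, and it substitutes small duck-typed stand-ins for the request/message models defined in vanna/core/llm/models.py (listed above but not shown in this diff), using only the attributes the code above actually reads.

import asyncio
from dataclasses import dataclass, field
from typing import Any, List, Optional

from vanna.integrations.ollama.llm import OllamaLlmService


# Hypothetical stand-ins for LlmRequest and its message model: send_request and
# _build_payload above only read role, content, tool_calls, system_prompt,
# messages, and tools, so duck-typed objects are enough for a smoke test.
@dataclass
class StubMessage:
    role: str
    content: Optional[str] = None
    tool_calls: Optional[List[Any]] = None


@dataclass
class StubRequest:
    messages: List[StubMessage] = field(default_factory=list)
    system_prompt: Optional[str] = None
    tools: Optional[List[Any]] = None


async def main() -> None:
    llm = OllamaLlmService(model="llama3.1", temperature=0.2)
    req = StubRequest(
        system_prompt="You are a helpful assistant.",
        messages=[StubMessage(role="user", content="Say hello in one sentence.")],
    )
    # stream_request yields LlmStreamChunk objects: .content carries text deltas,
    # and the terminal chunk carries finish_reason (plus tool_calls, if any).
    async for chunk in llm.stream_request(req):
        if chunk.content:
            print(chunk.content, end="", flush=True)


asyncio.run(main())
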
vanna/integrations/openai/llm.py
@@ -0,0 +1,267 @@
"""
OpenAI LLM service implementation.

This module provides an implementation of the LlmService interface backed by
OpenAI's Chat Completions API (openai>=1.0.0). It supports non-streaming
responses and best-effort streaming of text content. Tool/function calling is
passed through when tools are provided, but full tool-call conversation
round-tripping may require adding assistant tool-call messages to the
conversation upstream.
"""

from __future__ import annotations

import json
import os
from typing import Any, AsyncGenerator, Dict, List, Optional

from vanna.core.llm import (
    LlmService,
    LlmRequest,
    LlmResponse,
    LlmStreamChunk,
)
from vanna.core.tool import ToolCall, ToolSchema


class OpenAILlmService(LlmService):
    """OpenAI Chat Completions-backed LLM service.

    Args:
        model: OpenAI model name (e.g., "gpt-5").
        api_key: API key; falls back to env `OPENAI_API_KEY`.
        organization: Optional org; env `OPENAI_ORG` if unset.
        base_url: Optional custom base URL; env `OPENAI_BASE_URL` if unset.
        extra_client_kwargs: Extra kwargs forwarded to `openai.OpenAI()`.
    """

    def __init__(
        self,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        organization: Optional[str] = None,
        base_url: Optional[str] = None,
        **extra_client_kwargs: Any,
    ) -> None:
        try:
            from openai import OpenAI
        except Exception as e:  # pragma: no cover - import-time error surface
            raise ImportError(
                "openai package is required. Install with: pip install 'vanna[openai]'"
            ) from e

        self.model = model or os.getenv("OPENAI_MODEL", "gpt-5")
        api_key = api_key or os.getenv("OPENAI_API_KEY")
        organization = organization or os.getenv("OPENAI_ORG")
        base_url = base_url or os.getenv("OPENAI_BASE_URL")

        client_kwargs: Dict[str, Any] = {**extra_client_kwargs}
        if api_key:
            client_kwargs["api_key"] = api_key
        if organization:
            client_kwargs["organization"] = organization
        if base_url:
            client_kwargs["base_url"] = base_url

        self._client = OpenAI(**client_kwargs)

    async def send_request(self, request: LlmRequest) -> LlmResponse:
        """Send a non-streaming request to OpenAI and return the response."""
        payload = self._build_payload(request)

        # Call the API synchronously; this function is async but we can block here.
        resp = self._client.chat.completions.create(**payload, stream=False)

        if not resp.choices:
            return LlmResponse(content=None, tool_calls=None, finish_reason=None)

        choice = resp.choices[0]
        content: Optional[str] = getattr(choice.message, "content", None)
        tool_calls = self._extract_tool_calls_from_message(choice.message)

        usage: Dict[str, int] = {}
        if getattr(resp, "usage", None):
            usage = {
                k: int(v)
                for k, v in {
                    "prompt_tokens": getattr(resp.usage, "prompt_tokens", 0),
                    "completion_tokens": getattr(resp.usage, "completion_tokens", 0),
                    "total_tokens": getattr(resp.usage, "total_tokens", 0),
                }.items()
            }

        return LlmResponse(
            content=content,
            tool_calls=tool_calls or None,
            finish_reason=getattr(choice, "finish_reason", None),
            usage=usage or None,
        )

    async def stream_request(
        self, request: LlmRequest
    ) -> AsyncGenerator[LlmStreamChunk, None]:
        """Stream a request to OpenAI.

        Emits `LlmStreamChunk` for textual deltas as they arrive. Tool calls are
        accumulated and emitted in a final chunk when the stream ends.
        """
        payload = self._build_payload(request)

        # Synchronous streaming iterator; iterate within async context.
        stream = self._client.chat.completions.create(**payload, stream=True)

        # Builders for streamed tool calls (index -> partial)
        tc_builders: Dict[int, Dict[str, Optional[str]]] = {}
        last_finish: Optional[str] = None

        for event in stream:
            if not getattr(event, "choices", None):
                continue

            choice = event.choices[0]
            delta = getattr(choice, "delta", None)
            if delta is None:
                # Some SDK versions use `event.choices[0].message` on the final packet
                last_finish = getattr(choice, "finish_reason", last_finish)
                continue

            # Text content
            content_piece: Optional[str] = getattr(delta, "content", None)
            if content_piece:
                yield LlmStreamChunk(content=content_piece)

            # Tool calls (streamed)
            streamed_tool_calls = getattr(delta, "tool_calls", None)
            if streamed_tool_calls:
                for tc in streamed_tool_calls:
                    idx = getattr(tc, "index", 0) or 0
                    b = tc_builders.setdefault(
                        idx, {"id": None, "name": None, "arguments": ""}
                    )
                    if getattr(tc, "id", None):
                        b["id"] = tc.id
                    fn = getattr(tc, "function", None)
                    if fn is not None:
                        if getattr(fn, "name", None):
                            b["name"] = fn.name
                        if getattr(fn, "arguments", None):
                            b["arguments"] = (b["arguments"] or "") + fn.arguments

            last_finish = getattr(choice, "finish_reason", last_finish)

        # Emit final tool-calls chunk if any
        final_tool_calls: List[ToolCall] = []
        for b in tc_builders.values():
            if not b.get("name"):
                continue
            args_raw = b.get("arguments") or "{}"
            try:
                loaded = json.loads(args_raw)
                if isinstance(loaded, dict):
                    args_dict: Dict[str, Any] = loaded
                else:
                    args_dict = {"args": loaded}
            except Exception:
                args_dict = {"_raw": args_raw}
            final_tool_calls.append(
                ToolCall(
                    id=b.get("id") or "tool_call",
                    name=b["name"] or "tool",
                    arguments=args_dict,
                )
            )

        if final_tool_calls:
            yield LlmStreamChunk(tool_calls=final_tool_calls, finish_reason=last_finish)
        else:
            # Still emit a terminal chunk to signal completion
            yield LlmStreamChunk(finish_reason=last_finish or "stop")

    async def validate_tools(self, tools: List[ToolSchema]) -> List[str]:
        """Validate tool schemas. Returns a list of error messages."""
        errors: List[str] = []
        # Basic checks; OpenAI will enforce further validation server-side.
        for t in tools:
            if not t.name or len(t.name) > 64:
                errors.append(f"Invalid tool name: {t.name!r}")
        return errors

    # Internal helpers
    def _build_payload(self, request: LlmRequest) -> Dict[str, Any]:
        messages: List[Dict[str, Any]] = []

        # Add system prompt as first message if provided
        if request.system_prompt:
            messages.append({"role": "system", "content": request.system_prompt})

        for m in request.messages:
            msg: Dict[str, Any] = {"role": m.role, "content": m.content}
            if m.role == "tool" and m.tool_call_id:
                msg["tool_call_id"] = m.tool_call_id
            elif m.role == "assistant" and m.tool_calls:
                # Convert tool calls to OpenAI format
                tool_calls_payload = []
                for tc in m.tool_calls:
                    tool_calls_payload.append(
                        {
                            "id": tc.id,
                            "type": "function",
                            "function": {
                                "name": tc.name,
                                "arguments": json.dumps(tc.arguments),
                            },
                        }
                    )
                msg["tool_calls"] = tool_calls_payload
            messages.append(msg)

        tools_payload: Optional[List[Dict[str, Any]]] = None
        if request.tools:
            tools_payload = [
                {
                    "type": "function",
                    "function": {
                        "name": t.name,
                        "description": t.description,
                        "parameters": t.parameters,
                    },
                }
                for t in request.tools
            ]

        payload: Dict[str, Any] = {
            "model": self.model,
            "messages": messages,
        }
        if request.max_tokens is not None:
            payload["max_tokens"] = request.max_tokens
        if tools_payload:
            payload["tools"] = tools_payload
            payload["tool_choice"] = "auto"

        return payload

    def _extract_tool_calls_from_message(self, message: Any) -> List[ToolCall]:
        tool_calls: List[ToolCall] = []
        raw_tool_calls = getattr(message, "tool_calls", None) or []
        for tc in raw_tool_calls:
            fn = getattr(tc, "function", None)
            if not fn:
                continue
            args_raw = getattr(fn, "arguments", "{}")
            try:
                loaded = json.loads(args_raw)
                if isinstance(loaded, dict):
                    args_dict: Dict[str, Any] = loaded
                else:
                    args_dict = {"args": loaded}
            except Exception:
                args_dict = {"_raw": args_raw}
            tool_calls.append(
                ToolCall(
                    id=getattr(tc, "id", "tool_call"),
                    name=getattr(fn, "name", "tool"),
                    arguments=args_dict,
                )
            )
        return tool_calls