vanna 0.7.9__py3-none-any.whl → 2.0.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/__init__.py +167 -395
- vanna/agents/__init__.py +7 -0
- vanna/capabilities/__init__.py +17 -0
- vanna/capabilities/agent_memory/__init__.py +21 -0
- vanna/capabilities/agent_memory/base.py +103 -0
- vanna/capabilities/agent_memory/models.py +53 -0
- vanna/capabilities/file_system/__init__.py +14 -0
- vanna/capabilities/file_system/base.py +71 -0
- vanna/capabilities/file_system/models.py +25 -0
- vanna/capabilities/sql_runner/__init__.py +13 -0
- vanna/capabilities/sql_runner/base.py +37 -0
- vanna/capabilities/sql_runner/models.py +13 -0
- vanna/components/__init__.py +92 -0
- vanna/components/base.py +11 -0
- vanna/components/rich/__init__.py +83 -0
- vanna/components/rich/containers/__init__.py +7 -0
- vanna/components/rich/containers/card.py +20 -0
- vanna/components/rich/data/__init__.py +9 -0
- vanna/components/rich/data/chart.py +17 -0
- vanna/components/rich/data/dataframe.py +93 -0
- vanna/components/rich/feedback/__init__.py +21 -0
- vanna/components/rich/feedback/badge.py +16 -0
- vanna/components/rich/feedback/icon_text.py +14 -0
- vanna/components/rich/feedback/log_viewer.py +41 -0
- vanna/components/rich/feedback/notification.py +19 -0
- vanna/components/rich/feedback/progress.py +37 -0
- vanna/components/rich/feedback/status_card.py +28 -0
- vanna/components/rich/feedback/status_indicator.py +14 -0
- vanna/components/rich/interactive/__init__.py +21 -0
- vanna/components/rich/interactive/button.py +95 -0
- vanna/components/rich/interactive/task_list.py +58 -0
- vanna/components/rich/interactive/ui_state.py +93 -0
- vanna/components/rich/specialized/__init__.py +7 -0
- vanna/components/rich/specialized/artifact.py +20 -0
- vanna/components/rich/text.py +16 -0
- vanna/components/simple/__init__.py +15 -0
- vanna/components/simple/image.py +15 -0
- vanna/components/simple/link.py +15 -0
- vanna/components/simple/text.py +11 -0
- vanna/core/__init__.py +193 -0
- vanna/core/_compat.py +19 -0
- vanna/core/agent/__init__.py +10 -0
- vanna/core/agent/agent.py +1407 -0
- vanna/core/agent/config.py +123 -0
- vanna/core/audit/__init__.py +28 -0
- vanna/core/audit/base.py +299 -0
- vanna/core/audit/models.py +131 -0
- vanna/core/component_manager.py +329 -0
- vanna/core/components.py +53 -0
- vanna/core/enhancer/__init__.py +11 -0
- vanna/core/enhancer/base.py +94 -0
- vanna/core/enhancer/default.py +118 -0
- vanna/core/enricher/__init__.py +10 -0
- vanna/core/enricher/base.py +59 -0
- vanna/core/errors.py +47 -0
- vanna/core/evaluation/__init__.py +81 -0
- vanna/core/evaluation/base.py +186 -0
- vanna/core/evaluation/dataset.py +254 -0
- vanna/core/evaluation/evaluators.py +376 -0
- vanna/core/evaluation/report.py +289 -0
- vanna/core/evaluation/runner.py +313 -0
- vanna/core/filter/__init__.py +10 -0
- vanna/core/filter/base.py +67 -0
- vanna/core/lifecycle/__init__.py +10 -0
- vanna/core/lifecycle/base.py +83 -0
- vanna/core/llm/__init__.py +16 -0
- vanna/core/llm/base.py +40 -0
- vanna/core/llm/models.py +61 -0
- vanna/core/middleware/__init__.py +10 -0
- vanna/core/middleware/base.py +69 -0
- vanna/core/observability/__init__.py +11 -0
- vanna/core/observability/base.py +88 -0
- vanna/core/observability/models.py +47 -0
- vanna/core/recovery/__init__.py +11 -0
- vanna/core/recovery/base.py +84 -0
- vanna/core/recovery/models.py +32 -0
- vanna/core/registry.py +278 -0
- vanna/core/rich_component.py +156 -0
- vanna/core/simple_component.py +27 -0
- vanna/core/storage/__init__.py +14 -0
- vanna/core/storage/base.py +46 -0
- vanna/core/storage/models.py +46 -0
- vanna/core/system_prompt/__init__.py +13 -0
- vanna/core/system_prompt/base.py +36 -0
- vanna/core/system_prompt/default.py +157 -0
- vanna/core/tool/__init__.py +18 -0
- vanna/core/tool/base.py +70 -0
- vanna/core/tool/models.py +84 -0
- vanna/core/user/__init__.py +17 -0
- vanna/core/user/base.py +29 -0
- vanna/core/user/models.py +25 -0
- vanna/core/user/request_context.py +70 -0
- vanna/core/user/resolver.py +42 -0
- vanna/core/validation.py +164 -0
- vanna/core/workflow/__init__.py +12 -0
- vanna/core/workflow/base.py +254 -0
- vanna/core/workflow/default.py +789 -0
- vanna/examples/__init__.py +1 -0
- vanna/examples/__main__.py +44 -0
- vanna/examples/anthropic_quickstart.py +80 -0
- vanna/examples/artifact_example.py +293 -0
- vanna/examples/claude_sqlite_example.py +236 -0
- vanna/examples/coding_agent_example.py +300 -0
- vanna/examples/custom_system_prompt_example.py +174 -0
- vanna/examples/default_workflow_handler_example.py +208 -0
- vanna/examples/email_auth_example.py +340 -0
- vanna/examples/evaluation_example.py +269 -0
- vanna/examples/extensibility_example.py +262 -0
- vanna/examples/minimal_example.py +67 -0
- vanna/examples/mock_auth_example.py +227 -0
- vanna/examples/mock_custom_tool.py +311 -0
- vanna/examples/mock_quickstart.py +79 -0
- vanna/examples/mock_quota_example.py +145 -0
- vanna/examples/mock_rich_components_demo.py +396 -0
- vanna/examples/mock_sqlite_example.py +223 -0
- vanna/examples/openai_quickstart.py +83 -0
- vanna/examples/primitive_components_demo.py +305 -0
- vanna/examples/quota_lifecycle_example.py +139 -0
- vanna/examples/visualization_example.py +251 -0
- vanna/integrations/__init__.py +17 -0
- vanna/integrations/anthropic/__init__.py +9 -0
- vanna/integrations/anthropic/llm.py +270 -0
- vanna/integrations/azureopenai/__init__.py +9 -0
- vanna/integrations/azureopenai/llm.py +329 -0
- vanna/integrations/azuresearch/__init__.py +7 -0
- vanna/integrations/azuresearch/agent_memory.py +413 -0
- vanna/integrations/bigquery/__init__.py +5 -0
- vanna/integrations/bigquery/sql_runner.py +81 -0
- vanna/integrations/chromadb/__init__.py +104 -0
- vanna/integrations/chromadb/agent_memory.py +416 -0
- vanna/integrations/clickhouse/__init__.py +5 -0
- vanna/integrations/clickhouse/sql_runner.py +82 -0
- vanna/integrations/duckdb/__init__.py +5 -0
- vanna/integrations/duckdb/sql_runner.py +65 -0
- vanna/integrations/faiss/__init__.py +7 -0
- vanna/integrations/faiss/agent_memory.py +431 -0
- vanna/integrations/google/__init__.py +9 -0
- vanna/integrations/google/gemini.py +370 -0
- vanna/integrations/hive/__init__.py +5 -0
- vanna/integrations/hive/sql_runner.py +87 -0
- vanna/integrations/local/__init__.py +17 -0
- vanna/integrations/local/agent_memory/__init__.py +7 -0
- vanna/integrations/local/agent_memory/in_memory.py +285 -0
- vanna/integrations/local/audit.py +59 -0
- vanna/integrations/local/file_system.py +242 -0
- vanna/integrations/local/file_system_conversation_store.py +255 -0
- vanna/integrations/local/storage.py +62 -0
- vanna/integrations/marqo/__init__.py +7 -0
- vanna/integrations/marqo/agent_memory.py +354 -0
- vanna/integrations/milvus/__init__.py +7 -0
- vanna/integrations/milvus/agent_memory.py +458 -0
- vanna/integrations/mock/__init__.py +9 -0
- vanna/integrations/mock/llm.py +65 -0
- vanna/integrations/mssql/__init__.py +5 -0
- vanna/integrations/mssql/sql_runner.py +66 -0
- vanna/integrations/mysql/__init__.py +5 -0
- vanna/integrations/mysql/sql_runner.py +92 -0
- vanna/integrations/ollama/__init__.py +7 -0
- vanna/integrations/ollama/llm.py +252 -0
- vanna/integrations/openai/__init__.py +10 -0
- vanna/integrations/openai/llm.py +267 -0
- vanna/integrations/openai/responses.py +163 -0
- vanna/integrations/opensearch/__init__.py +7 -0
- vanna/integrations/opensearch/agent_memory.py +411 -0
- vanna/integrations/oracle/__init__.py +5 -0
- vanna/integrations/oracle/sql_runner.py +75 -0
- vanna/integrations/pinecone/__init__.py +7 -0
- vanna/integrations/pinecone/agent_memory.py +329 -0
- vanna/integrations/plotly/__init__.py +5 -0
- vanna/integrations/plotly/chart_generator.py +313 -0
- vanna/integrations/postgres/__init__.py +9 -0
- vanna/integrations/postgres/sql_runner.py +112 -0
- vanna/integrations/premium/agent_memory/__init__.py +7 -0
- vanna/integrations/premium/agent_memory/premium.py +186 -0
- vanna/integrations/presto/__init__.py +5 -0
- vanna/integrations/presto/sql_runner.py +107 -0
- vanna/integrations/qdrant/__init__.py +7 -0
- vanna/integrations/qdrant/agent_memory.py +439 -0
- vanna/integrations/snowflake/__init__.py +5 -0
- vanna/integrations/snowflake/sql_runner.py +147 -0
- vanna/integrations/sqlite/__init__.py +9 -0
- vanna/integrations/sqlite/sql_runner.py +65 -0
- vanna/integrations/weaviate/__init__.py +7 -0
- vanna/integrations/weaviate/agent_memory.py +428 -0
- vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_embeddings.py +11 -11
- vanna/legacy/__init__.py +403 -0
- vanna/legacy/adapter.py +463 -0
- vanna/{advanced → legacy/advanced}/__init__.py +3 -1
- vanna/{anthropic → legacy/anthropic}/anthropic_chat.py +9 -7
- vanna/{azuresearch → legacy/azuresearch}/azuresearch_vector.py +79 -41
- vanna/{base → legacy/base}/base.py +224 -217
- vanna/legacy/bedrock/__init__.py +1 -0
- vanna/{bedrock → legacy/bedrock}/bedrock_converse.py +13 -12
- vanna/{chromadb → legacy/chromadb}/chromadb_vector.py +3 -1
- vanna/legacy/cohere/__init__.py +2 -0
- vanna/{cohere → legacy/cohere}/cohere_chat.py +19 -14
- vanna/{cohere → legacy/cohere}/cohere_embeddings.py +25 -19
- vanna/{deepseek → legacy/deepseek}/deepseek_chat.py +5 -6
- vanna/legacy/faiss/__init__.py +1 -0
- vanna/{faiss → legacy/faiss}/faiss.py +113 -59
- vanna/{flask → legacy/flask}/__init__.py +84 -43
- vanna/{flask → legacy/flask}/assets.py +5 -5
- vanna/{flask → legacy/flask}/auth.py +5 -4
- vanna/{google → legacy/google}/bigquery_vector.py +75 -42
- vanna/{google → legacy/google}/gemini_chat.py +7 -3
- vanna/{hf → legacy/hf}/hf.py +0 -1
- vanna/{milvus → legacy/milvus}/milvus_vector.py +58 -35
- vanna/{mock → legacy/mock}/llm.py +0 -1
- vanna/legacy/mock/vectordb.py +67 -0
- vanna/legacy/ollama/ollama.py +110 -0
- vanna/{openai → legacy/openai}/openai_chat.py +2 -6
- vanna/legacy/opensearch/opensearch_vector.py +369 -0
- vanna/legacy/opensearch/opensearch_vector_semantic.py +200 -0
- vanna/legacy/oracle/oracle_vector.py +584 -0
- vanna/{pgvector → legacy/pgvector}/pgvector.py +42 -13
- vanna/{qdrant → legacy/qdrant}/qdrant.py +2 -6
- vanna/legacy/qianfan/Qianfan_Chat.py +170 -0
- vanna/legacy/qianfan/Qianfan_embeddings.py +36 -0
- vanna/legacy/qianwen/QianwenAI_chat.py +132 -0
- vanna/{remote.py → legacy/remote.py} +28 -26
- vanna/{utils.py → legacy/utils.py} +6 -11
- vanna/{vannadb → legacy/vannadb}/vannadb_vector.py +115 -46
- vanna/{vllm → legacy/vllm}/vllm.py +5 -6
- vanna/{weaviate → legacy/weaviate}/weaviate_vector.py +59 -40
- vanna/{xinference → legacy/xinference}/xinference.py +6 -6
- vanna/py.typed +0 -0
- vanna/servers/__init__.py +16 -0
- vanna/servers/__main__.py +8 -0
- vanna/servers/base/__init__.py +18 -0
- vanna/servers/base/chat_handler.py +65 -0
- vanna/servers/base/models.py +111 -0
- vanna/servers/base/rich_chat_handler.py +141 -0
- vanna/servers/base/templates.py +331 -0
- vanna/servers/cli/__init__.py +7 -0
- vanna/servers/cli/server_runner.py +204 -0
- vanna/servers/fastapi/__init__.py +7 -0
- vanna/servers/fastapi/app.py +163 -0
- vanna/servers/fastapi/routes.py +183 -0
- vanna/servers/flask/__init__.py +7 -0
- vanna/servers/flask/app.py +132 -0
- vanna/servers/flask/routes.py +137 -0
- vanna/tools/__init__.py +41 -0
- vanna/tools/agent_memory.py +322 -0
- vanna/tools/file_system.py +879 -0
- vanna/tools/python.py +222 -0
- vanna/tools/run_sql.py +165 -0
- vanna/tools/visualize_data.py +195 -0
- vanna/utils/__init__.py +0 -0
- vanna/web_components/__init__.py +44 -0
- vanna-2.0.0rc1.dist-info/METADATA +868 -0
- vanna-2.0.0rc1.dist-info/RECORD +289 -0
- vanna-2.0.0rc1.dist-info/entry_points.txt +3 -0
- vanna/bedrock/__init__.py +0 -1
- vanna/cohere/__init__.py +0 -2
- vanna/faiss/__init__.py +0 -1
- vanna/mock/vectordb.py +0 -55
- vanna/ollama/ollama.py +0 -103
- vanna/opensearch/opensearch_vector.py +0 -392
- vanna/opensearch/opensearch_vector_semantic.py +0 -175
- vanna/oracle/oracle_vector.py +0 -585
- vanna/qianfan/Qianfan_Chat.py +0 -165
- vanna/qianfan/Qianfan_embeddings.py +0 -36
- vanna/qianwen/QianwenAI_chat.py +0 -133
- vanna-0.7.9.dist-info/METADATA +0 -408
- vanna-0.7.9.dist-info/RECORD +0 -79
- /vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_Chat.py +0 -0
- /vanna/{ZhipuAI → legacy/ZhipuAI}/__init__.py +0 -0
- /vanna/{anthropic → legacy/anthropic}/__init__.py +0 -0
- /vanna/{azuresearch → legacy/azuresearch}/__init__.py +0 -0
- /vanna/{base → legacy/base}/__init__.py +0 -0
- /vanna/{chromadb → legacy/chromadb}/__init__.py +0 -0
- /vanna/{deepseek → legacy/deepseek}/__init__.py +0 -0
- /vanna/{exceptions → legacy/exceptions}/__init__.py +0 -0
- /vanna/{google → legacy/google}/__init__.py +0 -0
- /vanna/{hf → legacy/hf}/__init__.py +0 -0
- /vanna/{local.py → legacy/local.py} +0 -0
- /vanna/{marqo → legacy/marqo}/__init__.py +0 -0
- /vanna/{marqo → legacy/marqo}/marqo.py +0 -0
- /vanna/{milvus → legacy/milvus}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/mistral.py +0 -0
- /vanna/{mock → legacy/mock}/__init__.py +0 -0
- /vanna/{mock → legacy/mock}/embedding.py +0 -0
- /vanna/{ollama → legacy/ollama}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/openai_embeddings.py +0 -0
- /vanna/{opensearch → legacy/opensearch}/__init__.py +0 -0
- /vanna/{oracle → legacy/oracle}/__init__.py +0 -0
- /vanna/{pgvector → legacy/pgvector}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/pinecone_vector.py +0 -0
- /vanna/{qdrant → legacy/qdrant}/__init__.py +0 -0
- /vanna/{qianfan → legacy/qianfan}/__init__.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/QianwenAI_embeddings.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/__init__.py +0 -0
- /vanna/{types → legacy/types}/__init__.py +0 -0
- /vanna/{vannadb → legacy/vannadb}/__init__.py +0 -0
- /vanna/{vllm → legacy/vllm}/__init__.py +0 -0
- /vanna/{weaviate → legacy/weaviate}/__init__.py +0 -0
- /vanna/{xinference → legacy/xinference}/__init__.py +0 -0
- {vanna-0.7.9.dist-info → vanna-2.0.0rc1.dist-info}/WHEEL +0 -0
- {vanna-0.7.9.dist-info → vanna-2.0.0rc1.dist-info}/licenses/LICENSE +0 -0
vanna/integrations/__init__.py
@@ -0,0 +1,17 @@
+"""
+Integrations module.
+
+This package contains concrete implementations of core abstractions and capabilities.
+"""
+
+from .local import MemoryConversationStore
+from .mock import MockLlmService
+from .plotly import PlotlyChartGenerator
+from .sqlite import SqliteRunner
+
+__all__ = [
+    "MockLlmService",
+    "MemoryConversationStore",
+    "SqliteRunner",
+    "PlotlyChartGenerator",
+]
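These re-exports make the default local stack importable from the package root. A minimal sketch under stated assumptions: the names are grounded in the `__all__` block above, but this hunk does not show constructor signatures, so the zero-argument calls are guesses.

    # Names grounded in the re-exports above; constructor arguments are assumed.
    from vanna.integrations import MemoryConversationStore, MockLlmService

    llm = MockLlmService()             # assumed zero-arg construction
    store = MemoryConversationStore()  # assumed zero-arg construction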
vanna/integrations/anthropic/llm.py
@@ -0,0 +1,270 @@
+"""
+Anthropic LLM service implementation.
+
+Implements the LlmService interface using Anthropic's Messages API
+(anthropic>=0.8.0). Supports non-streaming and streaming text output.
+Tool-calls (tool_use blocks) are surfaced at the end of a stream or after a
+non-streaming call as ToolCall entries.
+"""
+
+from __future__ import annotations
+
+import logging
+import os
+from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
+
+logger = logging.getLogger(__name__)
+
+from vanna.core.llm import (
+    LlmService,
+    LlmRequest,
+    LlmResponse,
+    LlmStreamChunk,
+)
+from vanna.core.tool import ToolCall, ToolSchema
+
+
+class AnthropicLlmService(LlmService):
+    """Anthropic Messages-backed LLM service.
+
+    Args:
+        model: Anthropic model name (e.g., "claude-sonnet-4-5", "claude-opus-4").
+            Defaults to "claude-sonnet-4-5". Can also be set via ANTHROPIC_MODEL env var.
+        api_key: API key; falls back to env `ANTHROPIC_API_KEY`.
+        base_url: Optional custom base URL; env `ANTHROPIC_BASE_URL` if unset.
+        extra_client_kwargs: Extra kwargs forwarded to `anthropic.Anthropic()`.
+    """
+
+    def __init__(
+        self,
+        model: Optional[str] = None,
+        api_key: Optional[str] = None,
+        base_url: Optional[str] = None,
+        **extra_client_kwargs: Any,
+    ) -> None:
+        try:
+            import anthropic
+        except Exception as e:  # pragma: no cover
+            raise ImportError(
+                "anthropic package is required. Install with: pip install 'vanna[anthropic]'"
+            ) from e
+
+        # Model selection - use environment variable or default
+        self.model = model or os.getenv("ANTHROPIC_MODEL", "claude-sonnet-4-5")
+        api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
+        base_url = base_url or os.getenv("ANTHROPIC_BASE_URL")
+
+        client_kwargs: Dict[str, Any] = {**extra_client_kwargs}
+        if api_key:
+            client_kwargs["api_key"] = api_key
+        if base_url:
+            client_kwargs["base_url"] = base_url
+
+        self._client = anthropic.Anthropic(**client_kwargs)
+
+    async def send_request(self, request: LlmRequest) -> LlmResponse:
+        """Send a non-streaming request to Anthropic and return the response."""
+        payload = self._build_payload(request)
+
+        resp = self._client.messages.create(**payload)
+
+        logger.info(f"Anthropic response: {resp}")
+
+        text_content, tool_calls = self._parse_message_content(resp)
+
+        usage: Dict[str, int] = {}
+        if getattr(resp, "usage", None):
+            try:
+                usage = {
+                    "input_tokens": int(resp.usage.input_tokens),
+                    "output_tokens": int(resp.usage.output_tokens),
+                }
+            except Exception:
+                pass
+
+        return LlmResponse(
+            content=text_content or None,
+            tool_calls=tool_calls or None,
+            finish_reason=getattr(resp, "stop_reason", None),
+            usage=usage or None,
+        )
+
+    async def stream_request(
+        self, request: LlmRequest
+    ) -> AsyncGenerator[LlmStreamChunk, None]:
+        """Stream a request to Anthropic.
+
+        Yields text chunks as they arrive. Emits tool-calls at the end by
+        inspecting the final message.
+        """
+        payload = self._build_payload(request)
+
+        logger.info(f"Anthropic streaming payload: {payload}")
+
+        # SDK provides a streaming context manager with a text_stream iterator.
+        with self._client.messages.stream(**payload) as stream:
+            for text in stream.text_stream:
+                if text:
+                    yield LlmStreamChunk(content=text)
+
+            final = stream.get_final_message()
+            logger.info(f"Anthropic stream response: {final}")
+            _, tool_calls = self._parse_message_content(final)
+            if tool_calls:
+                yield LlmStreamChunk(
+                    tool_calls=tool_calls,
+                    finish_reason=getattr(final, "stop_reason", None),
+                )
+            else:
+                yield LlmStreamChunk(
+                    finish_reason=getattr(final, "stop_reason", None) or "stop"
+                )
+
+    async def validate_tools(self, tools: List[ToolSchema]) -> List[str]:
+        """Basic validation of tool schemas for Anthropic."""
+        errors: List[str] = []
+        for t in tools:
+            if not t.name:
+                errors.append("Tool name is required")
+        return errors
+
+    # Internal helpers
+    def _build_payload(self, request: LlmRequest) -> Dict[str, Any]:
+        # Anthropic requires messages content as list of content blocks per message
+        # We need to group consecutive tool messages into single user messages
+        messages: List[Dict[str, Any]] = []
+        i = 0
+
+        while i < len(request.messages):
+            m = request.messages[i]
+
+            if m.role == "tool":
+                # Group consecutive tool messages into one user message
+                tool_content_blocks = []
+                while i < len(request.messages) and request.messages[i].role == "tool":
+                    tool_msg = request.messages[i]
+                    if tool_msg.tool_call_id:
+                        tool_content_blocks.append(
+                            {
+                                "type": "tool_result",
+                                "tool_use_id": tool_msg.tool_call_id,
+                                "content": tool_msg.content,
+                            }
+                        )
+                    i += 1
+
+                if tool_content_blocks:
+                    messages.append(
+                        {
+                            "role": "user",
+                            "content": tool_content_blocks,
+                        }
+                    )
+            else:
+                # Handle non-tool messages normally
+                content_blocks = []
+
+                # Handle text content - only add if not empty
+                if m.content and m.content.strip():
+                    content_blocks.append({"type": "text", "text": m.content})
+
+                # Handle tool_calls for assistant messages (convert to tool_use blocks)
+                if m.role == "assistant" and m.tool_calls:
+                    for tc in m.tool_calls:
+                        content_blocks.append(
+                            {
+                                "type": "tool_use",
+                                "id": tc.id,
+                                "name": tc.name,
+                                "input": tc.arguments,  # type: ignore[dict-item]
+                            }
+                        )
+
+                # Ensure we have at least one content block for text messages
+                if not content_blocks and m.role in {"user", "assistant"}:
+                    content_blocks.append({"type": "text", "text": m.content or ""})
+
+                if content_blocks:
+                    role = m.role if m.role in {"user", "assistant"} else "user"
+                    messages.append(
+                        {
+                            "role": role,
+                            "content": content_blocks,
+                        }
+                    )
+
+                i += 1
+
+        tools_payload: Optional[List[Dict[str, Any]]] = None
+        if request.tools:
+            tools_payload = [
+                {
+                    "name": t.name,
+                    "description": t.description,
+                    "input_schema": t.parameters,
+                }
+                for t in request.tools
+            ]
+
+        payload: Dict[str, Any] = {
+            "model": self.model,
+            "messages": messages,
+            # Anthropic requires max_tokens; default if not provided
+            "max_tokens": request.max_tokens if request.max_tokens is not None else 512,
+            "temperature": request.temperature,
+        }
+        if tools_payload:
+            payload["tools"] = tools_payload
+            payload["tool_choice"] = {"type": "auto"}
+
+        # Add system prompt if provided
+        if request.system_prompt:
+            payload["system"] = request.system_prompt
+
+        return payload
+
+    def _parse_message_content(self, msg: Any) -> Tuple[str, List[ToolCall]]:
+        text_parts: List[str] = []
+        tool_calls: List[ToolCall] = []
+
+        content_list = getattr(msg, "content", []) or []
+        for block in content_list:
+            btype = getattr(block, "type", None) or (
+                block.get("type") if isinstance(block, dict) else None
+            )
+            if btype == "text":
+                # SDK returns block.text for typed object; dict uses {"text": ...}
+                text = getattr(block, "text", None)
+                if text is None and isinstance(block, dict):
+                    text = block.get("text")
+                if text:
+                    text_parts.append(str(text))
+            elif btype == "tool_use":
+                # Tool call with name and input
+                name = getattr(block, "name", None) or (
+                    block.get("name") if isinstance(block, dict) else None
+                )
+                tc_id = getattr(block, "id", None) or (
+                    block.get("id") if isinstance(block, dict) else None
+                )
+                input_data = getattr(block, "input", None) or (
+                    block.get("input") if isinstance(block, dict) else None
+                )
+                if name:
+                    try:
+                        # input_data should be a dict already
+                        args = (
+                            input_data
+                            if isinstance(input_data, dict)
+                            else {"_raw": input_data}
+                        )
+                    except Exception:
+                        args = {"_raw": str(input_data)}
+                    tool_calls.append(
+                        ToolCall(
+                            id=str(tc_id or "tool_call"), name=str(name), arguments=args
+                        )
+                    )
+
+        text_content = "".join(text_parts)
+        return text_content, tool_calls
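For orientation, a hedged usage sketch of the service above. `AnthropicLlmService`, `LlmRequest`, and the `LlmResponse` fields come from this hunk; the `_Msg` dataclass is a hypothetical stand-in mirroring only the attributes `_build_payload` actually reads (`role`, `content`, `tool_call_id`, `tool_calls`), and the `LlmRequest` keyword fields are likewise inferred from attribute reads (`messages`, `tools`, `system_prompt`, `max_tokens`, `temperature`), not confirmed against the real models in `vanna.core.llm`.

    import asyncio
    from dataclasses import dataclass
    from typing import Any, List, Optional

    from vanna.core.llm import LlmRequest
    from vanna.integrations.anthropic.llm import AnthropicLlmService

    @dataclass
    class _Msg:
        # Hypothetical stand-in; the hunk only shows that message objects
        # expose these four attributes.
        role: str
        content: Optional[str] = None
        tool_call_id: Optional[str] = None
        tool_calls: Optional[List[Any]] = None

    async def main() -> None:
        # Falls back to ANTHROPIC_MODEL / ANTHROPIC_API_KEY env vars when omitted.
        svc = AnthropicLlmService(model="claude-sonnet-4-5")
        request = LlmRequest(
            messages=[_Msg(role="user", content="What tables are available?")],
            system_prompt="You are a SQL assistant.",
            max_tokens=256,  # Anthropic requires max_tokens; the service defaults to 512
            temperature=0.2,
        )
        resp = await svc.send_request(request)
        print(resp.content, resp.finish_reason, resp.usage)

    asyncio.run(main())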
vanna/integrations/azureopenai/llm.py
@@ -0,0 +1,329 @@
+"""
+Azure OpenAI LLM service implementation.
+
+Provides an `LlmService` backed by Azure OpenAI Chat Completions (openai>=1.0.0)
+with support for streaming, deployment-scoped models, and Azure-specific
+authentication flows.
+"""
+
+from __future__ import annotations
+
+import json
+import os
+from typing import Any, AsyncGenerator, Dict, List, Optional, Set
+
+from vanna.core.llm import (
+    LlmService,
+    LlmRequest,
+    LlmResponse,
+    LlmStreamChunk,
+)
+from vanna.core.tool import ToolCall, ToolSchema
+
+
+# Models that don't support temperature and other sampling parameters
+REASONING_MODELS: Set[str] = {
+    "o1",
+    "o1-mini",
+    "o1-preview",
+    "o3-mini",
+    "gpt-5",
+    "gpt-5-mini",
+    "gpt-5-nano",
+    "gpt-5-pro",
+    "gpt-5-codex",
+}
+
+
+def _is_reasoning_model(model: str) -> bool:
+    """Return True when the deployment targets a reasoning-only model."""
+    model_lower = model.lower()
+    return any(reasoning_model in model_lower for reasoning_model in REASONING_MODELS)
+
+
+class AzureOpenAILlmService(LlmService):
+    """Azure OpenAI Chat Completions-backed LLM service.
+
+    Wraps `openai.AzureOpenAI` so Vanna can talk to deployment-scoped models
+    using either API key or Microsoft Entra ID authentication.
+
+    Args:
+        model: Deployment name in Azure OpenAI (required).
+        api_key: API key; falls back to `AZURE_OPENAI_API_KEY`.
+        azure_endpoint: Azure OpenAI endpoint URL; falls back to
+            `AZURE_OPENAI_ENDPOINT`.
+        api_version: API version; defaults to "2024-10-21" or
+            `AZURE_OPENAI_API_VERSION`.
+        azure_ad_token_provider: Optional bearer token provider for Entra ID.
+        **extra_client_kwargs: Additional keyword arguments forwarded to the
+            underlying client.
+    """
+
+    def __init__(
+        self,
+        model: Optional[str] = None,
+        api_key: Optional[str] = None,
+        azure_endpoint: Optional[str] = None,
+        api_version: Optional[str] = None,
+        azure_ad_token_provider: Optional[Any] = None,
+        **extra_client_kwargs: Any,
+    ) -> None:
+        try:
+            from openai import AzureOpenAI
+        except Exception as e:  # pragma: no cover
+            raise ImportError(
+                "openai package is required. Install with: pip install 'vanna[azureopenai]' "
+                "or 'pip install openai'"
+            ) from e
+
+        # Model/deployment name is required for Azure OpenAI
+        self.model = model or os.getenv("AZURE_OPENAI_MODEL")
+        if not self.model:
+            raise ValueError(
+                "model parameter (deployment name) is required for Azure OpenAI. "
+                "Provide it as argument or set AZURE_OPENAI_MODEL environment variable."
+            )
+
+        # Azure endpoint is required
+        azure_endpoint = azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
+        if not azure_endpoint:
+            raise ValueError(
+                "azure_endpoint is required for Azure OpenAI. "
+                "Provide it as argument or set AZURE_OPENAI_ENDPOINT environment variable."
+            )
+
+        # API version - use latest stable GA version by default
+        api_version = api_version or os.getenv("AZURE_OPENAI_API_VERSION", "2024-10-21")
+
+        # Build client kwargs
+        client_kwargs: Dict[str, Any] = {
+            "azure_endpoint": azure_endpoint,
+            "api_version": api_version,
+            **extra_client_kwargs,
+        }
+
+        # Authentication: prefer Azure AD token provider, fallback to API key
+        if azure_ad_token_provider is not None:
+            client_kwargs["azure_ad_token_provider"] = azure_ad_token_provider
+        else:
+            api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
+            if not api_key:
+                raise ValueError(
+                    "Authentication required: provide either api_key or azure_ad_token_provider. "
+                    "API key can also be set via AZURE_OPENAI_API_KEY environment variable."
+                )
+            client_kwargs["api_key"] = api_key
+
+        self._client = AzureOpenAI(**client_kwargs)
+        self._is_reasoning_model = _is_reasoning_model(self.model)
+
+    async def send_request(self, request: LlmRequest) -> LlmResponse:
+        """Send a non-streaming request to Azure OpenAI and return the response."""
+        payload = self._build_payload(request)
+
+        # Call the API synchronously; this function is async but we can block here.
+        resp = self._client.chat.completions.create(**payload, stream=False)
+
+        if not resp.choices:
+            return LlmResponse(content=None, tool_calls=None, finish_reason=None)
+
+        choice = resp.choices[0]
+        content: Optional[str] = getattr(choice.message, "content", None)
+        tool_calls = self._extract_tool_calls_from_message(choice.message)
+
+        usage: Dict[str, int] = {}
+        if getattr(resp, "usage", None):
+            usage = {
+                k: int(v)
+                for k, v in {
+                    "prompt_tokens": getattr(resp.usage, "prompt_tokens", 0),
+                    "completion_tokens": getattr(resp.usage, "completion_tokens", 0),
+                    "total_tokens": getattr(resp.usage, "total_tokens", 0),
+                }.items()
+            }
+
+        return LlmResponse(
+            content=content,
+            tool_calls=tool_calls or None,
+            finish_reason=getattr(choice, "finish_reason", None),
+            usage=usage or None,
+        )
+
+    async def stream_request(
+        self, request: LlmRequest
+    ) -> AsyncGenerator[LlmStreamChunk, None]:
+        """
+        Stream a request to Azure OpenAI.
+
+        Emits `LlmStreamChunk` for textual deltas as they arrive. Tool-calls are
+        accumulated and emitted in a final chunk when the stream ends.
+        """
+        payload = self._build_payload(request)
+
+        # Synchronous streaming iterator; iterate within async context.
+        stream = self._client.chat.completions.create(**payload, stream=True)
+
+        # Builders for streamed tool-calls (index -> partial)
+        tc_builders: Dict[int, Dict[str, Optional[str]]] = {}
+        last_finish: Optional[str] = None
+
+        for event in stream:
+            if not getattr(event, "choices", None):
+                continue
+
+            choice = event.choices[0]
+            delta = getattr(choice, "delta", None)
+            if delta is None:
+                # Some SDK versions use `event.choices[0].message` on the final packet
+                last_finish = getattr(choice, "finish_reason", last_finish)
+                continue
+
+            # Text content
+            content_piece: Optional[str] = getattr(delta, "content", None)
+            if content_piece:
+                yield LlmStreamChunk(content=content_piece)
+
+            # Tool calls (streamed)
+            streamed_tool_calls = getattr(delta, "tool_calls", None)
+            if streamed_tool_calls:
+                for tc in streamed_tool_calls:
+                    idx = getattr(tc, "index", 0) or 0
+                    b = tc_builders.setdefault(
+                        idx, {"id": None, "name": None, "arguments": ""}
+                    )
+                    if getattr(tc, "id", None):
+                        b["id"] = tc.id
+                    fn = getattr(tc, "function", None)
+                    if fn is not None:
+                        if getattr(fn, "name", None):
+                            b["name"] = fn.name
+                        if getattr(fn, "arguments", None):
+                            b["arguments"] = (b["arguments"] or "") + fn.arguments
+
+            last_finish = getattr(choice, "finish_reason", last_finish)
+
+        # Emit final tool-calls chunk if any
+        final_tool_calls: List[ToolCall] = []
+        for b in tc_builders.values():
+            if not b.get("name"):
+                continue
+            args_raw = b.get("arguments") or "{}"
+            try:
+                loaded = json.loads(args_raw)
+                if isinstance(loaded, dict):
+                    args_dict: Dict[str, Any] = loaded
+                else:
+                    args_dict = {"args": loaded}
+            except Exception:
+                args_dict = {"_raw": args_raw}
+            final_tool_calls.append(
+                ToolCall(
+                    id=b.get("id") or "tool_call",
+                    name=b["name"] or "tool",
+                    arguments=args_dict,
+                )
+            )
+
+        if final_tool_calls:
+            yield LlmStreamChunk(tool_calls=final_tool_calls, finish_reason=last_finish)
+        else:
+            # Still emit a terminal chunk to signal completion
+            yield LlmStreamChunk(finish_reason=last_finish or "stop")
+
+    async def validate_tools(self, tools: List[ToolSchema]) -> List[str]:
+        """Validate tool schemas. Returns a list of error messages."""
+        errors: List[str] = []
+        # Basic checks; Azure OpenAI will enforce further validation server-side.
+        for t in tools:
+            if not t.name or len(t.name) > 64:
+                errors.append(f"Invalid tool name: {t.name!r}")
+        return errors
+
+    # Internal helpers
+    def _build_payload(self, request: LlmRequest) -> Dict[str, Any]:
+        """Build the API payload from LlmRequest."""
+        messages: List[Dict[str, Any]] = []
+
+        # Add system prompt as first message if provided
+        if request.system_prompt:
+            messages.append({"role": "system", "content": request.system_prompt})
+
+        for m in request.messages:
+            msg: Dict[str, Any] = {"role": m.role, "content": m.content}
+            if m.role == "tool" and m.tool_call_id:
+                msg["tool_call_id"] = m.tool_call_id
+            elif m.role == "assistant" and m.tool_calls:
+                # Convert tool calls to OpenAI format
+                tool_calls_payload = []
+                for tc in m.tool_calls:
+                    tool_calls_payload.append(
+                        {
+                            "id": tc.id,
+                            "type": "function",
+                            "function": {
+                                "name": tc.name,
+                                "arguments": json.dumps(tc.arguments),
+                            },
+                        }
+                    )
+                msg["tool_calls"] = tool_calls_payload
+            messages.append(msg)
+
+        tools_payload: Optional[List[Dict[str, Any]]] = None
+        if request.tools:
+            tools_payload = [
+                {
+                    "type": "function",
+                    "function": {
+                        "name": t.name,
+                        "description": t.description,
+                        "parameters": t.parameters,
+                    },
+                }
+                for t in request.tools
+            ]
+
+        payload: Dict[str, Any] = {
+            "model": self.model,
+            "messages": messages,
+        }
+
+        # Add temperature only for non-reasoning models
+        # Reasoning models (GPT-5, o1, o3-mini) don't support temperature parameter
+        if not self._is_reasoning_model:
+            payload["temperature"] = request.temperature
+
+        if request.max_tokens is not None:
+            payload["max_tokens"] = request.max_tokens
+
+        if tools_payload:
+            payload["tools"] = tools_payload
+            payload["tool_choice"] = "auto"
+
+        return payload
+
+    def _extract_tool_calls_from_message(self, message: Any) -> List[ToolCall]:
+        """Extract tool calls from OpenAI message object."""
+        tool_calls: List[ToolCall] = []
+        raw_tool_calls = getattr(message, "tool_calls", None) or []
+        for tc in raw_tool_calls:
+            fn = getattr(tc, "function", None)
+            if not fn:
+                continue
+            args_raw = getattr(fn, "arguments", "{}")
+            try:
+                loaded = json.loads(args_raw)
+                if isinstance(loaded, dict):
+                    args_dict: Dict[str, Any] = loaded
+                else:
+                    args_dict = {"args": loaded}
+            except Exception:
+                args_dict = {"_raw": args_raw}
+            tool_calls.append(
+                ToolCall(
+                    id=getattr(tc, "id", "tool_call"),
+                    name=getattr(fn, "name", "tool"),
+                    arguments=args_dict,
+                )
+            )
+        return tool_calls
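And a hedged configuration sketch for the Azure service. The keyword arguments, environment-variable fallbacks, and the "2024-10-21" default come straight from `__init__` above; the deployment names and endpoint are placeholders, and the commented Entra ID path uses the standard azure-identity helpers, which this diff does not itself ship.

    import os

    from vanna.integrations.azureopenai.llm import AzureOpenAILlmService

    # API-key auth. `model` is the Azure *deployment* name; endpoint and key
    # may instead come from AZURE_OPENAI_ENDPOINT / AZURE_OPENAI_API_KEY, and
    # api_version defaults to "2024-10-21" when neither argument nor env var is set.
    svc = AzureOpenAILlmService(
        model="my-gpt-4o-deployment",
        azure_endpoint="https://example.openai.azure.com",
        api_key=os.environ["AZURE_OPENAI_API_KEY"],
    )

    # Entra ID auth instead of a key (assumes the azure-identity package):
    # from azure.identity import DefaultAzureCredential, get_bearer_token_provider
    # provider = get_bearer_token_provider(
    #     DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
    # )
    # svc = AzureOpenAILlmService(
    #     model="my-o1-deployment",
    #     azure_endpoint="https://example.openai.azure.com",
    #     azure_ad_token_provider=provider,
    # )

    # Note: a deployment name containing "o1", "o3-mini", "gpt-5", etc. trips
    # _is_reasoning_model(), so `temperature` is omitted from its payloads.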