vanna-0.7.8-py3-none-any.whl → vanna-2.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. vanna/__init__.py +167 -395
  2. vanna/agents/__init__.py +7 -0
  3. vanna/capabilities/__init__.py +17 -0
  4. vanna/capabilities/agent_memory/__init__.py +21 -0
  5. vanna/capabilities/agent_memory/base.py +103 -0
  6. vanna/capabilities/agent_memory/models.py +53 -0
  7. vanna/capabilities/file_system/__init__.py +14 -0
  8. vanna/capabilities/file_system/base.py +71 -0
  9. vanna/capabilities/file_system/models.py +25 -0
  10. vanna/capabilities/sql_runner/__init__.py +13 -0
  11. vanna/capabilities/sql_runner/base.py +37 -0
  12. vanna/capabilities/sql_runner/models.py +13 -0
  13. vanna/components/__init__.py +92 -0
  14. vanna/components/base.py +11 -0
  15. vanna/components/rich/__init__.py +83 -0
  16. vanna/components/rich/containers/__init__.py +7 -0
  17. vanna/components/rich/containers/card.py +20 -0
  18. vanna/components/rich/data/__init__.py +9 -0
  19. vanna/components/rich/data/chart.py +17 -0
  20. vanna/components/rich/data/dataframe.py +93 -0
  21. vanna/components/rich/feedback/__init__.py +21 -0
  22. vanna/components/rich/feedback/badge.py +16 -0
  23. vanna/components/rich/feedback/icon_text.py +14 -0
  24. vanna/components/rich/feedback/log_viewer.py +41 -0
  25. vanna/components/rich/feedback/notification.py +19 -0
  26. vanna/components/rich/feedback/progress.py +37 -0
  27. vanna/components/rich/feedback/status_card.py +28 -0
  28. vanna/components/rich/feedback/status_indicator.py +14 -0
  29. vanna/components/rich/interactive/__init__.py +21 -0
  30. vanna/components/rich/interactive/button.py +95 -0
  31. vanna/components/rich/interactive/task_list.py +58 -0
  32. vanna/components/rich/interactive/ui_state.py +93 -0
  33. vanna/components/rich/specialized/__init__.py +7 -0
  34. vanna/components/rich/specialized/artifact.py +20 -0
  35. vanna/components/rich/text.py +16 -0
  36. vanna/components/simple/__init__.py +15 -0
  37. vanna/components/simple/image.py +15 -0
  38. vanna/components/simple/link.py +15 -0
  39. vanna/components/simple/text.py +11 -0
  40. vanna/core/__init__.py +193 -0
  41. vanna/core/_compat.py +19 -0
  42. vanna/core/agent/__init__.py +10 -0
  43. vanna/core/agent/agent.py +1407 -0
  44. vanna/core/agent/config.py +123 -0
  45. vanna/core/audit/__init__.py +28 -0
  46. vanna/core/audit/base.py +299 -0
  47. vanna/core/audit/models.py +131 -0
  48. vanna/core/component_manager.py +329 -0
  49. vanna/core/components.py +53 -0
  50. vanna/core/enhancer/__init__.py +11 -0
  51. vanna/core/enhancer/base.py +94 -0
  52. vanna/core/enhancer/default.py +118 -0
  53. vanna/core/enricher/__init__.py +10 -0
  54. vanna/core/enricher/base.py +59 -0
  55. vanna/core/errors.py +47 -0
  56. vanna/core/evaluation/__init__.py +81 -0
  57. vanna/core/evaluation/base.py +186 -0
  58. vanna/core/evaluation/dataset.py +254 -0
  59. vanna/core/evaluation/evaluators.py +376 -0
  60. vanna/core/evaluation/report.py +289 -0
  61. vanna/core/evaluation/runner.py +313 -0
  62. vanna/core/filter/__init__.py +10 -0
  63. vanna/core/filter/base.py +67 -0
  64. vanna/core/lifecycle/__init__.py +10 -0
  65. vanna/core/lifecycle/base.py +83 -0
  66. vanna/core/llm/__init__.py +16 -0
  67. vanna/core/llm/base.py +40 -0
  68. vanna/core/llm/models.py +61 -0
  69. vanna/core/middleware/__init__.py +10 -0
  70. vanna/core/middleware/base.py +69 -0
  71. vanna/core/observability/__init__.py +11 -0
  72. vanna/core/observability/base.py +88 -0
  73. vanna/core/observability/models.py +47 -0
  74. vanna/core/recovery/__init__.py +11 -0
  75. vanna/core/recovery/base.py +84 -0
  76. vanna/core/recovery/models.py +32 -0
  77. vanna/core/registry.py +278 -0
  78. vanna/core/rich_component.py +156 -0
  79. vanna/core/simple_component.py +27 -0
  80. vanna/core/storage/__init__.py +14 -0
  81. vanna/core/storage/base.py +46 -0
  82. vanna/core/storage/models.py +46 -0
  83. vanna/core/system_prompt/__init__.py +13 -0
  84. vanna/core/system_prompt/base.py +36 -0
  85. vanna/core/system_prompt/default.py +157 -0
  86. vanna/core/tool/__init__.py +18 -0
  87. vanna/core/tool/base.py +70 -0
  88. vanna/core/tool/models.py +84 -0
  89. vanna/core/user/__init__.py +17 -0
  90. vanna/core/user/base.py +29 -0
  91. vanna/core/user/models.py +25 -0
  92. vanna/core/user/request_context.py +70 -0
  93. vanna/core/user/resolver.py +42 -0
  94. vanna/core/validation.py +164 -0
  95. vanna/core/workflow/__init__.py +12 -0
  96. vanna/core/workflow/base.py +254 -0
  97. vanna/core/workflow/default.py +789 -0
  98. vanna/examples/__init__.py +1 -0
  99. vanna/examples/__main__.py +44 -0
  100. vanna/examples/anthropic_quickstart.py +80 -0
  101. vanna/examples/artifact_example.py +293 -0
  102. vanna/examples/claude_sqlite_example.py +236 -0
  103. vanna/examples/coding_agent_example.py +300 -0
  104. vanna/examples/custom_system_prompt_example.py +174 -0
  105. vanna/examples/default_workflow_handler_example.py +208 -0
  106. vanna/examples/email_auth_example.py +340 -0
  107. vanna/examples/evaluation_example.py +269 -0
  108. vanna/examples/extensibility_example.py +262 -0
  109. vanna/examples/minimal_example.py +67 -0
  110. vanna/examples/mock_auth_example.py +227 -0
  111. vanna/examples/mock_custom_tool.py +311 -0
  112. vanna/examples/mock_quickstart.py +79 -0
  113. vanna/examples/mock_quota_example.py +145 -0
  114. vanna/examples/mock_rich_components_demo.py +396 -0
  115. vanna/examples/mock_sqlite_example.py +223 -0
  116. vanna/examples/openai_quickstart.py +83 -0
  117. vanna/examples/primitive_components_demo.py +305 -0
  118. vanna/examples/quota_lifecycle_example.py +139 -0
  119. vanna/examples/visualization_example.py +251 -0
  120. vanna/integrations/__init__.py +17 -0
  121. vanna/integrations/anthropic/__init__.py +9 -0
  122. vanna/integrations/anthropic/llm.py +270 -0
  123. vanna/integrations/azureopenai/__init__.py +9 -0
  124. vanna/integrations/azureopenai/llm.py +329 -0
  125. vanna/integrations/azuresearch/__init__.py +7 -0
  126. vanna/integrations/azuresearch/agent_memory.py +413 -0
  127. vanna/integrations/bigquery/__init__.py +5 -0
  128. vanna/integrations/bigquery/sql_runner.py +81 -0
  129. vanna/integrations/chromadb/__init__.py +104 -0
  130. vanna/integrations/chromadb/agent_memory.py +416 -0
  131. vanna/integrations/clickhouse/__init__.py +5 -0
  132. vanna/integrations/clickhouse/sql_runner.py +82 -0
  133. vanna/integrations/duckdb/__init__.py +5 -0
  134. vanna/integrations/duckdb/sql_runner.py +65 -0
  135. vanna/integrations/faiss/__init__.py +7 -0
  136. vanna/integrations/faiss/agent_memory.py +431 -0
  137. vanna/integrations/google/__init__.py +9 -0
  138. vanna/integrations/google/gemini.py +370 -0
  139. vanna/integrations/hive/__init__.py +5 -0
  140. vanna/integrations/hive/sql_runner.py +87 -0
  141. vanna/integrations/local/__init__.py +17 -0
  142. vanna/integrations/local/agent_memory/__init__.py +7 -0
  143. vanna/integrations/local/agent_memory/in_memory.py +285 -0
  144. vanna/integrations/local/audit.py +59 -0
  145. vanna/integrations/local/file_system.py +242 -0
  146. vanna/integrations/local/file_system_conversation_store.py +255 -0
  147. vanna/integrations/local/storage.py +62 -0
  148. vanna/integrations/marqo/__init__.py +7 -0
  149. vanna/integrations/marqo/agent_memory.py +354 -0
  150. vanna/integrations/milvus/__init__.py +7 -0
  151. vanna/integrations/milvus/agent_memory.py +458 -0
  152. vanna/integrations/mock/__init__.py +9 -0
  153. vanna/integrations/mock/llm.py +65 -0
  154. vanna/integrations/mssql/__init__.py +5 -0
  155. vanna/integrations/mssql/sql_runner.py +66 -0
  156. vanna/integrations/mysql/__init__.py +5 -0
  157. vanna/integrations/mysql/sql_runner.py +92 -0
  158. vanna/integrations/ollama/__init__.py +7 -0
  159. vanna/integrations/ollama/llm.py +252 -0
  160. vanna/integrations/openai/__init__.py +10 -0
  161. vanna/integrations/openai/llm.py +267 -0
  162. vanna/integrations/openai/responses.py +163 -0
  163. vanna/integrations/opensearch/__init__.py +7 -0
  164. vanna/integrations/opensearch/agent_memory.py +411 -0
  165. vanna/integrations/oracle/__init__.py +5 -0
  166. vanna/integrations/oracle/sql_runner.py +75 -0
  167. vanna/integrations/pinecone/__init__.py +7 -0
  168. vanna/integrations/pinecone/agent_memory.py +329 -0
  169. vanna/integrations/plotly/__init__.py +5 -0
  170. vanna/integrations/plotly/chart_generator.py +313 -0
  171. vanna/integrations/postgres/__init__.py +9 -0
  172. vanna/integrations/postgres/sql_runner.py +112 -0
  173. vanna/integrations/premium/agent_memory/__init__.py +7 -0
  174. vanna/integrations/premium/agent_memory/premium.py +186 -0
  175. vanna/integrations/presto/__init__.py +5 -0
  176. vanna/integrations/presto/sql_runner.py +107 -0
  177. vanna/integrations/qdrant/__init__.py +7 -0
  178. vanna/integrations/qdrant/agent_memory.py +461 -0
  179. vanna/integrations/snowflake/__init__.py +5 -0
  180. vanna/integrations/snowflake/sql_runner.py +147 -0
  181. vanna/integrations/sqlite/__init__.py +9 -0
  182. vanna/integrations/sqlite/sql_runner.py +65 -0
  183. vanna/integrations/weaviate/__init__.py +7 -0
  184. vanna/integrations/weaviate/agent_memory.py +428 -0
  185. vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_embeddings.py +11 -11
  186. vanna/legacy/__init__.py +403 -0
  187. vanna/legacy/adapter.py +463 -0
  188. vanna/{advanced → legacy/advanced}/__init__.py +3 -1
  189. vanna/{anthropic → legacy/anthropic}/anthropic_chat.py +9 -7
  190. vanna/{azuresearch → legacy/azuresearch}/azuresearch_vector.py +79 -41
  191. vanna/{base → legacy/base}/base.py +247 -223
  192. vanna/legacy/bedrock/__init__.py +1 -0
  193. vanna/{bedrock → legacy/bedrock}/bedrock_converse.py +13 -12
  194. vanna/{chromadb → legacy/chromadb}/chromadb_vector.py +3 -1
  195. vanna/legacy/cohere/__init__.py +2 -0
  196. vanna/{cohere → legacy/cohere}/cohere_chat.py +19 -14
  197. vanna/{cohere → legacy/cohere}/cohere_embeddings.py +25 -19
  198. vanna/{deepseek → legacy/deepseek}/deepseek_chat.py +5 -6
  199. vanna/legacy/faiss/__init__.py +1 -0
  200. vanna/{faiss → legacy/faiss}/faiss.py +113 -59
  201. vanna/{flask → legacy/flask}/__init__.py +84 -43
  202. vanna/{flask → legacy/flask}/assets.py +5 -5
  203. vanna/{flask → legacy/flask}/auth.py +5 -4
  204. vanna/{google → legacy/google}/bigquery_vector.py +75 -42
  205. vanna/{google → legacy/google}/gemini_chat.py +7 -3
  206. vanna/{hf → legacy/hf}/hf.py +0 -1
  207. vanna/{milvus → legacy/milvus}/milvus_vector.py +58 -35
  208. vanna/{mock → legacy/mock}/llm.py +0 -1
  209. vanna/legacy/mock/vectordb.py +67 -0
  210. vanna/legacy/ollama/ollama.py +110 -0
  211. vanna/{openai → legacy/openai}/openai_chat.py +2 -6
  212. vanna/legacy/opensearch/opensearch_vector.py +369 -0
  213. vanna/legacy/opensearch/opensearch_vector_semantic.py +200 -0
  214. vanna/legacy/oracle/oracle_vector.py +584 -0
  215. vanna/{pgvector → legacy/pgvector}/pgvector.py +42 -13
  216. vanna/{qdrant → legacy/qdrant}/qdrant.py +2 -6
  217. vanna/legacy/qianfan/Qianfan_Chat.py +170 -0
  218. vanna/legacy/qianfan/Qianfan_embeddings.py +36 -0
  219. vanna/legacy/qianwen/QianwenAI_chat.py +132 -0
  220. vanna/{remote.py → legacy/remote.py} +28 -26
  221. vanna/{utils.py → legacy/utils.py} +6 -11
  222. vanna/{vannadb → legacy/vannadb}/vannadb_vector.py +115 -46
  223. vanna/{vllm → legacy/vllm}/vllm.py +5 -6
  224. vanna/{weaviate → legacy/weaviate}/weaviate_vector.py +59 -40
  225. vanna/{xinference → legacy/xinference}/xinference.py +6 -6
  226. vanna/py.typed +0 -0
  227. vanna/servers/__init__.py +16 -0
  228. vanna/servers/__main__.py +8 -0
  229. vanna/servers/base/__init__.py +18 -0
  230. vanna/servers/base/chat_handler.py +65 -0
  231. vanna/servers/base/models.py +111 -0
  232. vanna/servers/base/rich_chat_handler.py +141 -0
  233. vanna/servers/base/templates.py +331 -0
  234. vanna/servers/cli/__init__.py +7 -0
  235. vanna/servers/cli/server_runner.py +204 -0
  236. vanna/servers/fastapi/__init__.py +7 -0
  237. vanna/servers/fastapi/app.py +163 -0
  238. vanna/servers/fastapi/routes.py +183 -0
  239. vanna/servers/flask/__init__.py +7 -0
  240. vanna/servers/flask/app.py +132 -0
  241. vanna/servers/flask/routes.py +137 -0
  242. vanna/tools/__init__.py +41 -0
  243. vanna/tools/agent_memory.py +322 -0
  244. vanna/tools/file_system.py +879 -0
  245. vanna/tools/python.py +222 -0
  246. vanna/tools/run_sql.py +165 -0
  247. vanna/tools/visualize_data.py +195 -0
  248. vanna/utils/__init__.py +0 -0
  249. vanna/web_components/__init__.py +44 -0
  250. vanna-2.0.0.dist-info/METADATA +485 -0
  251. vanna-2.0.0.dist-info/RECORD +289 -0
  252. vanna-2.0.0.dist-info/entry_points.txt +3 -0
  253. vanna/bedrock/__init__.py +0 -1
  254. vanna/cohere/__init__.py +0 -2
  255. vanna/faiss/__init__.py +0 -1
  256. vanna/mock/vectordb.py +0 -55
  257. vanna/ollama/ollama.py +0 -103
  258. vanna/opensearch/opensearch_vector.py +0 -392
  259. vanna/opensearch/opensearch_vector_semantic.py +0 -175
  260. vanna/oracle/oracle_vector.py +0 -585
  261. vanna/qianfan/Qianfan_Chat.py +0 -165
  262. vanna/qianfan/Qianfan_embeddings.py +0 -36
  263. vanna/qianwen/QianwenAI_chat.py +0 -133
  264. vanna-0.7.8.dist-info/METADATA +0 -408
  265. vanna-0.7.8.dist-info/RECORD +0 -79
  266. /vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_Chat.py +0 -0
  267. /vanna/{ZhipuAI → legacy/ZhipuAI}/__init__.py +0 -0
  268. /vanna/{anthropic → legacy/anthropic}/__init__.py +0 -0
  269. /vanna/{azuresearch → legacy/azuresearch}/__init__.py +0 -0
  270. /vanna/{base → legacy/base}/__init__.py +0 -0
  271. /vanna/{chromadb → legacy/chromadb}/__init__.py +0 -0
  272. /vanna/{deepseek → legacy/deepseek}/__init__.py +0 -0
  273. /vanna/{exceptions → legacy/exceptions}/__init__.py +0 -0
  274. /vanna/{google → legacy/google}/__init__.py +0 -0
  275. /vanna/{hf → legacy/hf}/__init__.py +0 -0
  276. /vanna/{local.py → legacy/local.py} +0 -0
  277. /vanna/{marqo → legacy/marqo}/__init__.py +0 -0
  278. /vanna/{marqo → legacy/marqo}/marqo.py +0 -0
  279. /vanna/{milvus → legacy/milvus}/__init__.py +0 -0
  280. /vanna/{mistral → legacy/mistral}/__init__.py +0 -0
  281. /vanna/{mistral → legacy/mistral}/mistral.py +0 -0
  282. /vanna/{mock → legacy/mock}/__init__.py +0 -0
  283. /vanna/{mock → legacy/mock}/embedding.py +0 -0
  284. /vanna/{ollama → legacy/ollama}/__init__.py +0 -0
  285. /vanna/{openai → legacy/openai}/__init__.py +0 -0
  286. /vanna/{openai → legacy/openai}/openai_embeddings.py +0 -0
  287. /vanna/{opensearch → legacy/opensearch}/__init__.py +0 -0
  288. /vanna/{oracle → legacy/oracle}/__init__.py +0 -0
  289. /vanna/{pgvector → legacy/pgvector}/__init__.py +0 -0
  290. /vanna/{pinecone → legacy/pinecone}/__init__.py +0 -0
  291. /vanna/{pinecone → legacy/pinecone}/pinecone_vector.py +0 -0
  292. /vanna/{qdrant → legacy/qdrant}/__init__.py +0 -0
  293. /vanna/{qianfan → legacy/qianfan}/__init__.py +0 -0
  294. /vanna/{qianwen → legacy/qianwen}/QianwenAI_embeddings.py +0 -0
  295. /vanna/{qianwen → legacy/qianwen}/__init__.py +0 -0
  296. /vanna/{types → legacy/types}/__init__.py +0 -0
  297. /vanna/{vannadb → legacy/vannadb}/__init__.py +0 -0
  298. /vanna/{vllm → legacy/vllm}/__init__.py +0 -0
  299. /vanna/{weaviate → legacy/weaviate}/__init__.py +0 -0
  300. /vanna/{xinference → legacy/xinference}/__init__.py +0 -0
  301. {vanna-0.7.8.dist-info → vanna-2.0.0.dist-info}/WHEEL +0 -0
  302. {vanna-0.7.8.dist-info → vanna-2.0.0.dist-info}/licenses/LICENSE +0 -0
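
The headline change in 2.0.0 is structural: the 0.x API surface moves wholesale under vanna.legacy, while the new agent architecture lands in vanna.core, vanna.capabilities, vanna.components, vanna.tools, vanna.servers, and per-provider vanna.integrations packages. A hedged sketch of how import paths shift; the module moves follow directly from the rename entries above, and the two service classes are confirmed exports in the __init__.py hunks below:

    # vanna 0.x shipped provider modules at the package top level; in 2.0.0
    # those same modules live under vanna.legacy (see the renames above).
    from vanna.legacy.openai import openai_chat        # was vanna/openai/openai_chat.py
    from vanna.legacy.chromadb import chromadb_vector  # was vanna/chromadb/chromadb_vector.py

    # New-style LLM services live under vanna.integrations; these exports are
    # confirmed by the __init__.py hunks reproduced below.
    from vanna.integrations.ollama import OllamaLlmService
    from vanna.integrations.openai import OpenAILlmService

The four hunks below reproduce the new Ollama and OpenAI integration modules in full; the file each hunk belongs to is identified by matching its line count against the file list above.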
--- /dev/null
+++ b/vanna/integrations/ollama/__init__.py
@@ -0,0 +1,7 @@
+"""
+Ollama integration for Vanna Agents.
+"""
+
+from .llm import OllamaLlmService
+
+__all__ = ["OllamaLlmService"]
--- /dev/null
+++ b/vanna/integrations/ollama/llm.py
@@ -0,0 +1,252 @@
+"""
+Ollama LLM service implementation.
+
+This module provides an implementation of the LlmService interface backed by
+Ollama's local LLM API. It supports non-streaming responses and streaming
+of text content. Tool calling support depends on the Ollama model being used.
+"""
+
+from __future__ import annotations
+
+import json
+import os
+from typing import Any, AsyncGenerator, Dict, List, Optional
+
+from vanna.core.llm import (
+    LlmService,
+    LlmRequest,
+    LlmResponse,
+    LlmStreamChunk,
+)
+from vanna.core.tool import ToolCall, ToolSchema
+
+
+class OllamaLlmService(LlmService):
+    """Ollama-backed LLM service for local model inference.
+
+    Args:
+        model: Ollama model name (e.g., "gpt-oss:20b").
+        host: Ollama server URL; defaults to "http://localhost:11434" or env `OLLAMA_HOST`.
+        timeout: Request timeout in seconds; defaults to 240.
+        num_ctx: Context window size; defaults to 8192.
+        temperature: Sampling temperature; defaults to 0.7.
+        extra_options: Additional options passed to Ollama (e.g., num_predict, top_k, top_p).
+    """
+
+    def __init__(
+        self,
+        model: str,
+        host: Optional[str] = None,
+        timeout: float = 240.0,
+        num_ctx: int = 8192,
+        temperature: float = 0.7,
+        **extra_options: Any,
+    ) -> None:
+        try:
+            import ollama
+        except ImportError as e:
+            raise ImportError(
+                "ollama package is required. Install with: pip install 'vanna[ollama]' or pip install ollama"
+            ) from e
+
+        if not model:
+            raise ValueError("model parameter is required for Ollama")
+
+        self.model = model
+        self.host = host or os.getenv("OLLAMA_HOST", "http://localhost:11434")
+        self.timeout = timeout
+        self.num_ctx = num_ctx
+        self.temperature = temperature
+        self.extra_options = extra_options
+
+        # Create Ollama client
+        self._client = ollama.Client(host=self.host, timeout=timeout)
+
+    async def send_request(self, request: LlmRequest) -> LlmResponse:
+        """Send a non-streaming request to Ollama and return the response."""
+        payload = self._build_payload(request)
+
+        # Call the Ollama API
+        try:
+            resp = self._client.chat(**payload)
+        except Exception as e:
+            raise RuntimeError(f"Ollama request failed: {str(e)}") from e
+
+        # Extract message from response
+        message = resp.get("message", {})
+        content = message.get("content")
+        tool_calls = self._extract_tool_calls_from_message(message)
+
+        # Extract usage information if available
+        usage: Dict[str, int] = {}
+        if "prompt_eval_count" in resp or "eval_count" in resp:
+            usage = {
+                "prompt_tokens": resp.get("prompt_eval_count", 0),
+                "completion_tokens": resp.get("eval_count", 0),
+                "total_tokens": resp.get("prompt_eval_count", 0)
+                + resp.get("eval_count", 0),
+            }
+
+        return LlmResponse(
+            content=content,
+            tool_calls=tool_calls or None,
+            finish_reason=resp.get("done_reason")
+            or ("stop" if resp.get("done") else None),
+            usage=usage or None,
+        )
+
+    async def stream_request(
+        self, request: LlmRequest
+    ) -> AsyncGenerator[LlmStreamChunk, None]:
+        """Stream a request to Ollama.
+
+        Emits `LlmStreamChunk` for textual deltas as they arrive. Tool calls are
+        accumulated and emitted in a final chunk when the stream ends.
+        """
+        payload = self._build_payload(request)
+
+        # Ollama streaming
+        try:
+            stream = self._client.chat(**payload, stream=True)
+        except Exception as e:
+            raise RuntimeError(f"Ollama streaming request failed: {str(e)}") from e
+
+        # Accumulate tool calls if present
+        accumulated_tool_calls: List[ToolCall] = []
+        last_finish: Optional[str] = None
+
+        for chunk in stream:
+            message = chunk.get("message", {})
+
+            # Yield text content
+            content = message.get("content")
+            if content:
+                yield LlmStreamChunk(content=content)
+
+            # Accumulate tool calls
+            tool_calls = self._extract_tool_calls_from_message(message)
+            if tool_calls:
+                accumulated_tool_calls.extend(tool_calls)
+
+            # Track finish reason
+            if chunk.get("done"):
+                last_finish = chunk.get("done_reason", "stop")
+
+        # Emit final chunk with tool calls if any
+        if accumulated_tool_calls:
+            yield LlmStreamChunk(
+                tool_calls=accumulated_tool_calls, finish_reason=last_finish or "stop"
+            )
+        else:
+            # Emit terminal chunk to signal completion
+            yield LlmStreamChunk(finish_reason=last_finish or "stop")
+
+    async def validate_tools(self, tools: List[ToolSchema]) -> List[str]:
+        """Validate tool schemas. Returns a list of error messages."""
+        errors: List[str] = []
+        # Basic validation; Ollama model support for tools varies
+        for t in tools:
+            if not t.name:
+                errors.append(f"Tool must have a name")
+            if not t.description:
+                errors.append(f"Tool '{t.name}' should have a description")
+        return errors
+
+    # Internal helpers
+    def _build_payload(self, request: LlmRequest) -> Dict[str, Any]:
+        """Build the Ollama chat payload from LlmRequest."""
+        messages: List[Dict[str, Any]] = []
+
+        # Add system prompt as first message if provided
+        if request.system_prompt:
+            messages.append({"role": "system", "content": request.system_prompt})
+
+        # Convert messages to Ollama format
+        for m in request.messages:
+            msg: Dict[str, Any] = {"role": m.role, "content": m.content or ""}
+
+            # Handle tool calls in assistant messages
+            if m.role == "assistant" and m.tool_calls:
+                # Some Ollama models support tool_calls in message
+                tool_calls_payload = []
+                for tc in m.tool_calls:
+                    tool_calls_payload.append(
+                        {"function": {"name": tc.name, "arguments": tc.arguments}}
+                    )
+                msg["tool_calls"] = tool_calls_payload
+
+            messages.append(msg)
+
+        # Build tools array if tools are provided
+        tools_payload: Optional[List[Dict[str, Any]]] = None
+        if request.tools:
+            tools_payload = []
+            for t in request.tools:
+                tools_payload.append(
+                    {
+                        "type": "function",
+                        "function": {
+                            "name": t.name,
+                            "description": t.description,
+                            "parameters": t.parameters,
+                        },
+                    }
+                )
+
+        # Build options
+        options: Dict[str, Any] = {
+            "num_ctx": self.num_ctx,
+            "temperature": self.temperature,
+            **self.extra_options,
+        }
+
+        # Build final payload
+        payload: Dict[str, Any] = {
+            "model": self.model,
+            "messages": messages,
+            "options": options,
+        }
+
+        # Add tools if provided (note: not all Ollama models support tools)
+        if tools_payload:
+            payload["tools"] = tools_payload
+
+        return payload
+
+    def _extract_tool_calls_from_message(
+        self, message: Dict[str, Any]
+    ) -> List[ToolCall]:
+        """Extract tool calls from Ollama message."""
+        tool_calls: List[ToolCall] = []
+
+        # Check for tool_calls in message
+        raw_tool_calls = message.get("tool_calls", [])
+        if not raw_tool_calls:
+            return tool_calls
+
+        for idx, tc in enumerate(raw_tool_calls):
+            fn = tc.get("function", {})
+            name = fn.get("name")
+            if not name:
+                continue
+
+            # Parse arguments
+            arguments = fn.get("arguments", {})
+            if isinstance(arguments, str):
+                try:
+                    arguments = json.loads(arguments)
+                except Exception:
+                    arguments = {"_raw": arguments}
+
+            if not isinstance(arguments, dict):
+                arguments = {"args": arguments}
+
+            tool_calls.append(
+                ToolCall(
+                    id=tc.get("id", f"tool_call_{idx}"),
+                    name=name,
+                    arguments=arguments,
+                )
+            )
+
+        return tool_calls
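
For orientation, a minimal usage sketch of the service above. The LlmRequest import is confirmed by this module; the LlmMessage name and the exact LlmRequest keyword arguments are assumptions inferred from the attributes the service reads (request.system_prompt, each message's .role/.content), so check vanna.core.llm.models for the real definitions.

    import asyncio

    from vanna.core.llm import LlmRequest, LlmMessage  # LlmMessage: assumed name
    from vanna.integrations.ollama import OllamaLlmService

    async def main() -> None:
        # Talks to a local Ollama server; the model must already be pulled,
        # e.g. `ollama pull llama3.1`.
        llm = OllamaLlmService(model="llama3.1", num_ctx=4096, temperature=0.2)

        request = LlmRequest(
            system_prompt="You are a concise SQL assistant.",
            messages=[LlmMessage(role="user", content="Top 5 customers by revenue?")],
        )

        # Non-streaming: one LlmResponse carrying content, tool_calls, and usage.
        response = await llm.send_request(request)
        print(response.content, response.usage)

        # Streaming: text deltas first, then a single terminal chunk with any
        # accumulated tool calls and the finish reason.
        async for chunk in llm.stream_request(request):
            if chunk.content:
                print(chunk.content, end="", flush=True)

    asyncio.run(main())

One caveat worth noting: stream_request iterates the blocking ollama.Client stream inside an async generator, so long generations hold the event loop between chunks.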
--- /dev/null
+++ b/vanna/integrations/openai/__init__.py
@@ -0,0 +1,10 @@
+"""
+OpenAI integration.
+
+This module provides OpenAI LLM service implementations.
+"""
+
+from .llm import OpenAILlmService
+from .responses import OpenAIResponsesService
+
+__all__ = ["OpenAILlmService", "OpenAIResponsesService"]
--- /dev/null
+++ b/vanna/integrations/openai/llm.py
@@ -0,0 +1,267 @@
+"""
+OpenAI LLM service implementation.
+
+This module provides an implementation of the LlmService interface backed by
+OpenAI's Chat Completions API (openai>=1.0.0). It supports non-streaming
+responses and best-effort streaming of text content. Tool/function calling is
+passed through when tools are provided, but full tool-call conversation
+round-tripping may require adding assistant tool-call messages to the
+conversation upstream.
+"""
+
+from __future__ import annotations
+
+import json
+import os
+from typing import Any, AsyncGenerator, Dict, List, Optional, cast
+
+from vanna.core.llm import (
+    LlmService,
+    LlmRequest,
+    LlmResponse,
+    LlmStreamChunk,
+)
+from vanna.core.tool import ToolCall, ToolSchema
+
+
+class OpenAILlmService(LlmService):
+    """OpenAI Chat Completions-backed LLM service.
+
+    Args:
+        model: OpenAI model name (e.g., "gpt-5").
+        api_key: API key; falls back to env `OPENAI_API_KEY`.
+        organization: Optional org; env `OPENAI_ORG` if unset.
+        base_url: Optional custom base URL; env `OPENAI_BASE_URL` if unset.
+        extra_client_kwargs: Extra kwargs forwarded to `openai.OpenAI()`.
+    """
+
+    def __init__(
+        self,
+        model: Optional[str] = None,
+        api_key: Optional[str] = None,
+        organization: Optional[str] = None,
+        base_url: Optional[str] = None,
+        **extra_client_kwargs: Any,
+    ) -> None:
+        try:
+            from openai import OpenAI
+        except Exception as e:  # pragma: no cover - import-time error surface
+            raise ImportError(
+                "openai package is required. Install with: pip install 'vanna[openai]'"
+            ) from e
+
+        self.model = model or os.getenv("OPENAI_MODEL", "gpt-5")
+        api_key = api_key or os.getenv("OPENAI_API_KEY")
+        organization = organization or os.getenv("OPENAI_ORG")
+        base_url = base_url or os.getenv("OPENAI_BASE_URL")
+
+        client_kwargs: Dict[str, Any] = {**extra_client_kwargs}
+        if api_key:
+            client_kwargs["api_key"] = api_key
+        if organization:
+            client_kwargs["organization"] = organization
+        if base_url:
+            client_kwargs["base_url"] = base_url
+
+        self._client = OpenAI(**client_kwargs)
+
+    async def send_request(self, request: LlmRequest) -> LlmResponse:
+        """Send a non-streaming request to OpenAI and return the response."""
+        payload = self._build_payload(request)
+
+        # Call the API synchronously; this function is async but we can block here.
+        resp = self._client.chat.completions.create(**payload, stream=False)
+
+        if not resp.choices:
+            return LlmResponse(content=None, tool_calls=None, finish_reason=None)
+
+        choice = resp.choices[0]
+        content: Optional[str] = getattr(choice.message, "content", None)
+        tool_calls = self._extract_tool_calls_from_message(choice.message)
+
+        usage: Dict[str, int] = {}
+        if getattr(resp, "usage", None):
+            usage = {
+                k: int(v)
+                for k, v in {
+                    "prompt_tokens": getattr(resp.usage, "prompt_tokens", 0),
+                    "completion_tokens": getattr(resp.usage, "completion_tokens", 0),
+                    "total_tokens": getattr(resp.usage, "total_tokens", 0),
+                }.items()
+            }
+
+        return LlmResponse(
+            content=content,
+            tool_calls=tool_calls or None,
+            finish_reason=getattr(choice, "finish_reason", None),
+            usage=usage or None,
+        )
+
+    async def stream_request(
+        self, request: LlmRequest
+    ) -> AsyncGenerator[LlmStreamChunk, None]:
+        """Stream a request to OpenAI.
+
+        Emits `LlmStreamChunk` for textual deltas as they arrive. Tool-calls are
+        accumulated and emitted in a final chunk when the stream ends.
+        """
+        payload = self._build_payload(request)
+
+        # Synchronous streaming iterator; iterate within async context.
+        stream = self._client.chat.completions.create(**payload, stream=True)
+
+        # Builders for streamed tool-calls (index -> partial)
+        tc_builders: Dict[int, Dict[str, Optional[str]]] = {}
+        last_finish: Optional[str] = None
+
+        for event in stream:
+            if not getattr(event, "choices", None):
+                continue
+
+            choice = event.choices[0]
+            delta = getattr(choice, "delta", None)
+            if delta is None:
+                # Some SDK versions use `event.choices[0].message` on the final packet
+                last_finish = getattr(choice, "finish_reason", last_finish)
+                continue
+
+            # Text content
+            content_piece: Optional[str] = getattr(delta, "content", None)
+            if content_piece:
+                yield LlmStreamChunk(content=content_piece)
+
+            # Tool calls (streamed)
+            streamed_tool_calls = getattr(delta, "tool_calls", None)
+            if streamed_tool_calls:
+                for tc in streamed_tool_calls:
+                    idx = getattr(tc, "index", 0) or 0
+                    b = tc_builders.setdefault(
+                        idx, {"id": None, "name": None, "arguments": ""}
+                    )
+                    if getattr(tc, "id", None):
+                        b["id"] = tc.id
+                    fn = getattr(tc, "function", None)
+                    if fn is not None:
+                        if getattr(fn, "name", None):
+                            b["name"] = fn.name
+                        if getattr(fn, "arguments", None):
+                            b["arguments"] = (b["arguments"] or "") + fn.arguments
+
+            last_finish = getattr(choice, "finish_reason", last_finish)
+
+        # Emit final tool-calls chunk if any
+        final_tool_calls: List[ToolCall] = []
+        for b in tc_builders.values():
+            if not b.get("name"):
+                continue
+            args_raw = b.get("arguments") or "{}"
+            try:
+                loaded = json.loads(args_raw)
+                if isinstance(loaded, dict):
+                    args_dict: Dict[str, Any] = loaded
+                else:
+                    args_dict = {"args": loaded}
+            except Exception:
+                args_dict = {"_raw": args_raw}
+            final_tool_calls.append(
+                ToolCall(
+                    id=b.get("id") or "tool_call",
+                    name=b["name"] or "tool",
+                    arguments=args_dict,
+                )
+            )
+
+        if final_tool_calls:
+            yield LlmStreamChunk(tool_calls=final_tool_calls, finish_reason=last_finish)
+        else:
+            # Still emit a terminal chunk to signal completion
+            yield LlmStreamChunk(finish_reason=last_finish or "stop")
+
+    async def validate_tools(self, tools: List[ToolSchema]) -> List[str]:
+        """Validate tool schemas. Returns a list of error messages."""
+        errors: List[str] = []
+        # Basic checks; OpenAI will enforce further validation server-side.
+        for t in tools:
+            if not t.name or len(t.name) > 64:
+                errors.append(f"Invalid tool name: {t.name!r}")
+        return errors
+
+    # Internal helpers
+    def _build_payload(self, request: LlmRequest) -> Dict[str, Any]:
+        messages: List[Dict[str, Any]] = []
+
+        # Add system prompt as first message if provided
+        if request.system_prompt:
+            messages.append({"role": "system", "content": request.system_prompt})
+
+        for m in request.messages:
+            msg: Dict[str, Any] = {"role": m.role, "content": m.content}
+            if m.role == "tool" and m.tool_call_id:
+                msg["tool_call_id"] = m.tool_call_id
+            elif m.role == "assistant" and m.tool_calls:
+                # Convert tool calls to OpenAI format
+                tool_calls_payload = []
+                for tc in m.tool_calls:
+                    tool_calls_payload.append(
+                        {
+                            "id": tc.id,
+                            "type": "function",
+                            "function": {
+                                "name": tc.name,
+                                "arguments": json.dumps(tc.arguments),
+                            },
+                        }
+                    )
+                msg["tool_calls"] = tool_calls_payload
+            messages.append(msg)
+
+        tools_payload: Optional[List[Dict[str, Any]]] = None
+        if request.tools:
+            tools_payload = [
+                {
+                    "type": "function",
+                    "function": {
+                        "name": t.name,
+                        "description": t.description,
+                        "parameters": t.parameters,
+                    },
+                }
+                for t in request.tools
+            ]
+
+        payload: Dict[str, Any] = {
+            "model": self.model,
+            "messages": messages,
+        }
+        if request.max_tokens is not None:
+            payload["max_tokens"] = request.max_tokens
+        if tools_payload:
+            payload["tools"] = tools_payload
+            payload["tool_choice"] = "auto"
+
+        return payload
+
+    def _extract_tool_calls_from_message(self, message: Any) -> List[ToolCall]:
+        tool_calls: List[ToolCall] = []
+        raw_tool_calls = getattr(message, "tool_calls", None) or []
+        for tc in raw_tool_calls:
+            fn = getattr(tc, "function", None)
+            if not fn:
+                continue
+            args_raw = getattr(fn, "arguments", "{}")
+            try:
+                loaded = json.loads(args_raw)
+                if isinstance(loaded, dict):
+                    args_dict: Dict[str, Any] = loaded
+                else:
+                    args_dict = {"args": loaded}
+            except Exception:
+                args_dict = {"_raw": args_raw}
+            tool_calls.append(
+                ToolCall(
+                    id=getattr(tc, "id", "tool_call"),
+                    name=getattr(fn, "name", "tool"),
+                    arguments=args_dict,
+                )
+            )
+        return tool_calls
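
A matching sketch for the OpenAI service, exercising tool validation and the streaming tool-call path. The ToolSchema constructor fields mirror the attributes _build_payload reads (name, description, parameters) but are otherwise assumed, and LlmMessage is the same assumed name as in the Ollama example above.

    import asyncio

    from vanna.core.llm import LlmRequest, LlmMessage  # LlmMessage: assumed name
    from vanna.core.tool import ToolSchema
    from vanna.integrations.openai import OpenAILlmService

    async def main() -> None:
        # model falls back to env OPENAI_MODEL (default "gpt-5"); credentials
        # fall back to OPENAI_API_KEY / OPENAI_ORG / OPENAI_BASE_URL.
        llm = OpenAILlmService(model="gpt-4o-mini")

        tool = ToolSchema(
            name="run_sql",
            description="Execute a SQL query and return the rows.",
            parameters={
                "type": "object",
                "properties": {"sql": {"type": "string"}},
                "required": ["sql"],
            },
        )
        # validate_tools only checks that names are present and <= 64 chars.
        assert await llm.validate_tools([tool]) == []

        request = LlmRequest(
            system_prompt="Use the run_sql tool to answer data questions.",
            messages=[LlmMessage(role="user", content="How many orders shipped in 2024?")],
            tools=[tool],
        )

        async for chunk in llm.stream_request(request):
            if chunk.content:
                print(chunk.content, end="", flush=True)
            if chunk.tool_calls:  # emitted once, in the final chunk
                print([(tc.name, tc.arguments) for tc in chunk.tool_calls])

    asyncio.run(main())

Because _build_payload sets tool_choice to "auto" whenever tools are supplied, the model decides per turn whether to emit a run_sql call or answer in plain text.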