agno 2.0.1__py3-none-any.whl → 2.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agno/agent/agent.py +6015 -2823
- agno/api/api.py +2 -0
- agno/api/os.py +1 -1
- agno/culture/__init__.py +3 -0
- agno/culture/manager.py +956 -0
- agno/db/async_postgres/__init__.py +3 -0
- agno/db/base.py +385 -6
- agno/db/dynamo/dynamo.py +388 -81
- agno/db/dynamo/schemas.py +47 -10
- agno/db/dynamo/utils.py +63 -4
- agno/db/firestore/firestore.py +435 -64
- agno/db/firestore/schemas.py +11 -0
- agno/db/firestore/utils.py +102 -4
- agno/db/gcs_json/gcs_json_db.py +384 -42
- agno/db/gcs_json/utils.py +60 -26
- agno/db/in_memory/in_memory_db.py +351 -66
- agno/db/in_memory/utils.py +60 -2
- agno/db/json/json_db.py +339 -48
- agno/db/json/utils.py +60 -26
- agno/db/migrations/manager.py +199 -0
- agno/db/migrations/v1_to_v2.py +510 -37
- agno/db/migrations/versions/__init__.py +0 -0
- agno/db/migrations/versions/v2_3_0.py +938 -0
- agno/db/mongo/__init__.py +15 -1
- agno/db/mongo/async_mongo.py +2036 -0
- agno/db/mongo/mongo.py +653 -76
- agno/db/mongo/schemas.py +13 -0
- agno/db/mongo/utils.py +80 -8
- agno/db/mysql/mysql.py +687 -25
- agno/db/mysql/schemas.py +61 -37
- agno/db/mysql/utils.py +60 -2
- agno/db/postgres/__init__.py +2 -1
- agno/db/postgres/async_postgres.py +2001 -0
- agno/db/postgres/postgres.py +676 -57
- agno/db/postgres/schemas.py +43 -18
- agno/db/postgres/utils.py +164 -2
- agno/db/redis/redis.py +344 -38
- agno/db/redis/schemas.py +18 -0
- agno/db/redis/utils.py +60 -2
- agno/db/schemas/__init__.py +2 -1
- agno/db/schemas/culture.py +120 -0
- agno/db/schemas/memory.py +13 -0
- agno/db/singlestore/schemas.py +26 -1
- agno/db/singlestore/singlestore.py +687 -53
- agno/db/singlestore/utils.py +60 -2
- agno/db/sqlite/__init__.py +2 -1
- agno/db/sqlite/async_sqlite.py +2371 -0
- agno/db/sqlite/schemas.py +24 -0
- agno/db/sqlite/sqlite.py +774 -85
- agno/db/sqlite/utils.py +168 -5
- agno/db/surrealdb/__init__.py +3 -0
- agno/db/surrealdb/metrics.py +292 -0
- agno/db/surrealdb/models.py +309 -0
- agno/db/surrealdb/queries.py +71 -0
- agno/db/surrealdb/surrealdb.py +1361 -0
- agno/db/surrealdb/utils.py +147 -0
- agno/db/utils.py +50 -22
- agno/eval/accuracy.py +50 -43
- agno/eval/performance.py +6 -3
- agno/eval/reliability.py +6 -3
- agno/eval/utils.py +33 -16
- agno/exceptions.py +68 -1
- agno/filters.py +354 -0
- agno/guardrails/__init__.py +6 -0
- agno/guardrails/base.py +19 -0
- agno/guardrails/openai.py +144 -0
- agno/guardrails/pii.py +94 -0
- agno/guardrails/prompt_injection.py +52 -0
- agno/integrations/discord/client.py +1 -0
- agno/knowledge/chunking/agentic.py +13 -10
- agno/knowledge/chunking/fixed.py +1 -1
- agno/knowledge/chunking/semantic.py +40 -8
- agno/knowledge/chunking/strategy.py +59 -15
- agno/knowledge/embedder/aws_bedrock.py +9 -4
- agno/knowledge/embedder/azure_openai.py +54 -0
- agno/knowledge/embedder/base.py +2 -0
- agno/knowledge/embedder/cohere.py +184 -5
- agno/knowledge/embedder/fastembed.py +1 -1
- agno/knowledge/embedder/google.py +79 -1
- agno/knowledge/embedder/huggingface.py +9 -4
- agno/knowledge/embedder/jina.py +63 -0
- agno/knowledge/embedder/mistral.py +78 -11
- agno/knowledge/embedder/nebius.py +1 -1
- agno/knowledge/embedder/ollama.py +13 -0
- agno/knowledge/embedder/openai.py +37 -65
- agno/knowledge/embedder/sentence_transformer.py +8 -4
- agno/knowledge/embedder/vllm.py +262 -0
- agno/knowledge/embedder/voyageai.py +69 -16
- agno/knowledge/knowledge.py +594 -186
- agno/knowledge/reader/base.py +9 -2
- agno/knowledge/reader/csv_reader.py +8 -10
- agno/knowledge/reader/docx_reader.py +5 -6
- agno/knowledge/reader/field_labeled_csv_reader.py +290 -0
- agno/knowledge/reader/json_reader.py +6 -5
- agno/knowledge/reader/markdown_reader.py +13 -13
- agno/knowledge/reader/pdf_reader.py +43 -68
- agno/knowledge/reader/pptx_reader.py +101 -0
- agno/knowledge/reader/reader_factory.py +51 -6
- agno/knowledge/reader/s3_reader.py +3 -15
- agno/knowledge/reader/tavily_reader.py +194 -0
- agno/knowledge/reader/text_reader.py +13 -13
- agno/knowledge/reader/web_search_reader.py +2 -43
- agno/knowledge/reader/website_reader.py +43 -25
- agno/knowledge/reranker/__init__.py +2 -8
- agno/knowledge/types.py +9 -0
- agno/knowledge/utils.py +20 -0
- agno/media.py +72 -0
- agno/memory/manager.py +336 -82
- agno/models/aimlapi/aimlapi.py +2 -2
- agno/models/anthropic/claude.py +183 -37
- agno/models/aws/bedrock.py +52 -112
- agno/models/aws/claude.py +33 -1
- agno/models/azure/ai_foundry.py +33 -15
- agno/models/azure/openai_chat.py +25 -8
- agno/models/base.py +999 -519
- agno/models/cerebras/cerebras.py +19 -13
- agno/models/cerebras/cerebras_openai.py +8 -5
- agno/models/cohere/chat.py +27 -1
- agno/models/cometapi/__init__.py +5 -0
- agno/models/cometapi/cometapi.py +57 -0
- agno/models/dashscope/dashscope.py +1 -0
- agno/models/deepinfra/deepinfra.py +2 -2
- agno/models/deepseek/deepseek.py +2 -2
- agno/models/fireworks/fireworks.py +2 -2
- agno/models/google/gemini.py +103 -31
- agno/models/groq/groq.py +28 -11
- agno/models/huggingface/huggingface.py +2 -1
- agno/models/internlm/internlm.py +2 -2
- agno/models/langdb/langdb.py +4 -4
- agno/models/litellm/chat.py +18 -1
- agno/models/litellm/litellm_openai.py +2 -2
- agno/models/llama_cpp/__init__.py +5 -0
- agno/models/llama_cpp/llama_cpp.py +22 -0
- agno/models/message.py +139 -0
- agno/models/meta/llama.py +27 -10
- agno/models/meta/llama_openai.py +5 -17
- agno/models/nebius/nebius.py +6 -6
- agno/models/nexus/__init__.py +3 -0
- agno/models/nexus/nexus.py +22 -0
- agno/models/nvidia/nvidia.py +2 -2
- agno/models/ollama/chat.py +59 -5
- agno/models/openai/chat.py +69 -29
- agno/models/openai/responses.py +103 -106
- agno/models/openrouter/openrouter.py +41 -3
- agno/models/perplexity/perplexity.py +4 -5
- agno/models/portkey/portkey.py +3 -3
- agno/models/requesty/__init__.py +5 -0
- agno/models/requesty/requesty.py +52 -0
- agno/models/response.py +77 -1
- agno/models/sambanova/sambanova.py +2 -2
- agno/models/siliconflow/__init__.py +5 -0
- agno/models/siliconflow/siliconflow.py +25 -0
- agno/models/together/together.py +2 -2
- agno/models/utils.py +254 -8
- agno/models/vercel/v0.py +2 -2
- agno/models/vertexai/__init__.py +0 -0
- agno/models/vertexai/claude.py +96 -0
- agno/models/vllm/vllm.py +1 -0
- agno/models/xai/xai.py +3 -2
- agno/os/app.py +543 -178
- agno/os/auth.py +24 -14
- agno/os/config.py +1 -0
- agno/os/interfaces/__init__.py +1 -0
- agno/os/interfaces/a2a/__init__.py +3 -0
- agno/os/interfaces/a2a/a2a.py +42 -0
- agno/os/interfaces/a2a/router.py +250 -0
- agno/os/interfaces/a2a/utils.py +924 -0
- agno/os/interfaces/agui/agui.py +23 -7
- agno/os/interfaces/agui/router.py +27 -3
- agno/os/interfaces/agui/utils.py +242 -142
- agno/os/interfaces/base.py +6 -2
- agno/os/interfaces/slack/router.py +81 -23
- agno/os/interfaces/slack/slack.py +29 -14
- agno/os/interfaces/whatsapp/router.py +11 -4
- agno/os/interfaces/whatsapp/whatsapp.py +14 -7
- agno/os/mcp.py +111 -54
- agno/os/middleware/__init__.py +7 -0
- agno/os/middleware/jwt.py +233 -0
- agno/os/router.py +556 -139
- agno/os/routers/evals/evals.py +71 -34
- agno/os/routers/evals/schemas.py +31 -31
- agno/os/routers/evals/utils.py +6 -5
- agno/os/routers/health.py +31 -0
- agno/os/routers/home.py +52 -0
- agno/os/routers/knowledge/knowledge.py +185 -38
- agno/os/routers/knowledge/schemas.py +82 -22
- agno/os/routers/memory/memory.py +158 -53
- agno/os/routers/memory/schemas.py +20 -16
- agno/os/routers/metrics/metrics.py +20 -8
- agno/os/routers/metrics/schemas.py +16 -16
- agno/os/routers/session/session.py +499 -38
- agno/os/schema.py +308 -198
- agno/os/utils.py +401 -41
- agno/reasoning/anthropic.py +80 -0
- agno/reasoning/azure_ai_foundry.py +2 -2
- agno/reasoning/deepseek.py +2 -2
- agno/reasoning/default.py +3 -1
- agno/reasoning/gemini.py +73 -0
- agno/reasoning/groq.py +2 -2
- agno/reasoning/ollama.py +2 -2
- agno/reasoning/openai.py +7 -2
- agno/reasoning/vertexai.py +76 -0
- agno/run/__init__.py +6 -0
- agno/run/agent.py +248 -94
- agno/run/base.py +44 -5
- agno/run/team.py +238 -97
- agno/run/workflow.py +144 -33
- agno/session/agent.py +105 -89
- agno/session/summary.py +65 -25
- agno/session/team.py +176 -96
- agno/session/workflow.py +406 -40
- agno/team/team.py +3854 -1610
- agno/tools/dalle.py +2 -4
- agno/tools/decorator.py +4 -2
- agno/tools/duckduckgo.py +15 -11
- agno/tools/e2b.py +14 -7
- agno/tools/eleven_labs.py +23 -25
- agno/tools/exa.py +21 -16
- agno/tools/file.py +153 -23
- agno/tools/file_generation.py +350 -0
- agno/tools/firecrawl.py +4 -4
- agno/tools/function.py +250 -30
- agno/tools/gmail.py +238 -14
- agno/tools/google_drive.py +270 -0
- agno/tools/googlecalendar.py +36 -8
- agno/tools/googlesheets.py +20 -5
- agno/tools/jira.py +20 -0
- agno/tools/knowledge.py +3 -3
- agno/tools/mcp/__init__.py +10 -0
- agno/tools/mcp/mcp.py +331 -0
- agno/tools/mcp/multi_mcp.py +347 -0
- agno/tools/mcp/params.py +24 -0
- agno/tools/mcp_toolbox.py +284 -0
- agno/tools/mem0.py +11 -17
- agno/tools/memori.py +1 -53
- agno/tools/memory.py +419 -0
- agno/tools/models/nebius.py +5 -5
- agno/tools/models_labs.py +20 -10
- agno/tools/notion.py +204 -0
- agno/tools/parallel.py +314 -0
- agno/tools/scrapegraph.py +58 -31
- agno/tools/searxng.py +2 -2
- agno/tools/serper.py +2 -2
- agno/tools/slack.py +18 -3
- agno/tools/spider.py +2 -2
- agno/tools/tavily.py +146 -0
- agno/tools/whatsapp.py +1 -1
- agno/tools/workflow.py +278 -0
- agno/tools/yfinance.py +12 -11
- agno/utils/agent.py +820 -0
- agno/utils/audio.py +27 -0
- agno/utils/common.py +90 -1
- agno/utils/events.py +217 -2
- agno/utils/gemini.py +180 -22
- agno/utils/hooks.py +57 -0
- agno/utils/http.py +111 -0
- agno/utils/knowledge.py +12 -5
- agno/utils/log.py +1 -0
- agno/utils/mcp.py +92 -2
- agno/utils/media.py +188 -10
- agno/utils/merge_dict.py +22 -1
- agno/utils/message.py +60 -0
- agno/utils/models/claude.py +40 -11
- agno/utils/print_response/agent.py +105 -21
- agno/utils/print_response/team.py +103 -38
- agno/utils/print_response/workflow.py +251 -34
- agno/utils/reasoning.py +22 -1
- agno/utils/serialize.py +32 -0
- agno/utils/streamlit.py +16 -10
- agno/utils/string.py +41 -0
- agno/utils/team.py +98 -9
- agno/utils/tools.py +1 -1
- agno/vectordb/base.py +23 -4
- agno/vectordb/cassandra/cassandra.py +65 -9
- agno/vectordb/chroma/chromadb.py +182 -38
- agno/vectordb/clickhouse/clickhousedb.py +64 -11
- agno/vectordb/couchbase/couchbase.py +105 -10
- agno/vectordb/lancedb/lance_db.py +124 -133
- agno/vectordb/langchaindb/langchaindb.py +25 -7
- agno/vectordb/lightrag/lightrag.py +17 -3
- agno/vectordb/llamaindex/__init__.py +3 -0
- agno/vectordb/llamaindex/llamaindexdb.py +46 -7
- agno/vectordb/milvus/milvus.py +126 -9
- agno/vectordb/mongodb/__init__.py +7 -1
- agno/vectordb/mongodb/mongodb.py +112 -7
- agno/vectordb/pgvector/pgvector.py +142 -21
- agno/vectordb/pineconedb/pineconedb.py +80 -8
- agno/vectordb/qdrant/qdrant.py +125 -39
- agno/vectordb/redis/__init__.py +9 -0
- agno/vectordb/redis/redisdb.py +694 -0
- agno/vectordb/singlestore/singlestore.py +111 -25
- agno/vectordb/surrealdb/surrealdb.py +31 -5
- agno/vectordb/upstashdb/upstashdb.py +76 -8
- agno/vectordb/weaviate/weaviate.py +86 -15
- agno/workflow/__init__.py +2 -0
- agno/workflow/agent.py +299 -0
- agno/workflow/condition.py +112 -18
- agno/workflow/loop.py +69 -10
- agno/workflow/parallel.py +266 -118
- agno/workflow/router.py +110 -17
- agno/workflow/step.py +638 -129
- agno/workflow/steps.py +65 -6
- agno/workflow/types.py +61 -23
- agno/workflow/workflow.py +2085 -272
- {agno-2.0.1.dist-info → agno-2.3.0.dist-info}/METADATA +182 -58
- agno-2.3.0.dist-info/RECORD +577 -0
- agno/knowledge/reader/url_reader.py +0 -128
- agno/tools/googlesearch.py +0 -98
- agno/tools/mcp.py +0 -610
- agno/utils/models/aws_claude.py +0 -170
- agno-2.0.1.dist-info/RECORD +0 -515
- {agno-2.0.1.dist-info → agno-2.3.0.dist-info}/WHEEL +0 -0
- {agno-2.0.1.dist-info → agno-2.3.0.dist-info}/licenses/LICENSE +0 -0
- {agno-2.0.1.dist-info → agno-2.3.0.dist-info}/top_level.txt +0 -0
agno/models/ollama/chat.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import json
|
|
2
|
-
from dataclasses import dataclass
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from os import getenv
|
|
3
4
|
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Type, Union
|
|
4
5
|
|
|
5
6
|
from pydantic import BaseModel
|
|
@@ -10,6 +11,7 @@ from agno.models.message import Message
|
|
|
10
11
|
from agno.models.metrics import Metrics
|
|
11
12
|
from agno.models.response import ModelResponse
|
|
12
13
|
from agno.utils.log import log_debug, log_warning
|
|
14
|
+
from agno.utils.reasoning import extract_thinking_content
|
|
13
15
|
|
|
14
16
|
try:
|
|
15
17
|
from ollama import AsyncClient as AsyncOllamaClient
|
|
@@ -43,6 +45,7 @@ class Ollama(Model):
|
|
|
43
45
|
# Client parameters
|
|
44
46
|
host: Optional[str] = None
|
|
45
47
|
timeout: Optional[Any] = None
|
|
48
|
+
api_key: Optional[str] = field(default_factory=lambda: getenv("OLLAMA_API_KEY"))
|
|
46
49
|
client_params: Optional[Dict[str, Any]] = None
|
|
47
50
|
|
|
48
51
|
# Ollama clients
|
|
@@ -50,10 +53,23 @@ class Ollama(Model):
|
|
|
50
53
|
async_client: Optional[AsyncOllamaClient] = None
|
|
51
54
|
|
|
52
55
|
def _get_client_params(self) -> Dict[str, Any]:
|
|
56
|
+
host = self.host
|
|
57
|
+
headers = {}
|
|
58
|
+
|
|
59
|
+
if self.api_key:
|
|
60
|
+
if not host:
|
|
61
|
+
host = "https://ollama.com"
|
|
62
|
+
headers["authorization"] = f"Bearer {self.api_key}"
|
|
63
|
+
log_debug(f"Using Ollama cloud endpoint: {host}")
|
|
64
|
+
|
|
53
65
|
base_params = {
|
|
54
|
-
"host":
|
|
66
|
+
"host": host,
|
|
55
67
|
"timeout": self.timeout,
|
|
56
68
|
}
|
|
69
|
+
|
|
70
|
+
if headers:
|
|
71
|
+
base_params["headers"] = headers
|
|
72
|
+
|
|
57
73
|
# Create client_params dict with non-None values
|
|
58
74
|
client_params = {k: v for k, v in base_params.items() if v is not None}
|
|
59
75
|
# Add additional client params if provided
|
|
@@ -84,7 +100,8 @@ class Ollama(Model):
|
|
|
84
100
|
if self.async_client is not None:
|
|
85
101
|
return self.async_client
|
|
86
102
|
|
|
87
|
-
|
|
103
|
+
self.async_client = AsyncOllamaClient(**self._get_client_params())
|
|
104
|
+
return self.async_client
|
|
88
105
|
|
|
89
106
|
def get_request_params(
|
|
90
107
|
self,
|
|
@@ -144,6 +161,28 @@ class Ollama(Model):
|
|
|
144
161
|
"role": message.role,
|
|
145
162
|
"content": message.content,
|
|
146
163
|
}
|
|
164
|
+
|
|
165
|
+
if message.role == "assistant" and message.tool_calls is not None:
|
|
166
|
+
# Format tool calls for assistant messages
|
|
167
|
+
formatted_tool_calls = []
|
|
168
|
+
for tool_call in message.tool_calls:
|
|
169
|
+
if "function" in tool_call:
|
|
170
|
+
function_data = tool_call["function"]
|
|
171
|
+
formatted_tool_call = {
|
|
172
|
+
"id": tool_call.get("id"),
|
|
173
|
+
"type": "function",
|
|
174
|
+
"function": {
|
|
175
|
+
"name": function_data["name"],
|
|
176
|
+
"arguments": json.loads(function_data["arguments"])
|
|
177
|
+
if isinstance(function_data["arguments"], str)
|
|
178
|
+
else function_data["arguments"],
|
|
179
|
+
},
|
|
180
|
+
}
|
|
181
|
+
formatted_tool_calls.append(formatted_tool_call)
|
|
182
|
+
|
|
183
|
+
if formatted_tool_calls:
|
|
184
|
+
_message["tool_calls"] = formatted_tool_calls
|
|
185
|
+
|
|
147
186
|
if message.role == "user":
|
|
148
187
|
if message.images is not None:
|
|
149
188
|
message_images = []
|
|
@@ -309,6 +348,16 @@ class Ollama(Model):
|
|
|
309
348
|
if response_message.get("content") is not None:
|
|
310
349
|
model_response.content = response_message.get("content")
|
|
311
350
|
|
|
351
|
+
# Extract thinking content between <think> tags if present
|
|
352
|
+
if model_response.content and model_response.content.find("<think>") != -1:
|
|
353
|
+
reasoning_content, clean_content = extract_thinking_content(model_response.content)
|
|
354
|
+
|
|
355
|
+
if reasoning_content:
|
|
356
|
+
# Store extracted thinking content separately
|
|
357
|
+
model_response.reasoning_content = reasoning_content
|
|
358
|
+
# Update main content with clean version
|
|
359
|
+
model_response.content = clean_content
|
|
360
|
+
|
|
312
361
|
if response_message.get("tool_calls") is not None:
|
|
313
362
|
if model_response.tool_calls is None:
|
|
314
363
|
model_response.tool_calls = []
|
|
@@ -380,8 +429,13 @@ class Ollama(Model):
|
|
|
380
429
|
"""
|
|
381
430
|
metrics = Metrics()
|
|
382
431
|
|
|
383
|
-
|
|
384
|
-
|
|
432
|
+
# Safely handle None values from Ollama Cloud responses
|
|
433
|
+
input_tokens = response.get("prompt_eval_count")
|
|
434
|
+
output_tokens = response.get("eval_count")
|
|
435
|
+
|
|
436
|
+
# Default to 0 if None
|
|
437
|
+
metrics.input_tokens = input_tokens if input_tokens is not None else 0
|
|
438
|
+
metrics.output_tokens = output_tokens if output_tokens is not None else 0
|
|
385
439
|
metrics.total_tokens = metrics.input_tokens + metrics.output_tokens
|
|
386
440
|
|
|
387
441
|
return metrics
|
agno/models/openai/chat.py
CHANGED
|
@@ -14,21 +14,19 @@ from agno.models.message import Message
|
|
|
14
14
|
from agno.models.metrics import Metrics
|
|
15
15
|
from agno.models.response import ModelResponse
|
|
16
16
|
from agno.run.agent import RunOutput
|
|
17
|
+
from agno.run.team import TeamRunOutput
|
|
18
|
+
from agno.utils.http import get_default_async_client, get_default_sync_client
|
|
17
19
|
from agno.utils.log import log_debug, log_error, log_warning
|
|
18
20
|
from agno.utils.openai import _format_file_for_message, audio_to_message, images_to_message
|
|
21
|
+
from agno.utils.reasoning import extract_thinking_content
|
|
19
22
|
|
|
20
23
|
try:
|
|
21
24
|
from openai import APIConnectionError, APIStatusError, RateLimitError
|
|
22
25
|
from openai import AsyncOpenAI as AsyncOpenAIClient
|
|
23
26
|
from openai import OpenAI as OpenAIClient
|
|
24
27
|
from openai.types import CompletionUsage
|
|
25
|
-
from openai.types.chat import ChatCompletionAudio
|
|
26
|
-
from openai.types.chat.
|
|
27
|
-
from openai.types.chat.chat_completion_chunk import (
|
|
28
|
-
ChatCompletionChunk,
|
|
29
|
-
ChoiceDelta,
|
|
30
|
-
ChoiceDeltaToolCall,
|
|
31
|
-
)
|
|
28
|
+
from openai.types.chat import ChatCompletion, ChatCompletionAudio, ChatCompletionChunk
|
|
29
|
+
from openai.types.chat.chat_completion_chunk import ChoiceDelta, ChoiceDeltaToolCall
|
|
32
30
|
except (ImportError, ModuleNotFoundError):
|
|
33
31
|
raise ImportError("`openai` not installed. Please install using `pip install openai`")
|
|
34
32
|
|
|
@@ -68,8 +66,10 @@ class OpenAIChat(Model):
|
|
|
68
66
|
user: Optional[str] = None
|
|
69
67
|
top_p: Optional[float] = None
|
|
70
68
|
service_tier: Optional[str] = None # "auto" | "default" | "flex" | "priority", defaults to "auto" when not set
|
|
69
|
+
strict_output: bool = True # When True, guarantees schema adherence for structured outputs. When False, attempts to follow schema as a guide but may occasionally deviate
|
|
71
70
|
extra_headers: Optional[Any] = None
|
|
72
71
|
extra_query: Optional[Any] = None
|
|
72
|
+
extra_body: Optional[Any] = None
|
|
73
73
|
request_params: Optional[Dict[str, Any]] = None
|
|
74
74
|
role_map: Optional[Dict[str, str]] = None
|
|
75
75
|
|
|
@@ -84,6 +84,10 @@ class OpenAIChat(Model):
|
|
|
84
84
|
http_client: Optional[Union[httpx.Client, httpx.AsyncClient]] = None
|
|
85
85
|
client_params: Optional[Dict[str, Any]] = None
|
|
86
86
|
|
|
87
|
+
# Cached clients to avoid recreating them on every request
|
|
88
|
+
client: Optional[OpenAIClient] = None
|
|
89
|
+
async_client: Optional[AsyncOpenAIClient] = None
|
|
90
|
+
|
|
87
91
|
# The role to map the message role to.
|
|
88
92
|
default_role_map = {
|
|
89
93
|
"system": "developer",
|
|
@@ -121,48 +125,68 @@ class OpenAIChat(Model):
|
|
|
121
125
|
|
|
122
126
|
def get_client(self) -> OpenAIClient:
|
|
123
127
|
"""
|
|
124
|
-
Returns an OpenAI client.
|
|
128
|
+
Returns an OpenAI client. Caches the client to avoid recreating it on every request.
|
|
125
129
|
|
|
126
130
|
Returns:
|
|
127
131
|
OpenAIClient: An instance of the OpenAI client.
|
|
128
132
|
"""
|
|
133
|
+
# Return cached client if it exists and is not closed
|
|
134
|
+
if self.client is not None and not self.client.is_closed():
|
|
135
|
+
return self.client
|
|
136
|
+
|
|
137
|
+
log_debug(f"Creating new sync OpenAI client for model {self.id}")
|
|
129
138
|
client_params: Dict[str, Any] = self._get_client_params()
|
|
130
139
|
if self.http_client:
|
|
131
140
|
if isinstance(self.http_client, httpx.Client):
|
|
132
141
|
client_params["http_client"] = self.http_client
|
|
133
142
|
else:
|
|
134
|
-
log_warning("http_client is not an instance of httpx.Client.")
|
|
135
|
-
|
|
143
|
+
log_warning("http_client is not an instance of httpx.Client. Using default global httpx.Client.")
|
|
144
|
+
# Use global sync client when user http_client is invalid
|
|
145
|
+
client_params["http_client"] = get_default_sync_client()
|
|
146
|
+
else:
|
|
147
|
+
# Use global sync client when no custom http_client is provided
|
|
148
|
+
client_params["http_client"] = get_default_sync_client()
|
|
149
|
+
|
|
150
|
+
# Create and cache the client
|
|
151
|
+
self.client = OpenAIClient(**client_params)
|
|
152
|
+
return self.client
|
|
136
153
|
|
|
137
154
|
def get_async_client(self) -> AsyncOpenAIClient:
|
|
138
155
|
"""
|
|
139
|
-
Returns an asynchronous OpenAI client.
|
|
156
|
+
Returns an asynchronous OpenAI client. Caches the client to avoid recreating it on every request.
|
|
140
157
|
|
|
141
158
|
Returns:
|
|
142
159
|
AsyncOpenAIClient: An instance of the asynchronous OpenAI client.
|
|
143
160
|
"""
|
|
161
|
+
# Return cached client if it exists and is not closed
|
|
162
|
+
if self.async_client is not None and not self.async_client.is_closed():
|
|
163
|
+
return self.async_client
|
|
164
|
+
|
|
165
|
+
log_debug(f"Creating new async OpenAI client for model {self.id}")
|
|
144
166
|
client_params: Dict[str, Any] = self._get_client_params()
|
|
145
167
|
if self.http_client:
|
|
146
168
|
if isinstance(self.http_client, httpx.AsyncClient):
|
|
147
169
|
client_params["http_client"] = self.http_client
|
|
148
170
|
else:
|
|
149
|
-
log_warning(
|
|
150
|
-
|
|
151
|
-
client_params["http_client"] = httpx.AsyncClient(
|
|
152
|
-
limits=httpx.Limits(max_connections=1000, max_keepalive_connections=100)
|
|
171
|
+
log_warning(
|
|
172
|
+
"http_client is not an instance of httpx.AsyncClient. Using default global httpx.AsyncClient."
|
|
153
173
|
)
|
|
174
|
+
# Use global async client when user http_client is invalid
|
|
175
|
+
client_params["http_client"] = get_default_async_client()
|
|
154
176
|
else:
|
|
155
|
-
#
|
|
156
|
-
client_params["http_client"] =
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
177
|
+
# Use global async client when no custom http_client is provided
|
|
178
|
+
client_params["http_client"] = get_default_async_client()
|
|
179
|
+
|
|
180
|
+
# Create and cache the client
|
|
181
|
+
self.async_client = AsyncOpenAIClient(**client_params)
|
|
182
|
+
return self.async_client
|
|
160
183
|
|
|
161
184
|
def get_request_params(
|
|
162
185
|
self,
|
|
163
186
|
response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
|
|
164
187
|
tools: Optional[List[Dict[str, Any]]] = None,
|
|
165
188
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
|
189
|
+
run_response: Optional[Union[RunOutput, TeamRunOutput]] = None,
|
|
166
190
|
) -> Dict[str, Any]:
|
|
167
191
|
"""
|
|
168
192
|
Returns keyword arguments for API requests.
|
|
@@ -191,6 +215,7 @@ class OpenAIChat(Model):
|
|
|
191
215
|
"top_p": self.top_p,
|
|
192
216
|
"extra_headers": self.extra_headers,
|
|
193
217
|
"extra_query": self.extra_query,
|
|
218
|
+
"extra_body": self.extra_body,
|
|
194
219
|
"metadata": self.metadata,
|
|
195
220
|
"service_tier": self.service_tier,
|
|
196
221
|
}
|
|
@@ -207,7 +232,7 @@ class OpenAIChat(Model):
|
|
|
207
232
|
"json_schema": {
|
|
208
233
|
"name": response_format.__name__,
|
|
209
234
|
"schema": schema,
|
|
210
|
-
"strict":
|
|
235
|
+
"strict": self.strict_output,
|
|
211
236
|
},
|
|
212
237
|
}
|
|
213
238
|
else:
|
|
@@ -270,6 +295,7 @@ class OpenAIChat(Model):
|
|
|
270
295
|
"user": self.user,
|
|
271
296
|
"extra_headers": self.extra_headers,
|
|
272
297
|
"extra_query": self.extra_query,
|
|
298
|
+
"extra_body": self.extra_body,
|
|
273
299
|
"service_tier": self.service_tier,
|
|
274
300
|
}
|
|
275
301
|
)
|
|
@@ -347,7 +373,7 @@ class OpenAIChat(Model):
|
|
|
347
373
|
response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
|
|
348
374
|
tools: Optional[List[Dict[str, Any]]] = None,
|
|
349
375
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
|
350
|
-
run_response: Optional[RunOutput] = None,
|
|
376
|
+
run_response: Optional[Union[RunOutput, TeamRunOutput]] = None,
|
|
351
377
|
) -> ModelResponse:
|
|
352
378
|
"""
|
|
353
379
|
Send a chat completion request to the OpenAI API and parse the response.
|
|
@@ -371,7 +397,9 @@ class OpenAIChat(Model):
|
|
|
371
397
|
provider_response = self.get_client().chat.completions.create(
|
|
372
398
|
model=self.id,
|
|
373
399
|
messages=[self._format_message(m) for m in messages], # type: ignore
|
|
374
|
-
**self.get_request_params(
|
|
400
|
+
**self.get_request_params(
|
|
401
|
+
response_format=response_format, tools=tools, tool_choice=tool_choice, run_response=run_response
|
|
402
|
+
),
|
|
375
403
|
)
|
|
376
404
|
assistant_message.metrics.stop_timer()
|
|
377
405
|
|
|
@@ -425,7 +453,7 @@ class OpenAIChat(Model):
|
|
|
425
453
|
response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
|
|
426
454
|
tools: Optional[List[Dict[str, Any]]] = None,
|
|
427
455
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
|
428
|
-
run_response: Optional[RunOutput] = None,
|
|
456
|
+
run_response: Optional[Union[RunOutput, TeamRunOutput]] = None,
|
|
429
457
|
) -> ModelResponse:
|
|
430
458
|
"""
|
|
431
459
|
Sends an asynchronous chat completion request to the OpenAI API.
|
|
@@ -448,7 +476,9 @@ class OpenAIChat(Model):
|
|
|
448
476
|
response = await self.get_async_client().chat.completions.create(
|
|
449
477
|
model=self.id,
|
|
450
478
|
messages=[self._format_message(m) for m in messages], # type: ignore
|
|
451
|
-
**self.get_request_params(
|
|
479
|
+
**self.get_request_params(
|
|
480
|
+
response_format=response_format, tools=tools, tool_choice=tool_choice, run_response=run_response
|
|
481
|
+
),
|
|
452
482
|
)
|
|
453
483
|
assistant_message.metrics.stop_timer()
|
|
454
484
|
|
|
@@ -502,7 +532,7 @@ class OpenAIChat(Model):
|
|
|
502
532
|
response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
|
|
503
533
|
tools: Optional[List[Dict[str, Any]]] = None,
|
|
504
534
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
|
505
|
-
run_response: Optional[RunOutput] = None,
|
|
535
|
+
run_response: Optional[Union[RunOutput, TeamRunOutput]] = None,
|
|
506
536
|
) -> Iterator[ModelResponse]:
|
|
507
537
|
"""
|
|
508
538
|
Send a streaming chat completion request to the OpenAI API.
|
|
@@ -525,7 +555,9 @@ class OpenAIChat(Model):
|
|
|
525
555
|
messages=[self._format_message(m) for m in messages], # type: ignore
|
|
526
556
|
stream=True,
|
|
527
557
|
stream_options={"include_usage": True},
|
|
528
|
-
**self.get_request_params(
|
|
558
|
+
**self.get_request_params(
|
|
559
|
+
response_format=response_format, tools=tools, tool_choice=tool_choice, run_response=run_response
|
|
560
|
+
),
|
|
529
561
|
):
|
|
530
562
|
yield self._parse_provider_response_delta(chunk)
|
|
531
563
|
|
|
@@ -576,7 +608,7 @@ class OpenAIChat(Model):
|
|
|
576
608
|
response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
|
|
577
609
|
tools: Optional[List[Dict[str, Any]]] = None,
|
|
578
610
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
|
579
|
-
run_response: Optional[RunOutput] = None,
|
|
611
|
+
run_response: Optional[Union[RunOutput, TeamRunOutput]] = None,
|
|
580
612
|
) -> AsyncIterator[ModelResponse]:
|
|
581
613
|
"""
|
|
582
614
|
Sends an asynchronous streaming chat completion request to the OpenAI API.
|
|
@@ -599,7 +631,9 @@ class OpenAIChat(Model):
|
|
|
599
631
|
messages=[self._format_message(m) for m in messages], # type: ignore
|
|
600
632
|
stream=True,
|
|
601
633
|
stream_options={"include_usage": True},
|
|
602
|
-
**self.get_request_params(
|
|
634
|
+
**self.get_request_params(
|
|
635
|
+
response_format=response_format, tools=tools, tool_choice=tool_choice, run_response=run_response
|
|
636
|
+
),
|
|
603
637
|
)
|
|
604
638
|
|
|
605
639
|
async for chunk in async_stream:
|
|
@@ -713,6 +747,12 @@ class OpenAIChat(Model):
|
|
|
713
747
|
if response_message.content is not None:
|
|
714
748
|
model_response.content = response_message.content
|
|
715
749
|
|
|
750
|
+
# Extract thinking content before any structured parsing
|
|
751
|
+
if model_response.content:
|
|
752
|
+
reasoning_content, output_content = extract_thinking_content(model_response.content)
|
|
753
|
+
if reasoning_content:
|
|
754
|
+
model_response.reasoning_content = reasoning_content
|
|
755
|
+
model_response.content = output_content
|
|
716
756
|
# Add tool calls
|
|
717
757
|
if response_message.tool_calls is not None and len(response_message.tool_calls) > 0:
|
|
718
758
|
try:
|