agno 2.1.2__py3-none-any.whl → 2.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agno/agent/agent.py +5540 -2273
- agno/api/api.py +2 -0
- agno/api/os.py +1 -1
- agno/compression/__init__.py +3 -0
- agno/compression/manager.py +247 -0
- agno/culture/__init__.py +3 -0
- agno/culture/manager.py +956 -0
- agno/db/async_postgres/__init__.py +3 -0
- agno/db/base.py +689 -6
- agno/db/dynamo/dynamo.py +933 -37
- agno/db/dynamo/schemas.py +174 -10
- agno/db/dynamo/utils.py +63 -4
- agno/db/firestore/firestore.py +831 -9
- agno/db/firestore/schemas.py +51 -0
- agno/db/firestore/utils.py +102 -4
- agno/db/gcs_json/gcs_json_db.py +660 -12
- agno/db/gcs_json/utils.py +60 -26
- agno/db/in_memory/in_memory_db.py +287 -14
- agno/db/in_memory/utils.py +60 -2
- agno/db/json/json_db.py +590 -14
- agno/db/json/utils.py +60 -26
- agno/db/migrations/manager.py +199 -0
- agno/db/migrations/v1_to_v2.py +43 -13
- agno/db/migrations/versions/__init__.py +0 -0
- agno/db/migrations/versions/v2_3_0.py +938 -0
- agno/db/mongo/__init__.py +15 -1
- agno/db/mongo/async_mongo.py +2760 -0
- agno/db/mongo/mongo.py +879 -11
- agno/db/mongo/schemas.py +42 -0
- agno/db/mongo/utils.py +80 -8
- agno/db/mysql/__init__.py +2 -1
- agno/db/mysql/async_mysql.py +2912 -0
- agno/db/mysql/mysql.py +946 -68
- agno/db/mysql/schemas.py +72 -10
- agno/db/mysql/utils.py +198 -7
- agno/db/postgres/__init__.py +2 -1
- agno/db/postgres/async_postgres.py +2579 -0
- agno/db/postgres/postgres.py +942 -57
- agno/db/postgres/schemas.py +81 -18
- agno/db/postgres/utils.py +164 -2
- agno/db/redis/redis.py +671 -7
- agno/db/redis/schemas.py +50 -0
- agno/db/redis/utils.py +65 -7
- agno/db/schemas/__init__.py +2 -1
- agno/db/schemas/culture.py +120 -0
- agno/db/schemas/evals.py +1 -0
- agno/db/schemas/memory.py +17 -2
- agno/db/singlestore/schemas.py +63 -0
- agno/db/singlestore/singlestore.py +949 -83
- agno/db/singlestore/utils.py +60 -2
- agno/db/sqlite/__init__.py +2 -1
- agno/db/sqlite/async_sqlite.py +2911 -0
- agno/db/sqlite/schemas.py +62 -0
- agno/db/sqlite/sqlite.py +965 -46
- agno/db/sqlite/utils.py +169 -8
- agno/db/surrealdb/__init__.py +3 -0
- agno/db/surrealdb/metrics.py +292 -0
- agno/db/surrealdb/models.py +334 -0
- agno/db/surrealdb/queries.py +71 -0
- agno/db/surrealdb/surrealdb.py +1908 -0
- agno/db/surrealdb/utils.py +147 -0
- agno/db/utils.py +2 -0
- agno/eval/__init__.py +10 -0
- agno/eval/accuracy.py +75 -55
- agno/eval/agent_as_judge.py +861 -0
- agno/eval/base.py +29 -0
- agno/eval/performance.py +16 -7
- agno/eval/reliability.py +28 -16
- agno/eval/utils.py +35 -17
- agno/exceptions.py +27 -2
- agno/filters.py +354 -0
- agno/guardrails/prompt_injection.py +1 -0
- agno/hooks/__init__.py +3 -0
- agno/hooks/decorator.py +164 -0
- agno/integrations/discord/client.py +1 -1
- agno/knowledge/chunking/agentic.py +13 -10
- agno/knowledge/chunking/fixed.py +4 -1
- agno/knowledge/chunking/semantic.py +9 -4
- agno/knowledge/chunking/strategy.py +59 -15
- agno/knowledge/embedder/fastembed.py +1 -1
- agno/knowledge/embedder/nebius.py +1 -1
- agno/knowledge/embedder/ollama.py +8 -0
- agno/knowledge/embedder/openai.py +8 -8
- agno/knowledge/embedder/sentence_transformer.py +6 -2
- agno/knowledge/embedder/vllm.py +262 -0
- agno/knowledge/knowledge.py +1618 -318
- agno/knowledge/reader/base.py +6 -2
- agno/knowledge/reader/csv_reader.py +8 -10
- agno/knowledge/reader/docx_reader.py +5 -6
- agno/knowledge/reader/field_labeled_csv_reader.py +16 -20
- agno/knowledge/reader/json_reader.py +5 -4
- agno/knowledge/reader/markdown_reader.py +8 -8
- agno/knowledge/reader/pdf_reader.py +17 -19
- agno/knowledge/reader/pptx_reader.py +101 -0
- agno/knowledge/reader/reader_factory.py +32 -3
- agno/knowledge/reader/s3_reader.py +3 -3
- agno/knowledge/reader/tavily_reader.py +193 -0
- agno/knowledge/reader/text_reader.py +22 -10
- agno/knowledge/reader/web_search_reader.py +1 -48
- agno/knowledge/reader/website_reader.py +10 -10
- agno/knowledge/reader/wikipedia_reader.py +33 -1
- agno/knowledge/types.py +1 -0
- agno/knowledge/utils.py +72 -7
- agno/media.py +22 -6
- agno/memory/__init__.py +14 -1
- agno/memory/manager.py +544 -83
- agno/memory/strategies/__init__.py +15 -0
- agno/memory/strategies/base.py +66 -0
- agno/memory/strategies/summarize.py +196 -0
- agno/memory/strategies/types.py +37 -0
- agno/models/aimlapi/aimlapi.py +17 -0
- agno/models/anthropic/claude.py +515 -40
- agno/models/aws/bedrock.py +102 -21
- agno/models/aws/claude.py +131 -274
- agno/models/azure/ai_foundry.py +41 -19
- agno/models/azure/openai_chat.py +39 -8
- agno/models/base.py +1249 -525
- agno/models/cerebras/cerebras.py +91 -21
- agno/models/cerebras/cerebras_openai.py +21 -2
- agno/models/cohere/chat.py +40 -6
- agno/models/cometapi/cometapi.py +18 -1
- agno/models/dashscope/dashscope.py +2 -3
- agno/models/deepinfra/deepinfra.py +18 -1
- agno/models/deepseek/deepseek.py +69 -3
- agno/models/fireworks/fireworks.py +18 -1
- agno/models/google/gemini.py +877 -80
- agno/models/google/utils.py +22 -0
- agno/models/groq/groq.py +51 -18
- agno/models/huggingface/huggingface.py +17 -6
- agno/models/ibm/watsonx.py +16 -6
- agno/models/internlm/internlm.py +18 -1
- agno/models/langdb/langdb.py +13 -1
- agno/models/litellm/chat.py +44 -9
- agno/models/litellm/litellm_openai.py +18 -1
- agno/models/message.py +28 -5
- agno/models/meta/llama.py +47 -14
- agno/models/meta/llama_openai.py +22 -17
- agno/models/mistral/mistral.py +8 -4
- agno/models/nebius/nebius.py +6 -7
- agno/models/nvidia/nvidia.py +20 -3
- agno/models/ollama/chat.py +24 -8
- agno/models/openai/chat.py +104 -29
- agno/models/openai/responses.py +101 -81
- agno/models/openrouter/openrouter.py +60 -3
- agno/models/perplexity/perplexity.py +17 -1
- agno/models/portkey/portkey.py +7 -6
- agno/models/requesty/requesty.py +24 -4
- agno/models/response.py +73 -2
- agno/models/sambanova/sambanova.py +20 -3
- agno/models/siliconflow/siliconflow.py +19 -2
- agno/models/together/together.py +20 -3
- agno/models/utils.py +254 -8
- agno/models/vercel/v0.py +20 -3
- agno/models/vertexai/__init__.py +0 -0
- agno/models/vertexai/claude.py +190 -0
- agno/models/vllm/vllm.py +19 -14
- agno/models/xai/xai.py +19 -2
- agno/os/app.py +549 -152
- agno/os/auth.py +190 -3
- agno/os/config.py +23 -0
- agno/os/interfaces/a2a/router.py +8 -11
- agno/os/interfaces/a2a/utils.py +1 -1
- agno/os/interfaces/agui/router.py +18 -3
- agno/os/interfaces/agui/utils.py +152 -39
- agno/os/interfaces/slack/router.py +55 -37
- agno/os/interfaces/slack/slack.py +9 -1
- agno/os/interfaces/whatsapp/router.py +0 -1
- agno/os/interfaces/whatsapp/security.py +3 -1
- agno/os/mcp.py +110 -52
- agno/os/middleware/__init__.py +2 -0
- agno/os/middleware/jwt.py +676 -112
- agno/os/router.py +40 -1478
- agno/os/routers/agents/__init__.py +3 -0
- agno/os/routers/agents/router.py +599 -0
- agno/os/routers/agents/schema.py +261 -0
- agno/os/routers/evals/evals.py +96 -39
- agno/os/routers/evals/schemas.py +65 -33
- agno/os/routers/evals/utils.py +80 -10
- agno/os/routers/health.py +10 -4
- agno/os/routers/knowledge/knowledge.py +196 -38
- agno/os/routers/knowledge/schemas.py +82 -22
- agno/os/routers/memory/memory.py +279 -52
- agno/os/routers/memory/schemas.py +46 -17
- agno/os/routers/metrics/metrics.py +20 -8
- agno/os/routers/metrics/schemas.py +16 -16
- agno/os/routers/session/session.py +462 -34
- agno/os/routers/teams/__init__.py +3 -0
- agno/os/routers/teams/router.py +512 -0
- agno/os/routers/teams/schema.py +257 -0
- agno/os/routers/traces/__init__.py +3 -0
- agno/os/routers/traces/schemas.py +414 -0
- agno/os/routers/traces/traces.py +499 -0
- agno/os/routers/workflows/__init__.py +3 -0
- agno/os/routers/workflows/router.py +624 -0
- agno/os/routers/workflows/schema.py +75 -0
- agno/os/schema.py +256 -693
- agno/os/scopes.py +469 -0
- agno/os/utils.py +514 -36
- agno/reasoning/anthropic.py +80 -0
- agno/reasoning/gemini.py +73 -0
- agno/reasoning/openai.py +5 -0
- agno/reasoning/vertexai.py +76 -0
- agno/run/__init__.py +6 -0
- agno/run/agent.py +155 -32
- agno/run/base.py +55 -3
- agno/run/requirement.py +181 -0
- agno/run/team.py +125 -38
- agno/run/workflow.py +72 -18
- agno/session/agent.py +102 -89
- agno/session/summary.py +56 -15
- agno/session/team.py +164 -90
- agno/session/workflow.py +405 -40
- agno/table.py +10 -0
- agno/team/team.py +3974 -1903
- agno/tools/dalle.py +2 -4
- agno/tools/eleven_labs.py +23 -25
- agno/tools/exa.py +21 -16
- agno/tools/file.py +153 -23
- agno/tools/file_generation.py +16 -10
- agno/tools/firecrawl.py +15 -7
- agno/tools/function.py +193 -38
- agno/tools/gmail.py +238 -14
- agno/tools/google_drive.py +271 -0
- agno/tools/googlecalendar.py +36 -8
- agno/tools/googlesheets.py +20 -5
- agno/tools/jira.py +20 -0
- agno/tools/mcp/__init__.py +10 -0
- agno/tools/mcp/mcp.py +331 -0
- agno/tools/mcp/multi_mcp.py +347 -0
- agno/tools/mcp/params.py +24 -0
- agno/tools/mcp_toolbox.py +3 -3
- agno/tools/models/nebius.py +5 -5
- agno/tools/models_labs.py +20 -10
- agno/tools/nano_banana.py +151 -0
- agno/tools/notion.py +204 -0
- agno/tools/parallel.py +314 -0
- agno/tools/postgres.py +76 -36
- agno/tools/redshift.py +406 -0
- agno/tools/scrapegraph.py +1 -1
- agno/tools/shopify.py +1519 -0
- agno/tools/slack.py +18 -3
- agno/tools/spotify.py +919 -0
- agno/tools/tavily.py +146 -0
- agno/tools/toolkit.py +25 -0
- agno/tools/workflow.py +8 -1
- agno/tools/yfinance.py +12 -11
- agno/tracing/__init__.py +12 -0
- agno/tracing/exporter.py +157 -0
- agno/tracing/schemas.py +276 -0
- agno/tracing/setup.py +111 -0
- agno/utils/agent.py +938 -0
- agno/utils/cryptography.py +22 -0
- agno/utils/dttm.py +33 -0
- agno/utils/events.py +151 -3
- agno/utils/gemini.py +15 -5
- agno/utils/hooks.py +118 -4
- agno/utils/http.py +113 -2
- agno/utils/knowledge.py +12 -5
- agno/utils/log.py +1 -0
- agno/utils/mcp.py +92 -2
- agno/utils/media.py +187 -1
- agno/utils/merge_dict.py +3 -3
- agno/utils/message.py +60 -0
- agno/utils/models/ai_foundry.py +9 -2
- agno/utils/models/claude.py +49 -14
- agno/utils/models/cohere.py +9 -2
- agno/utils/models/llama.py +9 -2
- agno/utils/models/mistral.py +4 -2
- agno/utils/print_response/agent.py +109 -16
- agno/utils/print_response/team.py +223 -30
- agno/utils/print_response/workflow.py +251 -34
- agno/utils/streamlit.py +1 -1
- agno/utils/team.py +98 -9
- agno/utils/tokens.py +657 -0
- agno/vectordb/base.py +39 -7
- agno/vectordb/cassandra/cassandra.py +21 -5
- agno/vectordb/chroma/chromadb.py +43 -12
- agno/vectordb/clickhouse/clickhousedb.py +21 -5
- agno/vectordb/couchbase/couchbase.py +29 -5
- agno/vectordb/lancedb/lance_db.py +92 -181
- agno/vectordb/langchaindb/langchaindb.py +24 -4
- agno/vectordb/lightrag/lightrag.py +17 -3
- agno/vectordb/llamaindex/llamaindexdb.py +25 -5
- agno/vectordb/milvus/milvus.py +50 -37
- agno/vectordb/mongodb/__init__.py +7 -1
- agno/vectordb/mongodb/mongodb.py +36 -30
- agno/vectordb/pgvector/pgvector.py +201 -77
- agno/vectordb/pineconedb/pineconedb.py +41 -23
- agno/vectordb/qdrant/qdrant.py +67 -54
- agno/vectordb/redis/__init__.py +9 -0
- agno/vectordb/redis/redisdb.py +682 -0
- agno/vectordb/singlestore/singlestore.py +50 -29
- agno/vectordb/surrealdb/surrealdb.py +31 -41
- agno/vectordb/upstashdb/upstashdb.py +34 -6
- agno/vectordb/weaviate/weaviate.py +53 -14
- agno/workflow/__init__.py +2 -0
- agno/workflow/agent.py +299 -0
- agno/workflow/condition.py +120 -18
- agno/workflow/loop.py +77 -10
- agno/workflow/parallel.py +231 -143
- agno/workflow/router.py +118 -17
- agno/workflow/step.py +609 -170
- agno/workflow/steps.py +73 -6
- agno/workflow/types.py +96 -21
- agno/workflow/workflow.py +2039 -262
- {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/METADATA +201 -66
- agno-2.3.13.dist-info/RECORD +613 -0
- agno/tools/googlesearch.py +0 -98
- agno/tools/mcp.py +0 -679
- agno/tools/memori.py +0 -339
- agno-2.1.2.dist-info/RECORD +0 -543
- {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/WHEEL +0 -0
- {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/licenses/LICENSE +0 -0
- {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/top_level.txt +0 -0
agno/models/cerebras/cerebras.py
CHANGED
|
@@ -12,6 +12,7 @@ from agno.models.message import Message
|
|
|
12
12
|
from agno.models.metrics import Metrics
|
|
13
13
|
from agno.models.response import ModelResponse
|
|
14
14
|
from agno.run.agent import RunOutput
|
|
15
|
+
from agno.utils.http import get_default_async_client, get_default_sync_client
|
|
15
16
|
from agno.utils.log import log_debug, log_error, log_warning
|
|
16
17
|
|
|
17
18
|
try:
|
|
@@ -51,6 +52,7 @@ class Cerebras(Model):
|
|
|
51
52
|
temperature: Optional[float] = None
|
|
52
53
|
top_p: Optional[float] = None
|
|
53
54
|
top_k: Optional[int] = None
|
|
55
|
+
strict_output: bool = True # When True, guarantees schema adherence for structured outputs. When False, attempts to follow schema as a guide but may occasionally deviate
|
|
54
56
|
extra_headers: Optional[Any] = None
|
|
55
57
|
extra_query: Optional[Any] = None
|
|
56
58
|
extra_body: Optional[Any] = None
|
|
@@ -63,7 +65,7 @@ class Cerebras(Model):
|
|
|
63
65
|
max_retries: Optional[int] = None
|
|
64
66
|
default_headers: Optional[Any] = None
|
|
65
67
|
default_query: Optional[Any] = None
|
|
66
|
-
http_client: Optional[httpx.Client] = None
|
|
68
|
+
http_client: Optional[Union[httpx.Client, httpx.AsyncClient]] = None
|
|
67
69
|
client_params: Optional[Dict[str, Any]] = None
|
|
68
70
|
|
|
69
71
|
# Cerebras clients
|
|
@@ -102,12 +104,15 @@ class Cerebras(Model):
|
|
|
102
104
|
Returns:
|
|
103
105
|
CerebrasClient: An instance of the Cerebras client.
|
|
104
106
|
"""
|
|
105
|
-
if self.client:
|
|
107
|
+
if self.client and not self.client.is_closed():
|
|
106
108
|
return self.client
|
|
107
109
|
|
|
108
110
|
client_params: Dict[str, Any] = self._get_client_params()
|
|
109
111
|
if self.http_client is not None:
|
|
110
112
|
client_params["http_client"] = self.http_client
|
|
113
|
+
else:
|
|
114
|
+
# Use global sync client when no custom http_client is provided
|
|
115
|
+
client_params["http_client"] = get_default_sync_client()
|
|
111
116
|
self.client = CerebrasClient(**client_params)
|
|
112
117
|
return self.client
|
|
113
118
|
|
|
@@ -118,17 +123,15 @@ class Cerebras(Model):
|
|
|
118
123
|
Returns:
|
|
119
124
|
AsyncCerebras: An instance of the asynchronous Cerebras client.
|
|
120
125
|
"""
|
|
121
|
-
if self.async_client:
|
|
126
|
+
if self.async_client and not self.async_client.is_closed():
|
|
122
127
|
return self.async_client
|
|
123
128
|
|
|
124
129
|
client_params: Dict[str, Any] = self._get_client_params()
|
|
125
|
-
if self.http_client:
|
|
130
|
+
if self.http_client and isinstance(self.http_client, httpx.AsyncClient):
|
|
126
131
|
client_params["http_client"] = self.http_client
|
|
127
132
|
else:
|
|
128
|
-
#
|
|
129
|
-
client_params["http_client"] =
|
|
130
|
-
limits=httpx.Limits(max_connections=1000, max_keepalive_connections=100)
|
|
131
|
-
)
|
|
133
|
+
# Use global async client when no custom http_client is provided
|
|
134
|
+
client_params["http_client"] = get_default_async_client()
|
|
132
135
|
self.async_client = AsyncCerebrasClient(**client_params)
|
|
133
136
|
return self.async_client
|
|
134
137
|
|
|
@@ -186,10 +189,10 @@ class Cerebras(Model):
|
|
|
186
189
|
and response_format.get("type") == "json_schema"
|
|
187
190
|
and isinstance(response_format.get("json_schema"), dict)
|
|
188
191
|
):
|
|
189
|
-
# Ensure json_schema has strict
|
|
192
|
+
# Ensure json_schema has strict parameter set
|
|
190
193
|
schema = response_format["json_schema"]
|
|
191
194
|
if isinstance(schema.get("schema"), dict) and "strict" not in schema:
|
|
192
|
-
schema["strict"] =
|
|
195
|
+
schema["strict"] = self.strict_output
|
|
193
196
|
|
|
194
197
|
request_params["response_format"] = response_format
|
|
195
198
|
|
|
@@ -209,6 +212,7 @@ class Cerebras(Model):
|
|
|
209
212
|
tools: Optional[List[Dict[str, Any]]] = None,
|
|
210
213
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
|
211
214
|
run_response: Optional[RunOutput] = None,
|
|
215
|
+
compress_tool_results: bool = False,
|
|
212
216
|
) -> ModelResponse:
|
|
213
217
|
"""
|
|
214
218
|
Send a chat completion request to the Cerebras API.
|
|
@@ -225,7 +229,7 @@ class Cerebras(Model):
|
|
|
225
229
|
assistant_message.metrics.start_timer()
|
|
226
230
|
provider_response = self.get_client().chat.completions.create(
|
|
227
231
|
model=self.id,
|
|
228
|
-
messages=[self._format_message(m) for m in messages], # type: ignore
|
|
232
|
+
messages=[self._format_message(m, compress_tool_results) for m in messages], # type: ignore
|
|
229
233
|
**self.get_request_params(response_format=response_format, tools=tools),
|
|
230
234
|
)
|
|
231
235
|
assistant_message.metrics.stop_timer()
|
|
@@ -242,6 +246,7 @@ class Cerebras(Model):
|
|
|
242
246
|
tools: Optional[List[Dict[str, Any]]] = None,
|
|
243
247
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
|
244
248
|
run_response: Optional[RunOutput] = None,
|
|
249
|
+
compress_tool_results: bool = False,
|
|
245
250
|
) -> ModelResponse:
|
|
246
251
|
"""
|
|
247
252
|
Sends an asynchronous chat completion request to the Cerebras API.
|
|
@@ -258,7 +263,7 @@ class Cerebras(Model):
|
|
|
258
263
|
assistant_message.metrics.start_timer()
|
|
259
264
|
provider_response = await self.get_async_client().chat.completions.create(
|
|
260
265
|
model=self.id,
|
|
261
|
-
messages=[self._format_message(m) for m in messages], # type: ignore
|
|
266
|
+
messages=[self._format_message(m, compress_tool_results) for m in messages], # type: ignore
|
|
262
267
|
**self.get_request_params(response_format=response_format, tools=tools),
|
|
263
268
|
)
|
|
264
269
|
assistant_message.metrics.stop_timer()
|
|
@@ -275,6 +280,7 @@ class Cerebras(Model):
|
|
|
275
280
|
tools: Optional[List[Dict[str, Any]]] = None,
|
|
276
281
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
|
277
282
|
run_response: Optional[RunOutput] = None,
|
|
283
|
+
compress_tool_results: bool = False,
|
|
278
284
|
) -> Iterator[ModelResponse]:
|
|
279
285
|
"""
|
|
280
286
|
Send a streaming chat completion request to the Cerebras API.
|
|
@@ -292,7 +298,7 @@ class Cerebras(Model):
|
|
|
292
298
|
|
|
293
299
|
for chunk in self.get_client().chat.completions.create(
|
|
294
300
|
model=self.id,
|
|
295
|
-
messages=[self._format_message(m) for m in messages], # type: ignore
|
|
301
|
+
messages=[self._format_message(m, compress_tool_results) for m in messages], # type: ignore
|
|
296
302
|
stream=True,
|
|
297
303
|
**self.get_request_params(response_format=response_format, tools=tools),
|
|
298
304
|
):
|
|
@@ -308,6 +314,7 @@ class Cerebras(Model):
|
|
|
308
314
|
tools: Optional[List[Dict[str, Any]]] = None,
|
|
309
315
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
|
310
316
|
run_response: Optional[RunOutput] = None,
|
|
317
|
+
compress_tool_results: bool = False,
|
|
311
318
|
) -> AsyncIterator[ModelResponse]:
|
|
312
319
|
"""
|
|
313
320
|
Sends an asynchronous streaming chat completion request to the Cerebras API.
|
|
@@ -325,7 +332,7 @@ class Cerebras(Model):
|
|
|
325
332
|
|
|
326
333
|
async_stream = await self.get_async_client().chat.completions.create(
|
|
327
334
|
model=self.id,
|
|
328
|
-
messages=[self._format_message(m) for m in messages], # type: ignore
|
|
335
|
+
messages=[self._format_message(m, compress_tool_results) for m in messages], # type: ignore
|
|
329
336
|
stream=True,
|
|
330
337
|
**self.get_request_params(response_format=response_format, tools=tools),
|
|
331
338
|
)
|
|
@@ -335,20 +342,27 @@ class Cerebras(Model):
|
|
|
335
342
|
|
|
336
343
|
assistant_message.metrics.stop_timer()
|
|
337
344
|
|
|
338
|
-
def _format_message(self, message: Message) -> Dict[str, Any]:
|
|
345
|
+
def _format_message(self, message: Message, compress_tool_results: bool = False) -> Dict[str, Any]:
|
|
339
346
|
"""
|
|
340
347
|
Format a message into the format expected by the Cerebras API.
|
|
341
348
|
|
|
342
349
|
Args:
|
|
343
350
|
message (Message): The message to format.
|
|
351
|
+
compress_tool_results: Whether to compress tool results.
|
|
344
352
|
|
|
345
353
|
Returns:
|
|
346
354
|
Dict[str, Any]: The formatted message.
|
|
347
355
|
"""
|
|
356
|
+
# Use compressed content for tool messages if compression is active
|
|
357
|
+
if message.role == "tool":
|
|
358
|
+
content = message.get_content(use_compressed_content=compress_tool_results)
|
|
359
|
+
else:
|
|
360
|
+
content = message.content if message.content is not None else ""
|
|
361
|
+
|
|
348
362
|
# Basic message content
|
|
349
363
|
message_dict: Dict[str, Any] = {
|
|
350
364
|
"role": message.role,
|
|
351
|
-
"content":
|
|
365
|
+
"content": content,
|
|
352
366
|
}
|
|
353
367
|
|
|
354
368
|
# Add name if present
|
|
@@ -377,7 +391,7 @@ class Cerebras(Model):
|
|
|
377
391
|
message_dict = {
|
|
378
392
|
"role": "tool",
|
|
379
393
|
"tool_call_id": message.tool_call_id,
|
|
380
|
-
"content":
|
|
394
|
+
"content": content,
|
|
381
395
|
}
|
|
382
396
|
|
|
383
397
|
# Ensure no None values in the message
|
|
@@ -456,18 +470,19 @@ class Cerebras(Model):
|
|
|
456
470
|
if choice_delta.content:
|
|
457
471
|
model_response.content = choice_delta.content
|
|
458
472
|
|
|
459
|
-
# Add tool calls
|
|
473
|
+
# Add tool calls - preserve index for proper aggregation in parse_tool_calls
|
|
460
474
|
if choice_delta.tool_calls:
|
|
461
475
|
model_response.tool_calls = [
|
|
462
476
|
{
|
|
477
|
+
"index": tool_call.index if hasattr(tool_call, "index") else idx,
|
|
463
478
|
"id": tool_call.id,
|
|
464
479
|
"type": tool_call.type,
|
|
465
480
|
"function": {
|
|
466
|
-
"name": tool_call.function.name,
|
|
467
|
-
"arguments": tool_call.function.arguments,
|
|
481
|
+
"name": tool_call.function.name if tool_call.function else None,
|
|
482
|
+
"arguments": tool_call.function.arguments if tool_call.function else None,
|
|
468
483
|
},
|
|
469
484
|
}
|
|
470
|
-
for tool_call in choice_delta.tool_calls
|
|
485
|
+
for idx, tool_call in enumerate(choice_delta.tool_calls)
|
|
471
486
|
]
|
|
472
487
|
|
|
473
488
|
# Add usage metrics
|
|
@@ -476,6 +491,61 @@ class Cerebras(Model):
|
|
|
476
491
|
|
|
477
492
|
return model_response
|
|
478
493
|
|
|
494
|
+
def parse_tool_calls(self, tool_calls_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
495
|
+
"""
|
|
496
|
+
Build complete tool calls from streamed tool call delta data.
|
|
497
|
+
|
|
498
|
+
Cerebras streams tool calls incrementally with partial data in each chunk.
|
|
499
|
+
This method aggregates those chunks by index to produce complete tool calls.
|
|
500
|
+
|
|
501
|
+
Args:
|
|
502
|
+
tool_calls_data: List of tool call deltas from streaming chunks.
|
|
503
|
+
|
|
504
|
+
Returns:
|
|
505
|
+
List[Dict[str, Any]]: List of fully-formed tool call dicts.
|
|
506
|
+
"""
|
|
507
|
+
tool_calls: List[Dict[str, Any]] = []
|
|
508
|
+
|
|
509
|
+
for tool_call_delta in tool_calls_data:
|
|
510
|
+
# Get the index for this tool call (default to 0 if not present)
|
|
511
|
+
index = tool_call_delta.get("index", 0)
|
|
512
|
+
|
|
513
|
+
# Extend the list if needed
|
|
514
|
+
while len(tool_calls) <= index:
|
|
515
|
+
tool_calls.append(
|
|
516
|
+
{
|
|
517
|
+
"id": None,
|
|
518
|
+
"type": None,
|
|
519
|
+
"function": {
|
|
520
|
+
"name": "",
|
|
521
|
+
"arguments": "",
|
|
522
|
+
},
|
|
523
|
+
}
|
|
524
|
+
)
|
|
525
|
+
|
|
526
|
+
tool_call_entry = tool_calls[index]
|
|
527
|
+
|
|
528
|
+
# Update id if present
|
|
529
|
+
if tool_call_delta.get("id"):
|
|
530
|
+
tool_call_entry["id"] = tool_call_delta["id"]
|
|
531
|
+
|
|
532
|
+
# Update type if present
|
|
533
|
+
if tool_call_delta.get("type"):
|
|
534
|
+
tool_call_entry["type"] = tool_call_delta["type"]
|
|
535
|
+
|
|
536
|
+
# Update function name and arguments (concatenate for streaming)
|
|
537
|
+
if tool_call_delta.get("function"):
|
|
538
|
+
func_delta = tool_call_delta["function"]
|
|
539
|
+
if func_delta.get("name"):
|
|
540
|
+
tool_call_entry["function"]["name"] += func_delta["name"]
|
|
541
|
+
if func_delta.get("arguments"):
|
|
542
|
+
tool_call_entry["function"]["arguments"] += func_delta["arguments"]
|
|
543
|
+
|
|
544
|
+
# Filter out any incomplete tool calls (missing id or function name)
|
|
545
|
+
complete_tool_calls = [tc for tc in tool_calls if tc.get("id") and tc.get("function", {}).get("name")]
|
|
546
|
+
|
|
547
|
+
return complete_tool_calls
|
|
548
|
+
|
|
479
549
|
def _get_metrics(self, response_usage: Union[ChatCompletionResponseUsage, ChatChunkResponseUsage]) -> Metrics:
|
|
480
550
|
"""
|
|
481
551
|
Parse the given Cerebras usage into an Agno Metrics object.
|
|
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Type, Union
|
|
|
5
5
|
|
|
6
6
|
from pydantic import BaseModel
|
|
7
7
|
|
|
8
|
+
from agno.exceptions import ModelAuthenticationError
|
|
8
9
|
from agno.models.message import Message
|
|
9
10
|
from agno.models.openai.like import OpenAILike
|
|
10
11
|
from agno.utils.log import log_debug
|
|
@@ -20,6 +21,22 @@ class CerebrasOpenAI(OpenAILike):
|
|
|
20
21
|
base_url: str = "https://api.cerebras.ai/v1"
|
|
21
22
|
api_key: Optional[str] = field(default_factory=lambda: getenv("CEREBRAS_API_KEY", None))
|
|
22
23
|
|
|
24
|
+
def _get_client_params(self) -> Dict[str, Any]:
|
|
25
|
+
"""
|
|
26
|
+
Returns client parameters for API requests, checking for CEREBRAS_API_KEY.
|
|
27
|
+
|
|
28
|
+
Returns:
|
|
29
|
+
Dict[str, Any]: A dictionary of client parameters for API requests.
|
|
30
|
+
"""
|
|
31
|
+
if not self.api_key:
|
|
32
|
+
self.api_key = getenv("CEREBRAS_API_KEY")
|
|
33
|
+
if not self.api_key:
|
|
34
|
+
raise ModelAuthenticationError(
|
|
35
|
+
message="CEREBRAS_API_KEY not set. Please set the CEREBRAS_API_KEY environment variable.",
|
|
36
|
+
model_name=self.name,
|
|
37
|
+
)
|
|
38
|
+
return super()._get_client_params()
|
|
39
|
+
|
|
23
40
|
def get_request_params(
|
|
24
41
|
self,
|
|
25
42
|
response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
|
|
@@ -61,7 +78,7 @@ class CerebrasOpenAI(OpenAILike):
|
|
|
61
78
|
log_debug(f"Calling {self.provider} with request parameters: {request_params}", log_level=2)
|
|
62
79
|
return request_params
|
|
63
80
|
|
|
64
|
-
def _format_message(self, message: Message) -> Dict[str, Any]:
|
|
81
|
+
def _format_message(self, message: Message, compress_tool_results: bool = False) -> Dict[str, Any]:
|
|
65
82
|
"""
|
|
66
83
|
Format a message into the format expected by the Cerebras API.
|
|
67
84
|
|
|
@@ -71,6 +88,7 @@ class CerebrasOpenAI(OpenAILike):
|
|
|
71
88
|
Returns:
|
|
72
89
|
Dict[str, Any]: The formatted message.
|
|
73
90
|
"""
|
|
91
|
+
|
|
74
92
|
# Basic message content
|
|
75
93
|
message_dict: Dict[str, Any] = {
|
|
76
94
|
"role": message.role,
|
|
@@ -100,10 +118,11 @@ class CerebrasOpenAI(OpenAILike):
|
|
|
100
118
|
|
|
101
119
|
# Handle tool responses
|
|
102
120
|
if message.role == "tool" and message.tool_call_id:
|
|
121
|
+
content = message.get_content(use_compressed_content=compress_tool_results)
|
|
103
122
|
message_dict = {
|
|
104
123
|
"role": "tool",
|
|
105
124
|
"tool_call_id": message.tool_call_id,
|
|
106
|
-
"content":
|
|
125
|
+
"content": content if message.content is not None else "",
|
|
107
126
|
}
|
|
108
127
|
|
|
109
128
|
# Ensure no None values in the message
|
agno/models/cohere/chat.py
CHANGED
|
@@ -2,6 +2,7 @@ from dataclasses import dataclass
|
|
|
2
2
|
from os import getenv
|
|
3
3
|
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple, Type, Union
|
|
4
4
|
|
|
5
|
+
import httpx
|
|
5
6
|
from pydantic import BaseModel
|
|
6
7
|
|
|
7
8
|
from agno.exceptions import ModelProviderError
|
|
@@ -10,7 +11,8 @@ from agno.models.message import Message
|
|
|
10
11
|
from agno.models.metrics import Metrics
|
|
11
12
|
from agno.models.response import ModelResponse
|
|
12
13
|
from agno.run.agent import RunOutput
|
|
13
|
-
from agno.utils.
|
|
14
|
+
from agno.utils.http import get_default_async_client, get_default_sync_client
|
|
15
|
+
from agno.utils.log import log_debug, log_error, log_warning
|
|
14
16
|
from agno.utils.models.cohere import format_messages
|
|
15
17
|
|
|
16
18
|
try:
|
|
@@ -50,6 +52,7 @@ class Cohere(Model):
|
|
|
50
52
|
# -*- Client parameters
|
|
51
53
|
api_key: Optional[str] = None
|
|
52
54
|
client_params: Optional[Dict[str, Any]] = None
|
|
55
|
+
http_client: Optional[Union[httpx.Client, httpx.AsyncClient]] = None
|
|
53
56
|
# -*- Provide the Cohere client manually
|
|
54
57
|
client: Optional[CohereClient] = None
|
|
55
58
|
async_client: Optional[CohereAsyncClient] = None
|
|
@@ -66,6 +69,17 @@ class Cohere(Model):
|
|
|
66
69
|
|
|
67
70
|
_client_params["api_key"] = self.api_key
|
|
68
71
|
|
|
72
|
+
if self.http_client:
|
|
73
|
+
if isinstance(self.http_client, httpx.Client):
|
|
74
|
+
_client_params["httpx_client"] = self.http_client
|
|
75
|
+
else:
|
|
76
|
+
log_warning("http_client is not an instance of httpx.Client. Using default global httpx.Client.")
|
|
77
|
+
# Use global sync client when user http_client is invalid
|
|
78
|
+
_client_params["httpx_client"] = get_default_sync_client()
|
|
79
|
+
else:
|
|
80
|
+
# Use global sync client when no custom http_client is provided
|
|
81
|
+
_client_params["httpx_client"] = get_default_sync_client()
|
|
82
|
+
|
|
69
83
|
self.client = CohereClient(**_client_params)
|
|
70
84
|
return self.client # type: ignore
|
|
71
85
|
|
|
@@ -78,10 +92,26 @@ class Cohere(Model):
|
|
|
78
92
|
self.api_key = self.api_key or getenv("CO_API_KEY")
|
|
79
93
|
|
|
80
94
|
if not self.api_key:
|
|
81
|
-
|
|
95
|
+
raise ModelProviderError(
|
|
96
|
+
message="CO_API_KEY not set. Please set the CO_API_KEY environment variable.",
|
|
97
|
+
model_name=self.name,
|
|
98
|
+
model_id=self.id,
|
|
99
|
+
)
|
|
82
100
|
|
|
83
101
|
_client_params["api_key"] = self.api_key
|
|
84
102
|
|
|
103
|
+
if self.http_client:
|
|
104
|
+
if isinstance(self.http_client, httpx.AsyncClient):
|
|
105
|
+
_client_params["httpx_client"] = self.http_client
|
|
106
|
+
else:
|
|
107
|
+
log_warning(
|
|
108
|
+
"http_client is not an instance of httpx.AsyncClient. Using default global httpx.AsyncClient."
|
|
109
|
+
)
|
|
110
|
+
# Use global async client when user http_client is invalid
|
|
111
|
+
_client_params["httpx_client"] = get_default_async_client()
|
|
112
|
+
else:
|
|
113
|
+
# Use global async client when no custom http_client is provided
|
|
114
|
+
_client_params["httpx_client"] = get_default_async_client()
|
|
85
115
|
self.async_client = CohereAsyncClient(**_client_params)
|
|
86
116
|
return self.async_client # type: ignore
|
|
87
117
|
|
|
@@ -155,6 +185,7 @@ class Cohere(Model):
|
|
|
155
185
|
tools: Optional[List[Dict[str, Any]]] = None,
|
|
156
186
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
|
157
187
|
run_response: Optional[RunOutput] = None,
|
|
188
|
+
compress_tool_results: bool = False,
|
|
158
189
|
) -> ModelResponse:
|
|
159
190
|
"""
|
|
160
191
|
Invoke a non-streamed chat response from the Cohere API.
|
|
@@ -168,7 +199,7 @@ class Cohere(Model):
|
|
|
168
199
|
assistant_message.metrics.start_timer()
|
|
169
200
|
provider_response = self.get_client().chat(
|
|
170
201
|
model=self.id,
|
|
171
|
-
messages=format_messages(messages), # type: ignore
|
|
202
|
+
messages=format_messages(messages, compress_tool_results), # type: ignore
|
|
172
203
|
**request_kwargs,
|
|
173
204
|
) # type: ignore
|
|
174
205
|
assistant_message.metrics.stop_timer()
|
|
@@ -189,6 +220,7 @@ class Cohere(Model):
|
|
|
189
220
|
tools: Optional[List[Dict[str, Any]]] = None,
|
|
190
221
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
|
191
222
|
run_response: Optional[RunOutput] = None,
|
|
223
|
+
compress_tool_results: bool = False,
|
|
192
224
|
) -> Iterator[ModelResponse]:
|
|
193
225
|
"""
|
|
194
226
|
Invoke a streamed chat response from the Cohere API.
|
|
@@ -205,7 +237,7 @@ class Cohere(Model):
|
|
|
205
237
|
|
|
206
238
|
for response in self.get_client().chat_stream(
|
|
207
239
|
model=self.id,
|
|
208
|
-
messages=format_messages(messages), # type: ignore
|
|
240
|
+
messages=format_messages(messages, compress_tool_results), # type: ignore
|
|
209
241
|
**request_kwargs,
|
|
210
242
|
):
|
|
211
243
|
model_response, tool_use = self._parse_provider_response_delta(response, tool_use=tool_use)
|
|
@@ -225,6 +257,7 @@ class Cohere(Model):
|
|
|
225
257
|
tools: Optional[List[Dict[str, Any]]] = None,
|
|
226
258
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
|
227
259
|
run_response: Optional[RunOutput] = None,
|
|
260
|
+
compress_tool_results: bool = False,
|
|
228
261
|
) -> ModelResponse:
|
|
229
262
|
"""
|
|
230
263
|
Asynchronously invoke a non-streamed chat response from the Cohere API.
|
|
@@ -238,7 +271,7 @@ class Cohere(Model):
|
|
|
238
271
|
assistant_message.metrics.start_timer()
|
|
239
272
|
provider_response = await self.get_async_client().chat(
|
|
240
273
|
model=self.id,
|
|
241
|
-
messages=format_messages(messages), # type: ignore
|
|
274
|
+
messages=format_messages(messages, compress_tool_results), # type: ignore
|
|
242
275
|
**request_kwargs,
|
|
243
276
|
)
|
|
244
277
|
assistant_message.metrics.stop_timer()
|
|
@@ -259,6 +292,7 @@ class Cohere(Model):
|
|
|
259
292
|
tools: Optional[List[Dict[str, Any]]] = None,
|
|
260
293
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
|
261
294
|
run_response: Optional[RunOutput] = None,
|
|
295
|
+
compress_tool_results: bool = False,
|
|
262
296
|
) -> AsyncIterator[ModelResponse]:
|
|
263
297
|
"""
|
|
264
298
|
Asynchronously invoke a streamed chat response from the Cohere API.
|
|
@@ -275,7 +309,7 @@ class Cohere(Model):
|
|
|
275
309
|
|
|
276
310
|
async for response in self.get_async_client().chat_stream(
|
|
277
311
|
model=self.id,
|
|
278
|
-
messages=format_messages(messages), # type: ignore
|
|
312
|
+
messages=format_messages(messages, compress_tool_results), # type: ignore
|
|
279
313
|
**request_kwargs,
|
|
280
314
|
):
|
|
281
315
|
model_response, tool_use = self._parse_provider_response_delta(response, tool_use=tool_use)
|
agno/models/cometapi/cometapi.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
from dataclasses import dataclass, field
|
|
2
2
|
from os import getenv
|
|
3
|
-
from typing import List, Optional
|
|
3
|
+
from typing import Any, Dict, List, Optional
|
|
4
4
|
|
|
5
5
|
import httpx
|
|
6
6
|
|
|
7
|
+
from agno.exceptions import ModelAuthenticationError
|
|
7
8
|
from agno.models.openai.like import OpenAILike
|
|
8
9
|
from agno.utils.log import log_debug
|
|
9
10
|
|
|
@@ -26,6 +27,22 @@ class CometAPI(OpenAILike):
|
|
|
26
27
|
api_key: Optional[str] = field(default_factory=lambda: getenv("COMETAPI_KEY"))
|
|
27
28
|
base_url: str = "https://api.cometapi.com/v1"
|
|
28
29
|
|
|
30
|
+
def _get_client_params(self) -> Dict[str, Any]:
|
|
31
|
+
"""
|
|
32
|
+
Returns client parameters for API requests, checking for COMETAPI_KEY.
|
|
33
|
+
|
|
34
|
+
Returns:
|
|
35
|
+
Dict[str, Any]: A dictionary of client parameters for API requests.
|
|
36
|
+
"""
|
|
37
|
+
if not self.api_key:
|
|
38
|
+
self.api_key = getenv("COMETAPI_KEY")
|
|
39
|
+
if not self.api_key:
|
|
40
|
+
raise ModelAuthenticationError(
|
|
41
|
+
message="COMETAPI_KEY not set. Please set the COMETAPI_KEY environment variable.",
|
|
42
|
+
model_name=self.name,
|
|
43
|
+
)
|
|
44
|
+
return super()._get_client_params()
|
|
45
|
+
|
|
29
46
|
def get_available_models(self) -> List[str]:
|
|
30
47
|
"""
|
|
31
48
|
Fetch available chat models from CometAPI, filtering out non-chat models.
|
|
@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Type, Union
|
|
|
4
4
|
|
|
5
5
|
from pydantic import BaseModel
|
|
6
6
|
|
|
7
|
-
from agno.exceptions import
|
|
7
|
+
from agno.exceptions import ModelAuthenticationError
|
|
8
8
|
from agno.models.openai.like import OpenAILike
|
|
9
9
|
|
|
10
10
|
|
|
@@ -43,10 +43,9 @@ class DashScope(OpenAILike):
|
|
|
43
43
|
if not self.api_key:
|
|
44
44
|
self.api_key = getenv("DASHSCOPE_API_KEY")
|
|
45
45
|
if not self.api_key:
|
|
46
|
-
raise
|
|
46
|
+
raise ModelAuthenticationError(
|
|
47
47
|
message="DASHSCOPE_API_KEY not set. Please set the DASHSCOPE_API_KEY environment variable.",
|
|
48
48
|
model_name=self.name,
|
|
49
|
-
model_id=self.id,
|
|
50
49
|
)
|
|
51
50
|
|
|
52
51
|
# Define base client params
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
from dataclasses import dataclass, field
|
|
2
2
|
from os import getenv
|
|
3
|
-
from typing import Optional
|
|
3
|
+
from typing import Any, Dict, Optional
|
|
4
4
|
|
|
5
|
+
from agno.exceptions import ModelAuthenticationError
|
|
5
6
|
from agno.models.openai.like import OpenAILike
|
|
6
7
|
|
|
7
8
|
|
|
@@ -26,3 +27,19 @@ class DeepInfra(OpenAILike):
|
|
|
26
27
|
base_url: str = "https://api.deepinfra.com/v1/openai"
|
|
27
28
|
|
|
28
29
|
supports_native_structured_outputs: bool = False
|
|
30
|
+
|
|
31
|
+
def _get_client_params(self) -> Dict[str, Any]:
|
|
32
|
+
"""
|
|
33
|
+
Returns client parameters for API requests, checking for DEEPINFRA_API_KEY.
|
|
34
|
+
|
|
35
|
+
Returns:
|
|
36
|
+
Dict[str, Any]: A dictionary of client parameters for API requests.
|
|
37
|
+
"""
|
|
38
|
+
if not self.api_key:
|
|
39
|
+
self.api_key = getenv("DEEPINFRA_API_KEY")
|
|
40
|
+
if not self.api_key:
|
|
41
|
+
raise ModelAuthenticationError(
|
|
42
|
+
message="DEEPINFRA_API_KEY not set. Please set the DEEPINFRA_API_KEY environment variable.",
|
|
43
|
+
model_name=self.name,
|
|
44
|
+
)
|
|
45
|
+
return super()._get_client_params()
|
agno/models/deepseek/deepseek.py
CHANGED
|
@@ -2,8 +2,11 @@ from dataclasses import dataclass, field
|
|
|
2
2
|
from os import getenv
|
|
3
3
|
from typing import Any, Dict, Optional
|
|
4
4
|
|
|
5
|
-
from agno.exceptions import
|
|
5
|
+
from agno.exceptions import ModelAuthenticationError
|
|
6
|
+
from agno.models.message import Message
|
|
6
7
|
from agno.models.openai.like import OpenAILike
|
|
8
|
+
from agno.utils.log import log_warning
|
|
9
|
+
from agno.utils.openai import _format_file_for_message, audio_to_message, images_to_message
|
|
7
10
|
|
|
8
11
|
|
|
9
12
|
@dataclass
|
|
@@ -35,10 +38,9 @@ class DeepSeek(OpenAILike):
|
|
|
35
38
|
self.api_key = getenv("DEEPSEEK_API_KEY")
|
|
36
39
|
if not self.api_key:
|
|
37
40
|
# Raise error immediately if key is missing
|
|
38
|
-
raise
|
|
41
|
+
raise ModelAuthenticationError(
|
|
39
42
|
message="DEEPSEEK_API_KEY not set. Please set the DEEPSEEK_API_KEY environment variable.",
|
|
40
43
|
model_name=self.name,
|
|
41
|
-
model_id=self.id,
|
|
42
44
|
)
|
|
43
45
|
|
|
44
46
|
# Define base client params
|
|
@@ -59,3 +61,67 @@ class DeepSeek(OpenAILike):
|
|
|
59
61
|
if self.client_params:
|
|
60
62
|
client_params.update(self.client_params)
|
|
61
63
|
return client_params
|
|
64
|
+
|
|
65
|
+
def _format_message(self, message: Message, compress_tool_results: bool = False) -> Dict[str, Any]:
|
|
66
|
+
"""
|
|
67
|
+
Format a message into the format expected by OpenAI.
|
|
68
|
+
|
|
69
|
+
Args:
|
|
70
|
+
message (Message): The message to format.
|
|
71
|
+
compress_tool_results: Whether to compress tool results.
|
|
72
|
+
|
|
73
|
+
Returns:
|
|
74
|
+
Dict[str, Any]: The formatted message.
|
|
75
|
+
"""
|
|
76
|
+
tool_result = message.get_content(use_compressed_content=compress_tool_results)
|
|
77
|
+
|
|
78
|
+
message_dict: Dict[str, Any] = {
|
|
79
|
+
"role": self.role_map[message.role] if self.role_map else self.default_role_map[message.role],
|
|
80
|
+
"content": tool_result,
|
|
81
|
+
"name": message.name,
|
|
82
|
+
"tool_call_id": message.tool_call_id,
|
|
83
|
+
"tool_calls": message.tool_calls,
|
|
84
|
+
"reasoning_content": message.reasoning_content,
|
|
85
|
+
}
|
|
86
|
+
message_dict = {k: v for k, v in message_dict.items() if v is not None}
|
|
87
|
+
|
|
88
|
+
# Ignore non-string message content
|
|
89
|
+
# because we assume that the images/audio are already added to the message
|
|
90
|
+
if (message.images is not None and len(message.images) > 0) or (
|
|
91
|
+
message.audio is not None and len(message.audio) > 0
|
|
92
|
+
):
|
|
93
|
+
# Ignore non-string message content
|
|
94
|
+
# because we assume that the images/audio are already added to the message
|
|
95
|
+
if isinstance(message.content, str):
|
|
96
|
+
message_dict["content"] = [{"type": "text", "text": message.content}]
|
|
97
|
+
if message.images is not None:
|
|
98
|
+
message_dict["content"].extend(images_to_message(images=message.images))
|
|
99
|
+
|
|
100
|
+
if message.audio is not None:
|
|
101
|
+
message_dict["content"].extend(audio_to_message(audio=message.audio))
|
|
102
|
+
|
|
103
|
+
if message.audio_output is not None:
|
|
104
|
+
message_dict["content"] = ""
|
|
105
|
+
message_dict["audio"] = {"id": message.audio_output.id}
|
|
106
|
+
|
|
107
|
+
if message.videos is not None and len(message.videos) > 0:
|
|
108
|
+
log_warning("Video input is currently unsupported.")
|
|
109
|
+
|
|
110
|
+
if message.files is not None:
|
|
111
|
+
# Ensure content is a list of parts
|
|
112
|
+
content = message_dict.get("content")
|
|
113
|
+
if isinstance(content, str): # wrap existing text
|
|
114
|
+
text = content
|
|
115
|
+
message_dict["content"] = [{"type": "text", "text": text}]
|
|
116
|
+
elif content is None:
|
|
117
|
+
message_dict["content"] = []
|
|
118
|
+
# Insert each file part before text parts
|
|
119
|
+
for file in message.files:
|
|
120
|
+
file_part = _format_file_for_message(file)
|
|
121
|
+
if file_part:
|
|
122
|
+
message_dict["content"].insert(0, file_part)
|
|
123
|
+
|
|
124
|
+
# Manually add the content field even if it is None
|
|
125
|
+
if message.content is None:
|
|
126
|
+
message_dict["content"] = ""
|
|
127
|
+
return message_dict
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
from dataclasses import dataclass, field
|
|
2
2
|
from os import getenv
|
|
3
|
-
from typing import Optional
|
|
3
|
+
from typing import Any, Dict, Optional
|
|
4
4
|
|
|
5
|
+
from agno.exceptions import ModelAuthenticationError
|
|
5
6
|
from agno.models.openai import OpenAILike
|
|
6
7
|
|
|
7
8
|
|
|
@@ -24,3 +25,19 @@ class Fireworks(OpenAILike):
|
|
|
24
25
|
|
|
25
26
|
api_key: Optional[str] = field(default_factory=lambda: getenv("FIREWORKS_API_KEY"))
|
|
26
27
|
base_url: str = "https://api.fireworks.ai/inference/v1"
|
|
28
|
+
|
|
29
|
+
def _get_client_params(self) -> Dict[str, Any]:
|
|
30
|
+
"""
|
|
31
|
+
Returns client parameters for API requests, checking for FIREWORKS_API_KEY.
|
|
32
|
+
|
|
33
|
+
Returns:
|
|
34
|
+
Dict[str, Any]: A dictionary of client parameters for API requests.
|
|
35
|
+
"""
|
|
36
|
+
if not self.api_key:
|
|
37
|
+
self.api_key = getenv("FIREWORKS_API_KEY")
|
|
38
|
+
if not self.api_key:
|
|
39
|
+
raise ModelAuthenticationError(
|
|
40
|
+
message="FIREWORKS_API_KEY not set. Please set the FIREWORKS_API_KEY environment variable.",
|
|
41
|
+
model_name=self.name,
|
|
42
|
+
)
|
|
43
|
+
return super()._get_client_params()
|