agno 2.2.13__py3-none-any.whl → 2.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agno/agent/__init__.py +6 -0
- agno/agent/agent.py +5252 -3145
- agno/agent/remote.py +525 -0
- agno/api/api.py +2 -0
- agno/client/__init__.py +3 -0
- agno/client/a2a/__init__.py +10 -0
- agno/client/a2a/client.py +554 -0
- agno/client/a2a/schemas.py +112 -0
- agno/client/a2a/utils.py +369 -0
- agno/client/os.py +2669 -0
- agno/compression/__init__.py +3 -0
- agno/compression/manager.py +247 -0
- agno/culture/manager.py +2 -2
- agno/db/base.py +927 -6
- agno/db/dynamo/dynamo.py +788 -2
- agno/db/dynamo/schemas.py +128 -0
- agno/db/dynamo/utils.py +26 -3
- agno/db/firestore/firestore.py +674 -50
- agno/db/firestore/schemas.py +41 -0
- agno/db/firestore/utils.py +25 -10
- agno/db/gcs_json/gcs_json_db.py +506 -3
- agno/db/gcs_json/utils.py +14 -2
- agno/db/in_memory/in_memory_db.py +203 -4
- agno/db/in_memory/utils.py +14 -2
- agno/db/json/json_db.py +498 -2
- agno/db/json/utils.py +14 -2
- agno/db/migrations/manager.py +199 -0
- agno/db/migrations/utils.py +19 -0
- agno/db/migrations/v1_to_v2.py +54 -16
- agno/db/migrations/versions/__init__.py +0 -0
- agno/db/migrations/versions/v2_3_0.py +977 -0
- agno/db/mongo/async_mongo.py +1013 -39
- agno/db/mongo/mongo.py +684 -4
- agno/db/mongo/schemas.py +48 -0
- agno/db/mongo/utils.py +17 -0
- agno/db/mysql/__init__.py +2 -1
- agno/db/mysql/async_mysql.py +2958 -0
- agno/db/mysql/mysql.py +722 -53
- agno/db/mysql/schemas.py +77 -11
- agno/db/mysql/utils.py +151 -8
- agno/db/postgres/async_postgres.py +1254 -137
- agno/db/postgres/postgres.py +2316 -93
- agno/db/postgres/schemas.py +153 -21
- agno/db/postgres/utils.py +22 -7
- agno/db/redis/redis.py +531 -3
- agno/db/redis/schemas.py +36 -0
- agno/db/redis/utils.py +31 -15
- agno/db/schemas/evals.py +1 -0
- agno/db/schemas/memory.py +20 -9
- agno/db/singlestore/schemas.py +70 -1
- agno/db/singlestore/singlestore.py +737 -74
- agno/db/singlestore/utils.py +13 -3
- agno/db/sqlite/async_sqlite.py +1069 -89
- agno/db/sqlite/schemas.py +133 -1
- agno/db/sqlite/sqlite.py +2203 -165
- agno/db/sqlite/utils.py +21 -11
- agno/db/surrealdb/models.py +25 -0
- agno/db/surrealdb/surrealdb.py +603 -1
- agno/db/utils.py +60 -0
- agno/eval/__init__.py +26 -3
- agno/eval/accuracy.py +25 -12
- agno/eval/agent_as_judge.py +871 -0
- agno/eval/base.py +29 -0
- agno/eval/performance.py +10 -4
- agno/eval/reliability.py +22 -13
- agno/eval/utils.py +2 -1
- agno/exceptions.py +42 -0
- agno/hooks/__init__.py +3 -0
- agno/hooks/decorator.py +164 -0
- agno/integrations/discord/client.py +13 -2
- agno/knowledge/__init__.py +4 -0
- agno/knowledge/chunking/code.py +90 -0
- agno/knowledge/chunking/document.py +65 -4
- agno/knowledge/chunking/fixed.py +4 -1
- agno/knowledge/chunking/markdown.py +102 -11
- agno/knowledge/chunking/recursive.py +2 -2
- agno/knowledge/chunking/semantic.py +130 -48
- agno/knowledge/chunking/strategy.py +18 -0
- agno/knowledge/embedder/azure_openai.py +0 -1
- agno/knowledge/embedder/google.py +1 -1
- agno/knowledge/embedder/mistral.py +1 -1
- agno/knowledge/embedder/nebius.py +1 -1
- agno/knowledge/embedder/openai.py +16 -12
- agno/knowledge/filesystem.py +412 -0
- agno/knowledge/knowledge.py +4261 -1199
- agno/knowledge/protocol.py +134 -0
- agno/knowledge/reader/arxiv_reader.py +3 -2
- agno/knowledge/reader/base.py +9 -7
- agno/knowledge/reader/csv_reader.py +91 -42
- agno/knowledge/reader/docx_reader.py +9 -10
- agno/knowledge/reader/excel_reader.py +225 -0
- agno/knowledge/reader/field_labeled_csv_reader.py +38 -48
- agno/knowledge/reader/firecrawl_reader.py +3 -2
- agno/knowledge/reader/json_reader.py +16 -22
- agno/knowledge/reader/markdown_reader.py +15 -14
- agno/knowledge/reader/pdf_reader.py +33 -28
- agno/knowledge/reader/pptx_reader.py +9 -10
- agno/knowledge/reader/reader_factory.py +135 -1
- agno/knowledge/reader/s3_reader.py +8 -16
- agno/knowledge/reader/tavily_reader.py +3 -3
- agno/knowledge/reader/text_reader.py +15 -14
- agno/knowledge/reader/utils/__init__.py +17 -0
- agno/knowledge/reader/utils/spreadsheet.py +114 -0
- agno/knowledge/reader/web_search_reader.py +8 -65
- agno/knowledge/reader/website_reader.py +16 -13
- agno/knowledge/reader/wikipedia_reader.py +36 -3
- agno/knowledge/reader/youtube_reader.py +3 -2
- agno/knowledge/remote_content/__init__.py +33 -0
- agno/knowledge/remote_content/config.py +266 -0
- agno/knowledge/remote_content/remote_content.py +105 -17
- agno/knowledge/utils.py +76 -22
- agno/learn/__init__.py +71 -0
- agno/learn/config.py +463 -0
- agno/learn/curate.py +185 -0
- agno/learn/machine.py +725 -0
- agno/learn/schemas.py +1114 -0
- agno/learn/stores/__init__.py +38 -0
- agno/learn/stores/decision_log.py +1156 -0
- agno/learn/stores/entity_memory.py +3275 -0
- agno/learn/stores/learned_knowledge.py +1583 -0
- agno/learn/stores/protocol.py +117 -0
- agno/learn/stores/session_context.py +1217 -0
- agno/learn/stores/user_memory.py +1495 -0
- agno/learn/stores/user_profile.py +1220 -0
- agno/learn/utils.py +209 -0
- agno/media.py +22 -6
- agno/memory/__init__.py +14 -1
- agno/memory/manager.py +223 -8
- agno/memory/strategies/__init__.py +15 -0
- agno/memory/strategies/base.py +66 -0
- agno/memory/strategies/summarize.py +196 -0
- agno/memory/strategies/types.py +37 -0
- agno/models/aimlapi/aimlapi.py +17 -0
- agno/models/anthropic/claude.py +434 -59
- agno/models/aws/bedrock.py +121 -20
- agno/models/aws/claude.py +131 -274
- agno/models/azure/ai_foundry.py +10 -6
- agno/models/azure/openai_chat.py +33 -10
- agno/models/base.py +1162 -561
- agno/models/cerebras/cerebras.py +120 -24
- agno/models/cerebras/cerebras_openai.py +21 -2
- agno/models/cohere/chat.py +65 -6
- agno/models/cometapi/cometapi.py +18 -1
- agno/models/dashscope/dashscope.py +2 -3
- agno/models/deepinfra/deepinfra.py +18 -1
- agno/models/deepseek/deepseek.py +69 -3
- agno/models/fireworks/fireworks.py +18 -1
- agno/models/google/gemini.py +959 -89
- agno/models/google/utils.py +22 -0
- agno/models/groq/groq.py +48 -18
- agno/models/huggingface/huggingface.py +17 -6
- agno/models/ibm/watsonx.py +16 -6
- agno/models/internlm/internlm.py +18 -1
- agno/models/langdb/langdb.py +13 -1
- agno/models/litellm/chat.py +88 -9
- agno/models/litellm/litellm_openai.py +18 -1
- agno/models/message.py +24 -5
- agno/models/meta/llama.py +40 -13
- agno/models/meta/llama_openai.py +22 -21
- agno/models/metrics.py +12 -0
- agno/models/mistral/mistral.py +8 -4
- agno/models/n1n/__init__.py +3 -0
- agno/models/n1n/n1n.py +57 -0
- agno/models/nebius/nebius.py +6 -7
- agno/models/nvidia/nvidia.py +20 -3
- agno/models/ollama/__init__.py +2 -0
- agno/models/ollama/chat.py +17 -6
- agno/models/ollama/responses.py +100 -0
- agno/models/openai/__init__.py +2 -0
- agno/models/openai/chat.py +117 -26
- agno/models/openai/open_responses.py +46 -0
- agno/models/openai/responses.py +110 -32
- agno/models/openrouter/__init__.py +2 -0
- agno/models/openrouter/openrouter.py +67 -2
- agno/models/openrouter/responses.py +146 -0
- agno/models/perplexity/perplexity.py +19 -1
- agno/models/portkey/portkey.py +7 -6
- agno/models/requesty/requesty.py +19 -2
- agno/models/response.py +20 -2
- agno/models/sambanova/sambanova.py +20 -3
- agno/models/siliconflow/siliconflow.py +19 -2
- agno/models/together/together.py +20 -3
- agno/models/vercel/v0.py +20 -3
- agno/models/vertexai/claude.py +124 -4
- agno/models/vllm/vllm.py +19 -14
- agno/models/xai/xai.py +19 -2
- agno/os/app.py +467 -137
- agno/os/auth.py +253 -5
- agno/os/config.py +22 -0
- agno/os/interfaces/a2a/a2a.py +7 -6
- agno/os/interfaces/a2a/router.py +635 -26
- agno/os/interfaces/a2a/utils.py +32 -33
- agno/os/interfaces/agui/agui.py +5 -3
- agno/os/interfaces/agui/router.py +26 -16
- agno/os/interfaces/agui/utils.py +97 -57
- agno/os/interfaces/base.py +7 -7
- agno/os/interfaces/slack/router.py +16 -7
- agno/os/interfaces/slack/slack.py +7 -7
- agno/os/interfaces/whatsapp/router.py +35 -7
- agno/os/interfaces/whatsapp/security.py +3 -1
- agno/os/interfaces/whatsapp/whatsapp.py +11 -8
- agno/os/managers.py +326 -0
- agno/os/mcp.py +652 -79
- agno/os/middleware/__init__.py +4 -0
- agno/os/middleware/jwt.py +718 -115
- agno/os/middleware/trailing_slash.py +27 -0
- agno/os/router.py +105 -1558
- agno/os/routers/agents/__init__.py +3 -0
- agno/os/routers/agents/router.py +655 -0
- agno/os/routers/agents/schema.py +288 -0
- agno/os/routers/components/__init__.py +3 -0
- agno/os/routers/components/components.py +475 -0
- agno/os/routers/database.py +155 -0
- agno/os/routers/evals/evals.py +111 -18
- agno/os/routers/evals/schemas.py +38 -5
- agno/os/routers/evals/utils.py +80 -11
- agno/os/routers/health.py +3 -3
- agno/os/routers/knowledge/knowledge.py +284 -35
- agno/os/routers/knowledge/schemas.py +14 -2
- agno/os/routers/memory/memory.py +274 -11
- agno/os/routers/memory/schemas.py +44 -3
- agno/os/routers/metrics/metrics.py +30 -15
- agno/os/routers/metrics/schemas.py +10 -6
- agno/os/routers/registry/__init__.py +3 -0
- agno/os/routers/registry/registry.py +337 -0
- agno/os/routers/session/session.py +143 -14
- agno/os/routers/teams/__init__.py +3 -0
- agno/os/routers/teams/router.py +550 -0
- agno/os/routers/teams/schema.py +280 -0
- agno/os/routers/traces/__init__.py +3 -0
- agno/os/routers/traces/schemas.py +414 -0
- agno/os/routers/traces/traces.py +549 -0
- agno/os/routers/workflows/__init__.py +3 -0
- agno/os/routers/workflows/router.py +757 -0
- agno/os/routers/workflows/schema.py +139 -0
- agno/os/schema.py +157 -584
- agno/os/scopes.py +469 -0
- agno/os/settings.py +3 -0
- agno/os/utils.py +574 -185
- agno/reasoning/anthropic.py +85 -1
- agno/reasoning/azure_ai_foundry.py +93 -1
- agno/reasoning/deepseek.py +102 -2
- agno/reasoning/default.py +6 -7
- agno/reasoning/gemini.py +87 -3
- agno/reasoning/groq.py +109 -2
- agno/reasoning/helpers.py +6 -7
- agno/reasoning/manager.py +1238 -0
- agno/reasoning/ollama.py +93 -1
- agno/reasoning/openai.py +115 -1
- agno/reasoning/vertexai.py +85 -1
- agno/registry/__init__.py +3 -0
- agno/registry/registry.py +68 -0
- agno/remote/__init__.py +3 -0
- agno/remote/base.py +581 -0
- agno/run/__init__.py +2 -4
- agno/run/agent.py +134 -19
- agno/run/base.py +49 -1
- agno/run/cancel.py +65 -52
- agno/run/cancellation_management/__init__.py +9 -0
- agno/run/cancellation_management/base.py +78 -0
- agno/run/cancellation_management/in_memory_cancellation_manager.py +100 -0
- agno/run/cancellation_management/redis_cancellation_manager.py +236 -0
- agno/run/requirement.py +181 -0
- agno/run/team.py +111 -19
- agno/run/workflow.py +2 -1
- agno/session/agent.py +57 -92
- agno/session/summary.py +1 -1
- agno/session/team.py +62 -115
- agno/session/workflow.py +353 -57
- agno/skills/__init__.py +17 -0
- agno/skills/agent_skills.py +377 -0
- agno/skills/errors.py +32 -0
- agno/skills/loaders/__init__.py +4 -0
- agno/skills/loaders/base.py +27 -0
- agno/skills/loaders/local.py +216 -0
- agno/skills/skill.py +65 -0
- agno/skills/utils.py +107 -0
- agno/skills/validator.py +277 -0
- agno/table.py +10 -0
- agno/team/__init__.py +5 -1
- agno/team/remote.py +447 -0
- agno/team/team.py +3769 -2202
- agno/tools/brandfetch.py +27 -18
- agno/tools/browserbase.py +225 -16
- agno/tools/crawl4ai.py +3 -0
- agno/tools/duckduckgo.py +25 -71
- agno/tools/exa.py +0 -21
- agno/tools/file.py +14 -13
- agno/tools/file_generation.py +12 -6
- agno/tools/firecrawl.py +15 -7
- agno/tools/function.py +94 -113
- agno/tools/google_bigquery.py +11 -2
- agno/tools/google_drive.py +4 -3
- agno/tools/knowledge.py +9 -4
- agno/tools/mcp/mcp.py +301 -18
- agno/tools/mcp/multi_mcp.py +269 -14
- agno/tools/mem0.py +11 -10
- agno/tools/memory.py +47 -46
- agno/tools/mlx_transcribe.py +10 -7
- agno/tools/models/nebius.py +5 -5
- agno/tools/models_labs.py +20 -10
- agno/tools/nano_banana.py +151 -0
- agno/tools/parallel.py +0 -7
- agno/tools/postgres.py +76 -36
- agno/tools/python.py +14 -6
- agno/tools/reasoning.py +30 -23
- agno/tools/redshift.py +406 -0
- agno/tools/shopify.py +1519 -0
- agno/tools/spotify.py +919 -0
- agno/tools/tavily.py +4 -1
- agno/tools/toolkit.py +253 -18
- agno/tools/websearch.py +93 -0
- agno/tools/website.py +1 -1
- agno/tools/wikipedia.py +1 -1
- agno/tools/workflow.py +56 -48
- agno/tools/yfinance.py +12 -11
- agno/tracing/__init__.py +12 -0
- agno/tracing/exporter.py +161 -0
- agno/tracing/schemas.py +276 -0
- agno/tracing/setup.py +112 -0
- agno/utils/agent.py +251 -10
- agno/utils/cryptography.py +22 -0
- agno/utils/dttm.py +33 -0
- agno/utils/events.py +264 -7
- agno/utils/hooks.py +111 -3
- agno/utils/http.py +161 -2
- agno/utils/mcp.py +49 -8
- agno/utils/media.py +22 -1
- agno/utils/models/ai_foundry.py +9 -2
- agno/utils/models/claude.py +20 -5
- agno/utils/models/cohere.py +9 -2
- agno/utils/models/llama.py +9 -2
- agno/utils/models/mistral.py +4 -2
- agno/utils/os.py +0 -0
- agno/utils/print_response/agent.py +99 -16
- agno/utils/print_response/team.py +223 -24
- agno/utils/print_response/workflow.py +0 -2
- agno/utils/prompts.py +8 -6
- agno/utils/remote.py +23 -0
- agno/utils/response.py +1 -13
- agno/utils/string.py +91 -2
- agno/utils/team.py +62 -12
- agno/utils/tokens.py +657 -0
- agno/vectordb/base.py +15 -2
- agno/vectordb/cassandra/cassandra.py +1 -1
- agno/vectordb/chroma/__init__.py +2 -1
- agno/vectordb/chroma/chromadb.py +468 -23
- agno/vectordb/clickhouse/clickhousedb.py +1 -1
- agno/vectordb/couchbase/couchbase.py +6 -2
- agno/vectordb/lancedb/lance_db.py +7 -38
- agno/vectordb/lightrag/lightrag.py +7 -6
- agno/vectordb/milvus/milvus.py +118 -84
- agno/vectordb/mongodb/__init__.py +2 -1
- agno/vectordb/mongodb/mongodb.py +14 -31
- agno/vectordb/pgvector/pgvector.py +120 -66
- agno/vectordb/pineconedb/pineconedb.py +2 -19
- agno/vectordb/qdrant/__init__.py +2 -1
- agno/vectordb/qdrant/qdrant.py +33 -56
- agno/vectordb/redis/__init__.py +2 -1
- agno/vectordb/redis/redisdb.py +19 -31
- agno/vectordb/singlestore/singlestore.py +17 -9
- agno/vectordb/surrealdb/surrealdb.py +2 -38
- agno/vectordb/weaviate/__init__.py +2 -1
- agno/vectordb/weaviate/weaviate.py +7 -3
- agno/workflow/__init__.py +5 -1
- agno/workflow/agent.py +2 -2
- agno/workflow/condition.py +12 -10
- agno/workflow/loop.py +28 -9
- agno/workflow/parallel.py +21 -13
- agno/workflow/remote.py +362 -0
- agno/workflow/router.py +12 -9
- agno/workflow/step.py +261 -36
- agno/workflow/steps.py +12 -8
- agno/workflow/types.py +40 -77
- agno/workflow/workflow.py +939 -213
- {agno-2.2.13.dist-info → agno-2.4.3.dist-info}/METADATA +134 -181
- agno-2.4.3.dist-info/RECORD +677 -0
- {agno-2.2.13.dist-info → agno-2.4.3.dist-info}/WHEEL +1 -1
- agno/tools/googlesearch.py +0 -98
- agno/tools/memori.py +0 -339
- agno-2.2.13.dist-info/RECORD +0 -575
- {agno-2.2.13.dist-info → agno-2.4.3.dist-info}/licenses/LICENSE +0 -0
- {agno-2.2.13.dist-info → agno-2.4.3.dist-info}/top_level.txt +0 -0
agno/models/cerebras/cerebras.py
CHANGED
|
@@ -12,6 +12,7 @@ from agno.models.message import Message
|
|
|
12
12
|
from agno.models.metrics import Metrics
|
|
13
13
|
from agno.models.response import ModelResponse
|
|
14
14
|
from agno.run.agent import RunOutput
|
|
15
|
+
from agno.utils.http import get_default_async_client, get_default_sync_client
|
|
15
16
|
from agno.utils.log import log_debug, log_error, log_warning
|
|
16
17
|
|
|
17
18
|
try:
|
|
@@ -96,6 +97,35 @@ class Cerebras(Model):
|
|
|
96
97
|
client_params.update(self.client_params)
|
|
97
98
|
return client_params
|
|
98
99
|
|
|
100
|
+
def _ensure_additional_properties_false(self, schema: Dict[str, Any]) -> None:
|
|
101
|
+
"""
|
|
102
|
+
Recursively ensure all object types have additionalProperties: false.
|
|
103
|
+
Cerebras API requires this for JSON schema validation.
|
|
104
|
+
"""
|
|
105
|
+
if not isinstance(schema, dict):
|
|
106
|
+
return
|
|
107
|
+
|
|
108
|
+
# Set additionalProperties: false for object types
|
|
109
|
+
if schema.get("type") == "object":
|
|
110
|
+
schema["additionalProperties"] = False
|
|
111
|
+
|
|
112
|
+
# Recursively process nested schemas
|
|
113
|
+
if "properties" in schema and isinstance(schema["properties"], dict):
|
|
114
|
+
for prop_schema in schema["properties"].values():
|
|
115
|
+
self._ensure_additional_properties_false(prop_schema)
|
|
116
|
+
|
|
117
|
+
if "items" in schema:
|
|
118
|
+
self._ensure_additional_properties_false(schema["items"])
|
|
119
|
+
|
|
120
|
+
if "$defs" in schema and isinstance(schema["$defs"], dict):
|
|
121
|
+
for def_schema in schema["$defs"].values():
|
|
122
|
+
self._ensure_additional_properties_false(def_schema)
|
|
123
|
+
|
|
124
|
+
for key in ["allOf", "anyOf", "oneOf"]:
|
|
125
|
+
if key in schema and isinstance(schema[key], list):
|
|
126
|
+
for item in schema[key]:
|
|
127
|
+
self._ensure_additional_properties_false(item)
|
|
128
|
+
|
|
99
129
|
def get_client(self) -> CerebrasClient:
|
|
100
130
|
"""
|
|
101
131
|
Returns a Cerebras client.
|
|
@@ -107,11 +137,11 @@ class Cerebras(Model):
|
|
|
107
137
|
return self.client
|
|
108
138
|
|
|
109
139
|
client_params: Dict[str, Any] = self._get_client_params()
|
|
110
|
-
if self.http_client:
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
140
|
+
if self.http_client is not None:
|
|
141
|
+
client_params["http_client"] = self.http_client
|
|
142
|
+
else:
|
|
143
|
+
# Use global sync client when no custom http_client is provided
|
|
144
|
+
client_params["http_client"] = get_default_sync_client()
|
|
115
145
|
self.client = CerebrasClient(**client_params)
|
|
116
146
|
return self.client
|
|
117
147
|
|
|
@@ -129,12 +159,8 @@ class Cerebras(Model):
|
|
|
129
159
|
if self.http_client and isinstance(self.http_client, httpx.AsyncClient):
|
|
130
160
|
client_params["http_client"] = self.http_client
|
|
131
161
|
else:
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
# Create a new async HTTP client with custom limits
|
|
135
|
-
client_params["http_client"] = httpx.AsyncClient(
|
|
136
|
-
limits=httpx.Limits(max_connections=1000, max_keepalive_connections=100)
|
|
137
|
-
)
|
|
162
|
+
# Use global async client when no custom http_client is provided
|
|
163
|
+
client_params["http_client"] = get_default_async_client()
|
|
138
164
|
self.async_client = AsyncCerebrasClient(**client_params)
|
|
139
165
|
return self.async_client
|
|
140
166
|
|
|
@@ -194,8 +220,11 @@ class Cerebras(Model):
|
|
|
194
220
|
):
|
|
195
221
|
# Ensure json_schema has strict parameter set
|
|
196
222
|
schema = response_format["json_schema"]
|
|
197
|
-
if isinstance(schema.get("schema"), dict)
|
|
198
|
-
|
|
223
|
+
if isinstance(schema.get("schema"), dict):
|
|
224
|
+
if "strict" not in schema:
|
|
225
|
+
schema["strict"] = self.strict_output
|
|
226
|
+
# Cerebras requires additionalProperties: false for all object types
|
|
227
|
+
self._ensure_additional_properties_false(schema["schema"])
|
|
199
228
|
|
|
200
229
|
request_params["response_format"] = response_format
|
|
201
230
|
|
|
@@ -215,6 +244,7 @@ class Cerebras(Model):
|
|
|
215
244
|
tools: Optional[List[Dict[str, Any]]] = None,
|
|
216
245
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
|
217
246
|
run_response: Optional[RunOutput] = None,
|
|
247
|
+
compress_tool_results: bool = False,
|
|
218
248
|
) -> ModelResponse:
|
|
219
249
|
"""
|
|
220
250
|
Send a chat completion request to the Cerebras API.
|
|
@@ -231,7 +261,7 @@ class Cerebras(Model):
|
|
|
231
261
|
assistant_message.metrics.start_timer()
|
|
232
262
|
provider_response = self.get_client().chat.completions.create(
|
|
233
263
|
model=self.id,
|
|
234
|
-
messages=[self._format_message(m) for m in messages], # type: ignore
|
|
264
|
+
messages=[self._format_message(m, compress_tool_results) for m in messages], # type: ignore
|
|
235
265
|
**self.get_request_params(response_format=response_format, tools=tools),
|
|
236
266
|
)
|
|
237
267
|
assistant_message.metrics.stop_timer()
|
|
@@ -248,6 +278,7 @@ class Cerebras(Model):
|
|
|
248
278
|
tools: Optional[List[Dict[str, Any]]] = None,
|
|
249
279
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
|
250
280
|
run_response: Optional[RunOutput] = None,
|
|
281
|
+
compress_tool_results: bool = False,
|
|
251
282
|
) -> ModelResponse:
|
|
252
283
|
"""
|
|
253
284
|
Sends an asynchronous chat completion request to the Cerebras API.
|
|
@@ -264,7 +295,7 @@ class Cerebras(Model):
|
|
|
264
295
|
assistant_message.metrics.start_timer()
|
|
265
296
|
provider_response = await self.get_async_client().chat.completions.create(
|
|
266
297
|
model=self.id,
|
|
267
|
-
messages=[self._format_message(m) for m in messages], # type: ignore
|
|
298
|
+
messages=[self._format_message(m, compress_tool_results) for m in messages], # type: ignore
|
|
268
299
|
**self.get_request_params(response_format=response_format, tools=tools),
|
|
269
300
|
)
|
|
270
301
|
assistant_message.metrics.stop_timer()
|
|
@@ -281,6 +312,7 @@ class Cerebras(Model):
|
|
|
281
312
|
tools: Optional[List[Dict[str, Any]]] = None,
|
|
282
313
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
|
283
314
|
run_response: Optional[RunOutput] = None,
|
|
315
|
+
compress_tool_results: bool = False,
|
|
284
316
|
) -> Iterator[ModelResponse]:
|
|
285
317
|
"""
|
|
286
318
|
Send a streaming chat completion request to the Cerebras API.
|
|
@@ -298,7 +330,7 @@ class Cerebras(Model):
|
|
|
298
330
|
|
|
299
331
|
for chunk in self.get_client().chat.completions.create(
|
|
300
332
|
model=self.id,
|
|
301
|
-
messages=[self._format_message(m) for m in messages], # type: ignore
|
|
333
|
+
messages=[self._format_message(m, compress_tool_results) for m in messages], # type: ignore
|
|
302
334
|
stream=True,
|
|
303
335
|
**self.get_request_params(response_format=response_format, tools=tools),
|
|
304
336
|
):
|
|
@@ -314,6 +346,7 @@ class Cerebras(Model):
|
|
|
314
346
|
tools: Optional[List[Dict[str, Any]]] = None,
|
|
315
347
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
|
316
348
|
run_response: Optional[RunOutput] = None,
|
|
349
|
+
compress_tool_results: bool = False,
|
|
317
350
|
) -> AsyncIterator[ModelResponse]:
|
|
318
351
|
"""
|
|
319
352
|
Sends an asynchronous streaming chat completion request to the Cerebras API.
|
|
@@ -331,7 +364,7 @@ class Cerebras(Model):
|
|
|
331
364
|
|
|
332
365
|
async_stream = await self.get_async_client().chat.completions.create(
|
|
333
366
|
model=self.id,
|
|
334
|
-
messages=[self._format_message(m) for m in messages], # type: ignore
|
|
367
|
+
messages=[self._format_message(m, compress_tool_results) for m in messages], # type: ignore
|
|
335
368
|
stream=True,
|
|
336
369
|
**self.get_request_params(response_format=response_format, tools=tools),
|
|
337
370
|
)
|
|
@@ -341,20 +374,27 @@ class Cerebras(Model):
|
|
|
341
374
|
|
|
342
375
|
assistant_message.metrics.stop_timer()
|
|
343
376
|
|
|
344
|
-
def _format_message(self, message: Message) -> Dict[str, Any]:
|
|
377
|
+
def _format_message(self, message: Message, compress_tool_results: bool = False) -> Dict[str, Any]:
|
|
345
378
|
"""
|
|
346
379
|
Format a message into the format expected by the Cerebras API.
|
|
347
380
|
|
|
348
381
|
Args:
|
|
349
382
|
message (Message): The message to format.
|
|
383
|
+
compress_tool_results: Whether to compress tool results.
|
|
350
384
|
|
|
351
385
|
Returns:
|
|
352
386
|
Dict[str, Any]: The formatted message.
|
|
353
387
|
"""
|
|
388
|
+
# Use compressed content for tool messages if compression is active
|
|
389
|
+
if message.role == "tool":
|
|
390
|
+
content = message.get_content(use_compressed_content=compress_tool_results)
|
|
391
|
+
else:
|
|
392
|
+
content = message.content if message.content is not None else ""
|
|
393
|
+
|
|
354
394
|
# Basic message content
|
|
355
395
|
message_dict: Dict[str, Any] = {
|
|
356
396
|
"role": message.role,
|
|
357
|
-
"content":
|
|
397
|
+
"content": content,
|
|
358
398
|
}
|
|
359
399
|
|
|
360
400
|
# Add name if present
|
|
@@ -383,7 +423,7 @@ class Cerebras(Model):
|
|
|
383
423
|
message_dict = {
|
|
384
424
|
"role": "tool",
|
|
385
425
|
"tool_call_id": message.tool_call_id,
|
|
386
|
-
"content":
|
|
426
|
+
"content": content,
|
|
387
427
|
}
|
|
388
428
|
|
|
389
429
|
# Ensure no None values in the message
|
|
@@ -462,18 +502,19 @@ class Cerebras(Model):
|
|
|
462
502
|
if choice_delta.content:
|
|
463
503
|
model_response.content = choice_delta.content
|
|
464
504
|
|
|
465
|
-
# Add tool calls
|
|
505
|
+
# Add tool calls - preserve index for proper aggregation in parse_tool_calls
|
|
466
506
|
if choice_delta.tool_calls:
|
|
467
507
|
model_response.tool_calls = [
|
|
468
508
|
{
|
|
509
|
+
"index": tool_call.index if hasattr(tool_call, "index") else idx,
|
|
469
510
|
"id": tool_call.id,
|
|
470
511
|
"type": tool_call.type,
|
|
471
512
|
"function": {
|
|
472
|
-
"name": tool_call.function.name,
|
|
473
|
-
"arguments": tool_call.function.arguments,
|
|
513
|
+
"name": tool_call.function.name if tool_call.function else None,
|
|
514
|
+
"arguments": tool_call.function.arguments if tool_call.function else None,
|
|
474
515
|
},
|
|
475
516
|
}
|
|
476
|
-
for tool_call in choice_delta.tool_calls
|
|
517
|
+
for idx, tool_call in enumerate(choice_delta.tool_calls)
|
|
477
518
|
]
|
|
478
519
|
|
|
479
520
|
# Add usage metrics
|
|
@@ -482,6 +523,61 @@ class Cerebras(Model):
|
|
|
482
523
|
|
|
483
524
|
return model_response
|
|
484
525
|
|
|
526
|
+
def parse_tool_calls(self, tool_calls_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
527
|
+
"""
|
|
528
|
+
Build complete tool calls from streamed tool call delta data.
|
|
529
|
+
|
|
530
|
+
Cerebras streams tool calls incrementally with partial data in each chunk.
|
|
531
|
+
This method aggregates those chunks by index to produce complete tool calls.
|
|
532
|
+
|
|
533
|
+
Args:
|
|
534
|
+
tool_calls_data: List of tool call deltas from streaming chunks.
|
|
535
|
+
|
|
536
|
+
Returns:
|
|
537
|
+
List[Dict[str, Any]]: List of fully-formed tool call dicts.
|
|
538
|
+
"""
|
|
539
|
+
tool_calls: List[Dict[str, Any]] = []
|
|
540
|
+
|
|
541
|
+
for tool_call_delta in tool_calls_data:
|
|
542
|
+
# Get the index for this tool call (default to 0 if not present)
|
|
543
|
+
index = tool_call_delta.get("index", 0)
|
|
544
|
+
|
|
545
|
+
# Extend the list if needed
|
|
546
|
+
while len(tool_calls) <= index:
|
|
547
|
+
tool_calls.append(
|
|
548
|
+
{
|
|
549
|
+
"id": None,
|
|
550
|
+
"type": None,
|
|
551
|
+
"function": {
|
|
552
|
+
"name": "",
|
|
553
|
+
"arguments": "",
|
|
554
|
+
},
|
|
555
|
+
}
|
|
556
|
+
)
|
|
557
|
+
|
|
558
|
+
tool_call_entry = tool_calls[index]
|
|
559
|
+
|
|
560
|
+
# Update id if present
|
|
561
|
+
if tool_call_delta.get("id"):
|
|
562
|
+
tool_call_entry["id"] = tool_call_delta["id"]
|
|
563
|
+
|
|
564
|
+
# Update type if present
|
|
565
|
+
if tool_call_delta.get("type"):
|
|
566
|
+
tool_call_entry["type"] = tool_call_delta["type"]
|
|
567
|
+
|
|
568
|
+
# Update function name and arguments (concatenate for streaming)
|
|
569
|
+
if tool_call_delta.get("function"):
|
|
570
|
+
func_delta = tool_call_delta["function"]
|
|
571
|
+
if func_delta.get("name"):
|
|
572
|
+
tool_call_entry["function"]["name"] += func_delta["name"]
|
|
573
|
+
if func_delta.get("arguments"):
|
|
574
|
+
tool_call_entry["function"]["arguments"] += func_delta["arguments"]
|
|
575
|
+
|
|
576
|
+
# Filter out any incomplete tool calls (missing id or function name)
|
|
577
|
+
complete_tool_calls = [tc for tc in tool_calls if tc.get("id") and tc.get("function", {}).get("name")]
|
|
578
|
+
|
|
579
|
+
return complete_tool_calls
|
|
580
|
+
|
|
485
581
|
def _get_metrics(self, response_usage: Union[ChatCompletionResponseUsage, ChatChunkResponseUsage]) -> Metrics:
|
|
486
582
|
"""
|
|
487
583
|
Parse the given Cerebras usage into an Agno Metrics object.
|
|
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Type, Union
|
|
|
5
5
|
|
|
6
6
|
from pydantic import BaseModel
|
|
7
7
|
|
|
8
|
+
from agno.exceptions import ModelAuthenticationError
|
|
8
9
|
from agno.models.message import Message
|
|
9
10
|
from agno.models.openai.like import OpenAILike
|
|
10
11
|
from agno.utils.log import log_debug
|
|
@@ -20,6 +21,22 @@ class CerebrasOpenAI(OpenAILike):
|
|
|
20
21
|
base_url: str = "https://api.cerebras.ai/v1"
|
|
21
22
|
api_key: Optional[str] = field(default_factory=lambda: getenv("CEREBRAS_API_KEY", None))
|
|
22
23
|
|
|
24
|
+
def _get_client_params(self) -> Dict[str, Any]:
|
|
25
|
+
"""
|
|
26
|
+
Returns client parameters for API requests, checking for CEREBRAS_API_KEY.
|
|
27
|
+
|
|
28
|
+
Returns:
|
|
29
|
+
Dict[str, Any]: A dictionary of client parameters for API requests.
|
|
30
|
+
"""
|
|
31
|
+
if not self.api_key:
|
|
32
|
+
self.api_key = getenv("CEREBRAS_API_KEY")
|
|
33
|
+
if not self.api_key:
|
|
34
|
+
raise ModelAuthenticationError(
|
|
35
|
+
message="CEREBRAS_API_KEY not set. Please set the CEREBRAS_API_KEY environment variable.",
|
|
36
|
+
model_name=self.name,
|
|
37
|
+
)
|
|
38
|
+
return super()._get_client_params()
|
|
39
|
+
|
|
23
40
|
def get_request_params(
|
|
24
41
|
self,
|
|
25
42
|
response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
|
|
@@ -61,7 +78,7 @@ class CerebrasOpenAI(OpenAILike):
|
|
|
61
78
|
log_debug(f"Calling {self.provider} with request parameters: {request_params}", log_level=2)
|
|
62
79
|
return request_params
|
|
63
80
|
|
|
64
|
-
def _format_message(self, message: Message) -> Dict[str, Any]:
|
|
81
|
+
def _format_message(self, message: Message, compress_tool_results: bool = False) -> Dict[str, Any]:
|
|
65
82
|
"""
|
|
66
83
|
Format a message into the format expected by the Cerebras API.
|
|
67
84
|
|
|
@@ -71,6 +88,7 @@ class CerebrasOpenAI(OpenAILike):
|
|
|
71
88
|
Returns:
|
|
72
89
|
Dict[str, Any]: The formatted message.
|
|
73
90
|
"""
|
|
91
|
+
|
|
74
92
|
# Basic message content
|
|
75
93
|
message_dict: Dict[str, Any] = {
|
|
76
94
|
"role": message.role,
|
|
@@ -100,10 +118,11 @@ class CerebrasOpenAI(OpenAILike):
|
|
|
100
118
|
|
|
101
119
|
# Handle tool responses
|
|
102
120
|
if message.role == "tool" and message.tool_call_id:
|
|
121
|
+
content = message.get_content(use_compressed_content=compress_tool_results)
|
|
103
122
|
message_dict = {
|
|
104
123
|
"role": "tool",
|
|
105
124
|
"tool_call_id": message.tool_call_id,
|
|
106
|
-
"content":
|
|
125
|
+
"content": content if message.content is not None else "",
|
|
107
126
|
}
|
|
108
127
|
|
|
109
128
|
# Ensure no None values in the message
|
agno/models/cohere/chat.py
CHANGED
|
@@ -2,6 +2,7 @@ from dataclasses import dataclass
|
|
|
2
2
|
from os import getenv
|
|
3
3
|
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple, Type, Union
|
|
4
4
|
|
|
5
|
+
import httpx
|
|
5
6
|
from pydantic import BaseModel
|
|
6
7
|
|
|
7
8
|
from agno.exceptions import ModelProviderError
|
|
@@ -10,7 +11,8 @@ from agno.models.message import Message
|
|
|
10
11
|
from agno.models.metrics import Metrics
|
|
11
12
|
from agno.models.response import ModelResponse
|
|
12
13
|
from agno.run.agent import RunOutput
|
|
13
|
-
from agno.utils.
|
|
14
|
+
from agno.utils.http import get_default_async_client, get_default_sync_client
|
|
15
|
+
from agno.utils.log import log_debug, log_error, log_warning
|
|
14
16
|
from agno.utils.models.cohere import format_messages
|
|
15
17
|
|
|
16
18
|
try:
|
|
@@ -50,6 +52,7 @@ class Cohere(Model):
|
|
|
50
52
|
# -*- Client parameters
|
|
51
53
|
api_key: Optional[str] = None
|
|
52
54
|
client_params: Optional[Dict[str, Any]] = None
|
|
55
|
+
http_client: Optional[Union[httpx.Client, httpx.AsyncClient]] = None
|
|
53
56
|
# -*- Provide the Cohere client manually
|
|
54
57
|
client: Optional[CohereClient] = None
|
|
55
58
|
async_client: Optional[CohereAsyncClient] = None
|
|
@@ -66,6 +69,17 @@ class Cohere(Model):
|
|
|
66
69
|
|
|
67
70
|
_client_params["api_key"] = self.api_key
|
|
68
71
|
|
|
72
|
+
if self.http_client:
|
|
73
|
+
if isinstance(self.http_client, httpx.Client):
|
|
74
|
+
_client_params["httpx_client"] = self.http_client
|
|
75
|
+
else:
|
|
76
|
+
log_warning("http_client is not an instance of httpx.Client. Using default global httpx.Client.")
|
|
77
|
+
# Use global sync client when user http_client is invalid
|
|
78
|
+
_client_params["httpx_client"] = get_default_sync_client()
|
|
79
|
+
else:
|
|
80
|
+
# Use global sync client when no custom http_client is provided
|
|
81
|
+
_client_params["httpx_client"] = get_default_sync_client()
|
|
82
|
+
|
|
69
83
|
self.client = CohereClient(**_client_params)
|
|
70
84
|
return self.client # type: ignore
|
|
71
85
|
|
|
@@ -78,13 +92,54 @@ class Cohere(Model):
|
|
|
78
92
|
self.api_key = self.api_key or getenv("CO_API_KEY")
|
|
79
93
|
|
|
80
94
|
if not self.api_key:
|
|
81
|
-
|
|
95
|
+
raise ModelProviderError(
|
|
96
|
+
message="CO_API_KEY not set. Please set the CO_API_KEY environment variable.",
|
|
97
|
+
model_name=self.name,
|
|
98
|
+
model_id=self.id,
|
|
99
|
+
)
|
|
82
100
|
|
|
83
101
|
_client_params["api_key"] = self.api_key
|
|
84
102
|
|
|
103
|
+
if self.http_client:
|
|
104
|
+
if isinstance(self.http_client, httpx.AsyncClient):
|
|
105
|
+
_client_params["httpx_client"] = self.http_client
|
|
106
|
+
else:
|
|
107
|
+
log_warning(
|
|
108
|
+
"http_client is not an instance of httpx.AsyncClient. Using default global httpx.AsyncClient."
|
|
109
|
+
)
|
|
110
|
+
# Use global async client when user http_client is invalid
|
|
111
|
+
_client_params["httpx_client"] = get_default_async_client()
|
|
112
|
+
else:
|
|
113
|
+
# Use global async client when no custom http_client is provided
|
|
114
|
+
_client_params["httpx_client"] = get_default_async_client()
|
|
85
115
|
self.async_client = CohereAsyncClient(**_client_params)
|
|
86
116
|
return self.async_client # type: ignore
|
|
87
117
|
|
|
118
|
+
def to_dict(self) -> Dict[str, Any]:
|
|
119
|
+
"""
|
|
120
|
+
Convert the model to a dictionary.
|
|
121
|
+
|
|
122
|
+
Returns:
|
|
123
|
+
Dict[str, Any]: The dictionary representation of the model.
|
|
124
|
+
"""
|
|
125
|
+
model_dict = super().to_dict()
|
|
126
|
+
model_dict.update(
|
|
127
|
+
{
|
|
128
|
+
"temperature": self.temperature,
|
|
129
|
+
"max_tokens": self.max_tokens,
|
|
130
|
+
"top_k": self.top_k,
|
|
131
|
+
"top_p": self.top_p,
|
|
132
|
+
"seed": self.seed,
|
|
133
|
+
"frequency_penalty": self.frequency_penalty,
|
|
134
|
+
"presence_penalty": self.presence_penalty,
|
|
135
|
+
"logprobs": self.logprobs,
|
|
136
|
+
"strict_tools": self.strict_tools,
|
|
137
|
+
"add_chat_history": self.add_chat_history,
|
|
138
|
+
}
|
|
139
|
+
)
|
|
140
|
+
cleaned_dict = {k: v for k, v in model_dict.items() if v is not None}
|
|
141
|
+
return cleaned_dict
|
|
142
|
+
|
|
88
143
|
def get_request_params(
|
|
89
144
|
self,
|
|
90
145
|
response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
|
|
@@ -155,6 +210,7 @@ class Cohere(Model):
|
|
|
155
210
|
tools: Optional[List[Dict[str, Any]]] = None,
|
|
156
211
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
|
157
212
|
run_response: Optional[RunOutput] = None,
|
|
213
|
+
compress_tool_results: bool = False,
|
|
158
214
|
) -> ModelResponse:
|
|
159
215
|
"""
|
|
160
216
|
Invoke a non-streamed chat response from the Cohere API.
|
|
@@ -168,7 +224,7 @@ class Cohere(Model):
|
|
|
168
224
|
assistant_message.metrics.start_timer()
|
|
169
225
|
provider_response = self.get_client().chat(
|
|
170
226
|
model=self.id,
|
|
171
|
-
messages=format_messages(messages), # type: ignore
|
|
227
|
+
messages=format_messages(messages, compress_tool_results), # type: ignore
|
|
172
228
|
**request_kwargs,
|
|
173
229
|
) # type: ignore
|
|
174
230
|
assistant_message.metrics.stop_timer()
|
|
@@ -189,6 +245,7 @@ class Cohere(Model):
|
|
|
189
245
|
tools: Optional[List[Dict[str, Any]]] = None,
|
|
190
246
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
|
191
247
|
run_response: Optional[RunOutput] = None,
|
|
248
|
+
compress_tool_results: bool = False,
|
|
192
249
|
) -> Iterator[ModelResponse]:
|
|
193
250
|
"""
|
|
194
251
|
Invoke a streamed chat response from the Cohere API.
|
|
@@ -205,7 +262,7 @@ class Cohere(Model):
|
|
|
205
262
|
|
|
206
263
|
for response in self.get_client().chat_stream(
|
|
207
264
|
model=self.id,
|
|
208
|
-
messages=format_messages(messages), # type: ignore
|
|
265
|
+
messages=format_messages(messages, compress_tool_results), # type: ignore
|
|
209
266
|
**request_kwargs,
|
|
210
267
|
):
|
|
211
268
|
model_response, tool_use = self._parse_provider_response_delta(response, tool_use=tool_use)
|
|
@@ -225,6 +282,7 @@ class Cohere(Model):
|
|
|
225
282
|
tools: Optional[List[Dict[str, Any]]] = None,
|
|
226
283
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
|
227
284
|
run_response: Optional[RunOutput] = None,
|
|
285
|
+
compress_tool_results: bool = False,
|
|
228
286
|
) -> ModelResponse:
|
|
229
287
|
"""
|
|
230
288
|
Asynchronously invoke a non-streamed chat response from the Cohere API.
|
|
@@ -238,7 +296,7 @@ class Cohere(Model):
|
|
|
238
296
|
assistant_message.metrics.start_timer()
|
|
239
297
|
provider_response = await self.get_async_client().chat(
|
|
240
298
|
model=self.id,
|
|
241
|
-
messages=format_messages(messages), # type: ignore
|
|
299
|
+
messages=format_messages(messages, compress_tool_results), # type: ignore
|
|
242
300
|
**request_kwargs,
|
|
243
301
|
)
|
|
244
302
|
assistant_message.metrics.stop_timer()
|
|
@@ -259,6 +317,7 @@ class Cohere(Model):
|
|
|
259
317
|
tools: Optional[List[Dict[str, Any]]] = None,
|
|
260
318
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
|
261
319
|
run_response: Optional[RunOutput] = None,
|
|
320
|
+
compress_tool_results: bool = False,
|
|
262
321
|
) -> AsyncIterator[ModelResponse]:
|
|
263
322
|
"""
|
|
264
323
|
Asynchronously invoke a streamed chat response from the Cohere API.
|
|
@@ -275,7 +334,7 @@ class Cohere(Model):
|
|
|
275
334
|
|
|
276
335
|
async for response in self.get_async_client().chat_stream(
|
|
277
336
|
model=self.id,
|
|
278
|
-
messages=format_messages(messages), # type: ignore
|
|
337
|
+
messages=format_messages(messages, compress_tool_results), # type: ignore
|
|
279
338
|
**request_kwargs,
|
|
280
339
|
):
|
|
281
340
|
model_response, tool_use = self._parse_provider_response_delta(response, tool_use=tool_use)
|
agno/models/cometapi/cometapi.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
from dataclasses import dataclass, field
|
|
2
2
|
from os import getenv
|
|
3
|
-
from typing import List, Optional
|
|
3
|
+
from typing import Any, Dict, List, Optional
|
|
4
4
|
|
|
5
5
|
import httpx
|
|
6
6
|
|
|
7
|
+
from agno.exceptions import ModelAuthenticationError
|
|
7
8
|
from agno.models.openai.like import OpenAILike
|
|
8
9
|
from agno.utils.log import log_debug
|
|
9
10
|
|
|
@@ -26,6 +27,22 @@ class CometAPI(OpenAILike):
|
|
|
26
27
|
api_key: Optional[str] = field(default_factory=lambda: getenv("COMETAPI_KEY"))
|
|
27
28
|
base_url: str = "https://api.cometapi.com/v1"
|
|
28
29
|
|
|
30
|
+
def _get_client_params(self) -> Dict[str, Any]:
|
|
31
|
+
"""
|
|
32
|
+
Returns client parameters for API requests, checking for COMETAPI_KEY.
|
|
33
|
+
|
|
34
|
+
Returns:
|
|
35
|
+
Dict[str, Any]: A dictionary of client parameters for API requests.
|
|
36
|
+
"""
|
|
37
|
+
if not self.api_key:
|
|
38
|
+
self.api_key = getenv("COMETAPI_KEY")
|
|
39
|
+
if not self.api_key:
|
|
40
|
+
raise ModelAuthenticationError(
|
|
41
|
+
message="COMETAPI_KEY not set. Please set the COMETAPI_KEY environment variable.",
|
|
42
|
+
model_name=self.name,
|
|
43
|
+
)
|
|
44
|
+
return super()._get_client_params()
|
|
45
|
+
|
|
29
46
|
def get_available_models(self) -> List[str]:
|
|
30
47
|
"""
|
|
31
48
|
Fetch available chat models from CometAPI, filtering out non-chat models.
|
|
@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Type, Union
|
|
|
4
4
|
|
|
5
5
|
from pydantic import BaseModel
|
|
6
6
|
|
|
7
|
-
from agno.exceptions import
|
|
7
|
+
from agno.exceptions import ModelAuthenticationError
|
|
8
8
|
from agno.models.openai.like import OpenAILike
|
|
9
9
|
|
|
10
10
|
|
|
@@ -43,10 +43,9 @@ class DashScope(OpenAILike):
|
|
|
43
43
|
if not self.api_key:
|
|
44
44
|
self.api_key = getenv("DASHSCOPE_API_KEY")
|
|
45
45
|
if not self.api_key:
|
|
46
|
-
raise
|
|
46
|
+
raise ModelAuthenticationError(
|
|
47
47
|
message="DASHSCOPE_API_KEY not set. Please set the DASHSCOPE_API_KEY environment variable.",
|
|
48
48
|
model_name=self.name,
|
|
49
|
-
model_id=self.id,
|
|
50
49
|
)
|
|
51
50
|
|
|
52
51
|
# Define base client params
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
from dataclasses import dataclass, field
|
|
2
2
|
from os import getenv
|
|
3
|
-
from typing import Optional
|
|
3
|
+
from typing import Any, Dict, Optional
|
|
4
4
|
|
|
5
|
+
from agno.exceptions import ModelAuthenticationError
|
|
5
6
|
from agno.models.openai.like import OpenAILike
|
|
6
7
|
|
|
7
8
|
|
|
@@ -26,3 +27,19 @@ class DeepInfra(OpenAILike):
|
|
|
26
27
|
base_url: str = "https://api.deepinfra.com/v1/openai"
|
|
27
28
|
|
|
28
29
|
supports_native_structured_outputs: bool = False
|
|
30
|
+
|
|
31
|
+
def _get_client_params(self) -> Dict[str, Any]:
|
|
32
|
+
"""
|
|
33
|
+
Returns client parameters for API requests, checking for DEEPINFRA_API_KEY.
|
|
34
|
+
|
|
35
|
+
Returns:
|
|
36
|
+
Dict[str, Any]: A dictionary of client parameters for API requests.
|
|
37
|
+
"""
|
|
38
|
+
if not self.api_key:
|
|
39
|
+
self.api_key = getenv("DEEPINFRA_API_KEY")
|
|
40
|
+
if not self.api_key:
|
|
41
|
+
raise ModelAuthenticationError(
|
|
42
|
+
message="DEEPINFRA_API_KEY not set. Please set the DEEPINFRA_API_KEY environment variable.",
|
|
43
|
+
model_name=self.name,
|
|
44
|
+
)
|
|
45
|
+
return super()._get_client_params()
|
agno/models/deepseek/deepseek.py
CHANGED
|
@@ -2,8 +2,11 @@ from dataclasses import dataclass, field
|
|
|
2
2
|
from os import getenv
|
|
3
3
|
from typing import Any, Dict, Optional
|
|
4
4
|
|
|
5
|
-
from agno.exceptions import
|
|
5
|
+
from agno.exceptions import ModelAuthenticationError
|
|
6
|
+
from agno.models.message import Message
|
|
6
7
|
from agno.models.openai.like import OpenAILike
|
|
8
|
+
from agno.utils.log import log_warning
|
|
9
|
+
from agno.utils.openai import _format_file_for_message, audio_to_message, images_to_message
|
|
7
10
|
|
|
8
11
|
|
|
9
12
|
@dataclass
|
|
@@ -35,10 +38,9 @@ class DeepSeek(OpenAILike):
|
|
|
35
38
|
self.api_key = getenv("DEEPSEEK_API_KEY")
|
|
36
39
|
if not self.api_key:
|
|
37
40
|
# Raise error immediately if key is missing
|
|
38
|
-
raise
|
|
41
|
+
raise ModelAuthenticationError(
|
|
39
42
|
message="DEEPSEEK_API_KEY not set. Please set the DEEPSEEK_API_KEY environment variable.",
|
|
40
43
|
model_name=self.name,
|
|
41
|
-
model_id=self.id,
|
|
42
44
|
)
|
|
43
45
|
|
|
44
46
|
# Define base client params
|
|
@@ -59,3 +61,67 @@ class DeepSeek(OpenAILike):
|
|
|
59
61
|
if self.client_params:
|
|
60
62
|
client_params.update(self.client_params)
|
|
61
63
|
return client_params
|
|
64
|
+
|
|
65
|
+
def _format_message(self, message: Message, compress_tool_results: bool = False) -> Dict[str, Any]:
|
|
66
|
+
"""
|
|
67
|
+
Format a message into the format expected by OpenAI.
|
|
68
|
+
|
|
69
|
+
Args:
|
|
70
|
+
message (Message): The message to format.
|
|
71
|
+
compress_tool_results: Whether to compress tool results.
|
|
72
|
+
|
|
73
|
+
Returns:
|
|
74
|
+
Dict[str, Any]: The formatted message.
|
|
75
|
+
"""
|
|
76
|
+
tool_result = message.get_content(use_compressed_content=compress_tool_results)
|
|
77
|
+
|
|
78
|
+
message_dict: Dict[str, Any] = {
|
|
79
|
+
"role": self.role_map[message.role] if self.role_map else self.default_role_map[message.role],
|
|
80
|
+
"content": tool_result,
|
|
81
|
+
"name": message.name,
|
|
82
|
+
"tool_call_id": message.tool_call_id,
|
|
83
|
+
"tool_calls": message.tool_calls,
|
|
84
|
+
"reasoning_content": message.reasoning_content,
|
|
85
|
+
}
|
|
86
|
+
message_dict = {k: v for k, v in message_dict.items() if v is not None}
|
|
87
|
+
|
|
88
|
+
# Ignore non-string message content
|
|
89
|
+
# because we assume that the images/audio are already added to the message
|
|
90
|
+
if (message.images is not None and len(message.images) > 0) or (
|
|
91
|
+
message.audio is not None and len(message.audio) > 0
|
|
92
|
+
):
|
|
93
|
+
# Ignore non-string message content
|
|
94
|
+
# because we assume that the images/audio are already added to the message
|
|
95
|
+
if isinstance(message.content, str):
|
|
96
|
+
message_dict["content"] = [{"type": "text", "text": message.content}]
|
|
97
|
+
if message.images is not None:
|
|
98
|
+
message_dict["content"].extend(images_to_message(images=message.images))
|
|
99
|
+
|
|
100
|
+
if message.audio is not None:
|
|
101
|
+
message_dict["content"].extend(audio_to_message(audio=message.audio))
|
|
102
|
+
|
|
103
|
+
if message.audio_output is not None:
|
|
104
|
+
message_dict["content"] = ""
|
|
105
|
+
message_dict["audio"] = {"id": message.audio_output.id}
|
|
106
|
+
|
|
107
|
+
if message.videos is not None and len(message.videos) > 0:
|
|
108
|
+
log_warning("Video input is currently unsupported.")
|
|
109
|
+
|
|
110
|
+
if message.files is not None:
|
|
111
|
+
# Ensure content is a list of parts
|
|
112
|
+
content = message_dict.get("content")
|
|
113
|
+
if isinstance(content, str): # wrap existing text
|
|
114
|
+
text = content
|
|
115
|
+
message_dict["content"] = [{"type": "text", "text": text}]
|
|
116
|
+
elif content is None:
|
|
117
|
+
message_dict["content"] = []
|
|
118
|
+
# Insert each file part before text parts
|
|
119
|
+
for file in message.files:
|
|
120
|
+
file_part = _format_file_for_message(file)
|
|
121
|
+
if file_part:
|
|
122
|
+
message_dict["content"].insert(0, file_part)
|
|
123
|
+
|
|
124
|
+
# Manually add the content field even if it is None
|
|
125
|
+
if message.content is None:
|
|
126
|
+
message_dict["content"] = ""
|
|
127
|
+
return message_dict
|