agno 2.2.13__py3-none-any.whl → 2.4.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- agno/agent/__init__.py +6 -0
- agno/agent/agent.py +5252 -3145
- agno/agent/remote.py +525 -0
- agno/api/api.py +2 -0
- agno/client/__init__.py +3 -0
- agno/client/a2a/__init__.py +10 -0
- agno/client/a2a/client.py +554 -0
- agno/client/a2a/schemas.py +112 -0
- agno/client/a2a/utils.py +369 -0
- agno/client/os.py +2669 -0
- agno/compression/__init__.py +3 -0
- agno/compression/manager.py +247 -0
- agno/culture/manager.py +2 -2
- agno/db/base.py +927 -6
- agno/db/dynamo/dynamo.py +788 -2
- agno/db/dynamo/schemas.py +128 -0
- agno/db/dynamo/utils.py +26 -3
- agno/db/firestore/firestore.py +674 -50
- agno/db/firestore/schemas.py +41 -0
- agno/db/firestore/utils.py +25 -10
- agno/db/gcs_json/gcs_json_db.py +506 -3
- agno/db/gcs_json/utils.py +14 -2
- agno/db/in_memory/in_memory_db.py +203 -4
- agno/db/in_memory/utils.py +14 -2
- agno/db/json/json_db.py +498 -2
- agno/db/json/utils.py +14 -2
- agno/db/migrations/manager.py +199 -0
- agno/db/migrations/utils.py +19 -0
- agno/db/migrations/v1_to_v2.py +54 -16
- agno/db/migrations/versions/__init__.py +0 -0
- agno/db/migrations/versions/v2_3_0.py +977 -0
- agno/db/mongo/async_mongo.py +1013 -39
- agno/db/mongo/mongo.py +684 -4
- agno/db/mongo/schemas.py +48 -0
- agno/db/mongo/utils.py +17 -0
- agno/db/mysql/__init__.py +2 -1
- agno/db/mysql/async_mysql.py +2958 -0
- agno/db/mysql/mysql.py +722 -53
- agno/db/mysql/schemas.py +77 -11
- agno/db/mysql/utils.py +151 -8
- agno/db/postgres/async_postgres.py +1254 -137
- agno/db/postgres/postgres.py +2316 -93
- agno/db/postgres/schemas.py +153 -21
- agno/db/postgres/utils.py +22 -7
- agno/db/redis/redis.py +531 -3
- agno/db/redis/schemas.py +36 -0
- agno/db/redis/utils.py +31 -15
- agno/db/schemas/evals.py +1 -0
- agno/db/schemas/memory.py +20 -9
- agno/db/singlestore/schemas.py +70 -1
- agno/db/singlestore/singlestore.py +737 -74
- agno/db/singlestore/utils.py +13 -3
- agno/db/sqlite/async_sqlite.py +1069 -89
- agno/db/sqlite/schemas.py +133 -1
- agno/db/sqlite/sqlite.py +2203 -165
- agno/db/sqlite/utils.py +21 -11
- agno/db/surrealdb/models.py +25 -0
- agno/db/surrealdb/surrealdb.py +603 -1
- agno/db/utils.py +60 -0
- agno/eval/__init__.py +26 -3
- agno/eval/accuracy.py +25 -12
- agno/eval/agent_as_judge.py +871 -0
- agno/eval/base.py +29 -0
- agno/eval/performance.py +10 -4
- agno/eval/reliability.py +22 -13
- agno/eval/utils.py +2 -1
- agno/exceptions.py +42 -0
- agno/hooks/__init__.py +3 -0
- agno/hooks/decorator.py +164 -0
- agno/integrations/discord/client.py +13 -2
- agno/knowledge/__init__.py +4 -0
- agno/knowledge/chunking/code.py +90 -0
- agno/knowledge/chunking/document.py +65 -4
- agno/knowledge/chunking/fixed.py +4 -1
- agno/knowledge/chunking/markdown.py +102 -11
- agno/knowledge/chunking/recursive.py +2 -2
- agno/knowledge/chunking/semantic.py +130 -48
- agno/knowledge/chunking/strategy.py +18 -0
- agno/knowledge/embedder/azure_openai.py +0 -1
- agno/knowledge/embedder/google.py +1 -1
- agno/knowledge/embedder/mistral.py +1 -1
- agno/knowledge/embedder/nebius.py +1 -1
- agno/knowledge/embedder/openai.py +16 -12
- agno/knowledge/filesystem.py +412 -0
- agno/knowledge/knowledge.py +4261 -1199
- agno/knowledge/protocol.py +134 -0
- agno/knowledge/reader/arxiv_reader.py +3 -2
- agno/knowledge/reader/base.py +9 -7
- agno/knowledge/reader/csv_reader.py +91 -42
- agno/knowledge/reader/docx_reader.py +9 -10
- agno/knowledge/reader/excel_reader.py +225 -0
- agno/knowledge/reader/field_labeled_csv_reader.py +38 -48
- agno/knowledge/reader/firecrawl_reader.py +3 -2
- agno/knowledge/reader/json_reader.py +16 -22
- agno/knowledge/reader/markdown_reader.py +15 -14
- agno/knowledge/reader/pdf_reader.py +33 -28
- agno/knowledge/reader/pptx_reader.py +9 -10
- agno/knowledge/reader/reader_factory.py +135 -1
- agno/knowledge/reader/s3_reader.py +8 -16
- agno/knowledge/reader/tavily_reader.py +3 -3
- agno/knowledge/reader/text_reader.py +15 -14
- agno/knowledge/reader/utils/__init__.py +17 -0
- agno/knowledge/reader/utils/spreadsheet.py +114 -0
- agno/knowledge/reader/web_search_reader.py +8 -65
- agno/knowledge/reader/website_reader.py +16 -13
- agno/knowledge/reader/wikipedia_reader.py +36 -3
- agno/knowledge/reader/youtube_reader.py +3 -2
- agno/knowledge/remote_content/__init__.py +33 -0
- agno/knowledge/remote_content/config.py +266 -0
- agno/knowledge/remote_content/remote_content.py +105 -17
- agno/knowledge/utils.py +76 -22
- agno/learn/__init__.py +71 -0
- agno/learn/config.py +463 -0
- agno/learn/curate.py +185 -0
- agno/learn/machine.py +725 -0
- agno/learn/schemas.py +1114 -0
- agno/learn/stores/__init__.py +38 -0
- agno/learn/stores/decision_log.py +1156 -0
- agno/learn/stores/entity_memory.py +3275 -0
- agno/learn/stores/learned_knowledge.py +1583 -0
- agno/learn/stores/protocol.py +117 -0
- agno/learn/stores/session_context.py +1217 -0
- agno/learn/stores/user_memory.py +1495 -0
- agno/learn/stores/user_profile.py +1220 -0
- agno/learn/utils.py +209 -0
- agno/media.py +22 -6
- agno/memory/__init__.py +14 -1
- agno/memory/manager.py +223 -8
- agno/memory/strategies/__init__.py +15 -0
- agno/memory/strategies/base.py +66 -0
- agno/memory/strategies/summarize.py +196 -0
- agno/memory/strategies/types.py +37 -0
- agno/models/aimlapi/aimlapi.py +17 -0
- agno/models/anthropic/claude.py +434 -59
- agno/models/aws/bedrock.py +121 -20
- agno/models/aws/claude.py +131 -274
- agno/models/azure/ai_foundry.py +10 -6
- agno/models/azure/openai_chat.py +33 -10
- agno/models/base.py +1162 -561
- agno/models/cerebras/cerebras.py +120 -24
- agno/models/cerebras/cerebras_openai.py +21 -2
- agno/models/cohere/chat.py +65 -6
- agno/models/cometapi/cometapi.py +18 -1
- agno/models/dashscope/dashscope.py +2 -3
- agno/models/deepinfra/deepinfra.py +18 -1
- agno/models/deepseek/deepseek.py +69 -3
- agno/models/fireworks/fireworks.py +18 -1
- agno/models/google/gemini.py +959 -89
- agno/models/google/utils.py +22 -0
- agno/models/groq/groq.py +48 -18
- agno/models/huggingface/huggingface.py +17 -6
- agno/models/ibm/watsonx.py +16 -6
- agno/models/internlm/internlm.py +18 -1
- agno/models/langdb/langdb.py +13 -1
- agno/models/litellm/chat.py +88 -9
- agno/models/litellm/litellm_openai.py +18 -1
- agno/models/message.py +24 -5
- agno/models/meta/llama.py +40 -13
- agno/models/meta/llama_openai.py +22 -21
- agno/models/metrics.py +12 -0
- agno/models/mistral/mistral.py +8 -4
- agno/models/n1n/__init__.py +3 -0
- agno/models/n1n/n1n.py +57 -0
- agno/models/nebius/nebius.py +6 -7
- agno/models/nvidia/nvidia.py +20 -3
- agno/models/ollama/__init__.py +2 -0
- agno/models/ollama/chat.py +17 -6
- agno/models/ollama/responses.py +100 -0
- agno/models/openai/__init__.py +2 -0
- agno/models/openai/chat.py +117 -26
- agno/models/openai/open_responses.py +46 -0
- agno/models/openai/responses.py +110 -32
- agno/models/openrouter/__init__.py +2 -0
- agno/models/openrouter/openrouter.py +67 -2
- agno/models/openrouter/responses.py +146 -0
- agno/models/perplexity/perplexity.py +19 -1
- agno/models/portkey/portkey.py +7 -6
- agno/models/requesty/requesty.py +19 -2
- agno/models/response.py +20 -2
- agno/models/sambanova/sambanova.py +20 -3
- agno/models/siliconflow/siliconflow.py +19 -2
- agno/models/together/together.py +20 -3
- agno/models/vercel/v0.py +20 -3
- agno/models/vertexai/claude.py +124 -4
- agno/models/vllm/vllm.py +19 -14
- agno/models/xai/xai.py +19 -2
- agno/os/app.py +467 -137
- agno/os/auth.py +253 -5
- agno/os/config.py +22 -0
- agno/os/interfaces/a2a/a2a.py +7 -6
- agno/os/interfaces/a2a/router.py +635 -26
- agno/os/interfaces/a2a/utils.py +32 -33
- agno/os/interfaces/agui/agui.py +5 -3
- agno/os/interfaces/agui/router.py +26 -16
- agno/os/interfaces/agui/utils.py +97 -57
- agno/os/interfaces/base.py +7 -7
- agno/os/interfaces/slack/router.py +16 -7
- agno/os/interfaces/slack/slack.py +7 -7
- agno/os/interfaces/whatsapp/router.py +35 -7
- agno/os/interfaces/whatsapp/security.py +3 -1
- agno/os/interfaces/whatsapp/whatsapp.py +11 -8
- agno/os/managers.py +326 -0
- agno/os/mcp.py +652 -79
- agno/os/middleware/__init__.py +4 -0
- agno/os/middleware/jwt.py +718 -115
- agno/os/middleware/trailing_slash.py +27 -0
- agno/os/router.py +105 -1558
- agno/os/routers/agents/__init__.py +3 -0
- agno/os/routers/agents/router.py +655 -0
- agno/os/routers/agents/schema.py +288 -0
- agno/os/routers/components/__init__.py +3 -0
- agno/os/routers/components/components.py +475 -0
- agno/os/routers/database.py +155 -0
- agno/os/routers/evals/evals.py +111 -18
- agno/os/routers/evals/schemas.py +38 -5
- agno/os/routers/evals/utils.py +80 -11
- agno/os/routers/health.py +3 -3
- agno/os/routers/knowledge/knowledge.py +284 -35
- agno/os/routers/knowledge/schemas.py +14 -2
- agno/os/routers/memory/memory.py +274 -11
- agno/os/routers/memory/schemas.py +44 -3
- agno/os/routers/metrics/metrics.py +30 -15
- agno/os/routers/metrics/schemas.py +10 -6
- agno/os/routers/registry/__init__.py +3 -0
- agno/os/routers/registry/registry.py +337 -0
- agno/os/routers/session/session.py +143 -14
- agno/os/routers/teams/__init__.py +3 -0
- agno/os/routers/teams/router.py +550 -0
- agno/os/routers/teams/schema.py +280 -0
- agno/os/routers/traces/__init__.py +3 -0
- agno/os/routers/traces/schemas.py +414 -0
- agno/os/routers/traces/traces.py +549 -0
- agno/os/routers/workflows/__init__.py +3 -0
- agno/os/routers/workflows/router.py +757 -0
- agno/os/routers/workflows/schema.py +139 -0
- agno/os/schema.py +157 -584
- agno/os/scopes.py +469 -0
- agno/os/settings.py +3 -0
- agno/os/utils.py +574 -185
- agno/reasoning/anthropic.py +85 -1
- agno/reasoning/azure_ai_foundry.py +93 -1
- agno/reasoning/deepseek.py +102 -2
- agno/reasoning/default.py +6 -7
- agno/reasoning/gemini.py +87 -3
- agno/reasoning/groq.py +109 -2
- agno/reasoning/helpers.py +6 -7
- agno/reasoning/manager.py +1238 -0
- agno/reasoning/ollama.py +93 -1
- agno/reasoning/openai.py +115 -1
- agno/reasoning/vertexai.py +85 -1
- agno/registry/__init__.py +3 -0
- agno/registry/registry.py +68 -0
- agno/remote/__init__.py +3 -0
- agno/remote/base.py +581 -0
- agno/run/__init__.py +2 -4
- agno/run/agent.py +134 -19
- agno/run/base.py +49 -1
- agno/run/cancel.py +65 -52
- agno/run/cancellation_management/__init__.py +9 -0
- agno/run/cancellation_management/base.py +78 -0
- agno/run/cancellation_management/in_memory_cancellation_manager.py +100 -0
- agno/run/cancellation_management/redis_cancellation_manager.py +236 -0
- agno/run/requirement.py +181 -0
- agno/run/team.py +111 -19
- agno/run/workflow.py +2 -1
- agno/session/agent.py +57 -92
- agno/session/summary.py +1 -1
- agno/session/team.py +62 -115
- agno/session/workflow.py +353 -57
- agno/skills/__init__.py +17 -0
- agno/skills/agent_skills.py +377 -0
- agno/skills/errors.py +32 -0
- agno/skills/loaders/__init__.py +4 -0
- agno/skills/loaders/base.py +27 -0
- agno/skills/loaders/local.py +216 -0
- agno/skills/skill.py +65 -0
- agno/skills/utils.py +107 -0
- agno/skills/validator.py +277 -0
- agno/table.py +10 -0
- agno/team/__init__.py +5 -1
- agno/team/remote.py +447 -0
- agno/team/team.py +3769 -2202
- agno/tools/brandfetch.py +27 -18
- agno/tools/browserbase.py +225 -16
- agno/tools/crawl4ai.py +3 -0
- agno/tools/duckduckgo.py +25 -71
- agno/tools/exa.py +0 -21
- agno/tools/file.py +14 -13
- agno/tools/file_generation.py +12 -6
- agno/tools/firecrawl.py +15 -7
- agno/tools/function.py +94 -113
- agno/tools/google_bigquery.py +11 -2
- agno/tools/google_drive.py +4 -3
- agno/tools/knowledge.py +9 -4
- agno/tools/mcp/mcp.py +301 -18
- agno/tools/mcp/multi_mcp.py +269 -14
- agno/tools/mem0.py +11 -10
- agno/tools/memory.py +47 -46
- agno/tools/mlx_transcribe.py +10 -7
- agno/tools/models/nebius.py +5 -5
- agno/tools/models_labs.py +20 -10
- agno/tools/nano_banana.py +151 -0
- agno/tools/parallel.py +0 -7
- agno/tools/postgres.py +76 -36
- agno/tools/python.py +14 -6
- agno/tools/reasoning.py +30 -23
- agno/tools/redshift.py +406 -0
- agno/tools/shopify.py +1519 -0
- agno/tools/spotify.py +919 -0
- agno/tools/tavily.py +4 -1
- agno/tools/toolkit.py +253 -18
- agno/tools/websearch.py +93 -0
- agno/tools/website.py +1 -1
- agno/tools/wikipedia.py +1 -1
- agno/tools/workflow.py +56 -48
- agno/tools/yfinance.py +12 -11
- agno/tracing/__init__.py +12 -0
- agno/tracing/exporter.py +161 -0
- agno/tracing/schemas.py +276 -0
- agno/tracing/setup.py +112 -0
- agno/utils/agent.py +251 -10
- agno/utils/cryptography.py +22 -0
- agno/utils/dttm.py +33 -0
- agno/utils/events.py +264 -7
- agno/utils/hooks.py +111 -3
- agno/utils/http.py +161 -2
- agno/utils/mcp.py +49 -8
- agno/utils/media.py +22 -1
- agno/utils/models/ai_foundry.py +9 -2
- agno/utils/models/claude.py +20 -5
- agno/utils/models/cohere.py +9 -2
- agno/utils/models/llama.py +9 -2
- agno/utils/models/mistral.py +4 -2
- agno/utils/os.py +0 -0
- agno/utils/print_response/agent.py +99 -16
- agno/utils/print_response/team.py +223 -24
- agno/utils/print_response/workflow.py +0 -2
- agno/utils/prompts.py +8 -6
- agno/utils/remote.py +23 -0
- agno/utils/response.py +1 -13
- agno/utils/string.py +91 -2
- agno/utils/team.py +62 -12
- agno/utils/tokens.py +657 -0
- agno/vectordb/base.py +15 -2
- agno/vectordb/cassandra/cassandra.py +1 -1
- agno/vectordb/chroma/__init__.py +2 -1
- agno/vectordb/chroma/chromadb.py +468 -23
- agno/vectordb/clickhouse/clickhousedb.py +1 -1
- agno/vectordb/couchbase/couchbase.py +6 -2
- agno/vectordb/lancedb/lance_db.py +7 -38
- agno/vectordb/lightrag/lightrag.py +7 -6
- agno/vectordb/milvus/milvus.py +118 -84
- agno/vectordb/mongodb/__init__.py +2 -1
- agno/vectordb/mongodb/mongodb.py +14 -31
- agno/vectordb/pgvector/pgvector.py +120 -66
- agno/vectordb/pineconedb/pineconedb.py +2 -19
- agno/vectordb/qdrant/__init__.py +2 -1
- agno/vectordb/qdrant/qdrant.py +33 -56
- agno/vectordb/redis/__init__.py +2 -1
- agno/vectordb/redis/redisdb.py +19 -31
- agno/vectordb/singlestore/singlestore.py +17 -9
- agno/vectordb/surrealdb/surrealdb.py +2 -38
- agno/vectordb/weaviate/__init__.py +2 -1
- agno/vectordb/weaviate/weaviate.py +7 -3
- agno/workflow/__init__.py +5 -1
- agno/workflow/agent.py +2 -2
- agno/workflow/condition.py +12 -10
- agno/workflow/loop.py +28 -9
- agno/workflow/parallel.py +21 -13
- agno/workflow/remote.py +362 -0
- agno/workflow/router.py +12 -9
- agno/workflow/step.py +261 -36
- agno/workflow/steps.py +12 -8
- agno/workflow/types.py +40 -77
- agno/workflow/workflow.py +939 -213
- {agno-2.2.13.dist-info → agno-2.4.3.dist-info}/METADATA +134 -181
- agno-2.4.3.dist-info/RECORD +677 -0
- {agno-2.2.13.dist-info → agno-2.4.3.dist-info}/WHEEL +1 -1
- agno/tools/googlesearch.py +0 -98
- agno/tools/memori.py +0 -339
- agno-2.2.13.dist-info/RECORD +0 -575
- {agno-2.2.13.dist-info → agno-2.4.3.dist-info}/licenses/LICENSE +0 -0
- {agno-2.2.13.dist-info → agno-2.4.3.dist-info}/top_level.txt +0 -0
agno/models/google/utils.py
ADDED
@@ -0,0 +1,22 @@
+from enum import Enum
+
+
+class GeminiFinishReason(Enum):
+    """Gemini API finish reasons"""
+
+    STOP = "STOP"
+    MAX_TOKENS = "MAX_TOKENS"
+    SAFETY = "SAFETY"
+    RECITATION = "RECITATION"
+    MALFORMED_FUNCTION_CALL = "MALFORMED_FUNCTION_CALL"
+    OTHER = "OTHER"
+
+
+# Guidance message used to retry a Gemini invocation after a MALFORMED_FUNCTION_CALL error
+MALFORMED_FUNCTION_CALL_GUIDANCE = """The previous function call was malformed. Please try again with a valid function call.
+
+Guidelines:
+- Generate the function call JSON directly, do not generate code
+- Use the function name exactly as defined (no namespace prefixes like 'default_api.')
+- Ensure all required parameters are provided with correct types
+"""
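
Note: the constants above suggest a retry path for Gemini's MALFORMED_FUNCTION_CALL finish reason. The actual retry logic lives in agno/models/google/gemini.py and is not shown in this diff; the sketch below is only a plausible illustration of how a caller might use the two exported names. Everything except the two imports (call_gemini, the response's finish_reason attribute) is hypothetical.

from agno.models.google.utils import (
    MALFORMED_FUNCTION_CALL_GUIDANCE,
    GeminiFinishReason,
)

def invoke_with_retry(call_gemini, messages, max_retries=1):
    """Hypothetical helper: re-prompt once when Gemini emits a malformed tool call."""
    for _ in range(max_retries + 1):
        response = call_gemini(messages)  # call_gemini is a stand-in for the provider call
        # Assumes the provider reports finish_reason as a plain string
        if response.finish_reason != GeminiFinishReason.MALFORMED_FUNCTION_CALL.value:
            return response
        # Append the guidance so the next attempt produces a well-formed call
        messages = messages + [{"role": "user", "content": MALFORMED_FUNCTION_CALL_GUIDANCE}]
    return response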
agno/models/groq/groq.py
CHANGED
@@ -6,12 +6,13 @@ from typing import Any, Dict, Iterator, List, Optional, Type, Union
 import httpx
 from pydantic import BaseModel

-from agno.exceptions import ModelProviderError
+from agno.exceptions import ModelAuthenticationError, ModelProviderError
 from agno.models.base import Model
 from agno.models.message import Message
 from agno.models.metrics import Metrics
 from agno.models.response import ModelResponse
 from agno.run.agent import RunOutput
+from agno.utils.http import get_default_async_client, get_default_sync_client
 from agno.utils.log import log_debug, log_error, log_warning
 from agno.utils.openai import images_to_message

@@ -73,7 +74,10 @@ class Groq(Model):
         if not self.api_key:
             self.api_key = getenv("GROQ_API_KEY")
             if not self.api_key:
-
+                raise ModelAuthenticationError(
+                    message="GROQ_API_KEY not set. Please set the GROQ_API_KEY environment variable.",
+                    model_name=self.name,
+                )

         # Define base client params
         base_params = {
@@ -93,7 +97,7 @@ class Groq(Model):

     def get_client(self) -> GroqClient:
         """
-        Returns a Groq client.
+        Returns a Groq client. Caches the client to avoid recreating it on every request.

         Returns:
             GroqClient: An instance of the Groq client.
@@ -103,14 +107,22 @@ class Groq(Model):

         client_params: Dict[str, Any] = self._get_client_params()
         if self.http_client is not None:
-            client_params["http_client"] = self.http_client
+            if isinstance(self.http_client, httpx.Client):
+                client_params["http_client"] = self.http_client
+            else:
+                log_warning("http_client is not an instance of httpx.Client. Using default global httpx.Client.")
+                # Use global sync client when user http_client is invalid
+                client_params["http_client"] = get_default_sync_client()
+        else:
+            # Use global sync client when no custom http_client is provided
+            client_params["http_client"] = get_default_sync_client()

         self.client = GroqClient(**client_params)
         return self.client

     def get_async_client(self) -> AsyncGroqClient:
         """
-        Returns an asynchronous Groq client.
+        Returns an asynchronous Groq client. Caches the client to avoid recreating it on every request.

         Returns:
             AsyncGroqClient: An instance of the asynchronous Groq client.
@@ -119,15 +131,20 @@ class Groq(Model):
             return self.async_client

         client_params: Dict[str, Any] = self._get_client_params()
-        if self.http_client:
-
+        if self.http_client:
+            if isinstance(self.http_client, httpx.AsyncClient):
+                client_params["http_client"] = self.http_client
+            else:
+                log_warning(
+                    "http_client is not an instance of httpx.AsyncClient. Using default global httpx.AsyncClient."
+                )
+                # Use global async client when user http_client is invalid
+                client_params["http_client"] = get_default_async_client()
         else:
-
-
-
-
-                limits=httpx.Limits(max_connections=1000, max_keepalive_connections=100)
-            )
+            # Use global async client when no custom http_client is provided
+            client_params["http_client"] = get_default_async_client()
+
+        # Create and cache the client
         self.async_client = AsyncGroqClient(**client_params)
         return self.async_client

@@ -207,19 +224,28 @@ class Groq(Model):
         self,
         message: Message,
         response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
+        compress_tool_results: bool = False,
     ) -> Dict[str, Any]:
         """
         Format a message into the format expected by Groq.

         Args:
             message (Message): The message to format.
+            response_format: Optional response format specification.
+            compress_tool_results: Whether to compress tool results.

         Returns:
             Dict[str, Any]: The formatted message.
         """
+        # Use compressed content for tool messages if compression is active
+        if message.role == "tool":
+            content = message.get_content(use_compressed_content=compress_tool_results)
+        else:
+            content = message.content
+
         message_dict: Dict[str, Any] = {
             "role": message.role,
-            "content": message.content,
+            "content": content,
             "name": message.name,
             "tool_call_id": message.tool_call_id,
             "tool_calls": message.tool_calls,
@@ -262,6 +288,7 @@ class Groq(Model):
         tools: Optional[List[Dict[str, Any]]] = None,
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
         run_response: Optional[RunOutput] = None,
+        compress_tool_results: bool = False,
     ) -> ModelResponse:
         """
         Send a chat completion request to the Groq API.
@@ -273,7 +300,7 @@ class Groq(Model):
         assistant_message.metrics.start_timer()
         provider_response = self.get_client().chat.completions.create(
             model=self.id,
-            messages=[self.format_message(m) for m in messages],  # type: ignore
+            messages=[self.format_message(m, response_format, compress_tool_results) for m in messages],  # type: ignore
             **self.get_request_params(response_format=response_format, tools=tools, tool_choice=tool_choice),
         )
         assistant_message.metrics.stop_timer()
@@ -302,6 +329,7 @@ class Groq(Model):
         tools: Optional[List[Dict[str, Any]]] = None,
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
         run_response: Optional[RunOutput] = None,
+        compress_tool_results: bool = False,
     ) -> ModelResponse:
         """
         Sends an asynchronous chat completion request to the Groq API.
@@ -313,7 +341,7 @@ class Groq(Model):
         assistant_message.metrics.start_timer()
         response = await self.get_async_client().chat.completions.create(
             model=self.id,
-            messages=[self.format_message(m) for m in messages],  # type: ignore
+            messages=[self.format_message(m, response_format, compress_tool_results) for m in messages],  # type: ignore
             **self.get_request_params(response_format=response_format, tools=tools, tool_choice=tool_choice),
         )
         assistant_message.metrics.stop_timer()
@@ -342,6 +370,7 @@ class Groq(Model):
         tools: Optional[List[Dict[str, Any]]] = None,
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
         run_response: Optional[RunOutput] = None,
+        compress_tool_results: bool = False,
     ) -> Iterator[ModelResponse]:
         """
         Send a streaming chat completion request to the Groq API.
@@ -354,7 +383,7 @@ class Groq(Model):

         for chunk in self.get_client().chat.completions.create(
             model=self.id,
-            messages=[self.format_message(m) for m in messages],  # type: ignore
+            messages=[self.format_message(m, response_format, compress_tool_results) for m in messages],  # type: ignore
             stream=True,
             **self.get_request_params(response_format=response_format, tools=tools, tool_choice=tool_choice),
         ):
@@ -382,6 +411,7 @@ class Groq(Model):
         tools: Optional[List[Dict[str, Any]]] = None,
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
         run_response: Optional[RunOutput] = None,
+        compress_tool_results: bool = False,
     ) -> AsyncIterator[ModelResponse]:
         """
         Sends an asynchronous streaming chat completion request to the Groq API.
@@ -395,7 +425,7 @@ class Groq(Model):

         async_stream = await self.get_async_client().chat.completions.create(
             model=self.id,
-            messages=[self.format_message(m) for m in messages],  # type: ignore
+            messages=[self.format_message(m, response_format, compress_tool_results) for m in messages],  # type: ignore
             stream=True,
             **self.get_request_params(response_format=response_format, tools=tools, tool_choice=tool_choice),
         )
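
Note: both Groq client getters now fall back to process-wide default httpx clients from agno.utils.http instead of building a fresh connection pool per model instance. The internals of those helpers are not part of this diff; a minimal sketch of the pattern they presumably implement (a lazily created module-level singleton, with the same connection limits the old inline code used) looks like this.

from typing import Optional
import httpx

_sync_client: Optional[httpx.Client] = None

def get_default_sync_client() -> httpx.Client:
    """Lazily create one shared httpx.Client so every model reuses the same
    connection pool instead of opening new sockets per request."""
    global _sync_client
    if _sync_client is None:
        _sync_client = httpx.Client(
            limits=httpx.Limits(max_connections=1000, max_keepalive_connections=100)
        )
    return _sync_client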
agno/models/huggingface/huggingface.py
CHANGED
@@ -191,19 +191,26 @@ class HuggingFace(Model):
         cleaned_dict = {k: v for k, v in _dict.items() if v is not None}
         return cleaned_dict

-    def _format_message(self, message: Message) -> Dict[str, Any]:
+    def _format_message(self, message: Message, compress_tool_results: bool = False) -> Dict[str, Any]:
         """
         Format a message into the format expected by HuggingFace.

         Args:
             message (Message): The message to format.
+            compress_tool_results: Whether to compress tool results.

         Returns:
             Dict[str, Any]: The formatted message.
         """
+        # Use compressed content for tool messages if compression is active
+        if message.role == "tool":
+            content = message.get_content(use_compressed_content=compress_tool_results)
+        else:
+            content = message.content if message.content is not None else ""
+
         message_dict: Dict[str, Any] = {
             "role": message.role,
-            "content": message.content if message.content is not None else "",
+            "content": content,
             "name": message.name or message.tool_name,
             "tool_call_id": message.tool_call_id,
             "tool_calls": message.tool_calls,
@@ -236,6 +243,7 @@ class HuggingFace(Model):
         tools: Optional[List[Dict[str, Any]]] = None,
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
         run_response: Optional[RunOutput] = None,
+        compress_tool_results: bool = False,
     ) -> ModelResponse:
         """
         Send a chat completion request to the HuggingFace Hub.
@@ -247,7 +255,7 @@ class HuggingFace(Model):
         assistant_message.metrics.start_timer()
         provider_response = self.get_client().chat.completions.create(
             model=self.id,
-            messages=[self._format_message(m) for m in messages],
+            messages=[self._format_message(m, compress_tool_results) for m in messages],
             **self.get_request_params(tools=tools, tool_choice=tool_choice),
         )
         assistant_message.metrics.stop_timer()
@@ -269,6 +277,7 @@ class HuggingFace(Model):
         tools: Optional[List[Dict[str, Any]]] = None,
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
         run_response: Optional[RunOutput] = None,
+        compress_tool_results: bool = False,
     ) -> ModelResponse:
         """
         Sends an asynchronous chat completion request to the HuggingFace Hub Inference.
@@ -280,7 +289,7 @@ class HuggingFace(Model):
         assistant_message.metrics.start_timer()
         provider_response = await self.get_async_client().chat.completions.create(
             model=self.id,
-            messages=[self._format_message(m) for m in messages],
+            messages=[self._format_message(m, compress_tool_results) for m in messages],
             **self.get_request_params(tools=tools, tool_choice=tool_choice),
         )
         assistant_message.metrics.stop_timer()
@@ -302,6 +311,7 @@ class HuggingFace(Model):
         tools: Optional[List[Dict[str, Any]]] = None,
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
         run_response: Optional[RunOutput] = None,
+        compress_tool_results: bool = False,
     ) -> Iterator[ModelResponse]:
         """
         Send a streaming chat completion request to the HuggingFace API.
@@ -314,7 +324,7 @@ class HuggingFace(Model):

         stream = self.get_client().chat.completions.create(
             model=self.id,
-            messages=[self._format_message(m) for m in messages],
+            messages=[self._format_message(m, compress_tool_results) for m in messages],
             stream=True,
             stream_options=ChatCompletionInputStreamOptions(include_usage=True),  # type: ignore
             **self.get_request_params(tools=tools, tool_choice=tool_choice),
@@ -340,6 +350,7 @@ class HuggingFace(Model):
         tools: Optional[List[Dict[str, Any]]] = None,
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
         run_response: Optional[RunOutput] = None,
+        compress_tool_results: bool = False,
     ) -> AsyncIterator[Any]:
         """
         Sends an asynchronous streaming chat completion request to the HuggingFace API.
@@ -351,7 +362,7 @@ class HuggingFace(Model):
         assistant_message.metrics.start_timer()
         provider_response = await self.get_async_client().chat.completions.create(
             model=self.id,
-            messages=[self._format_message(m) for m in messages],
+            messages=[self._format_message(m, compress_tool_results) for m in messages],
             stream=True,
             stream_options=ChatCompletionInputStreamOptions(include_usage=True),  # type: ignore
             **self.get_request_params(tools=tools, tool_choice=tool_choice),
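
Note: the compress_tool_results flag threaded through Groq, HuggingFace, and the providers below always routes tool messages through Message.get_content(use_compressed_content=...). The Message internals are outside this diff, but the fallback behavior the call sites imply can be sketched roughly as follows; the compressed_content field name is an assumption for illustration.

from dataclasses import dataclass
from typing import Optional

@dataclass
class Message:
    role: str
    content: Optional[str] = None
    compressed_content: Optional[str] = None  # assumed field, populated by the compression manager

    def get_content(self, use_compressed_content: bool = False) -> Optional[str]:
        # Prefer the compressed form only when compression is active and a
        # compressed payload actually exists; otherwise fall back to the original
        if use_compressed_content and self.compressed_content is not None:
            return self.compressed_content
        return self.content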
agno/models/ibm/watsonx.py
CHANGED
@@ -129,12 +129,13 @@ class WatsonX(Model):
         log_debug(f"Calling {self.provider} with request parameters: {request_params}", log_level=2)
         return request_params

-    def _format_message(self, message: Message) -> Dict[str, Any]:
+    def _format_message(self, message: Message, compress_tool_results: bool = False) -> Dict[str, Any]:
         """
         Format a message into the format expected by WatsonX.

         Args:
             message (Message): The message to format.
+            compress_tool_results: Whether to compress tool results.

         Returns:
             Dict[str, Any]: The formatted message.
@@ -151,7 +152,12 @@ class WatsonX(Model):
         if message.videos is not None and len(message.videos) > 0:
             log_warning("Video input is currently unsupported.")

-        return message.to_dict()
+        message_dict = message.to_dict()
+
+        # Use compressed content for tool messages if compression is active
+        if message.role == "tool" and compress_tool_results:
+            message_dict["content"] = message.get_content(use_compressed_content=True)
+        return message_dict

     def invoke(
         self,
@@ -161,6 +167,7 @@ class WatsonX(Model):
         tools: Optional[List[Dict[str, Any]]] = None,
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
         run_response: Optional[RunOutput] = None,
+        compress_tool_results: bool = False,
     ) -> ModelResponse:
         """
         Send a chat completion request to the WatsonX API.
@@ -171,7 +178,7 @@ class WatsonX(Model):

         client = self.get_client()

-        formatted_messages = [self._format_message(m) for m in messages]
+        formatted_messages = [self._format_message(m, compress_tool_results) for m in messages]
         request_params = self.get_request_params(
             response_format=response_format, tools=tools, tool_choice=tool_choice
         )
@@ -196,6 +203,7 @@ class WatsonX(Model):
         tools: Optional[List[Dict[str, Any]]] = None,
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
         run_response: Optional[RunOutput] = None,
+        compress_tool_results: bool = False,
     ) -> Any:
         """
         Sends an asynchronous chat completion request to the WatsonX API.
@@ -205,7 +213,7 @@ class WatsonX(Model):
             run_response.metrics.set_time_to_first_token()

         client = self.get_client()
-        formatted_messages = [self._format_message(m) for m in messages]
+        formatted_messages = [self._format_message(m, compress_tool_results) for m in messages]

         request_params = self.get_request_params(
             response_format=response_format, tools=tools, tool_choice=tool_choice
@@ -231,13 +239,14 @@ class WatsonX(Model):
         tools: Optional[List[Dict[str, Any]]] = None,
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
         run_response: Optional[RunOutput] = None,
+        compress_tool_results: bool = False,
     ) -> Iterator[ModelResponse]:
         """
         Send a streaming chat completion request to the WatsonX API.
         """
         try:
             client = self.get_client()
-            formatted_messages = [self._format_message(m) for m in messages]
+            formatted_messages = [self._format_message(m, compress_tool_results) for m in messages]

             request_params = self.get_request_params(
                 response_format=response_format, tools=tools, tool_choice=tool_choice
@@ -265,6 +274,7 @@ class WatsonX(Model):
         tools: Optional[List[Dict[str, Any]]] = None,
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
         run_response: Optional[RunOutput] = None,
+        compress_tool_results: bool = False,
     ) -> AsyncIterator[ModelResponse]:
         """
         Sends an asynchronous streaming chat completion request to the WatsonX API.
@@ -274,7 +284,7 @@ class WatsonX(Model):
             run_response.metrics.set_time_to_first_token()

         client = self.get_client()
-        formatted_messages = [self._format_message(m) for m in messages]
+        formatted_messages = [self._format_message(m, compress_tool_results) for m in messages]

         # Get parameters for chat
         request_params = self.get_request_params(
agno/models/internlm/internlm.py
CHANGED
@@ -1,7 +1,8 @@
 from dataclasses import dataclass, field
 from os import getenv
-from typing import Optional
+from typing import Any, Dict, Optional

+from agno.exceptions import ModelAuthenticationError
 from agno.models.openai.like import OpenAILike


@@ -24,3 +25,19 @@ class InternLM(OpenAILike):

     api_key: Optional[str] = field(default_factory=lambda: getenv("INTERNLM_API_KEY"))
     base_url: Optional[str] = "https://internlm-chat.intern-ai.org.cn/puyu/api/v1/chat/completions"
+
+    def _get_client_params(self) -> Dict[str, Any]:
+        """
+        Returns client parameters for API requests, checking for INTERNLM_API_KEY.
+
+        Returns:
+            Dict[str, Any]: A dictionary of client parameters for API requests.
+        """
+        if not self.api_key:
+            self.api_key = getenv("INTERNLM_API_KEY")
+            if not self.api_key:
+                raise ModelAuthenticationError(
+                    message="INTERNLM_API_KEY not set. Please set the INTERNLM_API_KEY environment variable.",
+                    model_name=self.name,
+                )
+        return super()._get_client_params()
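
Note: InternLM (and LangDB below) now raise ModelAuthenticationError from _get_client_params() instead of failing later with an opaque provider error. A hedged usage sketch, assuming _get_client_params() runs when the client is first built via get_client() (inherited from the OpenAI-compatible base class):

from agno.exceptions import ModelAuthenticationError
from agno.models.internlm import InternLM

model = InternLM()  # default model id; assume INTERNLM_API_KEY is absent from the environment
try:
    model.get_client()  # assumed to call _get_client_params() internally
except ModelAuthenticationError as exc:
    print(f"Set INTERNLM_API_KEY before using this model: {exc}")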
agno/models/langdb/langdb.py
CHANGED
@@ -2,6 +2,7 @@ from dataclasses import dataclass, field
 from os import getenv
 from typing import Any, Dict, Optional

+from agno.exceptions import ModelAuthenticationError
 from agno.models.openai.like import OpenAILike


@@ -32,8 +33,19 @@ class LangDB(OpenAILike):
     default_headers: Optional[dict] = None

     def _get_client_params(self) -> Dict[str, Any]:
+        if not self.api_key:
+            self.api_key = getenv("LANGDB_API_KEY")
+            if not self.api_key:
+                raise ModelAuthenticationError(
+                    message="LANGDB_API_KEY not set. Please set the LANGDB_API_KEY environment variable.",
+                    model_name=self.name,
+                )
+
         if not self.project_id:
-            raise
+            raise ModelAuthenticationError(
+                message="LANGDB_PROJECT_ID not set. Please set the LANGDB_PROJECT_ID environment variable.",
+                model_name=self.name,
+            )

         if not self.base_url:
             self.base_url = f"{self.base_host_url}/{self.project_id}/v1"
agno/models/litellm/chat.py
CHANGED
@@ -1,3 +1,4 @@
+import copy
 import json
 from dataclasses import dataclass
 from os import getenv
@@ -10,8 +11,10 @@ from agno.models.message import Message
 from agno.models.metrics import Metrics
 from agno.models.response import ModelResponse
 from agno.run.agent import RunOutput
+from agno.tools.function import Function
 from agno.utils.log import log_debug, log_error, log_warning
 from agno.utils.openai import _format_file_for_message, audio_to_message, images_to_message
+from agno.utils.tokens import count_schema_tokens

 try:
     import litellm
@@ -46,10 +49,18 @@ class LiteLLM(Model):

     client: Optional[Any] = None

+    # Store the original client to preserve it across copies (e.g., for Router instances)
+    _original_client: Optional[Any] = None
+
     def __post_init__(self):
         """Initialize the model after the dataclass initialization."""
         super().__post_init__()

+        # Store the original client if provided (e.g., Router instance)
+        # This ensures the client is preserved when the model is copied for background tasks
+        if self.client is not None and self._original_client is None:
+            self._original_client = self.client
+
         # Set up API key from environment variable if not already set
         if not self.client and not self.api_key:
             self.api_key = getenv("LITELLM_API_KEY")
@@ -57,8 +68,8 @@ class LiteLLM(Model):
             # Check for other present valid keys, e.g. OPENAI_API_KEY if self.id is an OpenAI model
             env_validation = validate_environment(model=self.id, api_base=self.api_base)
             if not env_validation.get("keys_in_environment"):
-
-
+                log_error(
+                    "LITELLM_API_KEY not set. Please set the LITELLM_API_KEY or other valid environment variables."
                 )

     def get_client(self) -> Any:
@@ -68,17 +79,52 @@ class LiteLLM(Model):
         Returns:
             Any: An instance of the LiteLLM client.
         """
+        # First check if we have a current client
         if self.client is not None:
             return self.client

+        # Check if we have an original client (e.g., Router) that was preserved
+        # This handles the case where the model was copied for background tasks
+        if self._original_client is not None:
+            self.client = self._original_client
+            return self.client
+
         self.client = litellm
         return self.client

-    def _format_messages(self, messages: List[Message]) -> List[Dict[str, Any]]:
+    def __deepcopy__(self, memo: Dict[int, Any]) -> "LiteLLM":
+        """
+        Custom deepcopy to preserve the client (e.g., Router) across copies.
+
+        This is needed because when the model is copied for background tasks
+        (memory, summarization), the client reference needs to be preserved.
+        """
+        # Create a shallow copy first
+        cls = self.__class__
+        result = cls.__new__(cls)
+        memo[id(self)] = result
+
+        # Copy all attributes, but keep the same client reference
+        for k, v in self.__dict__.items():
+            if k in ("client", "_original_client"):
+                # Keep the same client reference (don't deepcopy Router instances)
+                setattr(result, k, v)
+            else:
+                setattr(result, k, copy.deepcopy(v, memo))
+
+        return result
+
+    def _format_messages(self, messages: List[Message], compress_tool_results: bool = False) -> List[Dict[str, Any]]:
         """Format messages for LiteLLM API."""
        formatted_messages = []
         for m in messages:
-            msg = {"role": m.role, "content": m.content if m.content is not None else ""}
+            # Use compressed content for tool messages if compression is active
+            if m.role == "tool":
+                content = m.get_content(use_compressed_content=compress_tool_results)
+            else:
+                content = m.content if m.content is not None else ""
+
+            msg = {"role": m.role, "content": content}

             # Handle media
             if (m.images is not None and len(m.images) > 0) or (m.audio is not None and len(m.audio) > 0):
@@ -98,7 +144,7 @@ class LiteLLM(Model):
                 if isinstance(msg["content"], str):
                     content_list = [{"type": "text", "text": msg["content"]}]
                 else:
-                    content_list = msg["content"]
+                    content_list = msg["content"] if isinstance(msg["content"], list) else []
                 for file in m.files:
                     file_part = _format_file_for_message(file)
                     if file_part:
@@ -186,10 +232,11 @@ class LiteLLM(Model):
         tools: Optional[List[Dict[str, Any]]] = None,
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
         run_response: Optional[RunOutput] = None,
+        compress_tool_results: bool = False,
     ) -> ModelResponse:
         """Sends a chat completion request to the LiteLLM API."""
         completion_kwargs = self.get_request_params(tools=tools)
-        completion_kwargs["messages"] = self._format_messages(messages)
+        completion_kwargs["messages"] = self._format_messages(messages, compress_tool_results)

         if run_response and run_response.metrics:
             run_response.metrics.set_time_to_first_token()
@@ -211,10 +258,11 @@ class LiteLLM(Model):
         tools: Optional[List[Dict[str, Any]]] = None,
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
         run_response: Optional[RunOutput] = None,
+        compress_tool_results: bool = False,
     ) -> Iterator[ModelResponse]:
         """Sends a streaming chat completion request to the LiteLLM API."""
         completion_kwargs = self.get_request_params(tools=tools)
-        completion_kwargs["messages"] = self._format_messages(messages)
+        completion_kwargs["messages"] = self._format_messages(messages, compress_tool_results)
         completion_kwargs["stream"] = True
         completion_kwargs["stream_options"] = {"include_usage": True}

@@ -236,10 +284,11 @@ class LiteLLM(Model):
         tools: Optional[List[Dict[str, Any]]] = None,
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
         run_response: Optional[RunOutput] = None,
+        compress_tool_results: bool = False,
     ) -> ModelResponse:
         """Sends an asynchronous chat completion request to the LiteLLM API."""
         completion_kwargs = self.get_request_params(tools=tools)
-        completion_kwargs["messages"] = self._format_messages(messages)
+        completion_kwargs["messages"] = self._format_messages(messages, compress_tool_results)

         if run_response and run_response.metrics:
             run_response.metrics.set_time_to_first_token()
@@ -261,10 +310,11 @@ class LiteLLM(Model):
         tools: Optional[List[Dict[str, Any]]] = None,
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
         run_response: Optional[RunOutput] = None,
+        compress_tool_results: bool = False,
     ) -> AsyncIterator[ModelResponse]:
         """Sends an asynchronous streaming chat request to the LiteLLM API."""
         completion_kwargs = self.get_request_params(tools=tools)
-        completion_kwargs["messages"] = self._format_messages(messages)
+        completion_kwargs["messages"] = self._format_messages(messages, compress_tool_results)
         completion_kwargs["stream"] = True
         completion_kwargs["stream_options"] = {"include_usage": True}

@@ -295,6 +345,9 @@ class LiteLLM(Model):
         if response_message.content is not None:
             model_response.content = response_message.content

+        if hasattr(response_message, "reasoning_content") and response_message.reasoning_content is not None:
+            model_response.reasoning_content = response_message.reasoning_content
+
         if hasattr(response_message, "tool_calls") and response_message.tool_calls:
             model_response.tool_calls = []
             for tool_call in response_message.tool_calls:
@@ -322,6 +375,9 @@ class LiteLLM(Model):
         if hasattr(choice_delta, "content") and choice_delta.content is not None:
             model_response.content = choice_delta.content

+        if hasattr(choice_delta, "reasoning_content") and choice_delta.reasoning_content is not None:
+            model_response.reasoning_content = choice_delta.reasoning_content
+
         if hasattr(choice_delta, "tool_calls") and choice_delta.tool_calls:
             processed_tool_calls = []
             for tool_call in choice_delta.tool_calls:
@@ -466,3 +522,26 @@ class LiteLLM(Model):
         metrics.total_tokens = metrics.input_tokens + metrics.output_tokens

         return metrics
+
+    def count_tokens(
+        self,
+        messages: List[Message],
+        tools: Optional[List[Union[Function, Dict[str, Any]]]] = None,
+        response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
+    ) -> int:
+        formatted_messages = self._format_messages(messages, compress_tool_results=True)
+        formatted_tools = self._format_tools(tools) if tools else None
+        tokens = litellm.token_counter(
+            model=self.id,
+            messages=formatted_messages,
+            tools=formatted_tools,  # type: ignore
+        )
+        return tokens + count_schema_tokens(response_format, self.id)
+
+    async def acount_tokens(
+        self,
+        messages: List[Message],
+        tools: Optional[List[Union[Function, Dict[str, Any]]]] = None,
+        response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
+    ) -> int:
+        return self.count_tokens(messages, tools, response_format)
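
Note: the custom __deepcopy__ exists because, per the inline comments, agno deep-copies the model for background tasks (memory, summarization), and a live litellm Router holds resources that must be shared by reference rather than cloned. The same pattern in a self-contained form, with stand-in class names (LiveClient, ModelLike are illustrative, not agno APIs):

import copy

class LiveClient:
    """Stands in for a litellm Router: holds resources that must not be cloned."""
    def __deepcopy__(self, memo):
        raise TypeError("cannot deepcopy a live client")

class ModelLike:
    """Same copy strategy as LiteLLM.__deepcopy__, reduced to its essentials."""
    def __init__(self, client):
        self.client = client
        self.settings = {"retries": 3}

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if k == "client":
                setattr(result, k, v)  # share the client by reference
            else:
                setattr(result, k, copy.deepcopy(v, memo))
        return result

m = ModelLike(LiveClient())
m2 = copy.deepcopy(m)
assert m2.client is m.client          # client shared, not cloned
assert m2.settings is not m.settings  # everything else is a real deep copy

The new count_tokens/acount_tokens methods follow the same compression-aware path: they format messages with compress_tool_results=True and delegate to litellm.token_counter, so reported counts match what would actually be sent when compression is active.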