ag2 0.9.7__py3-none-any.whl → 0.9.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of ag2 might be problematic.
- {ag2-0.9.7.dist-info → ag2-0.9.9.dist-info}/METADATA +102 -75
- ag2-0.9.9.dist-info/RECORD +387 -0
- autogen/__init__.py +1 -2
- autogen/_website/generate_api_references.py +4 -5
- autogen/_website/generate_mkdocs.py +9 -15
- autogen/_website/notebook_processor.py +13 -14
- autogen/_website/process_notebooks.py +10 -10
- autogen/_website/utils.py +5 -4
- autogen/agentchat/agent.py +13 -13
- autogen/agentchat/assistant_agent.py +7 -6
- autogen/agentchat/contrib/agent_eval/agent_eval.py +3 -3
- autogen/agentchat/contrib/agent_eval/critic_agent.py +3 -3
- autogen/agentchat/contrib/agent_eval/quantifier_agent.py +3 -3
- autogen/agentchat/contrib/agent_eval/subcritic_agent.py +3 -3
- autogen/agentchat/contrib/agent_optimizer.py +3 -3
- autogen/agentchat/contrib/capabilities/generate_images.py +11 -11
- autogen/agentchat/contrib/capabilities/teachability.py +15 -15
- autogen/agentchat/contrib/capabilities/transforms.py +17 -18
- autogen/agentchat/contrib/capabilities/transforms_util.py +5 -5
- autogen/agentchat/contrib/capabilities/vision_capability.py +4 -3
- autogen/agentchat/contrib/captainagent/agent_builder.py +30 -30
- autogen/agentchat/contrib/captainagent/captainagent.py +22 -21
- autogen/agentchat/contrib/captainagent/tool_retriever.py +2 -3
- autogen/agentchat/contrib/gpt_assistant_agent.py +9 -9
- autogen/agentchat/contrib/graph_rag/document.py +3 -3
- autogen/agentchat/contrib/graph_rag/falkor_graph_query_engine.py +3 -3
- autogen/agentchat/contrib/graph_rag/falkor_graph_rag_capability.py +6 -6
- autogen/agentchat/contrib/graph_rag/graph_query_engine.py +3 -3
- autogen/agentchat/contrib/graph_rag/neo4j_graph_query_engine.py +5 -11
- autogen/agentchat/contrib/graph_rag/neo4j_graph_rag_capability.py +6 -6
- autogen/agentchat/contrib/graph_rag/neo4j_native_graph_query_engine.py +7 -7
- autogen/agentchat/contrib/graph_rag/neo4j_native_graph_rag_capability.py +6 -6
- autogen/agentchat/contrib/img_utils.py +1 -1
- autogen/agentchat/contrib/llamaindex_conversable_agent.py +11 -11
- autogen/agentchat/contrib/llava_agent.py +18 -4
- autogen/agentchat/contrib/math_user_proxy_agent.py +11 -11
- autogen/agentchat/contrib/multimodal_conversable_agent.py +8 -8
- autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py +6 -5
- autogen/agentchat/contrib/rag/chromadb_query_engine.py +22 -26
- autogen/agentchat/contrib/rag/llamaindex_query_engine.py +14 -17
- autogen/agentchat/contrib/rag/mongodb_query_engine.py +27 -37
- autogen/agentchat/contrib/rag/query_engine.py +7 -5
- autogen/agentchat/contrib/retrieve_assistant_agent.py +5 -5
- autogen/agentchat/contrib/retrieve_user_proxy_agent.py +8 -7
- autogen/agentchat/contrib/society_of_mind_agent.py +15 -14
- autogen/agentchat/contrib/swarm_agent.py +76 -98
- autogen/agentchat/contrib/text_analyzer_agent.py +7 -7
- autogen/agentchat/contrib/vectordb/base.py +10 -18
- autogen/agentchat/contrib/vectordb/chromadb.py +2 -1
- autogen/agentchat/contrib/vectordb/couchbase.py +18 -20
- autogen/agentchat/contrib/vectordb/mongodb.py +6 -5
- autogen/agentchat/contrib/vectordb/pgvectordb.py +40 -41
- autogen/agentchat/contrib/vectordb/qdrant.py +5 -5
- autogen/agentchat/contrib/web_surfer.py +20 -19
- autogen/agentchat/conversable_agent.py +292 -290
- autogen/agentchat/group/context_str.py +1 -3
- autogen/agentchat/group/context_variables.py +15 -25
- autogen/agentchat/group/group_tool_executor.py +10 -10
- autogen/agentchat/group/group_utils.py +15 -15
- autogen/agentchat/group/guardrails.py +7 -7
- autogen/agentchat/group/handoffs.py +19 -36
- autogen/agentchat/group/multi_agent_chat.py +7 -7
- autogen/agentchat/group/on_condition.py +4 -7
- autogen/agentchat/group/on_context_condition.py +4 -7
- autogen/agentchat/group/patterns/auto.py +8 -7
- autogen/agentchat/group/patterns/manual.py +7 -6
- autogen/agentchat/group/patterns/pattern.py +13 -12
- autogen/agentchat/group/patterns/random.py +3 -3
- autogen/agentchat/group/patterns/round_robin.py +3 -3
- autogen/agentchat/group/reply_result.py +2 -4
- autogen/agentchat/group/speaker_selection_result.py +5 -5
- autogen/agentchat/group/targets/group_chat_target.py +7 -6
- autogen/agentchat/group/targets/group_manager_target.py +4 -4
- autogen/agentchat/group/targets/transition_target.py +2 -1
- autogen/agentchat/groupchat.py +60 -63
- autogen/agentchat/realtime/experimental/audio_adapters/twilio_audio_adapter.py +4 -4
- autogen/agentchat/realtime/experimental/audio_adapters/websocket_audio_adapter.py +4 -4
- autogen/agentchat/realtime/experimental/clients/gemini/client.py +7 -7
- autogen/agentchat/realtime/experimental/clients/oai/base_client.py +8 -8
- autogen/agentchat/realtime/experimental/clients/oai/rtc_client.py +6 -6
- autogen/agentchat/realtime/experimental/clients/realtime_client.py +10 -9
- autogen/agentchat/realtime/experimental/realtime_agent.py +10 -9
- autogen/agentchat/realtime/experimental/realtime_observer.py +3 -3
- autogen/agentchat/realtime/experimental/realtime_swarm.py +44 -44
- autogen/agentchat/user_proxy_agent.py +10 -9
- autogen/agentchat/utils.py +3 -3
- autogen/agents/contrib/time/time_reply_agent.py +6 -5
- autogen/agents/contrib/time/time_tool_agent.py +2 -1
- autogen/agents/experimental/deep_research/deep_research.py +3 -3
- autogen/agents/experimental/discord/discord.py +2 -2
- autogen/agents/experimental/document_agent/chroma_query_engine.py +29 -44
- autogen/agents/experimental/document_agent/docling_doc_ingest_agent.py +9 -14
- autogen/agents/experimental/document_agent/document_agent.py +15 -16
- autogen/agents/experimental/document_agent/document_conditions.py +3 -3
- autogen/agents/experimental/document_agent/document_utils.py +5 -9
- autogen/agents/experimental/document_agent/inmemory_query_engine.py +14 -20
- autogen/agents/experimental/document_agent/parser_utils.py +4 -4
- autogen/agents/experimental/document_agent/url_utils.py +14 -23
- autogen/agents/experimental/reasoning/reasoning_agent.py +33 -33
- autogen/agents/experimental/slack/slack.py +2 -2
- autogen/agents/experimental/telegram/telegram.py +2 -3
- autogen/agents/experimental/websurfer/websurfer.py +4 -4
- autogen/agents/experimental/wikipedia/wikipedia.py +5 -7
- autogen/browser_utils.py +8 -8
- autogen/cache/abstract_cache_base.py +5 -5
- autogen/cache/cache.py +12 -12
- autogen/cache/cache_factory.py +4 -4
- autogen/cache/cosmos_db_cache.py +9 -9
- autogen/cache/disk_cache.py +6 -6
- autogen/cache/in_memory_cache.py +4 -4
- autogen/cache/redis_cache.py +4 -4
- autogen/code_utils.py +18 -18
- autogen/coding/base.py +6 -6
- autogen/coding/docker_commandline_code_executor.py +9 -9
- autogen/coding/func_with_reqs.py +7 -6
- autogen/coding/jupyter/base.py +3 -3
- autogen/coding/jupyter/docker_jupyter_server.py +3 -4
- autogen/coding/jupyter/import_utils.py +3 -3
- autogen/coding/jupyter/jupyter_client.py +5 -5
- autogen/coding/jupyter/jupyter_code_executor.py +3 -4
- autogen/coding/jupyter/local_jupyter_server.py +2 -6
- autogen/coding/local_commandline_code_executor.py +8 -7
- autogen/coding/markdown_code_extractor.py +1 -2
- autogen/coding/utils.py +1 -2
- autogen/doc_utils.py +3 -2
- autogen/environments/docker_python_environment.py +19 -29
- autogen/environments/python_environment.py +8 -17
- autogen/environments/system_python_environment.py +3 -4
- autogen/environments/venv_python_environment.py +8 -12
- autogen/environments/working_directory.py +1 -2
- autogen/events/agent_events.py +106 -109
- autogen/events/base_event.py +6 -5
- autogen/events/client_events.py +15 -14
- autogen/events/helpers.py +1 -1
- autogen/events/print_event.py +4 -5
- autogen/fast_depends/_compat.py +10 -15
- autogen/fast_depends/core/build.py +17 -36
- autogen/fast_depends/core/model.py +64 -113
- autogen/fast_depends/dependencies/model.py +2 -1
- autogen/fast_depends/dependencies/provider.py +3 -2
- autogen/fast_depends/library/model.py +4 -4
- autogen/fast_depends/schema.py +7 -7
- autogen/fast_depends/use.py +17 -25
- autogen/fast_depends/utils.py +10 -30
- autogen/formatting_utils.py +6 -6
- autogen/graph_utils.py +1 -4
- autogen/import_utils.py +38 -27
- autogen/interop/crewai/crewai.py +2 -2
- autogen/interop/interoperable.py +2 -2
- autogen/interop/langchain/langchain_chat_model_factory.py +3 -2
- autogen/interop/langchain/langchain_tool.py +2 -6
- autogen/interop/litellm/litellm_config_factory.py +6 -7
- autogen/interop/pydantic_ai/pydantic_ai.py +4 -7
- autogen/interop/registry.py +2 -1
- autogen/io/base.py +5 -5
- autogen/io/run_response.py +33 -32
- autogen/io/websockets.py +6 -5
- autogen/json_utils.py +1 -2
- autogen/llm_config/__init__.py +11 -0
- autogen/llm_config/client.py +58 -0
- autogen/llm_config/config.py +384 -0
- autogen/llm_config/entry.py +154 -0
- autogen/logger/base_logger.py +4 -3
- autogen/logger/file_logger.py +2 -1
- autogen/logger/logger_factory.py +2 -2
- autogen/logger/logger_utils.py +2 -2
- autogen/logger/sqlite_logger.py +2 -1
- autogen/math_utils.py +4 -5
- autogen/mcp/__main__.py +6 -6
- autogen/mcp/helpers.py +4 -4
- autogen/mcp/mcp_client.py +170 -29
- autogen/mcp/mcp_proxy/fastapi_code_generator_helpers.py +3 -4
- autogen/mcp/mcp_proxy/mcp_proxy.py +23 -26
- autogen/mcp/mcp_proxy/operation_grouping.py +4 -5
- autogen/mcp/mcp_proxy/operation_renaming.py +6 -10
- autogen/mcp/mcp_proxy/security.py +2 -3
- autogen/messages/agent_messages.py +96 -98
- autogen/messages/base_message.py +6 -5
- autogen/messages/client_messages.py +15 -14
- autogen/messages/print_message.py +4 -5
- autogen/oai/__init__.py +1 -2
- autogen/oai/anthropic.py +42 -41
- autogen/oai/bedrock.py +68 -57
- autogen/oai/cerebras.py +26 -25
- autogen/oai/client.py +113 -139
- autogen/oai/client_utils.py +3 -3
- autogen/oai/cohere.py +34 -11
- autogen/oai/gemini.py +39 -17
- autogen/oai/gemini_types.py +11 -12
- autogen/oai/groq.py +22 -10
- autogen/oai/mistral.py +17 -11
- autogen/oai/oai_models/__init__.py +14 -2
- autogen/oai/oai_models/_models.py +2 -2
- autogen/oai/oai_models/chat_completion.py +13 -14
- autogen/oai/oai_models/chat_completion_message.py +11 -9
- autogen/oai/oai_models/chat_completion_message_tool_call.py +26 -3
- autogen/oai/oai_models/chat_completion_token_logprob.py +3 -4
- autogen/oai/oai_models/completion_usage.py +8 -9
- autogen/oai/ollama.py +19 -9
- autogen/oai/openai_responses.py +40 -17
- autogen/oai/openai_utils.py +48 -38
- autogen/oai/together.py +29 -14
- autogen/retrieve_utils.py +6 -7
- autogen/runtime_logging.py +5 -4
- autogen/token_count_utils.py +7 -4
- autogen/tools/contrib/time/time.py +0 -1
- autogen/tools/dependency_injection.py +5 -6
- autogen/tools/experimental/browser_use/browser_use.py +10 -10
- autogen/tools/experimental/code_execution/python_code_execution.py +5 -7
- autogen/tools/experimental/crawl4ai/crawl4ai.py +12 -15
- autogen/tools/experimental/deep_research/deep_research.py +9 -8
- autogen/tools/experimental/duckduckgo/duckduckgo_search.py +5 -11
- autogen/tools/experimental/firecrawl/firecrawl_tool.py +98 -115
- autogen/tools/experimental/google/authentication/credentials_local_provider.py +1 -1
- autogen/tools/experimental/google/drive/drive_functions.py +4 -4
- autogen/tools/experimental/google/drive/toolkit.py +5 -5
- autogen/tools/experimental/google_search/google_search.py +5 -5
- autogen/tools/experimental/google_search/youtube_search.py +5 -5
- autogen/tools/experimental/messageplatform/discord/discord.py +8 -12
- autogen/tools/experimental/messageplatform/slack/slack.py +14 -20
- autogen/tools/experimental/messageplatform/telegram/telegram.py +8 -12
- autogen/tools/experimental/perplexity/perplexity_search.py +18 -29
- autogen/tools/experimental/reliable/reliable.py +68 -74
- autogen/tools/experimental/searxng/searxng_search.py +20 -19
- autogen/tools/experimental/tavily/tavily_search.py +12 -19
- autogen/tools/experimental/web_search_preview/web_search_preview.py +13 -7
- autogen/tools/experimental/wikipedia/wikipedia.py +7 -10
- autogen/tools/function_utils.py +7 -7
- autogen/tools/tool.py +8 -6
- autogen/types.py +2 -2
- autogen/version.py +1 -1
- ag2-0.9.7.dist-info/RECORD +0 -421
- autogen/llm_config.py +0 -385
- {ag2-0.9.7.dist-info → ag2-0.9.9.dist-info}/WHEEL +0 -0
- {ag2-0.9.7.dist-info → ag2-0.9.9.dist-info}/licenses/LICENSE +0 -0
- {ag2-0.9.7.dist-info → ag2-0.9.9.dist-info}/licenses/NOTICE.md +0 -0
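The headline structural change in this release is visible in the file list above: the single-module autogen/llm_config.py is removed (-385 lines) and replaced by an autogen/llm_config/ package (__init__.py, client.py, config.py, entry.py). Judging only by the import lines that appear in the diffs below, the entry types now live in the package's entry module, and the ModelClient protocol (deleted from autogen/oai/client.py below) is likely what moved into the new autogen/llm_config/client.py (+58 lines). A rough sketch of the import shift, using only paths that appear verbatim in the diffs:

    # 0.9.7: llm_config was one module
    #   from autogen.llm_config import LLMConfigEntry, ...
    # 0.9.9: llm_config is a package; these imports appear in the
    # autogen/oai/client.py diff below
    from autogen.llm_config import ModelClient
    from autogen.llm_config.entry import LLMConfigEntry, LLMConfigEntryDict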
autogen/oai/client.py
CHANGED
@@ -13,10 +13,11 @@ import re
 import sys
 import uuid
 import warnings
+from collections.abc import Callable
 from functools import lru_cache
-from typing import Any,
+from typing import Any, Literal
 
-from pydantic import BaseModel, Field, HttpUrl
+from pydantic import BaseModel, Field, HttpUrl
 from pydantic.type_adapter import TypeAdapter
 
 from ..cache import Cache
@@ -25,7 +26,8 @@ from ..events.client_events import StreamEvent, UsageSummaryEvent
 from ..exception_utils import ModelToolNotSupportedError
 from ..import_utils import optional_import_block, require_optional_import
 from ..io.base import IOStream
-from ..llm_config import
+from ..llm_config import ModelClient
+from ..llm_config.entry import LLMConfigEntry, LLMConfigEntryDict
 from ..logger.logger_utils import get_current_ts
 from ..runtime_logging import log_chat_completion, log_new_client, log_new_wrapper, logging_enabled
 from ..token_count_utils import count_token
@@ -58,7 +60,7 @@ if openai_result.is_successful:
     ERROR = None
     from openai.lib._pydantic import _ensure_strict_json_schema
 else:
-    ERROR:
+    ERROR: ImportError | None = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")
 
     # OpenAI = object
     # AzureOpenAI = object
@@ -73,7 +75,7 @@ with optional_import_block() as cerebras_result:
     from .cerebras import CerebrasClient
 
 if cerebras_result.is_successful:
-    cerebras_import_exception:
+    cerebras_import_exception: ImportError | None = None
 else:
     cerebras_AuthenticationError = cerebras_InternalServerError = cerebras_RateLimitError = Exception  # noqa: N816
     cerebras_import_exception = ImportError("cerebras_cloud_sdk not found")
@@ -87,7 +89,7 @@ with optional_import_block() as gemini_result:
     from .gemini import GeminiClient
 
 if gemini_result.is_successful:
-    gemini_import_exception:
+    gemini_import_exception: ImportError | None = None
 else:
     gemini_InternalServerError = gemini_ResourceExhausted = Exception  # noqa: N816
     gemini_import_exception = ImportError("google-genai not found")
@@ -101,7 +103,7 @@ with optional_import_block() as anthropic_result:
     from .anthropic import AnthropicClient
 
 if anthropic_result.is_successful:
-    anthropic_import_exception:
+    anthropic_import_exception: ImportError | None = None
 else:
     anthorpic_InternalServerError = anthorpic_RateLimitError = Exception  # noqa: N816
     anthropic_import_exception = ImportError("anthropic not found")
@@ -115,7 +117,7 @@ with optional_import_block() as mistral_result:
     from .mistral import MistralAIClient
 
 if mistral_result.is_successful:
-    mistral_import_exception:
+    mistral_import_exception: ImportError | None = None
 else:
     mistral_SDKError = mistral_HTTPValidationError = Exception  # noqa: N816
     mistral_import_exception = ImportError("mistralai not found")
@@ -126,7 +128,7 @@ with optional_import_block() as together_result:
     from .together import TogetherClient
 
 if together_result.is_successful:
-    together_import_exception:
+    together_import_exception: ImportError | None = None
 else:
     together_TogetherException = Exception  # noqa: N816
     together_import_exception = ImportError("together not found")
@@ -141,7 +143,7 @@ with optional_import_block() as groq_result:
     from .groq import GroqClient
 
 if groq_result.is_successful:
-    groq_import_exception:
+    groq_import_exception: ImportError | None = None
 else:
     groq_InternalServerError = groq_RateLimitError = groq_APIConnectionError = Exception  # noqa: N816
     groq_import_exception = ImportError("groq not found")
@@ -156,7 +158,7 @@ with optional_import_block() as cohere_result:
     from .cohere import CohereClient
 
 if cohere_result.is_successful:
-    cohere_import_exception:
+    cohere_import_exception: ImportError | None = None
 else:
     cohere_InternalServerError = cohere_TooManyRequestsError = cohere_ServiceUnavailableError = Exception  # noqa: N816
     cohere_import_exception = ImportError("cohere not found")
@@ -170,7 +172,7 @@ with optional_import_block() as ollama_result:
     from .ollama import OllamaClient
 
 if ollama_result.is_successful:
-    ollama_import_exception:
+    ollama_import_exception: ImportError | None = None
 else:
     ollama_RequestError = ollama_ResponseError = Exception  # noqa: N816
     ollama_import_exception = ImportError("ollama not found")
@@ -184,7 +186,7 @@ with optional_import_block() as bedrock_result:
     from .bedrock import BedrockClient
 
 if bedrock_result.is_successful:
-    bedrock_import_exception:
+    bedrock_import_exception: ImportError | None = None
 else:
     bedrock_BotoCoreError = bedrock_ClientError = Exception  # noqa: N816
     bedrock_import_exception = ImportError("botocore not found")
@@ -212,6 +214,7 @@ OPENAI_FALLBACK_KWARGS = {
     "default_query",
     "http_client",
     "_strict_response_validation",
+    "webhook_secret",
 }
 
 AOPENAI_FALLBACK_KWARGS = {
@@ -231,124 +234,103 @@ AOPENAI_FALLBACK_KWARGS = {
     "_strict_response_validation",
     "base_url",
     "project",
+    "webhook_secret",
 }
 
 
 @lru_cache(maxsize=128)
-def log_cache_seed_value(cache_seed_value:
+def log_cache_seed_value(cache_seed_value: str | int, client: ModelClient) -> None:
     logger.debug(f"Using cache with seed value {cache_seed_value} for client {client.__class__.__name__}")
 
 
-
+class OpenAIEntryDict(LLMConfigEntryDict, total=False):
+    api_type: Literal["openai"]
+
+    price: list[float] | None
+    tool_choice: Literal["none", "auto", "required"] | None
+    user: str | None
+    stream: bool
+    verbosity: Literal["low", "medium", "high"] | None
+    extra_body: dict[str, Any] | None
+    reasoning_effort: Literal["low", "minimal", "medium", "high"] | None
+    max_completion_tokens: int | None
+
+
 class OpenAILLMConfigEntry(LLMConfigEntry):
     api_type: Literal["openai"] = "openai"
-    top_p: Optional[float] = None
-    price: Optional[list[float]] = Field(default=None, min_length=2, max_length=2)
-    tool_choice: Optional[Literal["none", "auto", "required"]] = None
-    user: Optional[str] = None
 
-
+    price: list[float] | None = Field(default=None, min_length=2, max_length=2)
+    tool_choice: Literal["none", "auto", "required"] | None = None
+    user: str | None = None
+    stream: bool = False
+    verbosity: Literal["low", "medium", "high"] | None = None
+    # The extra_body parameter flows from OpenAILLMConfigEntry to the LLM request through this path:
     # 1. Config Definition: extra_body is defined in OpenAILLMConfigEntry (autogen/oai/client.py:248)
     # 2. Parameter Classification: It's classified as an OpenAI client parameter (not AG2-specific) via the openai_kwargs property (autogen/oai/client.py:752-758)
     # 3. Request Separation: In _separate_create_config() (autogen/oai/client.py:842), extra_body goes into create_config since it's not in the extra_kwargs set.
     # 4. API Call: The create_config becomes params and gets passed directly to OpenAI's create() method via **params (autogen/oai/client.py:551,658)
-    extra_body:
+    extra_body: dict[str, Any] | None = (
         None  # For VLLM - See here: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters
     )
     # reasoning models - see: https://platform.openai.com/docs/api-reference/chat/create#chat-create-reasoning_effort
-    reasoning_effort:
-    max_completion_tokens:
+    reasoning_effort: Literal["low", "minimal", "medium", "high"] | None = None
+    max_completion_tokens: int | None = None
 
-    def create_client(self) ->
+    def create_client(self) -> ModelClient:
         raise NotImplementedError("create_client method must be implemented in the derived class.")
 
 
-
+class AzureOpenAIEntryDict(LLMConfigEntryDict, total=False):
+    api_type: Literal["azure"]
+
+    azure_ad_token_provider: str | Callable[[], str] | None
+    stream: bool
+    tool_choice: Literal["none", "auto", "required"] | None
+    user: str | None
+    reasoning_effort: Literal["low", "medium", "high"] | None
+    max_completion_tokens: int | None
+
+
 class AzureOpenAILLMConfigEntry(LLMConfigEntry):
     api_type: Literal["azure"] = "azure"
-
-    azure_ad_token_provider:
-
-
+
+    azure_ad_token_provider: str | Callable[[], str] | None = None
+    stream: bool = False
+    tool_choice: Literal["none", "auto", "required"] | None = None
+    user: str | None = None
     # reasoning models - see:
     # - https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/reasoning
    # - https://learn.microsoft.com/en-us/azure/ai-services/openai/reference-preview
-    reasoning_effort:
-    max_completion_tokens:
+    reasoning_effort: Literal["low", "medium", "high"] | None = None
+    max_completion_tokens: int | None = None
 
-    def create_client(self) ->
+    def create_client(self) -> ModelClient:
        raise NotImplementedError
 
 
-
+class DeepSeekEntyDict(LLMConfigEntryDict, total=False):
+    api_type: Literal["deepseek"]
+
+    base_url: HttpUrl
+    stream: bool
+    tool_choice: Literal["none", "auto", "required"] | None
+
+
 class DeepSeekLLMConfigEntry(LLMConfigEntry):
     api_type: Literal["deepseek"] = "deepseek"
-
-    temperature: float = Field(
+
+    temperature: float | None = Field(default=None, ge=0.0, le=1.0)
+    top_p: float | None = Field(None, ge=0.0, le=1.0)
     max_tokens: int = Field(8192, ge=1, le=8192)
-    top_p: Optional[float] = Field(None, ge=0.0, le=1.0)
-    tool_choice: Optional[Literal["none", "auto", "required"]] = None
 
-
-
-
-        if v is not None and info.data.get("temperature") is not None:
-            raise ValueError("temperature and top_p cannot be set at the same time.")
-        return v
+    base_url: HttpUrl = HttpUrl("https://api.deepseek.com/v1")
+    stream: bool = False
+    tool_choice: Literal["none", "auto", "required"] | None = None
 
     def create_client(self) -> None:  # type: ignore [override]
         raise NotImplementedError("DeepSeekLLMConfigEntry.create_client is not implemented.")
 
 
-@export_module("autogen")
-class ModelClient(Protocol):
-    """A client class must implement the following methods:
-    - create must return a response object that implements the ModelClientResponseProtocol
-    - cost must return the cost of the response
-    - get_usage must return a dict with the following keys:
-        - prompt_tokens
-        - completion_tokens
-        - total_tokens
-        - cost
-        - model
-
-    This class is used to create a client that can be used by OpenAIWrapper.
-    The response returned from create must adhere to the ModelClientResponseProtocol but can be extended however needed.
-    The message_retrieval method must be implemented to return a list of str or a list of messages from the response.
-    """
-
-    RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"]
-
-    class ModelClientResponseProtocol(Protocol):
-        class Choice(Protocol):
-            class Message(Protocol):
-                content: Optional[str] | Optional[dict[str, Any]]
-
-            message: Message
-
-        choices: list[Choice]
-        model: str
-
-    def create(self, params: dict[str, Any]) -> ModelClientResponseProtocol: ...  # pragma: no cover
-
-    def message_retrieval(
-        self, response: ModelClientResponseProtocol
-    ) -> Union[list[str], list[ModelClient.ModelClientResponseProtocol.Choice.Message]]:
-        """Retrieve and return a list of strings or a list of Choice.Message from the response.
-
-        NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
-        since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
-        """
-        ...  # pragma: no cover
-
-    def cost(self, response: ModelClientResponseProtocol) -> float: ...  # pragma: no cover
-
-    @staticmethod
-    def get_usage(response: ModelClientResponseProtocol) -> dict:
-        """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
-        ...  # pragma: no cover
-
-
 class PlaceHolderClient:
     def __init__(self, config):
         self.config = config
@@ -358,9 +340,7 @@ class PlaceHolderClient:
 class OpenAIClient:
     """Follows the Client protocol and wraps the OpenAI client."""
 
-    def __init__(
-        self, client: Union[OpenAI, AzureOpenAI], response_format: Union[BaseModel, dict[str, Any], None] = None
-    ):
+    def __init__(self, client: OpenAI | AzureOpenAI, response_format: BaseModel | dict[str, Any] | None = None):
         self._oai_client = client
         self.response_format = response_format
         if (
@@ -372,9 +352,7 @@ class OpenAIClient:
                 "The API key specified is not a valid OpenAI format; it won't work with the OpenAI-hosted model."
             )
 
-    def message_retrieval(
-        self, response: Union[ChatCompletion, Completion]
-    ) -> Union[list[str], list[ChatCompletionMessage]]:
+    def message_retrieval(self, response: ChatCompletion | Completion) -> list[str] | list[ChatCompletionMessage]:
         """Retrieve the messages from the response.
 
         Args:
@@ -511,7 +489,10 @@ class OpenAIClient:
         if "stream" in kwargs:
             kwargs.pop("stream")
 
-        if
+        if (
+            isinstance(kwargs["response_format"], dict)
+            and kwargs["response_format"].get("type") != "json_object"
+        ):
             kwargs["response_format"] = {
                 "type": "json_schema",
                 "json_schema": {
@@ -550,8 +531,8 @@ class OpenAIClient:
         completion_tokens = 0
 
         # Prepare for potential function call
-        full_function_call:
-        full_tool_calls:
+        full_function_call: dict[str, Any] | None = None
+        full_tool_calls: list[dict[str, Any] | None] | None = None
 
         # Send the chat completion request to OpenAI's API and process the response in chunks
         for chunk in create_or_parse(**params):
@@ -678,9 +659,9 @@ class OpenAIClient:
         # Unsupported parameters
         unsupported_params = [
             "temperature",
+            "top_p",
             "frequency_penalty",
             "presence_penalty",
-            "top_p",
             "logprobs",
             "top_logprobs",
             "logit_bias",
@@ -706,7 +687,7 @@ class OpenAIClient:
             msg["role"] = "user"
             msg["content"] = f"System message: {msg['content']}"
 
-    def cost(self, response:
+    def cost(self, response: ChatCompletion | Completion) -> float:
         """Calculate the cost of the response."""
         model = response.model
         if model not in OAI_PRICE1K:
@@ -727,7 +708,7 @@ class OpenAIClient:
         return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000  # type: ignore [operator]
 
     @staticmethod
-    def get_usage(response:
+    def get_usage(response: ChatCompletion | Completion) -> dict:
         return {
             "prompt_tokens": response.usage.prompt_tokens if response.usage is not None else 0,
             "completion_tokens": response.usage.completion_tokens if response.usage is not None else 0,
@@ -763,13 +744,13 @@ class OpenAIWrapper:
         else:
             return OPENAI_FALLBACK_KWARGS | AOPENAI_FALLBACK_KWARGS
 
-    total_usage_summary:
-    actual_usage_summary:
+    total_usage_summary: dict[str, Any] | None = None
+    actual_usage_summary: dict[str, Any] | None = None
 
     def __init__(
         self,
         *,
-        config_list:
+        config_list: list[dict[str, Any]] | None = None,
         **base_config: Any,
     ):
         """Initialize the OpenAIWrapper.
@@ -851,17 +832,6 @@ class OpenAIWrapper:
 
     def _configure_azure_openai(self, config: dict[str, Any], openai_config: dict[str, Any]) -> None:
         openai_config["azure_deployment"] = openai_config.get("azure_deployment", config.get("model"))
-        if openai_config["azure_deployment"] is not None:
-            # Preserve dots for specific model versions that require them
-            deployment_name = openai_config["azure_deployment"]
-            if deployment_name in [
-                "gpt-4.1"
-            ]:  # Add more as needed, Whitelist approach so as to not break existing deployments
-                # Keep the deployment name as-is for these specific models
-                pass
-            else:
-                # Remove dots for all other models (maintain existing behavior)
-                openai_config["azure_deployment"] = deployment_name.replace(".", "")
         openai_config["azure_endpoint"] = openai_config.get("azure_endpoint", openai_config.pop("base_url", None))
 
         # Create a default Azure token provider if requested
@@ -890,6 +860,13 @@ class OpenAIWrapper:
             if key in config:
                 openai_config[key] = config[key]
 
+    def _configure_openai_config_for_gemini(self, config: dict[str, Any], openai_config: dict[str, Any]) -> None:
+        """Update openai_config with additional gemini genai configs."""
+        optional_keys = ["proxy"]
+        for key in optional_keys:
+            if key in config:
+                openai_config[key] = config[key]
+
     def _register_default_client(self, config: dict[str, Any], openai_config: dict[str, Any]) -> None:
         """Create a client with the given config to override openai_config,
         after removing extra kwargs.
@@ -915,7 +892,7 @@ class OpenAIWrapper:
         if api_type is not None and api_type.startswith("azure"):
 
             @require_optional_import("openai>=1.66.2", "openai")
-            def create_azure_openai_client() ->
+            def create_azure_openai_client() -> AzureOpenAI:
                 self._configure_azure_openai(config, openai_config)
                 client = AzureOpenAI(**openai_config)
                 self._clients.append(OpenAIClient(client, response_format=response_format))
@@ -930,6 +907,7 @@ class OpenAIWrapper:
         elif api_type is not None and api_type.startswith("google"):
             if gemini_import_exception:
                 raise ImportError("Please install `google-genai` and 'vertexai' to use Google's API.")
+            self._configure_openai_config_for_gemini(config, openai_config)
             client = GeminiClient(response_format=response_format, **openai_config)
             self._clients.append(client)
         elif api_type is not None and api_type.startswith("anthropic"):
@@ -975,7 +953,7 @@ class OpenAIWrapper:
         elif api_type is not None and api_type.startswith("responses"):
             # OpenAI Responses API (stateful). Reuse the same OpenAI SDK but call the `/responses` endpoint via the new client.
             @require_optional_import("openai>=1.66.2", "openai")
-            def create_responses_client() ->
+            def create_responses_client() -> OpenAI:
                 client = OpenAI(**openai_config)
                 self._clients.append(OpenAIResponsesClient(client, response_format=response_format))
                 return client
@@ -984,7 +962,7 @@ class OpenAIWrapper:
         else:
 
             @require_optional_import("openai>=1.66.2", "openai")
-            def create_openai_client() ->
+            def create_openai_client() -> OpenAI:
                 client = OpenAI(**openai_config)
                 self._clients.append(OpenAIClient(client, response_format))
                 return client
@@ -1025,10 +1003,10 @@ class OpenAIWrapper:
     @classmethod
     def instantiate(
         cls,
-        template:
-        context:
-        allow_format_str_template:
-    ) ->
+        template: str | Callable[[dict[str, Any]], str] | None,
+        context: dict[str, Any] | None = None,
+        allow_format_str_template: bool | None = False,
+    ) -> str | None:
         if not context or template is None:
             return template  # type: ignore [return-value]
         if isinstance(template, str):
@@ -1038,8 +1016,8 @@ class OpenAIWrapper:
     def _construct_create_params(self, create_config: dict[str, Any], extra_kwargs: dict[str, Any]) -> dict[str, Any]:
         """Prime the create_config with additional_kwargs."""
         # Validate the config
-        prompt:
-        messages:
+        prompt: str | None = create_config.get("prompt")
+        messages: list[dict[str, Any]] | None = create_config.get("messages")
         if (prompt is None) == (messages is None):
             raise ValueError("Either prompt or messages should be in create config but not both.")
         context = extra_kwargs.get("context")
@@ -1106,9 +1084,6 @@ class OpenAIWrapper:
             full_config = {**config, **self._config_list[i]}
             # separate the config into create_config and extra_kwargs
             create_config, extra_kwargs = self._separate_create_config(full_config)
-            api_type = extra_kwargs.get("api_type")
-            if api_type and api_type.startswith("azure") and "model" in create_config:
-                create_config["model"] = create_config["model"].replace(".", "")
             # construct the create params
             params = self._construct_create_params(create_config, extra_kwargs)
             # get the cache_seed, filter_func and context
@@ -1336,8 +1311,8 @@ class OpenAIWrapper:
 
     @staticmethod
     def _update_function_call_from_chunk(
-        function_call_chunk:
-        full_function_call:
+        function_call_chunk: ChoiceDeltaToolCallFunction | ChoiceDeltaFunctionCall,
+        full_function_call: dict[str, Any] | None,
         completion_tokens: int,
     ) -> tuple[dict[str, Any], int]:
         """Update the function call from the chunk.
@@ -1368,7 +1343,7 @@ class OpenAIWrapper:
     @staticmethod
     def _update_tool_calls_from_chunk(
         tool_calls_chunk: ChoiceDeltaToolCall,
-        full_tool_call:
+        full_tool_call: dict[str, Any] | None,
         completion_tokens: int,
     ) -> tuple[dict[str, Any], int]:
         """Update the tool call from the chunk.
@@ -1442,7 +1417,7 @@ class OpenAIWrapper:
         if actual_usage is not None:
             self.actual_usage_summary = update_usage(self.actual_usage_summary, actual_usage)
 
-    def print_usage_summary(self, mode:
+    def print_usage_summary(self, mode: str | list[str] = ["actual", "total"]) -> None:
         """Print the usage summary."""
         iostream = IOStream.get_default()
 
@@ -1470,7 +1445,7 @@ class OpenAIWrapper:
     @classmethod
     def extract_text_or_completion_object(
         cls, response: ModelClient.ModelClientResponseProtocol
-    ) ->
+    ) -> list[str] | list[ModelClient.ModelClientResponseProtocol.Choice.Message]:
        """Extract the text or ChatCompletion objects from a completion or chat response.
 
         Args:
@@ -1487,7 +1462,6 @@ class OpenAIWrapper:
     # -----------------------------------------------------------------------------
 
 
-@register_llm_config
 class OpenAIResponsesLLMConfigEntry(OpenAILLMConfigEntry):
     """LLMConfig entry for the OpenAI Responses API (stateful, tool-enabled).
 
@@ -1507,8 +1481,8 @@ class OpenAIResponsesLLMConfigEntry(OpenAILLMConfigEntry):
     """
 
     api_type: Literal["responses"] = "responses"
-    tool_choice:
-    built_in_tools:
+    tool_choice: Literal["none", "auto", "required"] | None = "auto"
+    built_in_tools: list[str] | None = None
 
-    def create_client(self) ->
+    def create_client(self) -> ModelClient:  # pragma: no cover
        raise NotImplementedError("Handled via OpenAIWrapper._register_default_client")
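Beyond the mechanical Optional[X]/Union[...] to X | None modernization, the diff above introduces TypedDict companions (OpenAIEntryDict, AzureOpenAIEntryDict, DeepSeekEntyDict; the "Enty" spelling is in the source) next to the pydantic entry classes. A minimal, self-contained sketch of that pattern, using a simplified stand-in for LLMConfigEntryDict, whose real fields live in autogen/llm_config/entry.py and are not shown in this diff (the model and api_key fields below are assumptions):

    from typing import Literal, TypedDict

    # Simplified stand-in for autogen.llm_config.entry.LLMConfigEntryDict;
    # its real fields are not part of this diff, so these two are assumptions.
    class LLMConfigEntryDictSketch(TypedDict, total=False):
        model: str
        api_key: str

    # Field names below are copied from the OpenAIEntryDict block above.
    class OpenAIEntryDictSketch(LLMConfigEntryDictSketch, total=False):
        api_type: Literal["openai"]
        price: list[float] | None
        tool_choice: Literal["none", "auto", "required"] | None
        stream: bool

    # A plain dict now type-checks at call sites that accept this TypedDict
    # (the model name is a placeholder):
    entry: OpenAIEntryDictSketch = {"model": "gpt-4o-mini", "api_type": "openai", "stream": False}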
autogen/oai/client_utils.py
CHANGED
@@ -8,7 +8,7 @@
 
 import logging
 import warnings
-from typing import Any,
+from typing import Any, Protocol, runtime_checkable
 
 
 @runtime_checkable
@@ -24,8 +24,8 @@ def validate_parameter(
     allowed_types: tuple[Any, ...],
     allow_None: bool,  # noqa: N803
     default_value: Any,
-    numerical_bound:
-    allowed_values:
+    numerical_bound: tuple[float | None, float | None] | None,
+    allowed_values: list[Any] | None,
 ) -> Any:
     """Validates a given config parameter, checking its type, values, and setting defaults
     Parameters:
autogen/oai/cohere.py
CHANGED
@@ -34,14 +34,15 @@ import os
 import sys
 import time
 import warnings
-from typing import Any, Literal
+from typing import Any, Literal
 
 from pydantic import BaseModel, Field
+from typing_extensions import Unpack
 
 from autogen.oai.client_utils import FormatterProtocol, logging_formatter, validate_parameter
 
 from ..import_utils import optional_import_block, require_optional_import
-from ..llm_config import LLMConfigEntry,
+from ..llm_config.entry import LLMConfigEntry, LLMConfigEntryDict
 from .oai_models import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall, Choice, CompletionUsage
 
 with optional_import_block():
@@ -66,20 +67,30 @@ COHERE_PRICING_1K = {
 }
 
 
-
+class CohereEntryDict(LLMConfigEntryDict, total=False):
+    api_type: Literal["cohere"]
+
+    k: int
+    seed: int | None
+    frequency_penalty: float
+    presence_penalty: float
+    client_name: str | None
+    strict_tools: bool
+    stream: bool
+    tool_choice: Literal["NONE", "REQUIRED"] | None
+
+
 class CohereLLMConfigEntry(LLMConfigEntry):
     api_type: Literal["cohere"] = "cohere"
-
-    max_tokens: Optional[int] = Field(default=None, ge=0)
+
     k: int = Field(default=0, ge=0, le=500)
-
-    seed: Optional[int] = None
+    seed: int | None = None
     frequency_penalty: float = Field(default=0, ge=0, le=1)
     presence_penalty: float = Field(default=0, ge=0, le=1)
-    client_name:
+    client_name: str | None = None
     strict_tools: bool = False
     stream: bool = False
-    tool_choice:
+    tool_choice: Literal["NONE", "REQUIRED"] | None = None
 
     def create_client(self):
         raise NotImplementedError("CohereLLMConfigEntry.create_client is not implemented.")
@@ -88,7 +99,7 @@ class CohereLLMConfigEntry(LLMConfigEntry):
 class CohereClient:
     """Client for Cohere's API."""
 
-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs: Unpack[CohereEntryDict]):
         """Requires api_key or environment variable to be set
 
         Args:
@@ -104,7 +115,7 @@ class CohereClient:
         )
 
         # Store the response format, if provided (for structured outputs)
-        self._response_format:
+        self._response_format: type[BaseModel] | None = None
 
     def message_retrieval(self, response) -> list:
         """Retrieve and return a list of strings or a list of Choice.Message from the response.
@@ -203,7 +214,17 @@ class CohereClient:
         if "k" in params:
             cohere_params["k"] = validate_parameter(params, "k", int, False, 0, (0, 500), None)
 
+        if "top_p" in params:
+            cohere_params["p"] = validate_parameter(params, "top_p", (int, float), False, 0.75, (0.01, 0.99), None)
+
         if "p" in params:
+            warnings.warn(
+                (
+                    "parameter 'p' is deprecated, use 'top_p' instead for consistency with OpenAI API spec. "
+                    "Scheduled for removal in 0.10.0 version."
+                ),
+                DeprecationWarning,
+            )
             cohere_params["p"] = validate_parameter(params, "p", (int, float), False, 0.75, (0.01, 0.99), None)
 
         if "seed" in params:
@@ -402,8 +423,10 @@ class CohereClient:
 
     def _convert_json_response(self, response: str) -> Any:
         """Extract and validate JSON response from the output for structured outputs.
+
         Args:
             response (str): The response from the API.
+
         Returns:
             Any: The parsed JSON response.
         """
|