ag2 0.9.6__py3-none-any.whl → 0.9.8.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ag2 has been flagged as possibly problematic.
- {ag2-0.9.6.dist-info → ag2-0.9.8.post1.dist-info}/METADATA +102 -75
- ag2-0.9.8.post1.dist-info/RECORD +387 -0
- autogen/__init__.py +1 -2
- autogen/_website/generate_api_references.py +4 -5
- autogen/_website/generate_mkdocs.py +9 -15
- autogen/_website/notebook_processor.py +13 -14
- autogen/_website/process_notebooks.py +10 -10
- autogen/_website/utils.py +5 -4
- autogen/agentchat/agent.py +13 -13
- autogen/agentchat/assistant_agent.py +7 -6
- autogen/agentchat/contrib/agent_eval/agent_eval.py +3 -3
- autogen/agentchat/contrib/agent_eval/critic_agent.py +3 -3
- autogen/agentchat/contrib/agent_eval/quantifier_agent.py +3 -3
- autogen/agentchat/contrib/agent_eval/subcritic_agent.py +3 -3
- autogen/agentchat/contrib/agent_optimizer.py +3 -3
- autogen/agentchat/contrib/capabilities/generate_images.py +11 -11
- autogen/agentchat/contrib/capabilities/teachability.py +15 -15
- autogen/agentchat/contrib/capabilities/transforms.py +17 -18
- autogen/agentchat/contrib/capabilities/transforms_util.py +5 -5
- autogen/agentchat/contrib/capabilities/vision_capability.py +4 -3
- autogen/agentchat/contrib/captainagent/agent_builder.py +30 -30
- autogen/agentchat/contrib/captainagent/captainagent.py +22 -21
- autogen/agentchat/contrib/captainagent/tool_retriever.py +2 -3
- autogen/agentchat/contrib/gpt_assistant_agent.py +9 -9
- autogen/agentchat/contrib/graph_rag/document.py +3 -3
- autogen/agentchat/contrib/graph_rag/falkor_graph_query_engine.py +3 -3
- autogen/agentchat/contrib/graph_rag/falkor_graph_rag_capability.py +6 -6
- autogen/agentchat/contrib/graph_rag/graph_query_engine.py +3 -3
- autogen/agentchat/contrib/graph_rag/neo4j_graph_query_engine.py +5 -11
- autogen/agentchat/contrib/graph_rag/neo4j_graph_rag_capability.py +6 -6
- autogen/agentchat/contrib/graph_rag/neo4j_native_graph_query_engine.py +7 -7
- autogen/agentchat/contrib/graph_rag/neo4j_native_graph_rag_capability.py +6 -6
- autogen/agentchat/contrib/img_utils.py +1 -1
- autogen/agentchat/contrib/llamaindex_conversable_agent.py +11 -11
- autogen/agentchat/contrib/llava_agent.py +18 -4
- autogen/agentchat/contrib/math_user_proxy_agent.py +11 -11
- autogen/agentchat/contrib/multimodal_conversable_agent.py +8 -8
- autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py +6 -5
- autogen/agentchat/contrib/rag/chromadb_query_engine.py +22 -26
- autogen/agentchat/contrib/rag/llamaindex_query_engine.py +14 -17
- autogen/agentchat/contrib/rag/mongodb_query_engine.py +27 -37
- autogen/agentchat/contrib/rag/query_engine.py +7 -5
- autogen/agentchat/contrib/retrieve_assistant_agent.py +5 -5
- autogen/agentchat/contrib/retrieve_user_proxy_agent.py +8 -7
- autogen/agentchat/contrib/society_of_mind_agent.py +15 -14
- autogen/agentchat/contrib/swarm_agent.py +76 -98
- autogen/agentchat/contrib/text_analyzer_agent.py +7 -7
- autogen/agentchat/contrib/vectordb/base.py +10 -18
- autogen/agentchat/contrib/vectordb/chromadb.py +2 -1
- autogen/agentchat/contrib/vectordb/couchbase.py +18 -20
- autogen/agentchat/contrib/vectordb/mongodb.py +6 -5
- autogen/agentchat/contrib/vectordb/pgvectordb.py +40 -41
- autogen/agentchat/contrib/vectordb/qdrant.py +5 -5
- autogen/agentchat/contrib/web_surfer.py +20 -19
- autogen/agentchat/conversable_agent.py +311 -295
- autogen/agentchat/group/context_str.py +1 -3
- autogen/agentchat/group/context_variables.py +15 -25
- autogen/agentchat/group/group_tool_executor.py +10 -10
- autogen/agentchat/group/group_utils.py +15 -15
- autogen/agentchat/group/guardrails.py +7 -7
- autogen/agentchat/group/handoffs.py +19 -36
- autogen/agentchat/group/multi_agent_chat.py +7 -7
- autogen/agentchat/group/on_condition.py +4 -7
- autogen/agentchat/group/on_context_condition.py +4 -7
- autogen/agentchat/group/patterns/auto.py +8 -7
- autogen/agentchat/group/patterns/manual.py +7 -6
- autogen/agentchat/group/patterns/pattern.py +13 -12
- autogen/agentchat/group/patterns/random.py +3 -3
- autogen/agentchat/group/patterns/round_robin.py +3 -3
- autogen/agentchat/group/reply_result.py +2 -4
- autogen/agentchat/group/speaker_selection_result.py +5 -5
- autogen/agentchat/group/targets/group_chat_target.py +7 -6
- autogen/agentchat/group/targets/group_manager_target.py +4 -4
- autogen/agentchat/group/targets/transition_target.py +2 -1
- autogen/agentchat/groupchat.py +58 -61
- autogen/agentchat/realtime/experimental/audio_adapters/twilio_audio_adapter.py +4 -4
- autogen/agentchat/realtime/experimental/audio_adapters/websocket_audio_adapter.py +4 -4
- autogen/agentchat/realtime/experimental/clients/gemini/client.py +7 -7
- autogen/agentchat/realtime/experimental/clients/oai/base_client.py +8 -8
- autogen/agentchat/realtime/experimental/clients/oai/rtc_client.py +6 -6
- autogen/agentchat/realtime/experimental/clients/realtime_client.py +10 -9
- autogen/agentchat/realtime/experimental/realtime_agent.py +10 -9
- autogen/agentchat/realtime/experimental/realtime_observer.py +3 -3
- autogen/agentchat/realtime/experimental/realtime_swarm.py +44 -44
- autogen/agentchat/user_proxy_agent.py +10 -9
- autogen/agentchat/utils.py +3 -3
- autogen/agents/contrib/time/time_reply_agent.py +6 -5
- autogen/agents/contrib/time/time_tool_agent.py +2 -1
- autogen/agents/experimental/deep_research/deep_research.py +3 -3
- autogen/agents/experimental/discord/discord.py +2 -2
- autogen/agents/experimental/document_agent/chroma_query_engine.py +29 -44
- autogen/agents/experimental/document_agent/docling_doc_ingest_agent.py +9 -14
- autogen/agents/experimental/document_agent/document_agent.py +15 -16
- autogen/agents/experimental/document_agent/document_conditions.py +3 -3
- autogen/agents/experimental/document_agent/document_utils.py +5 -9
- autogen/agents/experimental/document_agent/inmemory_query_engine.py +14 -20
- autogen/agents/experimental/document_agent/parser_utils.py +4 -4
- autogen/agents/experimental/document_agent/url_utils.py +14 -23
- autogen/agents/experimental/reasoning/reasoning_agent.py +33 -33
- autogen/agents/experimental/slack/slack.py +2 -2
- autogen/agents/experimental/telegram/telegram.py +2 -3
- autogen/agents/experimental/websurfer/websurfer.py +4 -4
- autogen/agents/experimental/wikipedia/wikipedia.py +5 -7
- autogen/browser_utils.py +8 -8
- autogen/cache/abstract_cache_base.py +5 -5
- autogen/cache/cache.py +12 -12
- autogen/cache/cache_factory.py +4 -4
- autogen/cache/cosmos_db_cache.py +9 -9
- autogen/cache/disk_cache.py +6 -6
- autogen/cache/in_memory_cache.py +4 -4
- autogen/cache/redis_cache.py +4 -4
- autogen/code_utils.py +18 -18
- autogen/coding/base.py +6 -6
- autogen/coding/docker_commandline_code_executor.py +9 -9
- autogen/coding/func_with_reqs.py +7 -6
- autogen/coding/jupyter/base.py +3 -3
- autogen/coding/jupyter/docker_jupyter_server.py +3 -4
- autogen/coding/jupyter/import_utils.py +3 -3
- autogen/coding/jupyter/jupyter_client.py +5 -5
- autogen/coding/jupyter/jupyter_code_executor.py +3 -4
- autogen/coding/jupyter/local_jupyter_server.py +2 -6
- autogen/coding/local_commandline_code_executor.py +8 -7
- autogen/coding/markdown_code_extractor.py +1 -2
- autogen/coding/utils.py +1 -2
- autogen/doc_utils.py +3 -2
- autogen/environments/docker_python_environment.py +19 -29
- autogen/environments/python_environment.py +8 -17
- autogen/environments/system_python_environment.py +3 -4
- autogen/environments/venv_python_environment.py +8 -12
- autogen/environments/working_directory.py +1 -2
- autogen/events/agent_events.py +106 -109
- autogen/events/base_event.py +6 -5
- autogen/events/client_events.py +15 -14
- autogen/events/helpers.py +1 -1
- autogen/events/print_event.py +4 -5
- autogen/fast_depends/_compat.py +10 -15
- autogen/fast_depends/core/build.py +17 -36
- autogen/fast_depends/core/model.py +64 -113
- autogen/fast_depends/dependencies/model.py +2 -1
- autogen/fast_depends/dependencies/provider.py +3 -2
- autogen/fast_depends/library/model.py +4 -4
- autogen/fast_depends/schema.py +7 -7
- autogen/fast_depends/use.py +17 -25
- autogen/fast_depends/utils.py +10 -30
- autogen/formatting_utils.py +6 -6
- autogen/graph_utils.py +1 -4
- autogen/import_utils.py +13 -13
- autogen/interop/crewai/crewai.py +2 -2
- autogen/interop/interoperable.py +2 -2
- autogen/interop/langchain/langchain_chat_model_factory.py +3 -2
- autogen/interop/langchain/langchain_tool.py +2 -6
- autogen/interop/litellm/litellm_config_factory.py +6 -7
- autogen/interop/pydantic_ai/pydantic_ai.py +4 -7
- autogen/interop/registry.py +2 -1
- autogen/io/base.py +5 -5
- autogen/io/run_response.py +33 -32
- autogen/io/websockets.py +6 -5
- autogen/json_utils.py +1 -2
- autogen/llm_config/__init__.py +11 -0
- autogen/llm_config/client.py +58 -0
- autogen/llm_config/config.py +384 -0
- autogen/llm_config/entry.py +154 -0
- autogen/logger/base_logger.py +4 -3
- autogen/logger/file_logger.py +2 -1
- autogen/logger/logger_factory.py +2 -2
- autogen/logger/logger_utils.py +2 -2
- autogen/logger/sqlite_logger.py +3 -2
- autogen/math_utils.py +4 -5
- autogen/mcp/__main__.py +6 -6
- autogen/mcp/helpers.py +4 -4
- autogen/mcp/mcp_client.py +170 -29
- autogen/mcp/mcp_proxy/fastapi_code_generator_helpers.py +3 -4
- autogen/mcp/mcp_proxy/mcp_proxy.py +23 -26
- autogen/mcp/mcp_proxy/operation_grouping.py +4 -5
- autogen/mcp/mcp_proxy/operation_renaming.py +6 -10
- autogen/mcp/mcp_proxy/security.py +2 -3
- autogen/messages/agent_messages.py +96 -98
- autogen/messages/base_message.py +6 -5
- autogen/messages/client_messages.py +15 -14
- autogen/messages/print_message.py +4 -5
- autogen/oai/__init__.py +1 -2
- autogen/oai/anthropic.py +42 -41
- autogen/oai/bedrock.py +68 -57
- autogen/oai/cerebras.py +26 -25
- autogen/oai/client.py +118 -138
- autogen/oai/client_utils.py +3 -3
- autogen/oai/cohere.py +34 -11
- autogen/oai/gemini.py +40 -17
- autogen/oai/gemini_types.py +11 -12
- autogen/oai/groq.py +22 -10
- autogen/oai/mistral.py +17 -11
- autogen/oai/oai_models/__init__.py +14 -2
- autogen/oai/oai_models/_models.py +2 -2
- autogen/oai/oai_models/chat_completion.py +13 -14
- autogen/oai/oai_models/chat_completion_message.py +11 -9
- autogen/oai/oai_models/chat_completion_message_tool_call.py +26 -3
- autogen/oai/oai_models/chat_completion_token_logprob.py +3 -4
- autogen/oai/oai_models/completion_usage.py +8 -9
- autogen/oai/ollama.py +22 -10
- autogen/oai/openai_responses.py +40 -17
- autogen/oai/openai_utils.py +159 -85
- autogen/oai/together.py +29 -14
- autogen/retrieve_utils.py +6 -7
- autogen/runtime_logging.py +5 -4
- autogen/token_count_utils.py +7 -4
- autogen/tools/contrib/time/time.py +0 -1
- autogen/tools/dependency_injection.py +5 -6
- autogen/tools/experimental/browser_use/browser_use.py +10 -10
- autogen/tools/experimental/code_execution/python_code_execution.py +5 -7
- autogen/tools/experimental/crawl4ai/crawl4ai.py +12 -15
- autogen/tools/experimental/deep_research/deep_research.py +9 -8
- autogen/tools/experimental/duckduckgo/duckduckgo_search.py +5 -11
- autogen/tools/experimental/firecrawl/firecrawl_tool.py +98 -115
- autogen/tools/experimental/google/authentication/credentials_local_provider.py +1 -1
- autogen/tools/experimental/google/drive/drive_functions.py +4 -4
- autogen/tools/experimental/google/drive/toolkit.py +5 -5
- autogen/tools/experimental/google_search/google_search.py +5 -5
- autogen/tools/experimental/google_search/youtube_search.py +5 -5
- autogen/tools/experimental/messageplatform/discord/discord.py +8 -12
- autogen/tools/experimental/messageplatform/slack/slack.py +14 -20
- autogen/tools/experimental/messageplatform/telegram/telegram.py +8 -12
- autogen/tools/experimental/perplexity/perplexity_search.py +18 -29
- autogen/tools/experimental/reliable/reliable.py +68 -74
- autogen/tools/experimental/searxng/searxng_search.py +20 -19
- autogen/tools/experimental/tavily/tavily_search.py +12 -19
- autogen/tools/experimental/web_search_preview/web_search_preview.py +13 -7
- autogen/tools/experimental/wikipedia/wikipedia.py +7 -10
- autogen/tools/function_utils.py +7 -7
- autogen/tools/tool.py +6 -5
- autogen/types.py +2 -2
- autogen/version.py +1 -1
- ag2-0.9.6.dist-info/RECORD +0 -421
- autogen/llm_config.py +0 -385
- {ag2-0.9.6.dist-info → ag2-0.9.8.post1.dist-info}/WHEEL +0 -0
- {ag2-0.9.6.dist-info → ag2-0.9.8.post1.dist-info}/licenses/LICENSE +0 -0
- {ag2-0.9.6.dist-info → ag2-0.9.8.post1.dist-info}/licenses/NOTICE.md +0 -0
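The headline structural change in this release is visible at the top and bottom of this list: the single module autogen/llm_config.py (removed, 385 lines) is replaced by an autogen/llm_config/ package (__init__.py, client.py, config.py, entry.py), which also absorbs the ModelClient protocol previously defined in autogen/oai/client.py. A sketch of the resulting import paths — the two new-style imports are taken (as relative imports) from the client.py diff below; the commented 0.9.6 line reflects the old single-module layout:

    # 0.9.6: llm_config was one module, and ModelClient lived in autogen.oai.client
    # from autogen.llm_config import LLMConfigEntry

    # 0.9.8.post1: llm_config is a package; entries and the client protocol move into it
    from autogen.llm_config import ModelClient
    from autogen.llm_config.entry import LLMConfigEntry, LLMConfigEntryDict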
autogen/oai/client.py
CHANGED

@@ -13,10 +13,11 @@ import re
 import sys
 import uuid
 import warnings
+from collections.abc import Callable
 from functools import lru_cache
-from typing import Any,
+from typing import Any, Literal

-from pydantic import BaseModel, Field, HttpUrl
+from pydantic import BaseModel, Field, HttpUrl
 from pydantic.type_adapter import TypeAdapter

 from ..cache import Cache

@@ -25,7 +26,8 @@ from ..events.client_events import StreamEvent, UsageSummaryEvent
 from ..exception_utils import ModelToolNotSupportedError
 from ..import_utils import optional_import_block, require_optional_import
 from ..io.base import IOStream
-from ..llm_config import
+from ..llm_config import ModelClient
+from ..llm_config.entry import LLMConfigEntry, LLMConfigEntryDict
 from ..logger.logger_utils import get_current_ts
 from ..runtime_logging import log_chat_completion, log_new_client, log_new_wrapper, logging_enabled
 from ..token_count_utils import count_token

@@ -58,7 +60,7 @@ if openai_result.is_successful:
     ERROR = None
     from openai.lib._pydantic import _ensure_strict_json_schema
 else:
-    ERROR:
+    ERROR: ImportError | None = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")

 # OpenAI = object
 # AzureOpenAI = object

@@ -73,7 +75,7 @@ with optional_import_block() as cerebras_result:
     from .cerebras import CerebrasClient

 if cerebras_result.is_successful:
-    cerebras_import_exception:
+    cerebras_import_exception: ImportError | None = None
 else:
     cerebras_AuthenticationError = cerebras_InternalServerError = cerebras_RateLimitError = Exception  # noqa: N816
     cerebras_import_exception = ImportError("cerebras_cloud_sdk not found")

@@ -87,7 +89,7 @@ with optional_import_block() as gemini_result:
     from .gemini import GeminiClient

 if gemini_result.is_successful:
-    gemini_import_exception:
+    gemini_import_exception: ImportError | None = None
 else:
     gemini_InternalServerError = gemini_ResourceExhausted = Exception  # noqa: N816
     gemini_import_exception = ImportError("google-genai not found")

@@ -101,7 +103,7 @@ with optional_import_block() as anthropic_result:
     from .anthropic import AnthropicClient

 if anthropic_result.is_successful:
-    anthropic_import_exception:
+    anthropic_import_exception: ImportError | None = None
 else:
     anthorpic_InternalServerError = anthorpic_RateLimitError = Exception  # noqa: N816
     anthropic_import_exception = ImportError("anthropic not found")

@@ -115,7 +117,7 @@ with optional_import_block() as mistral_result:
     from .mistral import MistralAIClient

 if mistral_result.is_successful:
-    mistral_import_exception:
+    mistral_import_exception: ImportError | None = None
 else:
     mistral_SDKError = mistral_HTTPValidationError = Exception  # noqa: N816
     mistral_import_exception = ImportError("mistralai not found")

@@ -126,7 +128,7 @@ with optional_import_block() as together_result:
     from .together import TogetherClient

 if together_result.is_successful:
-    together_import_exception:
+    together_import_exception: ImportError | None = None
 else:
     together_TogetherException = Exception  # noqa: N816
     together_import_exception = ImportError("together not found")

@@ -141,7 +143,7 @@ with optional_import_block() as groq_result:
     from .groq import GroqClient

 if groq_result.is_successful:
-    groq_import_exception:
+    groq_import_exception: ImportError | None = None
 else:
     groq_InternalServerError = groq_RateLimitError = groq_APIConnectionError = Exception  # noqa: N816
     groq_import_exception = ImportError("groq not found")

@@ -156,7 +158,7 @@ with optional_import_block() as cohere_result:
     from .cohere import CohereClient

 if cohere_result.is_successful:
-    cohere_import_exception:
+    cohere_import_exception: ImportError | None = None
 else:
     cohere_InternalServerError = cohere_TooManyRequestsError = cohere_ServiceUnavailableError = Exception  # noqa: N816
     cohere_import_exception = ImportError("cohere not found")

@@ -170,7 +172,7 @@ with optional_import_block() as ollama_result:
     from .ollama import OllamaClient

 if ollama_result.is_successful:
-    ollama_import_exception:
+    ollama_import_exception: ImportError | None = None
 else:
     ollama_RequestError = ollama_ResponseError = Exception  # noqa: N816
     ollama_import_exception = ImportError("ollama not found")

@@ -184,7 +186,7 @@ with optional_import_block() as bedrock_result:
     from .bedrock import BedrockClient

 if bedrock_result.is_successful:
-    bedrock_import_exception:
+    bedrock_import_exception: ImportError | None = None
 else:
     bedrock_BotoCoreError = bedrock_ClientError = Exception  # noqa: N816
     bedrock_import_exception = ImportError("botocore not found")

@@ -212,6 +214,7 @@ OPENAI_FALLBACK_KWARGS = {
     "default_query",
     "http_client",
     "_strict_response_validation",
+    "webhook_secret",
 }

 AOPENAI_FALLBACK_KWARGS = {

@@ -231,118 +234,103 @@ AOPENAI_FALLBACK_KWARGS = {
     "_strict_response_validation",
     "base_url",
     "project",
+    "webhook_secret",
 }


 @lru_cache(maxsize=128)
-def log_cache_seed_value(cache_seed_value:
+def log_cache_seed_value(cache_seed_value: str | int, client: ModelClient) -> None:
     logger.debug(f"Using cache with seed value {cache_seed_value} for client {client.__class__.__name__}")


-
+class OpenAIEntryDict(LLMConfigEntryDict, total=False):
+    api_type: Literal["openai"]
+
+    price: list[float] | None
+    tool_choice: Literal["none", "auto", "required"] | None
+    user: str | None
+    stream: bool
+    verbosity: Literal["low", "medium", "high"] | None
+    extra_body: dict[str, Any] | None
+    reasoning_effort: Literal["low", "minimal", "medium", "high"] | None
+    max_completion_tokens: int | None
+
+
 class OpenAILLMConfigEntry(LLMConfigEntry):
     api_type: Literal["openai"] = "openai"
-
-    price:
-    tool_choice:
-    user:
-
+
+    price: list[float] | None = Field(default=None, min_length=2, max_length=2)
+    tool_choice: Literal["none", "auto", "required"] | None = None
+    user: str | None = None
+    stream: bool = False
+    verbosity: Literal["low", "medium", "high"] | None = None
+    # The extra_body parameter flows from OpenAILLMConfigEntry to the LLM request through this path:
+    # 1. Config Definition: extra_body is defined in OpenAILLMConfigEntry (autogen/oai/client.py:248)
+    # 2. Parameter Classification: It's classified as an OpenAI client parameter (not AG2-specific) via the openai_kwargs property (autogen/oai/client.py:752-758)
+    # 3. Request Separation: In _separate_create_config() (autogen/oai/client.py:842), extra_body goes into create_config since it's not in the extra_kwargs set.
+    # 4. API Call: The create_config becomes params and gets passed directly to OpenAI's create() method via **params (autogen/oai/client.py:551,658)
+    extra_body: dict[str, Any] | None = (
         None  # For VLLM - See here: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters
     )
     # reasoning models - see: https://platform.openai.com/docs/api-reference/chat/create#chat-create-reasoning_effort
-    reasoning_effort:
-    max_completion_tokens:
+    reasoning_effort: Literal["low", "minimal", "medium", "high"] | None = None
+    max_completion_tokens: int | None = None

-    def create_client(self) ->
+    def create_client(self) -> ModelClient:
         raise NotImplementedError("create_client method must be implemented in the derived class.")


-
+class AzureOpenAIEntryDict(LLMConfigEntryDict, total=False):
+    api_type: Literal["azure"]
+
+    azure_ad_token_provider: str | Callable[[], str] | None
+    stream: bool
+    tool_choice: Literal["none", "auto", "required"] | None
+    user: str | None
+    reasoning_effort: Literal["low", "medium", "high"] | None
+    max_completion_tokens: int | None
+
+
 class AzureOpenAILLMConfigEntry(LLMConfigEntry):
     api_type: Literal["azure"] = "azure"
-
-    azure_ad_token_provider:
-
-
+
+    azure_ad_token_provider: str | Callable[[], str] | None = None
+    stream: bool = False
+    tool_choice: Literal["none", "auto", "required"] | None = None
+    user: str | None = None
     # reasoning models - see:
     # - https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/reasoning
     # - https://learn.microsoft.com/en-us/azure/ai-services/openai/reference-preview
-    reasoning_effort:
-    max_completion_tokens:
+    reasoning_effort: Literal["low", "medium", "high"] | None = None
+    max_completion_tokens: int | None = None

-    def create_client(self) ->
+    def create_client(self) -> ModelClient:
         raise NotImplementedError


-
+class DeepSeekEntyDict(LLMConfigEntryDict, total=False):
+    api_type: Literal["deepseek"]
+
+    base_url: HttpUrl
+    stream: bool
+    tool_choice: Literal["none", "auto", "required"] | None
+
+
 class DeepSeekLLMConfigEntry(LLMConfigEntry):
     api_type: Literal["deepseek"] = "deepseek"
-
-    temperature: float = Field(
+
+    temperature: float | None = Field(default=None, ge=0.0, le=1.0)
+    top_p: float | None = Field(None, ge=0.0, le=1.0)
     max_tokens: int = Field(8192, ge=1, le=8192)
-    top_p: Optional[float] = Field(None, ge=0.0, le=1.0)
-    tool_choice: Optional[Literal["none", "auto", "required"]] = None

-
-
-
-        if v is not None and info.data.get("temperature") is not None:
-            raise ValueError("temperature and top_p cannot be set at the same time.")
-        return v
+    base_url: HttpUrl = HttpUrl("https://api.deepseek.com/v1")
+    stream: bool = False
+    tool_choice: Literal["none", "auto", "required"] | None = None

     def create_client(self) -> None:  # type: ignore [override]
         raise NotImplementedError("DeepSeekLLMConfigEntry.create_client is not implemented.")


-@export_module("autogen")
-class ModelClient(Protocol):
-    """A client class must implement the following methods:
-    - create must return a response object that implements the ModelClientResponseProtocol
-    - cost must return the cost of the response
-    - get_usage must return a dict with the following keys:
-      - prompt_tokens
-      - completion_tokens
-      - total_tokens
-      - cost
-      - model
-
-    This class is used to create a client that can be used by OpenAIWrapper.
-    The response returned from create must adhere to the ModelClientResponseProtocol but can be extended however needed.
-    The message_retrieval method must be implemented to return a list of str or a list of messages from the response.
-    """
-
-    RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"]
-
-    class ModelClientResponseProtocol(Protocol):
-        class Choice(Protocol):
-            class Message(Protocol):
-                content: Optional[str] | Optional[dict[str, Any]]
-
-            message: Message
-
-        choices: list[Choice]
-        model: str
-
-    def create(self, params: dict[str, Any]) -> ModelClientResponseProtocol: ...  # pragma: no cover
-
-    def message_retrieval(
-        self, response: ModelClientResponseProtocol
-    ) -> Union[list[str], list[ModelClient.ModelClientResponseProtocol.Choice.Message]]:
-        """Retrieve and return a list of strings or a list of Choice.Message from the response.
-
-        NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
-        since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
-        """
-        ...  # pragma: no cover
-
-    def cost(self, response: ModelClientResponseProtocol) -> float: ...  # pragma: no cover
-
-    @staticmethod
-    def get_usage(response: ModelClientResponseProtocol) -> dict:
-        """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
-        ...  # pragma: no cover
-
-
 class PlaceHolderClient:
     def __init__(self, config):
         self.config = config

@@ -352,9 +340,7 @@ class PlaceHolderClient:
 class OpenAIClient:
     """Follows the Client protocol and wraps the OpenAI client."""

-    def __init__(
-        self, client: Union[OpenAI, AzureOpenAI], response_format: Union[BaseModel, dict[str, Any], None] = None
-    ):
+    def __init__(self, client: OpenAI | AzureOpenAI, response_format: BaseModel | dict[str, Any] | None = None):
         self._oai_client = client
         self.response_format = response_format
         if (

@@ -366,9 +352,7 @@ class OpenAIClient:
                 "The API key specified is not a valid OpenAI format; it won't work with the OpenAI-hosted model."
             )

-    def message_retrieval(
-        self, response: Union[ChatCompletion, Completion]
-    ) -> Union[list[str], list[ChatCompletionMessage]]:
+    def message_retrieval(self, response: ChatCompletion | Completion) -> list[str] | list[ChatCompletionMessage]:
         """Retrieve the messages from the response.

         Args:

@@ -505,7 +489,10 @@ class OpenAIClient:
         if "stream" in kwargs:
             kwargs.pop("stream")

-        if
+        if (
+            isinstance(kwargs["response_format"], dict)
+            and kwargs["response_format"].get("type") != "json_object"
+        ):
             kwargs["response_format"] = {
                 "type": "json_schema",
                 "json_schema": {

@@ -544,8 +531,8 @@ class OpenAIClient:
         completion_tokens = 0

         # Prepare for potential function call
-        full_function_call:
-        full_tool_calls:
+        full_function_call: dict[str, Any] | None = None
+        full_tool_calls: list[dict[str, Any] | None] | None = None

         # Send the chat completion request to OpenAI's API and process the response in chunks
         for chunk in create_or_parse(**params):

@@ -672,9 +659,9 @@ class OpenAIClient:
         # Unsupported parameters
         unsupported_params = [
             "temperature",
+            "top_p",
             "frequency_penalty",
             "presence_penalty",
-            "top_p",
             "logprobs",
             "top_logprobs",
             "logit_bias",

@@ -700,7 +687,7 @@ class OpenAIClient:
                 msg["role"] = "user"
                 msg["content"] = f"System message: {msg['content']}"

-    def cost(self, response:
+    def cost(self, response: ChatCompletion | Completion) -> float:
         """Calculate the cost of the response."""
         model = response.model
         if model not in OAI_PRICE1K:

@@ -721,7 +708,7 @@ class OpenAIClient:
         return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000  # type: ignore [operator]

     @staticmethod
-    def get_usage(response:
+    def get_usage(response: ChatCompletion | Completion) -> dict:
         return {
             "prompt_tokens": response.usage.prompt_tokens if response.usage is not None else 0,
             "completion_tokens": response.usage.completion_tokens if response.usage is not None else 0,

@@ -757,13 +744,13 @@ class OpenAIWrapper:
         else:
             return OPENAI_FALLBACK_KWARGS | AOPENAI_FALLBACK_KWARGS

-    total_usage_summary:
-    actual_usage_summary:
+    total_usage_summary: dict[str, Any] | None = None
+    actual_usage_summary: dict[str, Any] | None = None

     def __init__(
         self,
         *,
-        config_list:
+        config_list: list[dict[str, Any]] | None = None,
         **base_config: Any,
     ):
         """Initialize the OpenAIWrapper.

@@ -845,17 +832,6 @@ class OpenAIWrapper:

     def _configure_azure_openai(self, config: dict[str, Any], openai_config: dict[str, Any]) -> None:
         openai_config["azure_deployment"] = openai_config.get("azure_deployment", config.get("model"))
-        if openai_config["azure_deployment"] is not None:
-            # Preserve dots for specific model versions that require them
-            deployment_name = openai_config["azure_deployment"]
-            if deployment_name in [
-                "gpt-4.1"
-            ]:  # Add more as needed, Whitelist approach so as to not break existing deployments
-                # Keep the deployment name as-is for these specific models
-                pass
-            else:
-                # Remove dots for all other models (maintain existing behavior)
-                openai_config["azure_deployment"] = deployment_name.replace(".", "")
         openai_config["azure_endpoint"] = openai_config.get("azure_endpoint", openai_config.pop("base_url", None))

         # Create a default Azure token provider if requested

@@ -884,6 +860,13 @@ class OpenAIWrapper:
             if key in config:
                 openai_config[key] = config[key]

+    def _configure_openai_config_for_gemini(self, config: dict[str, Any], openai_config: dict[str, Any]) -> None:
+        """Update openai_config with additional gemini genai configs."""
+        optional_keys = ["proxy"]
+        for key in optional_keys:
+            if key in config:
+                openai_config[key] = config[key]
+
     def _register_default_client(self, config: dict[str, Any], openai_config: dict[str, Any]) -> None:
         """Create a client with the given config to override openai_config,
         after removing extra kwargs.

@@ -909,7 +892,7 @@ class OpenAIWrapper:
         if api_type is not None and api_type.startswith("azure"):

             @require_optional_import("openai>=1.66.2", "openai")
-            def create_azure_openai_client() ->
+            def create_azure_openai_client() -> AzureOpenAI:
                 self._configure_azure_openai(config, openai_config)
                 client = AzureOpenAI(**openai_config)
                 self._clients.append(OpenAIClient(client, response_format=response_format))

@@ -924,6 +907,7 @@ class OpenAIWrapper:
         elif api_type is not None and api_type.startswith("google"):
             if gemini_import_exception:
                 raise ImportError("Please install `google-genai` and 'vertexai' to use Google's API.")
+            self._configure_openai_config_for_gemini(config, openai_config)
             client = GeminiClient(response_format=response_format, **openai_config)
             self._clients.append(client)
         elif api_type is not None and api_type.startswith("anthropic"):

@@ -969,7 +953,7 @@ class OpenAIWrapper:
         elif api_type is not None and api_type.startswith("responses"):
             # OpenAI Responses API (stateful). Reuse the same OpenAI SDK but call the `/responses` endpoint via the new client.
             @require_optional_import("openai>=1.66.2", "openai")
-            def create_responses_client() ->
+            def create_responses_client() -> OpenAI:
                 client = OpenAI(**openai_config)
                 self._clients.append(OpenAIResponsesClient(client, response_format=response_format))
                 return client

@@ -978,7 +962,7 @@ class OpenAIWrapper:
         else:

             @require_optional_import("openai>=1.66.2", "openai")
-            def create_openai_client() ->
+            def create_openai_client() -> OpenAI:
                 client = OpenAI(**openai_config)
                 self._clients.append(OpenAIClient(client, response_format))
                 return client

@@ -1019,10 +1003,10 @@ class OpenAIWrapper:
     @classmethod
     def instantiate(
         cls,
-        template:
-        context:
-        allow_format_str_template:
-    ) ->
+        template: str | Callable[[dict[str, Any]], str] | None,
+        context: dict[str, Any] | None = None,
+        allow_format_str_template: bool | None = False,
+    ) -> str | None:
         if not context or template is None:
             return template  # type: ignore [return-value]
         if isinstance(template, str):

@@ -1032,8 +1016,8 @@ class OpenAIWrapper:
     def _construct_create_params(self, create_config: dict[str, Any], extra_kwargs: dict[str, Any]) -> dict[str, Any]:
         """Prime the create_config with additional_kwargs."""
         # Validate the config
-        prompt:
-        messages:
+        prompt: str | None = create_config.get("prompt")
+        messages: list[dict[str, Any]] | None = create_config.get("messages")
         if (prompt is None) == (messages is None):
             raise ValueError("Either prompt or messages should be in create config but not both.")
         context = extra_kwargs.get("context")

@@ -1100,9 +1084,6 @@ class OpenAIWrapper:
             full_config = {**config, **self._config_list[i]}
             # separate the config into create_config and extra_kwargs
             create_config, extra_kwargs = self._separate_create_config(full_config)
-            api_type = extra_kwargs.get("api_type")
-            if api_type and api_type.startswith("azure") and "model" in create_config:
-                create_config["model"] = create_config["model"].replace(".", "")
             # construct the create params
             params = self._construct_create_params(create_config, extra_kwargs)
             # get the cache_seed, filter_func and context

@@ -1330,8 +1311,8 @@ class OpenAIWrapper:

     @staticmethod
     def _update_function_call_from_chunk(
-        function_call_chunk:
-        full_function_call:
+        function_call_chunk: ChoiceDeltaToolCallFunction | ChoiceDeltaFunctionCall,
+        full_function_call: dict[str, Any] | None,
         completion_tokens: int,
     ) -> tuple[dict[str, Any], int]:
         """Update the function call from the chunk.

@@ -1362,7 +1343,7 @@ class OpenAIWrapper:
     @staticmethod
     def _update_tool_calls_from_chunk(
         tool_calls_chunk: ChoiceDeltaToolCall,
-        full_tool_call:
+        full_tool_call: dict[str, Any] | None,
         completion_tokens: int,
     ) -> tuple[dict[str, Any], int]:
         """Update the tool call from the chunk.

@@ -1436,7 +1417,7 @@ class OpenAIWrapper:
         if actual_usage is not None:
             self.actual_usage_summary = update_usage(self.actual_usage_summary, actual_usage)

-    def print_usage_summary(self, mode:
+    def print_usage_summary(self, mode: str | list[str] = ["actual", "total"]) -> None:
         """Print the usage summary."""
         iostream = IOStream.get_default()

@@ -1464,7 +1445,7 @@ class OpenAIWrapper:
     @classmethod
     def extract_text_or_completion_object(
         cls, response: ModelClient.ModelClientResponseProtocol
-    ) ->
+    ) -> list[str] | list[ModelClient.ModelClientResponseProtocol.Choice.Message]:
         """Extract the text or ChatCompletion objects from a completion or chat response.

         Args:

@@ -1481,7 +1462,6 @@ class OpenAIWrapper:
 # -----------------------------------------------------------------------------


-@register_llm_config
 class OpenAIResponsesLLMConfigEntry(OpenAILLMConfigEntry):
     """LLMConfig entry for the OpenAI Responses API (stateful, tool-enabled).


@@ -1501,8 +1481,8 @@ class OpenAIResponsesLLMConfigEntry(OpenAILLMConfigEntry):
     """

     api_type: Literal["responses"] = "responses"
-    tool_choice:
-    built_in_tools:
+    tool_choice: Literal["none", "auto", "required"] | None = "auto"
+    built_in_tools: list[str] | None = None

-    def create_client(self) ->
+    def create_client(self) -> ModelClient:  # pragma: no cover
         raise NotImplementedError("Handled via OpenAIWrapper._register_default_client")
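Two themes run through this file: Optional[X]/Union[...] annotations are replaced with PEP 604 X | None unions, and each provider config entry gains a TypedDict counterpart (OpenAIEntryDict, AzureOpenAIEntryDict, CohereEntryDict, ...) so that client constructors can type **kwargs via Unpack, as the Cohere client further down now does. A minimal, self-contained sketch of that pattern — the names here are illustrative, not the ag2 definitions:

    from __future__ import annotations

    from typing import Literal, TypedDict

    from typing_extensions import Unpack


    class ProviderEntryDict(TypedDict, total=False):
        # total=False makes every key optional, mirroring how the *EntryDict
        # classes in this release shadow their pydantic LLMConfigEntry fields.
        api_type: Literal["provider"]
        stream: bool
        tool_choice: Literal["none", "auto", "required"] | None


    class ProviderClient:
        def __init__(self, **kwargs: Unpack[ProviderEntryDict]) -> None:
            # At runtime this is an ordinary **kwargs dict; the Unpack annotation
            # lets type checkers validate keyword names and value types statically.
            self.stream = kwargs.get("stream", False)


    client = ProviderClient(stream=True, tool_choice="auto")  # keys checked statically

Note also the behavioural changes buried in the hunks above: the Azure code path no longer strips dots from deployment and model names (the replace(".", "") logic and its "gpt-4.1" whitelist are removed), "webhook_secret" joins the OpenAI/Azure fallback kwargs, and @register_llm_config is dropped from OpenAIResponsesLLMConfigEntry.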
autogen/oai/client_utils.py
CHANGED

@@ -8,7 +8,7 @@

 import logging
 import warnings
-from typing import Any,
+from typing import Any, Protocol, runtime_checkable


 @runtime_checkable

@@ -24,8 +24,8 @@ def validate_parameter(
     allowed_types: tuple[Any, ...],
     allow_None: bool,  # noqa: N803
     default_value: Any,
-    numerical_bound:
-    allowed_values:
+    numerical_bound: tuple[float | None, float | None] | None,
+    allowed_values: list[Any] | None,
 ) -> Any:
     """Validates a given config parameter, checking its type, values, and setting defaults
     Parameters:
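Only the typing of validate_parameter's trailing parameters changes here; the call contract is unchanged. For orientation, a runnable sketch of a representative call, mirroring the Cohere hunks below (the params dict is illustrative):

    from autogen.oai.client_utils import validate_parameter

    params = {"k": 10}  # hypothetical provider config
    # Positional arguments: params dict, field name, allowed types, allow_None,
    # default value, (lower, upper) numerical bound, list of allowed values.
    k = validate_parameter(params, "k", int, False, 0, (0, 500), None)
    print(k)  # 10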
autogen/oai/cohere.py
CHANGED

@@ -34,14 +34,15 @@ import os
 import sys
 import time
 import warnings
-from typing import Any, Literal
+from typing import Any, Literal

 from pydantic import BaseModel, Field
+from typing_extensions import Unpack

 from autogen.oai.client_utils import FormatterProtocol, logging_formatter, validate_parameter

 from ..import_utils import optional_import_block, require_optional_import
-from ..llm_config import LLMConfigEntry,
+from ..llm_config.entry import LLMConfigEntry, LLMConfigEntryDict
 from .oai_models import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall, Choice, CompletionUsage

 with optional_import_block():

@@ -66,20 +67,30 @@ COHERE_PRICING_1K = {
 }


-
+class CohereEntryDict(LLMConfigEntryDict, total=False):
+    api_type: Literal["cohere"]
+
+    k: int
+    seed: int | None
+    frequency_penalty: float
+    presence_penalty: float
+    client_name: str | None
+    strict_tools: bool
+    stream: bool
+    tool_choice: Literal["NONE", "REQUIRED"] | None
+
+
 class CohereLLMConfigEntry(LLMConfigEntry):
     api_type: Literal["cohere"] = "cohere"
-
-    max_tokens: Optional[int] = Field(default=None, ge=0)
+
     k: int = Field(default=0, ge=0, le=500)
-
-    seed: Optional[int] = None
+    seed: int | None = None
     frequency_penalty: float = Field(default=0, ge=0, le=1)
     presence_penalty: float = Field(default=0, ge=0, le=1)
-    client_name:
+    client_name: str | None = None
     strict_tools: bool = False
     stream: bool = False
-    tool_choice:
+    tool_choice: Literal["NONE", "REQUIRED"] | None = None

     def create_client(self):
         raise NotImplementedError("CohereLLMConfigEntry.create_client is not implemented.")

@@ -88,7 +99,7 @@ class CohereLLMConfigEntry(LLMConfigEntry):

 class CohereClient:
     """Client for Cohere's API."""
-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs: Unpack[CohereEntryDict]):
         """Requires api_key or environment variable to be set

         Args:

@@ -104,7 +115,7 @@ class CohereClient:
             )

         # Store the response format, if provided (for structured outputs)
-        self._response_format:
+        self._response_format: type[BaseModel] | None = None

     def message_retrieval(self, response) -> list:
         """Retrieve and return a list of strings or a list of Choice.Message from the response.

@@ -203,7 +214,17 @@ class CohereClient:
         if "k" in params:
             cohere_params["k"] = validate_parameter(params, "k", int, False, 0, (0, 500), None)

+        if "top_p" in params:
+            cohere_params["p"] = validate_parameter(params, "top_p", (int, float), False, 0.75, (0.01, 0.99), None)
+
         if "p" in params:
+            warnings.warn(
+                (
+                    "parameter 'p' is deprecated, use 'top_p' instead for consistency with OpenAI API spec. "
+                    "Scheduled for removal in 0.10.0 version."
+                ),
+                DeprecationWarning,
+            )
             cohere_params["p"] = validate_parameter(params, "p", (int, float), False, 0.75, (0.01, 0.99), None)

         if "seed" in params:

@@ -402,8 +423,10 @@ class CohereClient:

     def _convert_json_response(self, response: str) -> Any:
         """Extract and validate JSON response from the output for structured outputs.
+
         Args:
             response (str): The response from the API.
+
         Returns:
             Any: The parsed JSON response.
         """
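The user-visible change in this file is the top_p shim: the OpenAI-style top_p key now maps to Cohere's native p, while passing p directly still works but emits a DeprecationWarning until its scheduled removal in 0.10.0. A self-contained sketch of the equivalent logic (map_sampling_params is an illustrative helper, not an ag2 function; the warning text is taken verbatim from the diff):

    import warnings
    from typing import Any


    def map_sampling_params(params: dict[str, Any]) -> dict[str, Any]:
        """Illustrative stand-in for the new CohereClient parameter handling."""
        cohere_params: dict[str, Any] = {}
        if "top_p" in params:
            cohere_params["p"] = params["top_p"]  # OpenAI-style name maps onto Cohere's "p"
        if "p" in params:
            warnings.warn(
                "parameter 'p' is deprecated, use 'top_p' instead for consistency with "
                "OpenAI API spec. Scheduled for removal in 0.10.0 version.",
                DeprecationWarning,
            )
            cohere_params["p"] = params["p"]  # legacy name still honoured for now
        return cohere_params


    print(map_sampling_params({"top_p": 0.9}))  # -> {'p': 0.9}, no warning
    print(map_sampling_params({"p": 0.9}))      # -> {'p': 0.9}, with a DeprecationWarning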