ag2 0.9.9__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ag2 might be problematic. Click here for more details.

Files changed (113)
  1. {ag2-0.9.9.dist-info → ag2-0.10.0.dist-info}/METADATA +243 -214
  2. {ag2-0.9.9.dist-info → ag2-0.10.0.dist-info}/RECORD +113 -87
  3. autogen/_website/generate_mkdocs.py +3 -3
  4. autogen/_website/notebook_processor.py +1 -1
  5. autogen/_website/utils.py +1 -1
  6. autogen/a2a/__init__.py +36 -0
  7. autogen/a2a/agent_executor.py +105 -0
  8. autogen/a2a/client.py +280 -0
  9. autogen/a2a/errors.py +18 -0
  10. autogen/a2a/httpx_client_factory.py +79 -0
  11. autogen/a2a/server.py +221 -0
  12. autogen/a2a/utils.py +165 -0
  13. autogen/agentchat/__init__.py +3 -0
  14. autogen/agentchat/agent.py +0 -2
  15. autogen/agentchat/assistant_agent.py +15 -15
  16. autogen/agentchat/chat.py +57 -41
  17. autogen/agentchat/contrib/agent_eval/criterion.py +1 -1
  18. autogen/agentchat/contrib/capabilities/text_compressors.py +5 -5
  19. autogen/agentchat/contrib/capabilities/tools_capability.py +1 -1
  20. autogen/agentchat/contrib/capabilities/transforms.py +1 -1
  21. autogen/agentchat/contrib/captainagent/agent_builder.py +1 -1
  22. autogen/agentchat/contrib/captainagent/captainagent.py +20 -19
  23. autogen/agentchat/contrib/graph_rag/falkor_graph_query_engine.py +2 -5
  24. autogen/agentchat/contrib/graph_rag/graph_rag_capability.py +5 -5
  25. autogen/agentchat/contrib/graph_rag/neo4j_graph_query_engine.py +18 -17
  26. autogen/agentchat/contrib/llava_agent.py +1 -13
  27. autogen/agentchat/contrib/rag/mongodb_query_engine.py +2 -2
  28. autogen/agentchat/contrib/rag/query_engine.py +11 -11
  29. autogen/agentchat/contrib/retrieve_assistant_agent.py +3 -0
  30. autogen/agentchat/contrib/swarm_agent.py +3 -2
  31. autogen/agentchat/contrib/vectordb/couchbase.py +1 -1
  32. autogen/agentchat/contrib/vectordb/mongodb.py +1 -1
  33. autogen/agentchat/contrib/web_surfer.py +1 -1
  34. autogen/agentchat/conversable_agent.py +359 -150
  35. autogen/agentchat/group/context_expression.py +21 -21
  36. autogen/agentchat/group/group_tool_executor.py +46 -15
  37. autogen/agentchat/group/guardrails.py +41 -33
  38. autogen/agentchat/group/handoffs.py +11 -11
  39. autogen/agentchat/group/multi_agent_chat.py +56 -2
  40. autogen/agentchat/group/on_condition.py +11 -11
  41. autogen/agentchat/group/safeguards/__init__.py +21 -0
  42. autogen/agentchat/group/safeguards/api.py +241 -0
  43. autogen/agentchat/group/safeguards/enforcer.py +1158 -0
  44. autogen/agentchat/group/safeguards/events.py +119 -0
  45. autogen/agentchat/group/safeguards/validator.py +435 -0
  46. autogen/agentchat/groupchat.py +102 -49
  47. autogen/agentchat/realtime/experimental/clients/realtime_client.py +2 -2
  48. autogen/agentchat/realtime/experimental/function_observer.py +2 -3
  49. autogen/agentchat/realtime/experimental/realtime_agent.py +2 -3
  50. autogen/agentchat/realtime/experimental/realtime_swarm.py +22 -13
  51. autogen/agentchat/user_proxy_agent.py +55 -53
  52. autogen/agents/experimental/document_agent/document_agent.py +1 -10
  53. autogen/agents/experimental/document_agent/parser_utils.py +5 -1
  54. autogen/browser_utils.py +4 -4
  55. autogen/cache/abstract_cache_base.py +2 -6
  56. autogen/cache/disk_cache.py +1 -6
  57. autogen/cache/in_memory_cache.py +2 -6
  58. autogen/cache/redis_cache.py +1 -5
  59. autogen/coding/__init__.py +10 -2
  60. autogen/coding/base.py +2 -1
  61. autogen/coding/docker_commandline_code_executor.py +1 -6
  62. autogen/coding/factory.py +9 -0
  63. autogen/coding/jupyter/docker_jupyter_server.py +1 -7
  64. autogen/coding/jupyter/jupyter_client.py +2 -9
  65. autogen/coding/jupyter/jupyter_code_executor.py +2 -7
  66. autogen/coding/jupyter/local_jupyter_server.py +2 -6
  67. autogen/coding/local_commandline_code_executor.py +0 -65
  68. autogen/coding/yepcode_code_executor.py +197 -0
  69. autogen/environments/docker_python_environment.py +3 -3
  70. autogen/environments/system_python_environment.py +5 -5
  71. autogen/environments/venv_python_environment.py +5 -5
  72. autogen/events/agent_events.py +1 -1
  73. autogen/events/client_events.py +1 -1
  74. autogen/fast_depends/utils.py +10 -0
  75. autogen/graph_utils.py +5 -7
  76. autogen/import_utils.py +3 -1
  77. autogen/interop/pydantic_ai/pydantic_ai.py +8 -5
  78. autogen/io/processors/console_event_processor.py +8 -3
  79. autogen/llm_config/client.py +3 -2
  80. autogen/llm_config/config.py +168 -91
  81. autogen/llm_config/entry.py +38 -26
  82. autogen/llm_config/types.py +35 -0
  83. autogen/llm_config/utils.py +223 -0
  84. autogen/mcp/mcp_proxy/operation_grouping.py +48 -39
  85. autogen/messages/agent_messages.py +1 -1
  86. autogen/messages/client_messages.py +1 -1
  87. autogen/oai/__init__.py +8 -1
  88. autogen/oai/bedrock.py +0 -13
  89. autogen/oai/client.py +25 -11
  90. autogen/oai/client_utils.py +31 -1
  91. autogen/oai/cohere.py +4 -14
  92. autogen/oai/gemini.py +4 -6
  93. autogen/oai/gemini_types.py +1 -0
  94. autogen/oai/openai_utils.py +44 -115
  95. autogen/remote/__init__.py +18 -0
  96. autogen/remote/agent.py +199 -0
  97. autogen/remote/agent_service.py +142 -0
  98. autogen/remote/errors.py +17 -0
  99. autogen/remote/httpx_client_factory.py +131 -0
  100. autogen/remote/protocol.py +37 -0
  101. autogen/remote/retry.py +102 -0
  102. autogen/remote/runtime.py +96 -0
  103. autogen/testing/__init__.py +12 -0
  104. autogen/testing/messages.py +45 -0
  105. autogen/testing/test_agent.py +111 -0
  106. autogen/tools/dependency_injection.py +4 -8
  107. autogen/tools/experimental/reliable/reliable.py +3 -2
  108. autogen/tools/experimental/web_search_preview/web_search_preview.py +1 -1
  109. autogen/tools/function_utils.py +2 -1
  110. autogen/version.py +1 -1
  111. {ag2-0.9.9.dist-info → ag2-0.10.0.dist-info}/WHEEL +0 -0
  112. {ag2-0.9.9.dist-info → ag2-0.10.0.dist-info}/licenses/LICENSE +0 -0
  113. {ag2-0.9.9.dist-info → ag2-0.10.0.dist-info}/licenses/NOTICE.md +0 -0
@@ -0,0 +1,131 @@
1
+ # Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ import ssl
6
+ import typing
7
+ from typing import Protocol
8
+
9
+ from httpx._client import AsyncClient, Client, EventHook
10
+ from httpx._config import DEFAULT_LIMITS, DEFAULT_MAX_REDIRECTS, DEFAULT_TIMEOUT_CONFIG, Limits
11
+ from httpx._transports.base import AsyncBaseTransport
12
+ from httpx._types import AuthTypes, CertTypes, CookieTypes, HeaderTypes, ProxyTypes, QueryParamTypes, TimeoutTypes
13
+ from httpx._urls import URL
14
+
15
+ from autogen.doc_utils import export_module
16
+
17
+
18
class ClientFactory(Protocol):
    """Structural interface for objects that build HTTP clients on demand."""

    def __call__(self) -> AsyncClient:
        """Build and return a new asynchronous ``httpx.AsyncClient``."""
        ...

    def make_sync(self) -> Client:
        """Build and return a new synchronous ``httpx.Client``."""
        ...
22
+
23
+
24
@export_module("autogen.a2a")
class HttpxClientFactory(ClientFactory):
    """
    Factory of ``httpx`` clients that all share one set of configuration options.

    Each call creates a fresh client configured with the options given to the
    constructor. ``factory()`` returns an ``httpx.AsyncClient`` and ``factory.make_sync()``
    returns an ``httpx.Client``. The factory itself only stores options and holds no
    connections, so it can be shared between tasks. Every client it creates is a new
    object that the caller is responsible for closing (e.g. via ``async with``).

    Usage:

    ```python
    >>> factory = HttpxClientFactory()
    >>> async with factory() as client:
    ...     response = await client.get('https://example.org')
    ```

    **Parameters:**

    * **auth** - *(optional)* An authentication class to use when sending
    requests.
    * **params** - *(optional)* Query parameters to include in request URLs, as
    a string, dictionary, or sequence of two-tuples.
    * **headers** - *(optional)* Dictionary of HTTP headers to include when
    sending requests.
    * **cookies** - *(optional)* Dictionary of Cookie items to include when
    sending requests.
    * **verify** - *(optional)* Either `True` to use an SSL context with the
    default CA bundle, `False` to disable verification, or an instance of
    `ssl.SSLContext` to use a custom context.
    * **cert** - *(optional)* A client-side SSL certificate, passed unchanged to
    the httpx client.
    * **http1** - *(optional)* A boolean indicating if HTTP/1.1 support should be
    enabled. Defaults to `True`.
    * **http2** - *(optional)* A boolean indicating if HTTP/2 support should be
    enabled. Defaults to `False`.
    * **proxy** - *(optional)* A proxy URL where all the traffic should be routed.
    * **mounts** - *(optional)* A mapping of URL patterns to transports, used to
    route specific requests through specific transports.
    * **timeout** - *(optional)* The timeout configuration to use when sending
    requests.
    * **follow_redirects** - *(optional)* Whether redirect responses are followed
    automatically. Defaults to `False`.
    * **limits** - *(optional)* The limits configuration to use.
    * **max_redirects** - *(optional)* The maximum number of redirect responses
    that should be followed.
    * **event_hooks** - *(optional)* A mapping of event names to lists of hook
    callables.
    * **base_url** - *(optional)* A URL to use as the base when building
    request URLs.
    * **transport** - *(optional)* A transport class to use for sending requests
    over the network.
    * **trust_env** - *(optional)* Enables or disables usage of environment
    variables for configuration.
    * **default_encoding** - *(optional)* The default encoding to use for decoding
    response text, if no charset information is included in a response Content-Type
    header. Set to a callable for automatic character set detection. Default: "utf-8".
    * **kwargs** - Any further keyword arguments, passed unchanged to the client
    constructors.

    Note:
        The same option values are passed to both ``AsyncClient`` and ``Client``.
        ``transport``, ``mounts`` and ``event_hooks`` are typed for the asynchronous
        client. Async-only values for them may not work with ``make_sync()``, so
        confirm before relying on both client kinds with such options.
    """

    def __init__(
        self,
        *,
        auth: AuthTypes | None = None,
        params: QueryParamTypes | None = None,
        headers: HeaderTypes | None = None,
        cookies: CookieTypes | None = None,
        verify: ssl.SSLContext | str | bool = True,
        cert: CertTypes | None = None,
        http1: bool = True,
        http2: bool = False,
        proxy: ProxyTypes | None = None,
        mounts: None | (typing.Mapping[str, AsyncBaseTransport | None]) = None,
        timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
        follow_redirects: bool = False,
        limits: Limits = DEFAULT_LIMITS,
        max_redirects: int = DEFAULT_MAX_REDIRECTS,
        event_hooks: None | (typing.Mapping[str, list[EventHook]]) = None,
        base_url: URL | str = "",
        transport: AsyncBaseTransport | None = None,
        trust_env: bool = True,
        default_encoding: str | typing.Callable[[bytes], str] = "utf-8",
        **kwargs: typing.Any,
    ) -> None:
        # Keyword arguments forwarded verbatim to every client this factory creates.
        self.options = {
            "auth": auth,
            "params": params,
            "headers": headers,
            "cookies": cookies,
            "verify": verify,
            "cert": cert,
            "http1": http1,
            "http2": http2,
            "proxy": proxy,
            "mounts": mounts,
            "timeout": timeout,
            "follow_redirects": follow_redirects,
            "limits": limits,
            "max_redirects": max_redirects,
            "event_hooks": event_hooks,
            "base_url": base_url,
            "transport": transport,
            "trust_env": trust_env,
            "default_encoding": default_encoding,
            **kwargs,
        }

    def __call__(self) -> AsyncClient:
        """Create a new ``httpx.AsyncClient`` configured with the stored options."""
        return AsyncClient(**self.options)

    def make_sync(self) -> Client:
        """Create a new ``httpx.Client`` configured with the stored options."""
        return Client(**self.options)
124
+
125
+
126
class EmptyClientFactory(ClientFactory):
    """Client factory without custom options: plain clients with a 30 second timeout."""

    # Timeout (seconds) handed to both the async and the sync client.
    _TIMEOUT = 30.0

    def __call__(self) -> AsyncClient:
        return AsyncClient(timeout=self._TIMEOUT)

    def make_sync(self) -> Client:
        return Client(timeout=self._TIMEOUT)
@@ -0,0 +1,37 @@
1
+ # Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ from typing import Any, Protocol
5
+
6
+ from pydantic import BaseModel, Field
7
+
8
+
9
class AgentBusMessage(BaseModel):
    """Base payload exchanged over the agent bus: message dicts plus optional context."""

    # Conversation messages, each represented as a plain dict.
    messages: list[dict[str, Any]]
    # Optional extra context; ``None`` when not provided.
    context: dict[str, Any] | None = None
12
+
13
+
14
class RequestMessage(AgentBusMessage):
    """Payload sent to a remote service to start a call.

    In addition to the conversation it carries the tool schemas offered by the client.
    """

    # Tool schemas supplied by the calling client; empty by default.
    client_tools: list[dict[str, Any]] = Field(default_factory=list)

    @property
    def client_tool_names(self) -> set[str]:
        """Names of the functions declared in ``client_tools``."""
        names = get_tool_names(self.client_tools)
        return names
20
+
21
+
22
class ResponseMessage(AgentBusMessage):
    """Payload returned by a remote service; adds no fields to ``AgentBusMessage``."""
24
+
25
+
26
class RemoteService(Protocol):
    """Interface to make AgentBus compatible with non AG2 systems.

    Any object exposing a ``name`` and an async ``__call__`` that accepts a
    ``RequestMessage`` satisfies this protocol.
    """

    # Service identifier; HTTPAgentBus uses it as the route path segment.
    name: str

    async def __call__(self, state: RequestMessage) -> ResponseMessage | None:
        """Consume the conversation state and return a new state, or ``None`` for no reply."""
        ...
34
+
35
+
36
def get_tool_names(tools: list[dict[str, Any]]) -> set[str]:
    """Collect the truthy ``tool["function"]["name"]`` values from ``tools``.

    Tools without a ``function`` entry, or whose function has an empty/missing
    name, are skipped.
    """
    names: set[str] = set()
    for tool in tools:
        name = tool.get("function", {}).get("name", "")
        if name:
            names.add(name)
    return names
@@ -0,0 +1,102 @@
1
+ # Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ import time
5
+ from types import TracebackType
6
+ from typing import Protocol
7
+
8
+ import anyio
9
+
10
+
11
+ class RetryPolicyManager(Protocol):
12
+ def __enter__(self) -> None:
13
+ pass
14
+
15
+ async def __aenter__(self) -> None:
16
+ pass
17
+
18
+ def __exit__(
19
+ self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None
20
+ ) -> None | bool:
21
+ pass
22
+
23
+ async def __aexit__(
24
+ self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None
25
+ ) -> None | bool:
26
+ pass
27
+
28
+
29
class RetryPolicy(Protocol):
    """Callable that produces the ``RetryPolicyManager`` used for a retried operation."""

    def __call__(self) -> RetryPolicyManager:
        """Return a manager to wrap the attempts of one operation."""
        ...
31
+
32
+
33
class SleepRetryPolicy(RetryPolicy):
    """Retry policy that waits a fixed interval between failed attempts.

    Args:
        retry_interval: Seconds to sleep between attempts.
        retry_count: Number of failures at which the exception is no longer
            suppressed and propagates to the caller.
    """

    def __init__(self, retry_interval: float = 10.0, retry_count: int = 3) -> None:
        self.retry_interval, self.retry_count = retry_interval, retry_count

    def __call__(self) -> RetryPolicyManager:
        # Every call returns a new manager, i.e. a fresh failure counter.
        return _SleepRetryPolicy(retry_interval=self.retry_interval, retry_count=self.retry_count)
40
+
41
+
42
class _SleepRetryPolicy(RetryPolicyManager):
    """Retry state for one retried operation, created by ``SleepRetryPolicy``.

    Wrap every attempt in this (sync or async) context manager. An attempt that raises
    an ``Exception`` is counted as a failure:

    * While fewer than ``retry_count`` failures have been seen, the manager sleeps
      ``retry_interval`` seconds and suppresses the exception, so the caller can retry.
    * The ``retry_count``-th failure is re-raised immediately, with no final sleep.

    Exceptions that are not ``Exception`` subclasses (``KeyboardInterrupt``,
    ``SystemExit``, task cancellation such as ``asyncio.CancelledError``) are never
    counted or suppressed. Interrupting or cancelling a retry loop therefore still works.
    """

    def __init__(self, retry_interval: float = 10.0, retry_count: int = 3) -> None:
        self.retry_interval = retry_interval
        self.retry_count = retry_count
        # Number of retryable failures observed so far by this manager.
        self.errors_count = 0

    def __enter__(self) -> None:
        pass

    async def __aenter__(self) -> None:
        pass

    def _record_failure(self, exc_type: type[BaseException]) -> bool:
        """Count a failed attempt and return whether its exception should be suppressed."""
        if not issubclass(exc_type, Exception):
            # Interrupts and cancellations must always propagate; swallowing them
            # would make the retry loop impossible to stop.
            return False
        self.errors_count += 1
        return self.errors_count < self.retry_count

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None | bool:
        if exc_type is None:
            return None
        should_suppress = self._record_failure(exc_type)
        if should_suppress:
            # Back off only when another attempt will follow.
            time.sleep(self.retry_interval)
        return should_suppress

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None | bool:
        if exc_type is None:
            return None
        should_suppress = self._record_failure(exc_type)
        if should_suppress:
            # Back off only when another attempt will follow.
            await anyio.sleep(self.retry_interval)
        return should_suppress
79
+
80
+
81
class NoRetryPolicy(RetryPolicyManager):
    """Manager that never retries: every exception propagates unchanged.

    NOTE(review): this class is a RetryPolicyManager, not a RetryPolicy. Calling the
    class itself yields an instance, so it can stand in where a RetryPolicy factory is
    expected. Confirm that is the intended usage.
    """

    def __enter__(self) -> None:
        return None

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None | bool:
        # Falsy result: never suppress the exception.
        return None

    async def __aenter__(self) -> None:
        return None

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None | bool:
        # Falsy result: never suppress the exception.
        return None
@@ -0,0 +1,96 @@
1
+ # Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ import asyncio
6
+ from collections.abc import Awaitable, Callable, Iterable, MutableMapping
7
+ from itertools import chain
8
+ from typing import Any
9
+ from uuid import UUID, uuid4
10
+
11
+ from fastapi import FastAPI, HTTPException, Response, status
12
+
13
+ from autogen.agentchat import ConversableAgent
14
+
15
+ from .agent_service import AgentService
16
+ from .protocol import RemoteService, RequestMessage, ResponseMessage
17
+
18
+
19
class HTTPAgentBus:
    """ASGI application exposing agents and other remote services over HTTP."""

    def __init__(
        self,
        agents: Iterable[ConversableAgent] = (),
        *,
        long_polling_interval: float = 10.0,
        additional_services: Iterable[RemoteService] = (),
    ) -> None:
        """Create HTTPAgentBus runtime.

        Makes the passed agents capable of processing remote calls.

        Args:
            agents: Agents to register as remote services.
            long_polling_interval: Timeout to respond on task status calls for long-living executions.
                Should be less than clients' HTTP request timeout.
            additional_services: Additional services to register.
        """
        self.app = FastAPI()

        # Agents are wrapped lazily, one at a time, before the extra services.
        agent_services = (AgentService(agent) for agent in agents)
        for service in chain(agent_services, additional_services):
            register_agent_endpoints(
                app=self.app,
                service=service,
                long_polling_interval=long_polling_interval,
            )

    async def __call__(
        self,
        scope: MutableMapping[str, Any],
        receive: Callable[[], Awaitable[MutableMapping[str, Any]]],
        send: Callable[[MutableMapping[str, Any]], Awaitable[None]],
    ) -> None:
        """ASGI interface: delegate the request to the underlying FastAPI app."""
        app = self.app
        await app(scope, receive, send)
54
+
55
+
56
def register_agent_endpoints(
    app: FastAPI,
    service: RemoteService,
    long_polling_interval: float,
) -> None:
    """Expose ``service`` on ``app`` as a start/poll pair of HTTP endpoints.

    * ``POST /{service.name}`` schedules ``service(state)`` as a background task and
      responds ``202 Accepted`` with the task id.
    * ``GET /{service.name}/{task_id}`` waits up to ``long_polling_interval`` seconds
      for that task and then responds with one of:
      * ``404`` for an unknown id,
      * ``425 Too Early`` while the task is still running,
      * ``204 No Content`` if the service returned ``None``,
      * the ``ResponseMessage`` otherwise.

    A finished task is dropped from the registry once its result (or error) has been
    returned by a poll.

    Args:
        app: FastAPI application to register the routes on.
        service: Remote service to expose; its ``name`` is used as the route prefix.
        long_polling_interval: Maximum number of seconds a poll request waits for the
            task before answering ``425``.
    """
    # Background tasks of this service, keyed by the id handed out to the client.
    tasks: dict[UUID, asyncio.Task[ResponseMessage | None]] = {}

    @app.get(f"/{service.name}" + "/{task_id}", response_model=ResponseMessage | None)
    async def remote_call_result(task_id: UUID) -> Response | ResponseMessage | None:
        if task_id not in tasks:
            raise HTTPException(
                detail=f"`{task_id}` task not found",
                status_code=status.HTTP_404_NOT_FOUND,
            )

        task = tasks[task_id]

        # Use asyncio.wait's own timeout rather than racing against a sleep task.
        # That sleep task was never cancelled, so it leaked whenever `task` finished
        # first or the request was cancelled. asyncio.wait does not cancel `task`.
        await asyncio.wait((task,), timeout=long_polling_interval)

        if not task.done():
            return Response(status_code=status.HTTP_425_TOO_EARLY)

        try:
            reply = task.result()  # Task inner errors raising here
        finally:
            # TODO: tasks whose result is never polled stay in `tasks` forever.
            tasks.pop(task_id, None)

        if reply is None:
            return Response(status_code=status.HTTP_204_NO_CONTENT)

        return reply

    @app.post(f"/{service.name}", status_code=status.HTTP_202_ACCEPTED)
    async def remote_call_starter(state: RequestMessage) -> UUID:
        task, task_id = asyncio.create_task(service(state)), uuid4()
        tasks[task_id] = task
        return task_id
@@ -0,0 +1,12 @@
1
+ # Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ from .messages import ToolCall, tools_message
6
+ from .test_agent import TestAgent
7
+
8
# Public API of the testing helpers package.
__all__ = ("TestAgent", "ToolCall", "tools_message")
+ )
@@ -0,0 +1,45 @@
1
+ # Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ from typing import Any
6
+ from uuid import uuid4
7
+
8
+ from pydantic_core import to_json
9
+
10
+ from autogen.events.agent_events import FunctionCall
11
+ from autogen.events.agent_events import ToolCall as RawToolCall
12
+
13
+
14
class ToolCall:
    """Represents a tool call with a specified tool name and arguments.

    Args:
        tool_name: Name of the tool to call. The tool should be registered in the
            agent the message is sent to.
        arguments: Keyword arguments to pass to the tool; they are serialized to a
            JSON string.
    """

    def __init__(self, tool_name: str, /, **arguments: Any) -> None:
        call_id = f"call_{uuid4()}"
        encoded_arguments = to_json(arguments).decode()
        raw_call = RawToolCall(
            id=call_id,
            type="function",
            function=FunctionCall(name=tool_name, arguments=encoded_arguments),
        )
        # Plain-dict form of the call, as placed in a message's ``tool_calls`` list.
        self.tool_message = raw_call.model_dump()

    def to_message(self) -> dict[str, Any]:
        """Wrap this single tool call into a message dict.

        Returns:
            The same dict ``tools_message(self)`` builds: ``content`` is ``None`` and
            ``tool_calls`` holds this call.
        """
        return tools_message(self)
37
+
38
+
39
def tools_message(*tool_calls: ToolCall) -> dict[str, Any]:
    """Bundle tool calls into a single message dict.

    Args:
        *tool_calls: ToolCall objects to include, in order.

    Returns:
        ``{"content": None, "tool_calls": [...]}`` holding each call's raw dict.
    """
    raw_calls = [call.tool_message for call in tool_calls]
    return {"content": None, "tool_calls": raw_calls}
@@ -0,0 +1,111 @@
1
+ # Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ from collections.abc import Iterable
6
+ from dataclasses import dataclass
7
+ from types import TracebackType
8
+ from typing import Any, TypedDict
9
+
10
+ from autogen import ConversableAgent, ModelClient
11
+
12
+
13
class TestAgent:
    """A context manager for testing ConversableAgent instances with predefined messages.

    While the context is active:

    * the agent's LLM client is replaced by a fake client that returns the predefined
      messages one by one;
    * ``human_input_mode`` is forced to ``"NEVER"``.

    Both are restored when the context exits. This makes it possible to test agent
    behavior without actual API calls.

    Args:
        agent: The agent to be tested.
        messages: Messages to be returned by the fake client, in order. The iterable is
            consumed lazily, so endless generators are allowed.
        suppress_messages_end: Whether to suppress the error raised when the fake client
            runs out of messages. That error is ``StopIteration``, or the ``RuntimeError``
            Python wraps it in when it escapes a generator or coroutine.

    Attributes:
        agent (ConversableAgent): The agent to be tested.
        suppress_messages_end (bool): See the argument of the same name.

    Example:
        >>> with TestAgent(agent, ["Hello", "How are you?"]):
        ...     # Agent will respond with "Hello" then "How are you?"
        ...     pass
    """

    # Tell pytest not to try collecting this helper as a test class when it is
    # imported into a test module (its name starts with "Test").
    __test__ = False

    def __init__(
        self,
        agent: ConversableAgent,
        messages: Iterable[str | dict[str, Any]] = (),
        *,
        suppress_messages_end: bool = False,
    ) -> None:
        self.agent = agent

        # Also captured here so __exit__ can restore state even if __enter__ never ran.
        # __enter__ refreshes both right before patching the agent.
        self.__original_human_input = self.agent.human_input_mode
        self.__original_client = agent.client
        self.__fake_client = FakeClient(messages)

        self.suppress_messages_end = suppress_messages_end

    def __enter__(self) -> None:
        # Snapshot the state at entry, so that __exit__ restores exactly what the agent
        # had when the context started, even if it changed after construction.
        self.__original_human_input = self.agent.human_input_mode
        self.__original_client = self.agent.client

        self.agent.human_input_mode = "NEVER"
        self.agent.client = self.__fake_client  # type: ignore[assignment]
        return None

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None | bool:
        self.agent.human_input_mode = self.__original_human_input

        self.agent.client = self.__original_client

        if self._is_messages_end(exc_value):
            # The fake client's iterator ran out; suppress only if requested.
            return self.suppress_messages_end
        return None

    @staticmethod
    def _is_messages_end(exc: BaseException | None) -> bool:
        """Return True if ``exc`` signals that the fake client's messages are exhausted."""
        if isinstance(exc, StopIteration):
            return True
        # PEP 479: a StopIteration escaping a generator or coroutine is re-raised as a
        # RuntimeError whose __cause__ is the original StopIteration.
        return isinstance(exc, RuntimeError) and isinstance(exc.__cause__, StopIteration)
68
+
69
+
70
class FakeClient:
    """Stand-in for an agent's LLM client that replays predefined messages.

    Each ``create`` call consumes the next message. Once the messages are exhausted,
    ``create`` raises ``StopIteration``.
    """

    def __init__(self, messages: Iterable[str | dict[str, Any]]) -> None:
        # map() is lazy, so endless generators of messages are supported.
        self.choice_iterator = map(convert_fake_message, messages)

        # Usage summaries are not tracked by the fake client.
        # NOTE(review): presumably read by the agent's usage reporting — confirm.
        self.total_usage_summary = None
        self.actual_usage_summary = None

    def create(self, **params: Any) -> ModelClient.ModelClientResponseProtocol:
        """Return a response holding the next predefined message; ``params`` are ignored."""
        return FakeClientResponse(choices=[next(self.choice_iterator)])

    def extract_text_or_completion_object(
        self,
        response: "FakeClientResponse",
    ) -> list[str] | list["FakeMessage"]:
        """Return the messages carried by ``response``."""
        extracted = response.message_retrieval_function()
        return extracted
87
+
88
+
89
def convert_fake_message(message: str | dict[str, Any]) -> "FakeChoice":
    """Wrap a predefined message into a ``FakeChoice``.

    A string becomes ``{"content": message}``. A dict is copied with ``role`` defaulting
    to ``"assistant"``; keys present in the dict take precedence.
    """
    if not isinstance(message, str):
        return FakeChoice({"role": "assistant", **message})  # type: ignore[typeddict-item]
    return FakeChoice({"content": message})
94
+
95
+
96
+ class FakeMessage(TypedDict):
97
+ content: str | dict[str, Any]
98
+
99
+
100
@dataclass
class FakeChoice(ModelClient.ModelClientResponseProtocol.Choice):
    """Single completion choice wrapping one predefined message."""

    message: FakeMessage  # type: ignore[assignment]
103
+
104
+
105
@dataclass
class FakeClientResponse(ModelClient.ModelClientResponseProtocol):
    """Response object returned by ``FakeClient.create``."""

    choices: list[FakeChoice]
    # Placeholder model name reported by every fake response.
    model: str = "fake-model"

    def message_retrieval_function(self) -> list[str] | list[FakeMessage]:
        """Return the message of every choice, in order."""
        return [choice.message for choice in self.choices]
@@ -15,6 +15,7 @@ from ..doc_utils import export_module
15
15
  from ..fast_depends import Depends as FastDepends
16
16
  from ..fast_depends import inject
17
17
  from ..fast_depends.dependencies import model
18
+ from ..fast_depends.utils import is_coroutine_callable
18
19
 
19
20
  if TYPE_CHECKING:
20
21
  from ..agentchat.conversable_agent import ConversableAgent
@@ -140,10 +141,7 @@ def remove_params(func: Callable[..., Any], sig: inspect.Signature, params: Iter
140
141
 
141
142
 
142
143
  def _remove_injected_params_from_signature(func: Callable[..., Any]) -> Callable[..., Any]:
143
- # This is a workaround for Python 3.9+ where staticmethod.__func__ is accessible
144
- if sys.version_info >= (3, 9) and isinstance(func, staticmethod) and hasattr(func, "__func__"):
145
- func = _fix_staticmethod(func)
146
-
144
+ func = _fix_staticmethod(func)
147
145
  sig = inspect.signature(func)
148
146
  params_to_remove = [p.name for p in sig.parameters.values() if _is_context_param(p) or _is_depends_param(p)]
149
147
  remove_params(func, sig, params_to_remove)
@@ -205,7 +203,7 @@ def _fix_staticmethod(f: Callable[..., Any]) -> Callable[..., Any]:
205
203
 
206
204
 
207
205
  def _set_return_annotation_to_any(f: Callable[..., Any]) -> Callable[..., Any]:
208
- if inspect.iscoroutinefunction(f):
206
+ if is_coroutine_callable(f):
209
207
 
210
208
  @functools.wraps(f)
211
209
  async def _a_wrapped_func(*args: Any, **kwargs: Any) -> Any:
@@ -242,9 +240,7 @@ def inject_params(f: Callable[..., Any]) -> Callable[..., Any]:
242
240
  The modified function with injected dependencies and updated signature.
243
241
  """
244
242
  # This is a workaround for Python 3.9+ where staticmethod.__func__ is accessible
245
- if sys.version_info >= (3, 9) and isinstance(f, staticmethod) and hasattr(f, "__func__"):
246
- f = _fix_staticmethod(f)
247
-
243
+ f = _fix_staticmethod(f)
248
244
  f = _string_metadata_to_description_field(f)
249
245
  f = _set_return_annotation_to_any(f)
250
246
  f = inject(f)
@@ -26,6 +26,7 @@ from ....agentchat.group import AgentTarget, ReplyResult, TerminateTarget
26
26
  from ....agentchat.group.context_variables import ContextVariables
27
27
  from ....agentchat.group.patterns import DefaultPattern
28
28
  from ....doc_utils import export_module
29
+ from ....fast_depends.utils import is_coroutine_callable
29
30
  from ....llm_config import LLMConfig
30
31
  from ....tools.dependency_injection import Field as AG2Field
31
32
  from ....tools.tool import Tool
@@ -375,7 +376,7 @@ def reliable_function_wrapper(
375
376
  Adds 'hypothesis' and 'context_variables' keyword-only arguments.
376
377
  Returns a ReplyResult targeting the validator.
377
378
  """
378
- is_original_func_async = inspect.iscoroutinefunction(tool_function)
379
+ is_original_func_async = is_coroutine_callable(tool_function)
379
380
  tool_sig = inspect.signature(tool_function)
380
381
  wrapper_func: Callable[..., Any] # Declare type for wrapper_func
381
382
 
@@ -653,7 +654,7 @@ class ReliableTool(Tool):
653
654
  Example: `ground_truth=["The API rate limit is 10 requests per minute.", "User preference: only show results from the last 7 days."]`
654
655
  """
655
656
  self._original_func, original_name, original_description = self._extract_func_details(func_or_tool)
656
- self._is_original_func_async = inspect.iscoroutinefunction(self._original_func)
657
+ self._is_original_func_async = is_coroutine_callable(self._original_func)
657
658
 
658
659
  self._runner_llm_config = ConversableAgent._validate_llm_config(runner_llm_config)
659
660
  if self._runner_llm_config is False:
@@ -49,7 +49,7 @@ class WebSearchPreviewTool(Tool):
49
49
  The default is `None`, which means the text will be returned as a string.
50
50
  """
51
51
  self.web_search_tool_param = WebSearchToolParam(
52
- type="web_search_preview",
52
+ type="web_search",
53
53
  search_context_size=search_context_size,
54
54
  user_location=UserLocation(**user_location) if user_location else None, # type: ignore[typeddict-item]
55
55
  )
@@ -17,6 +17,7 @@ from pydantic import __version__ as pydantic_version
17
17
  from pydantic.json_schema import JsonSchemaValue
18
18
 
19
19
  from ..doc_utils import export_module
20
+ from ..fast_depends.utils import is_coroutine_callable
20
21
  from .dependency_injection import Field as AG2Field
21
22
 
22
23
  if parse(pydantic_version) < parse("2.10.2"):
@@ -381,7 +382,7 @@ def load_basemodels_if_needed(func: Callable[..., Any]) -> Callable[..., Any]:
381
382
  # call the original function
382
383
  return await func(*args, **kwargs)
383
384
 
384
- if inspect.iscoroutinefunction(func):
385
+ if is_coroutine_callable(func):
385
386
  return _a_load_parameters_if_needed
386
387
  else:
387
388
  return _load_parameters_if_needed
autogen/version.py CHANGED
@@ -4,4 +4,4 @@
4
4
 
5
5
  __all__ = ["__version__"]
6
6
 
7
- __version__ = "0.9.9"
7
+ __version__ = "0.10.0"
File without changes