ag2 0.9.10__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ag2 might be problematic. Click here for more details.
- {ag2-0.9.10.dist-info → ag2-0.10.0.dist-info}/METADATA +14 -7
- {ag2-0.9.10.dist-info → ag2-0.10.0.dist-info}/RECORD +42 -24
- autogen/a2a/__init__.py +36 -0
- autogen/a2a/agent_executor.py +105 -0
- autogen/a2a/client.py +280 -0
- autogen/a2a/errors.py +18 -0
- autogen/a2a/httpx_client_factory.py +79 -0
- autogen/a2a/server.py +221 -0
- autogen/a2a/utils.py +165 -0
- autogen/agentchat/__init__.py +3 -0
- autogen/agentchat/agent.py +0 -2
- autogen/agentchat/chat.py +5 -1
- autogen/agentchat/contrib/llava_agent.py +1 -13
- autogen/agentchat/conversable_agent.py +178 -73
- autogen/agentchat/group/group_tool_executor.py +46 -15
- autogen/agentchat/group/guardrails.py +41 -33
- autogen/agentchat/group/multi_agent_chat.py +53 -0
- autogen/agentchat/group/safeguards/api.py +19 -2
- autogen/agentchat/group/safeguards/enforcer.py +134 -40
- autogen/agentchat/groupchat.py +45 -33
- autogen/agentchat/realtime/experimental/realtime_swarm.py +1 -3
- autogen/interop/pydantic_ai/pydantic_ai.py +1 -1
- autogen/llm_config/client.py +3 -2
- autogen/oai/bedrock.py +0 -13
- autogen/oai/client.py +15 -8
- autogen/oai/client_utils.py +30 -0
- autogen/oai/cohere.py +0 -10
- autogen/remote/__init__.py +18 -0
- autogen/remote/agent.py +199 -0
- autogen/remote/agent_service.py +142 -0
- autogen/remote/errors.py +17 -0
- autogen/remote/httpx_client_factory.py +131 -0
- autogen/remote/protocol.py +37 -0
- autogen/remote/retry.py +102 -0
- autogen/remote/runtime.py +96 -0
- autogen/testing/__init__.py +12 -0
- autogen/testing/messages.py +45 -0
- autogen/testing/test_agent.py +111 -0
- autogen/version.py +1 -1
- {ag2-0.9.10.dist-info → ag2-0.10.0.dist-info}/WHEEL +0 -0
- {ag2-0.9.10.dist-info → ag2-0.10.0.dist-info}/licenses/LICENSE +0 -0
- {ag2-0.9.10.dist-info → ag2-0.10.0.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -7,7 +7,7 @@ import re
|
|
|
7
7
|
from abc import ABC, abstractmethod
|
|
8
8
|
from typing import TYPE_CHECKING, Any
|
|
9
9
|
|
|
10
|
-
from pydantic import BaseModel, Field
|
|
10
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
11
11
|
|
|
12
12
|
from ...oai.client import OpenAIWrapper
|
|
13
13
|
|
|
@@ -16,32 +16,6 @@ if TYPE_CHECKING:
|
|
|
16
16
|
from .targets.transition_target import TransitionTarget
|
|
17
17
|
|
|
18
18
|
|
|
19
|
-
class GuardrailResult(BaseModel):
|
|
20
|
-
"""Represents the outcome of a guardrail check."""
|
|
21
|
-
|
|
22
|
-
activated: bool
|
|
23
|
-
justification: str = Field(default="No justification provided")
|
|
24
|
-
|
|
25
|
-
def __str__(self) -> str:
|
|
26
|
-
return f"Guardrail Result: {self.activated}\nJustification: {self.justification}"
|
|
27
|
-
|
|
28
|
-
@staticmethod
|
|
29
|
-
def parse(text: str) -> "GuardrailResult":
|
|
30
|
-
"""Parses a JSON string into a GuardrailResult object.
|
|
31
|
-
|
|
32
|
-
Args:
|
|
33
|
-
text (str): The JSON string to parse.
|
|
34
|
-
|
|
35
|
-
Returns:
|
|
36
|
-
GuardrailResult: The parsed GuardrailResult object.
|
|
37
|
-
"""
|
|
38
|
-
try:
|
|
39
|
-
data = json.loads(text)
|
|
40
|
-
return GuardrailResult(**data)
|
|
41
|
-
except (json.JSONDecodeError, ValueError) as e:
|
|
42
|
-
raise ValueError(f"Failed to parse GuardrailResult from text: {text}") from e
|
|
43
|
-
|
|
44
|
-
|
|
45
19
|
class Guardrail(ABC):
|
|
46
20
|
"""Abstract base class for guardrails."""
|
|
47
21
|
|
|
@@ -59,7 +33,7 @@ class Guardrail(ABC):
|
|
|
59
33
|
def check(
|
|
60
34
|
self,
|
|
61
35
|
context: str | list[dict[str, Any]],
|
|
62
|
-
) -> GuardrailResult:
|
|
36
|
+
) -> "GuardrailResult":
|
|
63
37
|
"""Checks the text against the guardrail and returns a GuardrailResult.
|
|
64
38
|
|
|
65
39
|
Args:
|
|
@@ -99,7 +73,7 @@ You will activate the guardrail only if the condition is met.
|
|
|
99
73
|
def check(
|
|
100
74
|
self,
|
|
101
75
|
context: str | list[dict[str, Any]],
|
|
102
|
-
) -> GuardrailResult:
|
|
76
|
+
) -> "GuardrailResult":
|
|
103
77
|
"""Checks the context against the guardrail using an LLM.
|
|
104
78
|
|
|
105
79
|
Args:
|
|
@@ -120,7 +94,7 @@ You will activate the guardrail only if the condition is met.
|
|
|
120
94
|
raise ValueError("Context must be a string or a list of messages.")
|
|
121
95
|
# Call the LLM with the check messages
|
|
122
96
|
response = self.client.create(messages=check_messages)
|
|
123
|
-
return GuardrailResult.parse(response.choices[0].message.content) # type: ignore
|
|
97
|
+
return GuardrailResult.parse(response.choices[0].message.content, guardrail=self) # type: ignore
|
|
124
98
|
|
|
125
99
|
|
|
126
100
|
class RegexGuardrail(Guardrail):
|
|
@@ -143,7 +117,7 @@ class RegexGuardrail(Guardrail):
|
|
|
143
117
|
def check(
|
|
144
118
|
self,
|
|
145
119
|
context: str | list[dict[str, Any]],
|
|
146
|
-
) -> GuardrailResult:
|
|
120
|
+
) -> "GuardrailResult":
|
|
147
121
|
"""Checks the context against the guardrail using a regular expression.
|
|
148
122
|
|
|
149
123
|
Args:
|
|
@@ -167,5 +141,39 @@ class RegexGuardrail(Guardrail):
|
|
|
167
141
|
if match:
|
|
168
142
|
activated = True
|
|
169
143
|
justification = f"Match found -> {match.group(0)}"
|
|
170
|
-
return GuardrailResult(activated=activated, justification=justification)
|
|
171
|
-
return GuardrailResult(activated=False, justification="No match found in the context.")
|
|
144
|
+
return GuardrailResult(activated=activated, justification=justification, guardrail=self)
|
|
145
|
+
return GuardrailResult(activated=False, justification="No match found in the context.", guardrail=self)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
class GuardrailResult(BaseModel):
|
|
149
|
+
"""Represents the outcome of a guardrail check."""
|
|
150
|
+
|
|
151
|
+
activated: bool
|
|
152
|
+
guardrail: Guardrail
|
|
153
|
+
justification: str = Field(default="No justification provided")
|
|
154
|
+
|
|
155
|
+
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
156
|
+
|
|
157
|
+
def __str__(self) -> str:
|
|
158
|
+
return f"Guardrail Result: {self.activated}\nJustification: {self.justification}"
|
|
159
|
+
|
|
160
|
+
@property
|
|
161
|
+
def reply(self) -> str:
|
|
162
|
+
return f"{self.guardrail.activation_message}\nJustification: {self.justification}"
|
|
163
|
+
|
|
164
|
+
@staticmethod
|
|
165
|
+
def parse(text: str, guardrail: "Guardrail") -> "GuardrailResult":
|
|
166
|
+
"""Parses a JSON string into a GuardrailResult object.
|
|
167
|
+
|
|
168
|
+
Args:
|
|
169
|
+
text (str): The JSON string to parse.
|
|
170
|
+
guardrail (Guardrail): The guardrail that the result is for.
|
|
171
|
+
|
|
172
|
+
Returns:
|
|
173
|
+
GuardrailResult: The parsed GuardrailResult object.
|
|
174
|
+
"""
|
|
175
|
+
try:
|
|
176
|
+
data = json.loads(text)
|
|
177
|
+
return GuardrailResult(**data, guardrail=guardrail)
|
|
178
|
+
except (json.JSONDecodeError, ValueError) as e:
|
|
179
|
+
raise ValueError(f"Failed to parse GuardrailResult from text: {text}") from e
|
|
@@ -11,6 +11,7 @@ from ...events.agent_events import ErrorEvent, RunCompletionEvent
|
|
|
11
11
|
from ...io.base import IOStream
|
|
12
12
|
from ...io.run_response import AsyncRunResponse, AsyncRunResponseProtocol, RunResponse, RunResponseProtocol
|
|
13
13
|
from ...io.thread_io_stream import AsyncThreadIOStream, ThreadIOStream
|
|
14
|
+
from ...llm_config import LLMConfig
|
|
14
15
|
from ..chat import ChatResult
|
|
15
16
|
from .context_variables import ContextVariables
|
|
16
17
|
from .group_utils import cleanup_temp_user_messages
|
|
@@ -32,6 +33,9 @@ def initiate_group_chat(
|
|
|
32
33
|
pattern: "Pattern",
|
|
33
34
|
messages: list[dict[str, Any]] | str,
|
|
34
35
|
max_rounds: int = 20,
|
|
36
|
+
safeguard_policy: dict[str, Any] | str | None = None,
|
|
37
|
+
safeguard_llm_config: LLMConfig | None = None,
|
|
38
|
+
mask_llm_config: LLMConfig | None = None,
|
|
35
39
|
) -> tuple[ChatResult, ContextVariables, "Agent"]:
|
|
36
40
|
"""Initialize and run a group chat using a pattern for configuration.
|
|
37
41
|
|
|
@@ -39,6 +43,9 @@ def initiate_group_chat(
|
|
|
39
43
|
pattern: Pattern object that encapsulates the chat configuration.
|
|
40
44
|
messages: Initial message(s).
|
|
41
45
|
max_rounds: Maximum number of conversation rounds.
|
|
46
|
+
safeguard_policy: Optional safeguard policy dict or path to JSON file.
|
|
47
|
+
safeguard_llm_config: Optional LLM configuration for safeguard checks.
|
|
48
|
+
mask_llm_config: Optional LLM configuration for masking.
|
|
42
49
|
|
|
43
50
|
Returns:
|
|
44
51
|
ChatResult: Conversations chat history.
|
|
@@ -66,6 +73,17 @@ def initiate_group_chat(
|
|
|
66
73
|
messages=messages,
|
|
67
74
|
)
|
|
68
75
|
|
|
76
|
+
# Apply safeguards if provided
|
|
77
|
+
if safeguard_policy:
|
|
78
|
+
from .safeguards import apply_safeguard_policy
|
|
79
|
+
|
|
80
|
+
apply_safeguard_policy(
|
|
81
|
+
groupchat_manager=manager,
|
|
82
|
+
policy=safeguard_policy,
|
|
83
|
+
safeguard_llm_config=safeguard_llm_config,
|
|
84
|
+
mask_llm_config=mask_llm_config,
|
|
85
|
+
)
|
|
86
|
+
|
|
69
87
|
# Start or resume the conversation
|
|
70
88
|
if len(processed_messages) > 1:
|
|
71
89
|
last_agent, last_message = manager.resume(messages=processed_messages)
|
|
@@ -94,6 +112,9 @@ async def a_initiate_group_chat(
|
|
|
94
112
|
pattern: "Pattern",
|
|
95
113
|
messages: list[dict[str, Any]] | str,
|
|
96
114
|
max_rounds: int = 20,
|
|
115
|
+
safeguard_policy: dict[str, Any] | str | None = None,
|
|
116
|
+
safeguard_llm_config: LLMConfig | None = None,
|
|
117
|
+
mask_llm_config: LLMConfig | None = None,
|
|
97
118
|
) -> tuple[ChatResult, ContextVariables, "Agent"]:
|
|
98
119
|
"""Initialize and run a group chat using a pattern for configuration, asynchronously.
|
|
99
120
|
|
|
@@ -101,6 +122,9 @@ async def a_initiate_group_chat(
|
|
|
101
122
|
pattern: Pattern object that encapsulates the chat configuration.
|
|
102
123
|
messages: Initial message(s).
|
|
103
124
|
max_rounds: Maximum number of conversation rounds.
|
|
125
|
+
safeguard_policy: Optional safeguard policy dict or path to JSON file.
|
|
126
|
+
safeguard_llm_config: Optional LLM configuration for safeguard checks.
|
|
127
|
+
mask_llm_config: Optional LLM configuration for masking.
|
|
104
128
|
|
|
105
129
|
Returns:
|
|
106
130
|
ChatResult: Conversations chat history.
|
|
@@ -128,6 +152,17 @@ async def a_initiate_group_chat(
|
|
|
128
152
|
messages=messages,
|
|
129
153
|
)
|
|
130
154
|
|
|
155
|
+
# Apply safeguards if provided
|
|
156
|
+
if safeguard_policy:
|
|
157
|
+
from .safeguards import apply_safeguard_policy
|
|
158
|
+
|
|
159
|
+
apply_safeguard_policy(
|
|
160
|
+
groupchat_manager=manager,
|
|
161
|
+
policy=safeguard_policy,
|
|
162
|
+
safeguard_llm_config=safeguard_llm_config,
|
|
163
|
+
mask_llm_config=mask_llm_config,
|
|
164
|
+
)
|
|
165
|
+
|
|
131
166
|
# Start or resume the conversation
|
|
132
167
|
if len(processed_messages) > 1:
|
|
133
168
|
last_agent, last_message = await manager.a_resume(messages=processed_messages)
|
|
@@ -156,6 +191,9 @@ def run_group_chat(
|
|
|
156
191
|
pattern: "Pattern",
|
|
157
192
|
messages: list[dict[str, Any]] | str,
|
|
158
193
|
max_rounds: int = 20,
|
|
194
|
+
safeguard_policy: dict[str, Any] | str | None = None,
|
|
195
|
+
safeguard_llm_config: LLMConfig | None = None,
|
|
196
|
+
mask_llm_config: LLMConfig | None = None,
|
|
159
197
|
) -> RunResponseProtocol:
|
|
160
198
|
iostream = ThreadIOStream()
|
|
161
199
|
# todo: add agents
|
|
@@ -165,6 +203,9 @@ def run_group_chat(
|
|
|
165
203
|
pattern: "Pattern" = pattern,
|
|
166
204
|
messages: list[dict[str, Any]] | str = messages,
|
|
167
205
|
max_rounds: int = max_rounds,
|
|
206
|
+
safeguard_policy: dict[str, Any] | str | None = safeguard_policy,
|
|
207
|
+
safeguard_llm_config: LLMConfig | None = safeguard_llm_config,
|
|
208
|
+
mask_llm_config: LLMConfig | None = mask_llm_config,
|
|
168
209
|
iostream: ThreadIOStream = iostream,
|
|
169
210
|
response: RunResponse = response,
|
|
170
211
|
) -> None:
|
|
@@ -174,6 +215,9 @@ def run_group_chat(
|
|
|
174
215
|
pattern=pattern,
|
|
175
216
|
messages=messages,
|
|
176
217
|
max_rounds=max_rounds,
|
|
218
|
+
safeguard_policy=safeguard_policy,
|
|
219
|
+
safeguard_llm_config=safeguard_llm_config,
|
|
220
|
+
mask_llm_config=mask_llm_config,
|
|
177
221
|
)
|
|
178
222
|
|
|
179
223
|
IOStream.get_default().send(
|
|
@@ -200,6 +244,9 @@ async def a_run_group_chat(
|
|
|
200
244
|
pattern: "Pattern",
|
|
201
245
|
messages: list[dict[str, Any]] | str,
|
|
202
246
|
max_rounds: int = 20,
|
|
247
|
+
safeguard_policy: dict[str, Any] | str | None = None,
|
|
248
|
+
safeguard_llm_config: LLMConfig | None = None,
|
|
249
|
+
mask_llm_config: LLMConfig | None = None,
|
|
203
250
|
) -> AsyncRunResponseProtocol:
|
|
204
251
|
iostream = AsyncThreadIOStream()
|
|
205
252
|
# todo: add agents
|
|
@@ -209,6 +256,9 @@ async def a_run_group_chat(
|
|
|
209
256
|
pattern: "Pattern" = pattern,
|
|
210
257
|
messages: list[dict[str, Any]] | str = messages,
|
|
211
258
|
max_rounds: int = max_rounds,
|
|
259
|
+
safeguard_policy: dict[str, Any] | str | None = safeguard_policy,
|
|
260
|
+
safeguard_llm_config: LLMConfig | None = safeguard_llm_config,
|
|
261
|
+
mask_llm_config: LLMConfig | None = mask_llm_config,
|
|
212
262
|
iostream: AsyncThreadIOStream = iostream,
|
|
213
263
|
response: AsyncRunResponse = response,
|
|
214
264
|
) -> None:
|
|
@@ -218,6 +268,9 @@ async def a_run_group_chat(
|
|
|
218
268
|
pattern=pattern,
|
|
219
269
|
messages=messages,
|
|
220
270
|
max_rounds=max_rounds,
|
|
271
|
+
safeguard_policy=safeguard_policy,
|
|
272
|
+
safeguard_llm_config=safeguard_llm_config,
|
|
273
|
+
mask_llm_config=mask_llm_config,
|
|
221
274
|
)
|
|
222
275
|
|
|
223
276
|
IOStream.get_default().send(
|
|
@@ -158,6 +158,8 @@ def apply_safeguard_policy(
|
|
|
158
158
|
policy=policy,
|
|
159
159
|
safeguard_llm_config=safeguard_llm_config,
|
|
160
160
|
mask_llm_config=mask_llm_config,
|
|
161
|
+
groupchat_manager=groupchat_manager,
|
|
162
|
+
agents=agents,
|
|
161
163
|
)
|
|
162
164
|
|
|
163
165
|
# Determine which agents to apply safeguards to
|
|
@@ -168,8 +170,14 @@ def apply_safeguard_policy(
|
|
|
168
170
|
if not isinstance(groupchat_manager, GroupChatManager):
|
|
169
171
|
raise ValueError("groupchat_manager must be an instance of GroupChatManager")
|
|
170
172
|
|
|
171
|
-
target_agents.extend([
|
|
172
|
-
|
|
173
|
+
target_agents.extend([
|
|
174
|
+
agent
|
|
175
|
+
for agent in groupchat_manager.groupchat.agents
|
|
176
|
+
if hasattr(agent, "hook_lists") and agent.name != "_Group_Tool_Executor"
|
|
177
|
+
])
|
|
178
|
+
all_agent_names = [
|
|
179
|
+
agent.name for agent in groupchat_manager.groupchat.agents if agent.name != "_Group_Tool_Executor"
|
|
180
|
+
]
|
|
173
181
|
all_agent_names.append(groupchat_manager.name)
|
|
174
182
|
|
|
175
183
|
# Register inter-agent guardrails with the groupchat
|
|
@@ -221,4 +229,13 @@ def apply_safeguard_policy(
|
|
|
221
229
|
f"Agent {agent.name} does not support hooks. Please ensure it inherits from ConversableAgent."
|
|
222
230
|
)
|
|
223
231
|
|
|
232
|
+
# Apply hooks to GroupToolExecutor if it exists (for GroupChat scenarios)
|
|
233
|
+
if groupchat_manager and enforcer.group_tool_executor and hasattr(enforcer.group_tool_executor, "hook_lists"):
|
|
234
|
+
# Create hooks for GroupToolExecutor - it needs tool interaction hooks
|
|
235
|
+
# since it's the one actually executing tools in GroupChat
|
|
236
|
+
hooks = enforcer.create_agent_hooks(enforcer.group_tool_executor.name)
|
|
237
|
+
for hook_name, hook_func in hooks.items():
|
|
238
|
+
if hook_name in enforcer.group_tool_executor.hook_lists:
|
|
239
|
+
enforcer.group_tool_executor.hook_lists[hook_name].append(hook_func)
|
|
240
|
+
|
|
224
241
|
return enforcer
|
|
@@ -9,8 +9,11 @@ import re
|
|
|
9
9
|
from collections.abc import Callable
|
|
10
10
|
from typing import Any
|
|
11
11
|
|
|
12
|
+
from ....code_utils import content_str
|
|
12
13
|
from ....io.base import IOStream
|
|
13
14
|
from ....llm_config import LLMConfig
|
|
15
|
+
from ...conversable_agent import ConversableAgent
|
|
16
|
+
from ...groupchat import GroupChatManager
|
|
14
17
|
from ..guardrails import LLMGuardrail, RegexGuardrail
|
|
15
18
|
from ..targets.transition_target import TransitionTarget
|
|
16
19
|
from .events import SafeguardEvent
|
|
@@ -19,11 +22,22 @@ from .events import SafeguardEvent
|
|
|
19
22
|
class SafeguardEnforcer:
|
|
20
23
|
"""Main safeguard enforcer - executes safeguard policies"""
|
|
21
24
|
|
|
25
|
+
@staticmethod
|
|
26
|
+
def _stringify_content(value: Any) -> str:
|
|
27
|
+
if isinstance(value, (str, list)) or value is None:
|
|
28
|
+
try:
|
|
29
|
+
return content_str(value)
|
|
30
|
+
except (TypeError, ValueError, AssertionError):
|
|
31
|
+
pass
|
|
32
|
+
return "" if value is None else str(value)
|
|
33
|
+
|
|
22
34
|
def __init__(
|
|
23
35
|
self,
|
|
24
36
|
policy: dict[str, Any] | str,
|
|
25
37
|
safeguard_llm_config: LLMConfig | dict[str, Any] | None = None,
|
|
26
38
|
mask_llm_config: LLMConfig | dict[str, Any] | None = None,
|
|
39
|
+
groupchat_manager: GroupChatManager | None = None,
|
|
40
|
+
agents: list[ConversableAgent] | None = None,
|
|
27
41
|
):
|
|
28
42
|
"""Initialize the safeguard enforcer.
|
|
29
43
|
|
|
@@ -31,10 +45,20 @@ class SafeguardEnforcer:
|
|
|
31
45
|
policy: Safeguard policy dict or path to JSON file
|
|
32
46
|
safeguard_llm_config: LLM configuration for safeguard checks
|
|
33
47
|
mask_llm_config: LLM configuration for masking
|
|
48
|
+
groupchat_manager: GroupChat manager instance for group chat scenarios
|
|
49
|
+
agents: List of conversable agents to apply safeguards to
|
|
34
50
|
"""
|
|
35
51
|
self.policy = self._load_policy(policy)
|
|
36
52
|
self.safeguard_llm_config = safeguard_llm_config
|
|
37
53
|
self.mask_llm_config = mask_llm_config
|
|
54
|
+
self.groupchat_manager = groupchat_manager
|
|
55
|
+
self.agents = agents
|
|
56
|
+
self.group_tool_executor = None
|
|
57
|
+
if self.groupchat_manager:
|
|
58
|
+
for agent in self.groupchat_manager.groupchat.agents:
|
|
59
|
+
if agent.name == "_Group_Tool_Executor":
|
|
60
|
+
self.group_tool_executor = agent # type: ignore[assignment]
|
|
61
|
+
break
|
|
38
62
|
|
|
39
63
|
# Validate policy format before proceeding
|
|
40
64
|
self._validate_policy()
|
|
@@ -351,18 +375,21 @@ class SafeguardEnforcer:
|
|
|
351
375
|
hooks = {}
|
|
352
376
|
|
|
353
377
|
# Check if we have any tool interaction rules that apply to this agent
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
for rule in self.environment_rules
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
rule
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
378
|
+
if agent_name == "_Group_Tool_Executor":
|
|
379
|
+
# group tool executor is running all tools, so we need to check all tool interaction rules
|
|
380
|
+
agent_tool_rules = [rule for rule in self.environment_rules if rule["type"] == "tool_interaction"]
|
|
381
|
+
else:
|
|
382
|
+
agent_tool_rules = [
|
|
383
|
+
rule
|
|
384
|
+
for rule in self.environment_rules
|
|
385
|
+
if rule["type"] == "tool_interaction"
|
|
386
|
+
and (
|
|
387
|
+
rule.get("message_destination") == agent_name
|
|
388
|
+
or rule.get("message_source") == agent_name
|
|
389
|
+
or rule.get("agent_name") == agent_name
|
|
390
|
+
or "message_destination" not in rule
|
|
391
|
+
)
|
|
392
|
+
]
|
|
366
393
|
if agent_tool_rules:
|
|
367
394
|
|
|
368
395
|
def tool_input_hook(tool_input: dict[str, Any]) -> dict[str, Any] | None:
|
|
@@ -624,7 +651,7 @@ class SafeguardEnforcer:
|
|
|
624
651
|
# Handle tool_calls (like in tool inputs)
|
|
625
652
|
elif "tool_calls" in blocked_content and blocked_content["tool_calls"]:
|
|
626
653
|
blocked_content["tool_calls"] = [
|
|
627
|
-
{**tool_call, "function": {**tool_call["function"], "arguments": block_msg}}
|
|
654
|
+
{**tool_call, "function": {**tool_call["function"], "arguments": json.dumps({"error": block_msg})}}
|
|
628
655
|
for tool_call in blocked_content["tool_calls"]
|
|
629
656
|
]
|
|
630
657
|
# Handle regular content
|
|
@@ -649,7 +676,10 @@ class SafeguardEnforcer:
|
|
|
649
676
|
blocked_item["content"] = block_msg
|
|
650
677
|
if "tool_calls" in blocked_item:
|
|
651
678
|
blocked_item["tool_calls"] = [
|
|
652
|
-
{
|
|
679
|
+
{
|
|
680
|
+
**tool_call,
|
|
681
|
+
"function": {**tool_call["function"], "arguments": json.dumps({"error": block_msg})},
|
|
682
|
+
}
|
|
653
683
|
for tool_call in blocked_item["tool_calls"]
|
|
654
684
|
]
|
|
655
685
|
if "tool_responses" in blocked_item:
|
|
@@ -675,9 +705,12 @@ class SafeguardEnforcer:
|
|
|
675
705
|
# Handle tool_responses
|
|
676
706
|
if "tool_responses" in masked_content and masked_content["tool_responses"]:
|
|
677
707
|
if "content" in masked_content:
|
|
678
|
-
masked_content["content"] = mask_func(
|
|
708
|
+
masked_content["content"] = mask_func(self._stringify_content(masked_content.get("content")))
|
|
679
709
|
masked_content["tool_responses"] = [
|
|
680
|
-
{
|
|
710
|
+
{
|
|
711
|
+
**response,
|
|
712
|
+
"content": mask_func(self._stringify_content(response.get("content"))),
|
|
713
|
+
}
|
|
681
714
|
for response in masked_content["tool_responses"]
|
|
682
715
|
]
|
|
683
716
|
# Handle tool_calls
|
|
@@ -687,17 +720,17 @@ class SafeguardEnforcer:
|
|
|
687
720
|
**tool_call,
|
|
688
721
|
"function": {
|
|
689
722
|
**tool_call["function"],
|
|
690
|
-
"arguments": mask_func(
|
|
723
|
+
"arguments": mask_func(self._stringify_content(tool_call["function"].get("arguments"))),
|
|
691
724
|
},
|
|
692
725
|
}
|
|
693
726
|
for tool_call in masked_content["tool_calls"]
|
|
694
727
|
]
|
|
695
728
|
# Handle regular content
|
|
696
729
|
elif "content" in masked_content:
|
|
697
|
-
masked_content["content"] = mask_func(
|
|
730
|
+
masked_content["content"] = mask_func(self._stringify_content(masked_content.get("content")))
|
|
698
731
|
# Handle arguments
|
|
699
732
|
elif "arguments" in masked_content:
|
|
700
|
-
masked_content["arguments"] = mask_func(
|
|
733
|
+
masked_content["arguments"] = mask_func(self._stringify_content(masked_content.get("arguments")))
|
|
701
734
|
|
|
702
735
|
return masked_content
|
|
703
736
|
|
|
@@ -708,39 +741,64 @@ class SafeguardEnforcer:
|
|
|
708
741
|
if isinstance(item, dict):
|
|
709
742
|
masked_item = item.copy()
|
|
710
743
|
if "content" in masked_item:
|
|
711
|
-
masked_item["content"] = mask_func(
|
|
744
|
+
masked_item["content"] = mask_func(self._stringify_content(masked_item.get("content")))
|
|
712
745
|
if "tool_calls" in masked_item:
|
|
713
746
|
masked_item["tool_calls"] = [
|
|
714
747
|
{
|
|
715
748
|
**tool_call,
|
|
716
749
|
"function": {
|
|
717
750
|
**tool_call["function"],
|
|
718
|
-
"arguments": mask_func(
|
|
751
|
+
"arguments": mask_func(
|
|
752
|
+
self._stringify_content(tool_call["function"].get("arguments"))
|
|
753
|
+
),
|
|
719
754
|
},
|
|
720
755
|
}
|
|
721
756
|
for tool_call in masked_item["tool_calls"]
|
|
722
757
|
]
|
|
723
758
|
if "tool_responses" in masked_item:
|
|
724
759
|
masked_item["tool_responses"] = [
|
|
725
|
-
{
|
|
760
|
+
{
|
|
761
|
+
**response,
|
|
762
|
+
"content": mask_func(self._stringify_content(response.get("content"))),
|
|
763
|
+
}
|
|
726
764
|
for response in masked_item["tool_responses"]
|
|
727
765
|
]
|
|
728
766
|
masked_list.append(masked_item)
|
|
729
767
|
else:
|
|
730
768
|
# For non-dict items, wrap the masked content in a dict
|
|
731
|
-
masked_item_content: str = mask_func(
|
|
769
|
+
masked_item_content: str = mask_func(self._stringify_content(item))
|
|
732
770
|
masked_list.append({"content": masked_item_content, "role": "function"})
|
|
733
771
|
return masked_list
|
|
734
772
|
|
|
735
773
|
else:
|
|
736
774
|
# String content
|
|
737
|
-
return mask_func(
|
|
775
|
+
return mask_func(self._stringify_content(content))
|
|
738
776
|
|
|
739
777
|
def _check_inter_agent_communication(
|
|
740
778
|
self, sender_name: str, recipient_name: str, message: str | dict[str, Any]
|
|
741
779
|
) -> str | dict[str, Any]:
|
|
742
780
|
"""Check inter-agent communication."""
|
|
743
|
-
|
|
781
|
+
if isinstance(message, dict):
|
|
782
|
+
if "tool_calls" in message and isinstance(message["tool_calls"], list):
|
|
783
|
+
# Extract arguments from all tool calls and combine them
|
|
784
|
+
tool_args = []
|
|
785
|
+
for tool_call in message["tool_calls"]:
|
|
786
|
+
if "function" in tool_call and "arguments" in tool_call["function"]:
|
|
787
|
+
tool_args.append(tool_call["function"]["arguments"])
|
|
788
|
+
content_to_check = " | ".join(tool_args) if tool_args else ""
|
|
789
|
+
elif "tool_responses" in message and isinstance(message["tool_responses"], list):
|
|
790
|
+
# Extract content from all tool responses and combine them
|
|
791
|
+
tool_contents = []
|
|
792
|
+
for tool_response in message["tool_responses"]:
|
|
793
|
+
if "content" in tool_response:
|
|
794
|
+
tool_contents.append(str(tool_response["content"]))
|
|
795
|
+
content_to_check = " | ".join(tool_contents) if tool_contents else ""
|
|
796
|
+
else:
|
|
797
|
+
content_to_check = str(message.get("content", ""))
|
|
798
|
+
elif isinstance(message, str):
|
|
799
|
+
content_to_check = message
|
|
800
|
+
else:
|
|
801
|
+
raise ValueError("Message must be a dictionary or a string")
|
|
744
802
|
|
|
745
803
|
for rule in self.inter_agent_rules:
|
|
746
804
|
if rule["type"] == "agent_transition":
|
|
@@ -750,7 +808,7 @@ class SafeguardEnforcer:
|
|
|
750
808
|
|
|
751
809
|
if source_match and target_match:
|
|
752
810
|
# Prepare content preview
|
|
753
|
-
content_preview =
|
|
811
|
+
content_preview = str(content_to_check)[:100] + ("..." if len(str(content_to_check)) > 100 else "")
|
|
754
812
|
|
|
755
813
|
# Use guardrail if available
|
|
756
814
|
if "guardrail" in rule and rule["guardrail"]:
|
|
@@ -766,7 +824,7 @@ class SafeguardEnforcer:
|
|
|
766
824
|
)
|
|
767
825
|
|
|
768
826
|
try:
|
|
769
|
-
result = rule["guardrail"].check(
|
|
827
|
+
result = rule["guardrail"].check(content_to_check)
|
|
770
828
|
if result.activated:
|
|
771
829
|
self._send_safeguard_event(
|
|
772
830
|
event_type="violation",
|
|
@@ -812,7 +870,7 @@ class SafeguardEnforcer:
|
|
|
812
870
|
# action=rule.get('action', 'N/A'),
|
|
813
871
|
content_preview=content_preview,
|
|
814
872
|
)
|
|
815
|
-
is_violation, explanation = self._check_regex_violation(
|
|
873
|
+
is_violation, explanation = self._check_regex_violation(content_to_check, rule["pattern"])
|
|
816
874
|
if is_violation:
|
|
817
875
|
result_value = self._apply_action(
|
|
818
876
|
action=rule["action"],
|
|
@@ -847,11 +905,11 @@ class SafeguardEnforcer:
|
|
|
847
905
|
)
|
|
848
906
|
if "custom_prompt" in rule:
|
|
849
907
|
is_violation, explanation = self._check_llm_violation(
|
|
850
|
-
|
|
908
|
+
content_to_check, custom_prompt=rule["custom_prompt"]
|
|
851
909
|
)
|
|
852
910
|
else:
|
|
853
911
|
is_violation, explanation = self._check_llm_violation(
|
|
854
|
-
|
|
912
|
+
content_to_check, disallow_items=rule["disallow"]
|
|
855
913
|
)
|
|
856
914
|
|
|
857
915
|
if is_violation:
|
|
@@ -971,12 +1029,20 @@ class SafeguardEnforcer:
|
|
|
971
1029
|
# Extract tool name from data
|
|
972
1030
|
tool_name = data.get("name", data.get("tool_name", ""))
|
|
973
1031
|
|
|
1032
|
+
# Resolve the actual agent name if this is GroupToolExecutor
|
|
1033
|
+
actual_agent_name = agent_name
|
|
1034
|
+
if agent_name == "_Group_Tool_Executor" and self.group_tool_executor:
|
|
1035
|
+
# Get the original tool caller from GroupToolExecutor
|
|
1036
|
+
originator = self.group_tool_executor.get_tool_call_originator() # type: ignore[attr-defined]
|
|
1037
|
+
if originator:
|
|
1038
|
+
actual_agent_name = originator
|
|
1039
|
+
|
|
974
1040
|
# Determine source/destination based on direction
|
|
975
1041
|
if direction == "output":
|
|
976
|
-
source_name, dest_name = tool_name,
|
|
1042
|
+
source_name, dest_name = tool_name, actual_agent_name
|
|
977
1043
|
content = str(data.get("content", ""))
|
|
978
1044
|
else: # input
|
|
979
|
-
source_name, dest_name =
|
|
1045
|
+
source_name, dest_name = actual_agent_name, tool_name
|
|
980
1046
|
content = str(data.get("arguments", ""))
|
|
981
1047
|
|
|
982
1048
|
result = self._check_interaction(
|
|
@@ -985,7 +1051,7 @@ class SafeguardEnforcer:
|
|
|
985
1051
|
dest_name=dest_name,
|
|
986
1052
|
content=content,
|
|
987
1053
|
data=data,
|
|
988
|
-
context_info=f"{
|
|
1054
|
+
context_info=f"{actual_agent_name} <-> {tool_name} ({direction})",
|
|
989
1055
|
)
|
|
990
1056
|
|
|
991
1057
|
if result is not None:
|
|
@@ -1050,15 +1116,43 @@ class SafeguardEnforcer:
|
|
|
1050
1116
|
Returns:
|
|
1051
1117
|
Optional replacement message if a safeguard triggers, None otherwise
|
|
1052
1118
|
"""
|
|
1053
|
-
#
|
|
1054
|
-
|
|
1055
|
-
|
|
1056
|
-
|
|
1119
|
+
# Handle GroupToolExecutor transparency for safeguards
|
|
1120
|
+
if src_agent_name == "_Group_Tool_Executor":
|
|
1121
|
+
actual_src_agent_name = self._resolve_tool_executor_source(src_agent_name, self.group_tool_executor)
|
|
1122
|
+
else:
|
|
1123
|
+
actual_src_agent_name = src_agent_name
|
|
1124
|
+
|
|
1125
|
+
# Store original message for comparison
|
|
1126
|
+
original_message = message_content
|
|
1057
1127
|
|
|
1058
|
-
result = self._check_inter_agent_communication(
|
|
1128
|
+
result = self._check_inter_agent_communication(actual_src_agent_name, dst_agent_name, message_content)
|
|
1059
1129
|
|
|
1060
|
-
if result
|
|
1061
|
-
|
|
1130
|
+
# Check if the result is different from the original
|
|
1131
|
+
if result != original_message:
|
|
1062
1132
|
return result
|
|
1063
1133
|
|
|
1064
1134
|
return None
|
|
1135
|
+
|
|
1136
|
+
def _resolve_tool_executor_source(self, src_agent_name: str, tool_executor: Any = None) -> str:
|
|
1137
|
+
"""Resolve the actual source agent when GroupToolExecutor is involved.
|
|
1138
|
+
|
|
1139
|
+
When src_agent_name is "_Group_Tool_Executor", get the original agent who called the tool.
|
|
1140
|
+
|
|
1141
|
+
Args:
|
|
1142
|
+
src_agent_name: The source agent name from the communication
|
|
1143
|
+
tool_executor: GroupToolExecutor instance for getting originator
|
|
1144
|
+
|
|
1145
|
+
Returns:
|
|
1146
|
+
The actual source agent name (original tool caller for tool responses)
|
|
1147
|
+
"""
|
|
1148
|
+
if src_agent_name != "_Group_Tool_Executor":
|
|
1149
|
+
return src_agent_name
|
|
1150
|
+
|
|
1151
|
+
# Handle GroupToolExecutor - get the original tool caller
|
|
1152
|
+
if tool_executor and hasattr(tool_executor, "get_tool_call_originator"):
|
|
1153
|
+
originator = tool_executor.get_tool_call_originator()
|
|
1154
|
+
if originator:
|
|
1155
|
+
return originator # type: ignore[no-any-return]
|
|
1156
|
+
|
|
1157
|
+
# Fallback: Could not determine original caller
|
|
1158
|
+
return "tool_executor"
|