ag2 0.9.10__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (42)
  1. {ag2-0.9.10.dist-info → ag2-0.10.0.dist-info}/METADATA +14 -7
  2. {ag2-0.9.10.dist-info → ag2-0.10.0.dist-info}/RECORD +42 -24
  3. autogen/a2a/__init__.py +36 -0
  4. autogen/a2a/agent_executor.py +105 -0
  5. autogen/a2a/client.py +280 -0
  6. autogen/a2a/errors.py +18 -0
  7. autogen/a2a/httpx_client_factory.py +79 -0
  8. autogen/a2a/server.py +221 -0
  9. autogen/a2a/utils.py +165 -0
  10. autogen/agentchat/__init__.py +3 -0
  11. autogen/agentchat/agent.py +0 -2
  12. autogen/agentchat/chat.py +5 -1
  13. autogen/agentchat/contrib/llava_agent.py +1 -13
  14. autogen/agentchat/conversable_agent.py +178 -73
  15. autogen/agentchat/group/group_tool_executor.py +46 -15
  16. autogen/agentchat/group/guardrails.py +41 -33
  17. autogen/agentchat/group/multi_agent_chat.py +53 -0
  18. autogen/agentchat/group/safeguards/api.py +19 -2
  19. autogen/agentchat/group/safeguards/enforcer.py +134 -40
  20. autogen/agentchat/groupchat.py +45 -33
  21. autogen/agentchat/realtime/experimental/realtime_swarm.py +1 -3
  22. autogen/interop/pydantic_ai/pydantic_ai.py +1 -1
  23. autogen/llm_config/client.py +3 -2
  24. autogen/oai/bedrock.py +0 -13
  25. autogen/oai/client.py +15 -8
  26. autogen/oai/client_utils.py +30 -0
  27. autogen/oai/cohere.py +0 -10
  28. autogen/remote/__init__.py +18 -0
  29. autogen/remote/agent.py +199 -0
  30. autogen/remote/agent_service.py +142 -0
  31. autogen/remote/errors.py +17 -0
  32. autogen/remote/httpx_client_factory.py +131 -0
  33. autogen/remote/protocol.py +37 -0
  34. autogen/remote/retry.py +102 -0
  35. autogen/remote/runtime.py +96 -0
  36. autogen/testing/__init__.py +12 -0
  37. autogen/testing/messages.py +45 -0
  38. autogen/testing/test_agent.py +111 -0
  39. autogen/version.py +1 -1
  40. {ag2-0.9.10.dist-info → ag2-0.10.0.dist-info}/WHEEL +0 -0
  41. {ag2-0.9.10.dist-info → ag2-0.10.0.dist-info}/licenses/LICENSE +0 -0
  42. {ag2-0.9.10.dist-info → ag2-0.10.0.dist-info}/licenses/NOTICE.md +0 -0
autogen/agentchat/group/guardrails.py

@@ -7,7 +7,7 @@ import re
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Any
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field
 
 from ...oai.client import OpenAIWrapper
 
@@ -16,32 +16,6 @@ if TYPE_CHECKING:
     from .targets.transition_target import TransitionTarget
 
 
-class GuardrailResult(BaseModel):
-    """Represents the outcome of a guardrail check."""
-
-    activated: bool
-    justification: str = Field(default="No justification provided")
-
-    def __str__(self) -> str:
-        return f"Guardrail Result: {self.activated}\nJustification: {self.justification}"
-
-    @staticmethod
-    def parse(text: str) -> "GuardrailResult":
-        """Parses a JSON string into a GuardrailResult object.
-
-        Args:
-            text (str): The JSON string to parse.
-
-        Returns:
-            GuardrailResult: The parsed GuardrailResult object.
-        """
-        try:
-            data = json.loads(text)
-            return GuardrailResult(**data)
-        except (json.JSONDecodeError, ValueError) as e:
-            raise ValueError(f"Failed to parse GuardrailResult from text: {text}") from e
-
-
 class Guardrail(ABC):
     """Abstract base class for guardrails."""
 
@@ -59,7 +33,7 @@ class Guardrail(ABC):
     def check(
         self,
         context: str | list[dict[str, Any]],
-    ) -> GuardrailResult:
+    ) -> "GuardrailResult":
         """Checks the text against the guardrail and returns a GuardrailResult.
 
         Args:
@@ -99,7 +73,7 @@ You will activate the guardrail only if the condition is met.
     def check(
         self,
         context: str | list[dict[str, Any]],
-    ) -> GuardrailResult:
+    ) -> "GuardrailResult":
         """Checks the context against the guardrail using an LLM.
 
         Args:
@@ -120,7 +94,7 @@ You will activate the guardrail only if the condition is met.
             raise ValueError("Context must be a string or a list of messages.")
         # Call the LLM with the check messages
         response = self.client.create(messages=check_messages)
-        return GuardrailResult.parse(response.choices[0].message.content)  # type: ignore
+        return GuardrailResult.parse(response.choices[0].message.content, guardrail=self)  # type: ignore
 
 
 class RegexGuardrail(Guardrail):
@@ -143,7 +117,7 @@ class RegexGuardrail(Guardrail):
     def check(
         self,
         context: str | list[dict[str, Any]],
-    ) -> GuardrailResult:
+    ) -> "GuardrailResult":
         """Checks the context against the guardrail using a regular expression.
 
         Args:
@@ -167,5 +141,39 @@ class RegexGuardrail(Guardrail):
             if match:
                 activated = True
                 justification = f"Match found -> {match.group(0)}"
-                return GuardrailResult(activated=activated, justification=justification)
-        return GuardrailResult(activated=False, justification="No match found in the context.")
+                return GuardrailResult(activated=activated, justification=justification, guardrail=self)
+        return GuardrailResult(activated=False, justification="No match found in the context.", guardrail=self)
+
+
+class GuardrailResult(BaseModel):
+    """Represents the outcome of a guardrail check."""
+
+    activated: bool
+    guardrail: Guardrail
+    justification: str = Field(default="No justification provided")
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    def __str__(self) -> str:
+        return f"Guardrail Result: {self.activated}\nJustification: {self.justification}"
+
+    @property
+    def reply(self) -> str:
+        return f"{self.guardrail.activation_message}\nJustification: {self.justification}"
+
+    @staticmethod
+    def parse(text: str, guardrail: "Guardrail") -> "GuardrailResult":
+        """Parses a JSON string into a GuardrailResult object.
+
+        Args:
+            text (str): The JSON string to parse.
+            guardrail (Guardrail): The guardrail that the result is for.
+
+        Returns:
+            GuardrailResult: The parsed GuardrailResult object.
+        """
+        try:
+            data = json.loads(text)
+            return GuardrailResult(**data, guardrail=guardrail)
+        except (json.JSONDecodeError, ValueError) as e:
+            raise ValueError(f"Failed to parse GuardrailResult from text: {text}") from e
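Note: GuardrailResult now carries a reference to the guardrail that produced it, which is what the new reply property and the extra parse() argument are for; since Guardrail is a plain (non-pydantic) class, the model needs ConfigDict(arbitrary_types_allowed=True), and the class moved below Guardrail to resolve the reference. A minimal usage sketch — my_guardrail stands in for any concrete Guardrail instance, and the activation_message attribute is assumed to exist on Guardrail, as the reply property implies:

result = GuardrailResult(activated=True, justification="Match found -> 123-45-6789", guardrail=my_guardrail)
print(result)        # "Guardrail Result: True\nJustification: Match found -> 123-45-6789"
print(result.reply)  # new: the guardrail's activation_message plus the justification

# parse() now requires the originating guardrail as well:
parsed = GuardrailResult.parse('{"activated": false}', guardrail=my_guardrail)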
autogen/agentchat/group/multi_agent_chat.py

@@ -11,6 +11,7 @@ from ...events.agent_events import ErrorEvent, RunCompletionEvent
 from ...io.base import IOStream
 from ...io.run_response import AsyncRunResponse, AsyncRunResponseProtocol, RunResponse, RunResponseProtocol
 from ...io.thread_io_stream import AsyncThreadIOStream, ThreadIOStream
+from ...llm_config import LLMConfig
 from ..chat import ChatResult
 from .context_variables import ContextVariables
 from .group_utils import cleanup_temp_user_messages
@@ -32,6 +33,9 @@ def initiate_group_chat(
     pattern: "Pattern",
     messages: list[dict[str, Any]] | str,
     max_rounds: int = 20,
+    safeguard_policy: dict[str, Any] | str | None = None,
+    safeguard_llm_config: LLMConfig | None = None,
+    mask_llm_config: LLMConfig | None = None,
 ) -> tuple[ChatResult, ContextVariables, "Agent"]:
     """Initialize and run a group chat using a pattern for configuration.
 
@@ -39,6 +43,9 @@ def initiate_group_chat(
         pattern: Pattern object that encapsulates the chat configuration.
         messages: Initial message(s).
         max_rounds: Maximum number of conversation rounds.
+        safeguard_policy: Optional safeguard policy dict or path to JSON file.
+        safeguard_llm_config: Optional LLM configuration for safeguard checks.
+        mask_llm_config: Optional LLM configuration for masking.
 
     Returns:
         ChatResult: Conversations chat history.
@@ -66,6 +73,17 @@ def initiate_group_chat(
         messages=messages,
     )
 
+    # Apply safeguards if provided
+    if safeguard_policy:
+        from .safeguards import apply_safeguard_policy
+
+        apply_safeguard_policy(
+            groupchat_manager=manager,
+            policy=safeguard_policy,
+            safeguard_llm_config=safeguard_llm_config,
+            mask_llm_config=mask_llm_config,
+        )
+
     # Start or resume the conversation
     if len(processed_messages) > 1:
         last_agent, last_message = manager.resume(messages=processed_messages)
@@ -94,6 +112,9 @@ async def a_initiate_group_chat(
     pattern: "Pattern",
     messages: list[dict[str, Any]] | str,
     max_rounds: int = 20,
+    safeguard_policy: dict[str, Any] | str | None = None,
+    safeguard_llm_config: LLMConfig | None = None,
+    mask_llm_config: LLMConfig | None = None,
 ) -> tuple[ChatResult, ContextVariables, "Agent"]:
     """Initialize and run a group chat using a pattern for configuration, asynchronously.
 
@@ -101,6 +122,9 @@ async def a_initiate_group_chat(
         pattern: Pattern object that encapsulates the chat configuration.
         messages: Initial message(s).
         max_rounds: Maximum number of conversation rounds.
+        safeguard_policy: Optional safeguard policy dict or path to JSON file.
+        safeguard_llm_config: Optional LLM configuration for safeguard checks.
+        mask_llm_config: Optional LLM configuration for masking.
 
     Returns:
         ChatResult: Conversations chat history.
@@ -128,6 +152,17 @@ async def a_initiate_group_chat(
         messages=messages,
     )
 
+    # Apply safeguards if provided
+    if safeguard_policy:
+        from .safeguards import apply_safeguard_policy
+
+        apply_safeguard_policy(
+            groupchat_manager=manager,
+            policy=safeguard_policy,
+            safeguard_llm_config=safeguard_llm_config,
+            mask_llm_config=mask_llm_config,
+        )
+
     # Start or resume the conversation
     if len(processed_messages) > 1:
         last_agent, last_message = await manager.a_resume(messages=processed_messages)
@@ -156,6 +191,9 @@ def run_group_chat(
     pattern: "Pattern",
     messages: list[dict[str, Any]] | str,
     max_rounds: int = 20,
+    safeguard_policy: dict[str, Any] | str | None = None,
+    safeguard_llm_config: LLMConfig | None = None,
+    mask_llm_config: LLMConfig | None = None,
 ) -> RunResponseProtocol:
     iostream = ThreadIOStream()
     # todo: add agents
@@ -165,6 +203,9 @@ def run_group_chat(
         pattern: "Pattern" = pattern,
         messages: list[dict[str, Any]] | str = messages,
         max_rounds: int = max_rounds,
+        safeguard_policy: dict[str, Any] | str | None = safeguard_policy,
+        safeguard_llm_config: LLMConfig | None = safeguard_llm_config,
+        mask_llm_config: LLMConfig | None = mask_llm_config,
         iostream: ThreadIOStream = iostream,
         response: RunResponse = response,
     ) -> None:
@@ -174,6 +215,9 @@ def run_group_chat(
             pattern=pattern,
             messages=messages,
             max_rounds=max_rounds,
+            safeguard_policy=safeguard_policy,
+            safeguard_llm_config=safeguard_llm_config,
+            mask_llm_config=mask_llm_config,
         )
 
     IOStream.get_default().send(
@@ -200,6 +244,9 @@ async def a_run_group_chat(
     pattern: "Pattern",
     messages: list[dict[str, Any]] | str,
     max_rounds: int = 20,
+    safeguard_policy: dict[str, Any] | str | None = None,
+    safeguard_llm_config: LLMConfig | None = None,
+    mask_llm_config: LLMConfig | None = None,
 ) -> AsyncRunResponseProtocol:
     iostream = AsyncThreadIOStream()
     # todo: add agents
@@ -209,6 +256,9 @@ async def a_run_group_chat(
         pattern: "Pattern" = pattern,
         messages: list[dict[str, Any]] | str = messages,
         max_rounds: int = max_rounds,
+        safeguard_policy: dict[str, Any] | str | None = safeguard_policy,
+        safeguard_llm_config: LLMConfig | None = safeguard_llm_config,
+        mask_llm_config: LLMConfig | None = mask_llm_config,
         iostream: AsyncThreadIOStream = iostream,
         response: AsyncRunResponse = response,
     ) -> None:
@@ -218,6 +268,9 @@ async def a_run_group_chat(
             pattern=pattern,
            messages=messages,
             max_rounds=max_rounds,
+            safeguard_policy=safeguard_policy,
+            safeguard_llm_config=safeguard_llm_config,
+            mask_llm_config=mask_llm_config,
         )
 
     IOStream.get_default().send(
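Note: all four entry points (initiate_group_chat, a_initiate_group_chat, run_group_chat, a_run_group_chat) gain the same three optional parameters and wire them into apply_safeguard_policy before the chat starts. A sketch of the new call shape — the pattern and llm_config objects are assumed to be configured elsewhere:

from autogen.agentchat.group.multi_agent_chat import initiate_group_chat

chat_result, context, last_agent = initiate_group_chat(
    pattern=pattern,                     # your configured Pattern
    messages="Start the task",
    max_rounds=20,
    safeguard_policy="safeguards.json",  # dict or path to a JSON policy file
    safeguard_llm_config=llm_config,     # used for LLM-based safeguard checks
    mask_llm_config=llm_config,          # used for masking actions
)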
autogen/agentchat/group/safeguards/api.py

@@ -158,6 +158,8 @@ def apply_safeguard_policy(
         policy=policy,
         safeguard_llm_config=safeguard_llm_config,
         mask_llm_config=mask_llm_config,
+        groupchat_manager=groupchat_manager,
+        agents=agents,
     )
 
     # Determine which agents to apply safeguards to
@@ -168,8 +170,14 @@ def apply_safeguard_policy(
         if not isinstance(groupchat_manager, GroupChatManager):
             raise ValueError("groupchat_manager must be an instance of GroupChatManager")
 
-        target_agents.extend([agent for agent in groupchat_manager.groupchat.agents if hasattr(agent, "hook_lists")])
-        all_agent_names = [agent.name for agent in groupchat_manager.groupchat.agents]
+        target_agents.extend([
+            agent
+            for agent in groupchat_manager.groupchat.agents
+            if hasattr(agent, "hook_lists") and agent.name != "_Group_Tool_Executor"
+        ])
+        all_agent_names = [
+            agent.name for agent in groupchat_manager.groupchat.agents if agent.name != "_Group_Tool_Executor"
+        ]
         all_agent_names.append(groupchat_manager.name)
 
     # Register inter-agent guardrails with the groupchat
@@ -221,4 +229,13 @@ def apply_safeguard_policy(
                 f"Agent {agent.name} does not support hooks. Please ensure it inherits from ConversableAgent."
             )
 
+    # Apply hooks to GroupToolExecutor if it exists (for GroupChat scenarios)
+    if groupchat_manager and enforcer.group_tool_executor and hasattr(enforcer.group_tool_executor, "hook_lists"):
+        # Create hooks for GroupToolExecutor - it needs tool interaction hooks
+        # since it's the one actually executing tools in GroupChat
+        hooks = enforcer.create_agent_hooks(enforcer.group_tool_executor.name)
+        for hook_name, hook_func in hooks.items():
+            if hook_name in enforcer.group_tool_executor.hook_lists:
+                enforcer.group_tool_executor.hook_lists[hook_name].append(hook_func)
+
     return enforcer
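Note: the enforcer now receives the groupchat_manager and agents up front, excludes the internal _Group_Tool_Executor from the regular target list, and instead attaches tool-interaction hooks to it directly, since it is the agent that actually executes tools in a GroupChat. For orientation, a sketch of a tool_interaction rule as matched by the enforcer — only keys visible in this diff ("type", "message_source", "message_destination", "agent_name", "pattern", "action", "disallow", "custom_prompt") appear here; the enclosing policy structure is not shown in this diff and is assumed:

rule = {
    "type": "tool_interaction",
    "message_source": "coder",  # or "message_destination" / "agent_name"
    "pattern": r"rm -rf",       # regex checked against tool arguments or output
    "action": "block",          # illustrative; the action vocabulary isn't shown in this diff
}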
autogen/agentchat/group/safeguards/enforcer.py

@@ -9,8 +9,11 @@ import re
 from collections.abc import Callable
 from typing import Any
 
+from ....code_utils import content_str
 from ....io.base import IOStream
 from ....llm_config import LLMConfig
+from ...conversable_agent import ConversableAgent
+from ...groupchat import GroupChatManager
 from ..guardrails import LLMGuardrail, RegexGuardrail
 from ..targets.transition_target import TransitionTarget
 from .events import SafeguardEvent
@@ -19,11 +22,22 @@ from .events import SafeguardEvent
 class SafeguardEnforcer:
     """Main safeguard enforcer - executes safeguard policies"""
 
+    @staticmethod
+    def _stringify_content(value: Any) -> str:
+        if isinstance(value, (str, list)) or value is None:
+            try:
+                return content_str(value)
+            except (TypeError, ValueError, AssertionError):
+                pass
+        return "" if value is None else str(value)
+
     def __init__(
         self,
         policy: dict[str, Any] | str,
         safeguard_llm_config: LLMConfig | dict[str, Any] | None = None,
         mask_llm_config: LLMConfig | dict[str, Any] | None = None,
+        groupchat_manager: GroupChatManager | None = None,
+        agents: list[ConversableAgent] | None = None,
     ):
         """Initialize the safeguard enforcer.
 
@@ -31,10 +45,20 @@ class SafeguardEnforcer:
             policy: Safeguard policy dict or path to JSON file
             safeguard_llm_config: LLM configuration for safeguard checks
             mask_llm_config: LLM configuration for masking
+            groupchat_manager: GroupChat manager instance for group chat scenarios
+            agents: List of conversable agents to apply safeguards to
         """
         self.policy = self._load_policy(policy)
         self.safeguard_llm_config = safeguard_llm_config
         self.mask_llm_config = mask_llm_config
+        self.groupchat_manager = groupchat_manager
+        self.agents = agents
+        self.group_tool_executor = None
+        if self.groupchat_manager:
+            for agent in self.groupchat_manager.groupchat.agents:
+                if agent.name == "_Group_Tool_Executor":
+                    self.group_tool_executor = agent  # type: ignore[assignment]
+                    break
 
         # Validate policy format before proceeding
         self._validate_policy()
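Note: the new _stringify_content helper exists so masking no longer str()-ifies multimodal message content. autogen's content_str flattens OpenAI-style content lists; a quick sketch of the difference:

from autogen.code_utils import content_str

content = [{"type": "text", "text": "hello"}, {"type": "image_url", "image_url": {"url": "https://example.com/x.png"}}]
content_str(content)  # "hello<image>" -- text parts joined, image parts become a placeholder
str(content)          # "[{'type': 'text', ...}]" -- the raw repr the old code masked

The try/except keeps behavior safe for values content_str rejects (it raises on unknown part types), falling back to str().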
@@ -351,18 +375,21 @@ class SafeguardEnforcer:
         hooks = {}
 
         # Check if we have any tool interaction rules that apply to this agent
-        agent_tool_rules = [
-            rule
-            for rule in self.environment_rules
-            if rule["type"] == "tool_interaction"
-            and (
-                rule.get("message_destination") == agent_name
-                or rule.get("message_source") == agent_name
-                or rule.get("agent_name") == agent_name
-                or "message_destination" not in rule
-            )
-        ]  # Simple pattern rules apply to all
-
+        if agent_name == "_Group_Tool_Executor":
+            # group tool executor is running all tools, so we need to check all tool interaction rules
+            agent_tool_rules = [rule for rule in self.environment_rules if rule["type"] == "tool_interaction"]
+        else:
+            agent_tool_rules = [
+                rule
+                for rule in self.environment_rules
+                if rule["type"] == "tool_interaction"
+                and (
+                    rule.get("message_destination") == agent_name
+                    or rule.get("message_source") == agent_name
+                    or rule.get("agent_name") == agent_name
+                    or "message_destination" not in rule
+                )
+            ]
         if agent_tool_rules:
 
             def tool_input_hook(tool_input: dict[str, Any]) -> dict[str, Any] | None:
@@ -624,7 +651,7 @@ class SafeguardEnforcer:
             # Handle tool_calls (like in tool inputs)
             elif "tool_calls" in blocked_content and blocked_content["tool_calls"]:
                 blocked_content["tool_calls"] = [
-                    {**tool_call, "function": {**tool_call["function"], "arguments": block_msg}}
+                    {**tool_call, "function": {**tool_call["function"], "arguments": json.dumps({"error": block_msg})}}
                     for tool_call in blocked_content["tool_calls"]
                 ]
             # Handle regular content
@@ -649,7 +676,10 @@ class SafeguardEnforcer:
                     blocked_item["content"] = block_msg
                 if "tool_calls" in blocked_item:
                     blocked_item["tool_calls"] = [
-                        {**tool_call, "function": {**tool_call["function"], "arguments": block_msg}}
+                        {
+                            **tool_call,
+                            "function": {**tool_call["function"], "arguments": json.dumps({"error": block_msg})},
+                        }
                        for tool_call in blocked_item["tool_calls"]
                     ]
                 if "tool_responses" in blocked_item:
@@ -675,9 +705,12 @@ class SafeguardEnforcer:
         # Handle tool_responses
         if "tool_responses" in masked_content and masked_content["tool_responses"]:
             if "content" in masked_content:
-                masked_content["content"] = mask_func(str(masked_content["content"]))
+                masked_content["content"] = mask_func(self._stringify_content(masked_content.get("content")))
             masked_content["tool_responses"] = [
-                {**response, "content": mask_func(str(response.get("content", "")))}
+                {
+                    **response,
+                    "content": mask_func(self._stringify_content(response.get("content"))),
+                }
                 for response in masked_content["tool_responses"]
             ]
         # Handle tool_calls
@@ -687,17 +720,17 @@ class SafeguardEnforcer:
                     **tool_call,
                     "function": {
                         **tool_call["function"],
-                        "arguments": mask_func(str(tool_call["function"].get("arguments", ""))),
+                        "arguments": mask_func(self._stringify_content(tool_call["function"].get("arguments"))),
                     },
                 }
                 for tool_call in masked_content["tool_calls"]
             ]
         # Handle regular content
         elif "content" in masked_content:
-            masked_content["content"] = mask_func(str(masked_content["content"]))
+            masked_content["content"] = mask_func(self._stringify_content(masked_content.get("content")))
         # Handle arguments
         elif "arguments" in masked_content:
-            masked_content["arguments"] = mask_func(str(masked_content["arguments"]))
+            masked_content["arguments"] = mask_func(self._stringify_content(masked_content.get("arguments")))
 
         return masked_content
 
@@ -708,39 +741,64 @@ class SafeguardEnforcer:
             if isinstance(item, dict):
                 masked_item = item.copy()
                 if "content" in masked_item:
-                    masked_item["content"] = mask_func(str(masked_item["content"]))
+                    masked_item["content"] = mask_func(self._stringify_content(masked_item.get("content")))
                 if "tool_calls" in masked_item:
                     masked_item["tool_calls"] = [
                         {
                             **tool_call,
                             "function": {
                                 **tool_call["function"],
-                                "arguments": mask_func(str(tool_call["function"].get("arguments", ""))),
+                                "arguments": mask_func(
+                                    self._stringify_content(tool_call["function"].get("arguments"))
+                                ),
                             },
                         }
                         for tool_call in masked_item["tool_calls"]
                     ]
                 if "tool_responses" in masked_item:
                     masked_item["tool_responses"] = [
-                        {**response, "content": mask_func(str(response.get("content", "")))}
+                        {
+                            **response,
+                            "content": mask_func(self._stringify_content(response.get("content"))),
+                        }
                         for response in masked_item["tool_responses"]
                     ]
                 masked_list.append(masked_item)
             else:
                 # For non-dict items, wrap the masked content in a dict
-                masked_item_content: str = mask_func(str(item))
+                masked_item_content: str = mask_func(self._stringify_content(item))
                 masked_list.append({"content": masked_item_content, "role": "function"})
         return masked_list
 
     else:
         # String content
-        return mask_func(str(content))
+        return mask_func(self._stringify_content(content))
 
     def _check_inter_agent_communication(
         self, sender_name: str, recipient_name: str, message: str | dict[str, Any]
     ) -> str | dict[str, Any]:
         """Check inter-agent communication."""
-        content = message.get("content", "") if isinstance(message, dict) else str(message)
+        if isinstance(message, dict):
+            if "tool_calls" in message and isinstance(message["tool_calls"], list):
+                # Extract arguments from all tool calls and combine them
+                tool_args = []
+                for tool_call in message["tool_calls"]:
+                    if "function" in tool_call and "arguments" in tool_call["function"]:
+                        tool_args.append(tool_call["function"]["arguments"])
+                content_to_check = " | ".join(tool_args) if tool_args else ""
+            elif "tool_responses" in message and isinstance(message["tool_responses"], list):
+                # Extract content from all tool responses and combine them
+                tool_contents = []
+                for tool_response in message["tool_responses"]:
+                    if "content" in tool_response:
+                        tool_contents.append(str(tool_response["content"]))
+                content_to_check = " | ".join(tool_contents) if tool_contents else ""
+            else:
+                content_to_check = str(message.get("content", ""))
+        elif isinstance(message, str):
+            content_to_check = message
+        else:
+            raise ValueError("Message must be a dictionary or a string")
 
         for rule in self.inter_agent_rules:
             if rule["type"] == "agent_transition":
@@ -750,7 +808,7 @@ class SafeguardEnforcer:
 
                 if source_match and target_match:
                     # Prepare content preview
-                    content_preview = content[:100] + ("..." if len(content) > 100 else "")
+                    content_preview = str(content_to_check)[:100] + ("..." if len(str(content_to_check)) > 100 else "")
 
                     # Use guardrail if available
                     if "guardrail" in rule and rule["guardrail"]:
@@ -766,7 +824,7 @@ class SafeguardEnforcer:
                         )
 
                         try:
-                            result = rule["guardrail"].check(content)
+                            result = rule["guardrail"].check(content_to_check)
                             if result.activated:
                                 self._send_safeguard_event(
                                     event_type="violation",
@@ -812,7 +870,7 @@ class SafeguardEnforcer:
                         # action=rule.get('action', 'N/A'),
                         content_preview=content_preview,
                     )
-                    is_violation, explanation = self._check_regex_violation(content, rule["pattern"])
+                    is_violation, explanation = self._check_regex_violation(content_to_check, rule["pattern"])
                     if is_violation:
                         result_value = self._apply_action(
                             action=rule["action"],
@@ -847,11 +905,11 @@ class SafeguardEnforcer:
                     )
                     if "custom_prompt" in rule:
                         is_violation, explanation = self._check_llm_violation(
-                            content, custom_prompt=rule["custom_prompt"]
+                            content_to_check, custom_prompt=rule["custom_prompt"]
                         )
                     else:
                         is_violation, explanation = self._check_llm_violation(
-                            content, disallow_items=rule["disallow"]
+                            content_to_check, disallow_items=rule["disallow"]
                         )
 
                     if is_violation:
@@ -971,12 +1029,20 @@ class SafeguardEnforcer:
         # Extract tool name from data
         tool_name = data.get("name", data.get("tool_name", ""))
 
+        # Resolve the actual agent name if this is GroupToolExecutor
+        actual_agent_name = agent_name
+        if agent_name == "_Group_Tool_Executor" and self.group_tool_executor:
+            # Get the original tool caller from GroupToolExecutor
+            originator = self.group_tool_executor.get_tool_call_originator()  # type: ignore[attr-defined]
+            if originator:
+                actual_agent_name = originator
+
         # Determine source/destination based on direction
         if direction == "output":
-            source_name, dest_name = tool_name, agent_name
+            source_name, dest_name = tool_name, actual_agent_name
             content = str(data.get("content", ""))
         else:  # input
-            source_name, dest_name = agent_name, tool_name
+            source_name, dest_name = actual_agent_name, tool_name
             content = str(data.get("arguments", ""))
 
         result = self._check_interaction(
@@ -985,7 +1051,7 @@ class SafeguardEnforcer:
             dest_name=dest_name,
             content=content,
             data=data,
-            context_info=f"{agent_name} <-> {tool_name} ({direction})",
+            context_info=f"{actual_agent_name} <-> {tool_name} ({direction})",
         )
 
         if result is not None:
@@ -1050,15 +1116,43 @@ class SafeguardEnforcer:
         Returns:
             Optional replacement message if a safeguard triggers, None otherwise
         """
-        # Store original content for comparison
-        original_content = (
-            message_content.get("content", "") if isinstance(message_content, dict) else str(message_content)
-        )
+        # Handle GroupToolExecutor transparency for safeguards
+        if src_agent_name == "_Group_Tool_Executor":
+            actual_src_agent_name = self._resolve_tool_executor_source(src_agent_name, self.group_tool_executor)
+        else:
+            actual_src_agent_name = src_agent_name
+
+        # Store original message for comparison
+        original_message = message_content
 
-        result = self._check_inter_agent_communication(src_agent_name, dst_agent_name, message_content)
+        result = self._check_inter_agent_communication(actual_src_agent_name, dst_agent_name, message_content)
 
-        if result != original_content:
-            # Return the complete modified message structure to preserve tool_calls/tool_responses pairing
+        # Check if the result is different from the original
+        if result != original_message:
             return result
 
         return None
+
+    def _resolve_tool_executor_source(self, src_agent_name: str, tool_executor: Any = None) -> str:
+        """Resolve the actual source agent when GroupToolExecutor is involved.
+
+        When src_agent_name is "_Group_Tool_Executor", get the original agent who called the tool.
+
+        Args:
+            src_agent_name: The source agent name from the communication
+            tool_executor: GroupToolExecutor instance for getting originator
+
+        Returns:
+            The actual source agent name (original tool caller for tool responses)
+        """
+        if src_agent_name != "_Group_Tool_Executor":
+            return src_agent_name
+
+        # Handle GroupToolExecutor - get the original tool caller
+        if tool_executor and hasattr(tool_executor, "get_tool_call_originator"):
+            originator = tool_executor.get_tool_call_originator()
+            if originator:
+                return originator  # type: ignore[no-any-return]
+
+        # Fallback: Could not determine original caller
+        return "tool_executor"