ag2 0.9.8.post1__py3-none-any.whl → 0.9.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ag2 might be problematic.
- {ag2-0.9.8.post1.dist-info → ag2-0.9.10.dist-info}/METADATA +232 -210
- {ag2-0.9.8.post1.dist-info → ag2-0.9.10.dist-info}/RECORD +88 -80
- autogen/_website/generate_mkdocs.py +3 -3
- autogen/_website/notebook_processor.py +1 -1
- autogen/_website/utils.py +1 -1
- autogen/agentchat/assistant_agent.py +15 -15
- autogen/agentchat/chat.py +52 -40
- autogen/agentchat/contrib/agent_eval/criterion.py +1 -1
- autogen/agentchat/contrib/capabilities/text_compressors.py +5 -5
- autogen/agentchat/contrib/capabilities/tools_capability.py +1 -1
- autogen/agentchat/contrib/capabilities/transforms.py +1 -1
- autogen/agentchat/contrib/captainagent/agent_builder.py +1 -1
- autogen/agentchat/contrib/captainagent/captainagent.py +20 -19
- autogen/agentchat/contrib/graph_rag/falkor_graph_query_engine.py +2 -5
- autogen/agentchat/contrib/graph_rag/graph_rag_capability.py +5 -5
- autogen/agentchat/contrib/graph_rag/neo4j_graph_query_engine.py +18 -17
- autogen/agentchat/contrib/rag/mongodb_query_engine.py +2 -2
- autogen/agentchat/contrib/rag/query_engine.py +11 -11
- autogen/agentchat/contrib/retrieve_assistant_agent.py +3 -0
- autogen/agentchat/contrib/swarm_agent.py +3 -2
- autogen/agentchat/contrib/vectordb/couchbase.py +1 -1
- autogen/agentchat/contrib/vectordb/mongodb.py +1 -1
- autogen/agentchat/contrib/web_surfer.py +1 -1
- autogen/agentchat/conversable_agent.py +184 -80
- autogen/agentchat/group/context_expression.py +21 -21
- autogen/agentchat/group/handoffs.py +11 -11
- autogen/agentchat/group/multi_agent_chat.py +3 -2
- autogen/agentchat/group/on_condition.py +11 -11
- autogen/agentchat/group/safeguards/__init__.py +21 -0
- autogen/agentchat/group/safeguards/api.py +224 -0
- autogen/agentchat/group/safeguards/enforcer.py +1064 -0
- autogen/agentchat/group/safeguards/events.py +119 -0
- autogen/agentchat/group/safeguards/validator.py +435 -0
- autogen/agentchat/groupchat.py +60 -19
- autogen/agentchat/realtime/experimental/clients/realtime_client.py +2 -2
- autogen/agentchat/realtime/experimental/function_observer.py +2 -3
- autogen/agentchat/realtime/experimental/realtime_agent.py +2 -3
- autogen/agentchat/realtime/experimental/realtime_swarm.py +21 -10
- autogen/agentchat/user_proxy_agent.py +55 -53
- autogen/agents/experimental/document_agent/document_agent.py +1 -10
- autogen/agents/experimental/document_agent/parser_utils.py +5 -1
- autogen/browser_utils.py +4 -4
- autogen/cache/abstract_cache_base.py +2 -6
- autogen/cache/disk_cache.py +1 -6
- autogen/cache/in_memory_cache.py +2 -6
- autogen/cache/redis_cache.py +1 -5
- autogen/coding/__init__.py +10 -2
- autogen/coding/base.py +2 -1
- autogen/coding/docker_commandline_code_executor.py +1 -6
- autogen/coding/factory.py +9 -0
- autogen/coding/jupyter/docker_jupyter_server.py +1 -7
- autogen/coding/jupyter/jupyter_client.py +2 -9
- autogen/coding/jupyter/jupyter_code_executor.py +2 -7
- autogen/coding/jupyter/local_jupyter_server.py +2 -6
- autogen/coding/local_commandline_code_executor.py +0 -65
- autogen/coding/yepcode_code_executor.py +197 -0
- autogen/environments/docker_python_environment.py +3 -3
- autogen/environments/system_python_environment.py +5 -5
- autogen/environments/venv_python_environment.py +5 -5
- autogen/events/agent_events.py +1 -1
- autogen/events/client_events.py +1 -1
- autogen/fast_depends/utils.py +10 -0
- autogen/graph_utils.py +5 -7
- autogen/import_utils.py +28 -15
- autogen/interop/pydantic_ai/pydantic_ai.py +8 -5
- autogen/io/processors/console_event_processor.py +8 -3
- autogen/llm_config/config.py +168 -91
- autogen/llm_config/entry.py +38 -26
- autogen/llm_config/types.py +35 -0
- autogen/llm_config/utils.py +223 -0
- autogen/mcp/mcp_proxy/operation_grouping.py +48 -39
- autogen/messages/agent_messages.py +1 -1
- autogen/messages/client_messages.py +1 -1
- autogen/oai/__init__.py +8 -1
- autogen/oai/client.py +10 -3
- autogen/oai/client_utils.py +1 -1
- autogen/oai/cohere.py +4 -4
- autogen/oai/gemini.py +4 -6
- autogen/oai/gemini_types.py +1 -0
- autogen/oai/openai_utils.py +44 -115
- autogen/tools/dependency_injection.py +4 -8
- autogen/tools/experimental/reliable/reliable.py +3 -2
- autogen/tools/experimental/web_search_preview/web_search_preview.py +1 -1
- autogen/tools/function_utils.py +2 -1
- autogen/version.py +1 -1
- {ag2-0.9.8.post1.dist-info → ag2-0.9.10.dist-info}/WHEEL +0 -0
- {ag2-0.9.8.post1.dist-info → ag2-0.9.10.dist-info}/licenses/LICENSE +0 -0
- {ag2-0.9.8.post1.dist-info → ag2-0.9.10.dist-info}/licenses/NOTICE.md +0 -0
@@ -0,0 +1,1064 @@
+# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import json
+import re
+from collections.abc import Callable
+from typing import Any
+
+from ....io.base import IOStream
+from ....llm_config import LLMConfig
+from ..guardrails import LLMGuardrail, RegexGuardrail
+from ..targets.transition_target import TransitionTarget
+from .events import SafeguardEvent
+
+
+class SafeguardEnforcer:
+    """Main safeguard enforcer - executes safeguard policies"""
+
+    def __init__(
+        self,
+        policy: dict[str, Any] | str,
+        safeguard_llm_config: LLMConfig | dict[str, Any] | None = None,
+        mask_llm_config: LLMConfig | dict[str, Any] | None = None,
+    ):
+        """Initialize the safeguard enforcer.
+
+        Args:
+            policy: Safeguard policy dict or path to JSON file
+            safeguard_llm_config: LLM configuration for safeguard checks
+            mask_llm_config: LLM configuration for masking
+        """
+        self.policy = self._load_policy(policy)
+        self.safeguard_llm_config = safeguard_llm_config
+        self.mask_llm_config = mask_llm_config
+
+        # Validate policy format before proceeding
+        self._validate_policy()
+
+        # Create mask agent for content masking
+        if self.mask_llm_config:
+            from ...conversable_agent import ConversableAgent
+
+            self.mask_agent = ConversableAgent(
+                name="mask_agent",
+                system_message="You are a agent responsible for masking sensitive information.",
+                llm_config=self.mask_llm_config,
+                human_input_mode="NEVER",
+                max_consecutive_auto_reply=1,
+            )
+
+        # Parse safeguard rules
+        self.inter_agent_rules = self._parse_inter_agent_rules()
+        self.environment_rules = self._parse_environment_rules()
+
+        # Send load event
+        self._send_safeguard_event(
+            event_type="load",
+            message=f"Loaded {len(self.inter_agent_rules)} inter-agent and {len(self.environment_rules)} environment safeguard rules",
+        )
+
+    def _send_safeguard_event(
+        self,
+        event_type: str,
+        message: str,
+        source_agent: str | None = None,
+        target_agent: str | None = None,
+        guardrail_type: str | None = None,
+        action: str | None = None,
+        content_preview: str | None = None,
+    ) -> None:
+        """Send a safeguard event to the IOStream."""
+        iostream = IOStream.get_default()
+        event = SafeguardEvent(
+            event_type=event_type,
+            message=message,
+            source_agent=source_agent,
+            target_agent=target_agent,
+            guardrail_type=guardrail_type,
+            action=action,
+            content_preview=content_preview,
+        )
+        iostream.send(event)
+
+    def _load_policy(self, policy: dict[str, Any] | str) -> dict[str, Any]:
+        """Load policy from file or use provided dict."""
+        if isinstance(policy, str):
+            with open(policy) as f:
+                result: dict[str, Any] = json.load(f)
+                return result
+        return policy
+
+    def _validate_policy(self) -> None:
+        """Validate policy format and syntax."""
+        from .validator import SafeguardValidator
+
+        validator = SafeguardValidator(self.policy)
+        validator.validate_policy_structure()
+
+    def _parse_inter_agent_rules(self) -> list[dict[str, Any]]:
+        """Parse inter-agent safeguard rules from policy."""
+        rules = []
+        inter_agent = self.policy.get("inter_agent_safeguards", {})
+
+        # Agent transitions
+        for rule in inter_agent.get("agent_transitions", []):
+            # Create guardrail based on check_method
+            check_method = rule.get("check_method", "regex")
+            guardrail: LLMGuardrail | RegexGuardrail | None = None
+            action = rule.get("violation_response", rule.get("action", "block"))  # Support both field names
+
+            if check_method == "llm":
+                if not self.safeguard_llm_config:
+                    raise ValueError(
+                        f"safeguard_llm_config is required for LLM-based guardrail: {rule['message_source']} -> {rule['message_destination']}"
+                    )
+
+                # Handle different LLM check types
+                if "custom_prompt" in rule:
+                    # Custom prompt for LLM guardrail
+                    condition = rule["custom_prompt"]
+
+                elif "disallow_item" in rule:
+                    # Disallow items for LLM guardrail
+                    condition = f"Check if this content contains any of these disallowed categories: {', '.join(rule['disallow_item'])}"
+
+                else:
+                    raise ValueError(
+                        f"Either custom_prompt or disallow_item must be provided for LLM guardrail: {rule['message_source']} -> {rule['message_destination']}"
+                    )
+
+                # Create LLM guardrail - handle dict config by converting to LLMConfig
+                llm_config = self.safeguard_llm_config
+                if isinstance(llm_config, dict):
+                    llm_config = LLMConfig(config_list=[llm_config])
+
+                guardrail = LLMGuardrail(
+                    name=f"llm_guard_{rule['message_source']}_{rule['message_destination']}",
+                    condition=condition,
+                    target=TransitionTarget(),
+                    llm_config=llm_config,
+                    activation_message=rule.get("activation_message", "LLM detected violation"),
+                )
+
+            elif check_method == "regex":
+                if "pattern" in rule:
+                    # Regex pattern guardrail
+                    guardrail = RegexGuardrail(
+                        name=f"regex_guard_{rule['message_source']}_{rule['message_destination']}",
+                        condition=rule["pattern"],
+                        target=TransitionTarget(),
+                        activation_message=rule.get("activation_message", "Regex pattern matched"),
+                    )
+
+            # Add rule with guardrail
+            parsed_rule = {
+                "type": "agent_transition",
+                "source": rule["message_source"],
+                "target": rule["message_destination"],
+                "action": action,
+                "guardrail": guardrail,
+                "activation_message": rule.get("activation_message", "Content blocked by safeguard"),
+            }
+
+            # Keep legacy fields for backward compatibility
+            if "disallow_item" in rule:
+                parsed_rule["disallow"] = rule["disallow_item"]
+            if "pattern" in rule:
+                parsed_rule["pattern"] = rule["pattern"]
+            if "custom_prompt" in rule:
+                parsed_rule["custom_prompt"] = rule["custom_prompt"]
+
+            rules.append(parsed_rule)
+
+        # Groupchat message check
+        if "groupchat_message_check" in inter_agent:
+            rule = inter_agent["groupchat_message_check"]
+            rules.append({
+                "type": "groupchat_message",
+                "source": "*",
+                "target": "*",
+                "action": rule.get("pet_action", "block"),
+                "disallow": rule.get("disallow_item", []),
+            })
+
+        return rules
+
+    def _parse_environment_rules(self) -> list[dict[str, Any]]:
+        """Parse agent-environment safeguard rules from policy."""
+        rules = []
+        env_rules = self.policy.get("agent_environment_safeguards", {})
+
+        # Tool interaction rules
+        for rule in env_rules.get("tool_interaction", []):
+            check_method = rule.get("check_method", "regex")  # default to regex for backward compatibility
+            action = rule.get("violation_response", rule.get("action", "block"))
+
+            if check_method == "llm":
+                # LLM-based tool interaction rule - requires message_source/message_destination
+                if "message_source" not in rule or "message_destination" not in rule:
+                    raise ValueError(
+                        "tool_interaction with check_method 'llm' must have 'message_source' and 'message_destination'"
+                    )
+
+                parsed_rule = {
+                    "type": "tool_interaction",
+                    "message_source": rule["message_source"],
+                    "message_destination": rule["message_destination"],
+                    "check_method": "llm",
+                    "action": action,
+                    "activation_message": rule.get("activation_message", "LLM blocked tool output"),
+                }
+
+                # Add LLM-specific parameters
+                if "custom_prompt" in rule:
+                    parsed_rule["custom_prompt"] = rule["custom_prompt"]
+                elif "disallow_item" in rule:
+                    parsed_rule["disallow"] = rule["disallow_item"]
+
+                rules.append(parsed_rule)
+
+            elif check_method == "regex":
+                # Regex pattern-based rule - now requires message_source/message_destination
+                if "message_source" not in rule or "message_destination" not in rule:
+                    raise ValueError(
+                        "tool_interaction with check_method 'regex' must have 'message_source' and 'message_destination'"
+                    )
+                if "pattern" not in rule:
+                    raise ValueError("tool_interaction with check_method 'regex' must have 'pattern'")
+
+                rules.append({
+                    "type": "tool_interaction",
+                    "message_source": rule["message_source"],
+                    "message_destination": rule["message_destination"],
+                    "check_method": "regex",
+                    "pattern": rule["pattern"],
+                    "action": action,
+                    "activation_message": rule.get("activation_message", "Content blocked by safeguard"),
+                })
+            else:
+                raise ValueError(
+                    "tool_interaction rule must have check_method 'llm' or 'regex' with appropriate parameters"
+                )
+
+        # LLM interaction rules
+        for rule in env_rules.get("llm_interaction", []):
+            check_method = rule.get("check_method", "regex")  # default to regex for backward compatibility
+            action = rule.get("action", "block")
+
+            # All llm_interaction rules now require message_source/message_destination
+            if "message_source" not in rule or "message_destination" not in rule:
+                raise ValueError("llm_interaction rule must have 'message_source' and 'message_destination'")
+
+            if check_method == "llm":
+                # LLM-based LLM interaction rule
+                parsed_rule = {
+                    "type": "llm_interaction",
+                    "message_source": rule["message_source"],
+                    "message_destination": rule["message_destination"],
+                    "check_method": "llm",
+                    "action": action,
+                    "activation_message": rule.get("activation_message", "LLM blocked content"),
+                }
+
+                # Add LLM-specific parameters
+                if "custom_prompt" in rule:
+                    parsed_rule["custom_prompt"] = rule["custom_prompt"]
+                elif "disallow_item" in rule:
+                    parsed_rule["disallow_item"] = rule["disallow_item"]
+                else:
+                    raise ValueError(
+                        "llm_interaction with check_method 'llm' must have either 'custom_prompt' or 'disallow_item'"
+                    )
+
+                rules.append(parsed_rule)
+
+            elif check_method == "regex":
+                # Regex-based LLM interaction rule
+                if "pattern" not in rule:
+                    raise ValueError("llm_interaction with check_method 'regex' must have 'pattern'")
+
+                rules.append({
+                    "type": "llm_interaction",
+                    "message_source": rule["message_source"],
+                    "message_destination": rule["message_destination"],
+                    "check_method": "regex",
+                    "pattern": rule["pattern"],
+                    "action": action,
+                    "activation_message": rule.get("activation_message", "Content blocked by safeguard"),
+                })
+            else:
+                raise ValueError(
+                    "llm_interaction rule must have check_method 'llm' or 'regex' with appropriate parameters"
+                )
+
+        # User interaction rules
+        for rule in env_rules.get("user_interaction", []):
+            check_method = rule.get("check_method", "llm")  # default to llm for backward compatibility
+            action = rule.get("action", "block")
+
+            # All user_interaction rules now require message_source/message_destination
+            if "message_source" not in rule or "message_destination" not in rule:
+                raise ValueError("user_interaction rule must have 'message_source' and 'message_destination'")
+
+            if check_method == "llm":
+                # LLM-based user interaction rule
+                parsed_rule = {
+                    "type": "user_interaction",
+                    "message_source": rule["message_source"],
+                    "message_destination": rule["message_destination"],
+                    "check_method": "llm",
+                    "action": action,
+                }
+
+                # Add LLM-specific parameters
+                if "custom_prompt" in rule:
+                    parsed_rule["custom_prompt"] = rule["custom_prompt"]
+                elif "disallow_item" in rule:
+                    parsed_rule["disallow_item"] = rule["disallow_item"]
+                else:
+                    raise ValueError(
+                        "user_interaction with check_method 'llm' must have either 'custom_prompt' or 'disallow_item'"
+                    )
+
+                rules.append(parsed_rule)
+
+            elif check_method == "regex":
+                # Regex-based user interaction rule
+                if "pattern" not in rule:
+                    raise ValueError("user_interaction with check_method 'regex' must have 'pattern'")
+
+                rules.append({
+                    "type": "user_interaction",
+                    "message_source": rule["message_source"],
+                    "message_destination": rule["message_destination"],
+                    "check_method": "regex",
+                    "pattern": rule["pattern"],
+                    "action": action,
+                })
+            else:
+                raise ValueError(
+                    "user_interaction rule must have check_method 'llm' or 'regex' with appropriate parameters"
+                )
+
+        return rules
+
+    def create_agent_hooks(self, agent_name: str) -> dict[str, Callable[..., Any]]:
+        """Create hook functions for a specific agent, only for rule types that exist."""
+        hooks = {}
+
+        # Check if we have any tool interaction rules that apply to this agent
+        agent_tool_rules = [
+            rule
+            for rule in self.environment_rules
+            if rule["type"] == "tool_interaction"
+            and (
+                rule.get("message_destination") == agent_name
+                or rule.get("message_source") == agent_name
+                or rule.get("agent_name") == agent_name
+                or "message_destination" not in rule
+            )
+        ]  # Simple pattern rules apply to all
+
+        if agent_tool_rules:
+
+            def tool_input_hook(tool_input: dict[str, Any]) -> dict[str, Any] | None:
+                result = self._check_tool_interaction(agent_name, tool_input, "input")
+                return result if result is not None else tool_input
+
+            def tool_output_hook(tool_input: dict[str, Any]) -> dict[str, Any] | None:
+                result = self._check_tool_interaction(agent_name, tool_input, "output")
+                return result if result is not None else tool_input
+
+            hooks["safeguard_tool_inputs"] = tool_input_hook
+            hooks["safeguard_tool_outputs"] = tool_output_hook
+
+        # Check if we have any LLM interaction rules that apply to this agent
+        agent_llm_rules = [
+            rule
+            for rule in self.environment_rules
+            if rule["type"] == "llm_interaction"
+            and (
+                rule.get("message_destination") == agent_name
+                or rule.get("message_source") == agent_name
+                or rule.get("agent_name") == agent_name
+                or "message_destination" not in rule
+            )
+        ]  # Simple pattern rules apply to all
+
+        if agent_llm_rules:
+
+            def llm_input_hook(tool_input: dict[str, Any]) -> dict[str, Any] | None:
+                # Extract messages from the data structure if needed
+                messages = tool_input if isinstance(tool_input, list) else tool_input.get("messages", tool_input)
+                result = self._check_llm_interaction(agent_name, messages, "input")
+                if isinstance(result, list) and isinstance(tool_input, dict) and "messages" in tool_input:
+                    return {**tool_input, "messages": result}
+                elif isinstance(result, dict):
+                    return result
+                elif result is not None and not isinstance(result, dict):
+                    # Convert string or other types to dict format
+                    return {"content": str(result), "role": "function"}
+                elif result is not None and isinstance(result, dict) and result != tool_input:
+                    # Return the modified dict result
+                    return result
+                return tool_input
+
+            def llm_output_hook(tool_input: dict[str, Any]) -> dict[str, Any] | None:
+                result = self._check_llm_interaction(agent_name, tool_input, "output")
+                if isinstance(result, dict):
+                    return result
+                elif result is not None and not isinstance(result, dict):
+                    # Convert string or other types to dict format
+                    return {"content": str(result), "role": "function"}
+                elif result is not None and isinstance(result, dict) and result != tool_input:
+                    # Return the modified dict result
+                    return result
+                return tool_input
+
+            hooks["safeguard_llm_inputs"] = llm_input_hook
+            hooks["safeguard_llm_outputs"] = llm_output_hook
+
+        # Check if we have any user interaction rules that apply to this agent
+        agent_user_rules = [
+            rule
+            for rule in self.environment_rules
+            if rule["type"] == "user_interaction" and rule.get("message_destination") == agent_name
+        ]
+
+        if agent_user_rules:
+
+            def human_input_hook(tool_input: dict[str, Any]) -> dict[str, Any] | None:
+                # Extract human input from data structure
+                human_input = tool_input.get("content", str(tool_input))
+                result = self._check_user_interaction(agent_name, human_input)
+                if result != human_input and isinstance(tool_input, dict):
+                    return {**tool_input, "content": result}
+                return tool_input if result == human_input else {"content": result}
+
+            hooks["safeguard_human_inputs"] = human_input_hook
+
+        # Check if we have any inter-agent rules that apply to this agent
+        # Note: For group chats, inter-agent communication is handled by GroupChat._run_inter_agent_guardrails()
+        # But for direct agent-to-agent communication, we need the process_message_before_send hook
+        agent_inter_rules = [
+            rule
+            for rule in self.inter_agent_rules
+            if (
+                rule.get("source") == agent_name
+                or rule.get("target") == agent_name
+                or rule.get("source") == "*"
+                or rule.get("target") == "*"
+            )
+        ]
+
+        if agent_inter_rules:
+
+            def message_before_send_hook(
+                sender: Any, message: dict[str, Any] | str, recipient: Any, silent: Any = None
+            ) -> dict[str, Any] | str:
+                _ = silent  # Unused parameter
+                result = self._check_inter_agent_communication(sender.name, recipient.name, message)
+                return result
+
+            hooks["process_message_before_send"] = message_before_send_hook  # type: ignore[assignment]
+
+        return hooks
+
+    def _check_llm_violation(
+        self, content: str, disallow_items: list[str] | None = None, custom_prompt: str | None = None
+    ) -> tuple[bool, str]:
+        """Check if content violates LLM-based safeguard conditions."""
+        if not self.safeguard_llm_config:
+            raise ValueError("safeguard_llm_config is required for LLM-based safeguard checks")
+
+        # Determine condition based on available parameters
+        if custom_prompt:
+            condition = custom_prompt
+        elif disallow_items:
+            condition = (
+                f"Check if this content contains any of these disallowed categories: {', '.join(disallow_items)}"
+            )
+        else:
+            raise ValueError("Either custom_prompt or disallow_items must be provided")
+
+        # Create LLM guardrail for checking
+        # Handle dict config by converting to LLMConfig
+        llm_config = self.safeguard_llm_config
+        if isinstance(llm_config, dict):
+            llm_config = LLMConfig(config_list=[llm_config])
+
+        from ..targets.transition_target import TransitionTarget
+
+        guardrail = LLMGuardrail(
+            name="temp_safeguard_check",
+            condition=condition,
+            target=TransitionTarget(),
+            llm_config=llm_config,
+            activation_message="Content violates safeguard conditions",
+        )
+
+        try:
+            result = guardrail.check(content)
+            return result.activated, result.justification
+        except Exception as e:
+            raise RuntimeError(f"Safeguard check failed: {e}")
+
+    def _check_regex_violation(self, content: str, pattern: str) -> tuple[bool, str]:
+        """Check if content matches a regex pattern."""
+        try:
+            if re.search(pattern, content, re.IGNORECASE):
+                return True, f"Content matched pattern: {pattern}"
+        except re.error as e:
+            raise ValueError(f"Invalid regex pattern '{pattern}': {e}")
+
+        return False, "No pattern match"
+
+    def _apply_action(
+        self,
+        action: str,
+        content: str | dict[str, Any] | list[Any],
+        disallow_items: list[str],
+        explanation: str,
+        custom_message: str | None = None,
+        pattern: str | None = None,
+        guardrail_type: str | None = None,
+        source_agent: str | None = None,
+        target_agent: str | None = None,
+        content_preview: str | None = None,
+    ) -> str | dict[str, Any] | list[Any]:
+        """Apply the specified action to content."""
+        message = custom_message or explanation
+
+        if action == "block":
+            self._send_safeguard_event(
+                event_type="action",
+                message=f"BLOCKED: {message}",
+                action="block",
+                source_agent=source_agent,
+                target_agent=target_agent,
+                content_preview=content_preview,
+            )
+            return self._handle_blocked_content(content, message)
+        elif action == "mask":
+            self._send_safeguard_event(
+                event_type="action",
+                message=f"MASKED: {message}",
+                action="mask",
+                source_agent=source_agent,
+                target_agent=target_agent,
+                content_preview=content_preview,
+            )
+
+            def mask_func(text: str) -> str:
+                return self._mask_content(text, disallow_items, explanation, pattern)
+
+            return self._handle_masked_content(content, mask_func)
+        elif action == "warning":
+            self._send_safeguard_event(
+                event_type="action",
+                message=f"WARNING: {message}",
+                action="warning",
+                source_agent=source_agent,
+                target_agent=target_agent,
+                content_preview=content_preview,
+            )
+            return content
+        else:
+            return content
+
+    def _mask_content(
+        self, content: str, disallow_items: list[str], explanation: str, pattern: str | None = None
+    ) -> str:
+        """Mask sensitive content using LLM, pattern-based, or simple replacement."""
+        # If we have a specific pattern from a regex guardrail, use it first
+        if pattern:
+            try:
+                masked = re.sub(pattern, "[SENSITIVE_INFO]", content, flags=re.IGNORECASE)
+                if masked != content:  # Only return if something was actually masked
+                    return masked
+            except re.error as e:
+                raise ValueError(f"Pattern masking failed: {e}")
+
+        # Try LLM-based masking if available
+        if self.mask_agent and disallow_items:
+            mask_prompt = f"""
+            Mask the sensitive information in this content with [SENSITIVE_INFO]:
+
+            Content: {content}
+            Sensitive categories: {", ".join(disallow_items)}
+            Reason: {explanation}
+
+            Return only the masked content, nothing else.
+            """
+
+            try:
+                response = self.mask_agent.generate_oai_reply(messages=[{"role": "user", "content": mask_prompt}])
+
+                if response[0] and response[1]:
+                    masked = response[1].get("content", content) if isinstance(response[1], dict) else str(response[1])
+                    return masked
+            except Exception as e:
+                raise ValueError(f"LLM masking failed: {e}")
+
+        return masked
+
+    def _handle_blocked_content(
+        self, content: str | dict[str, Any] | list[Any], block_message: str
+    ) -> str | dict[str, Any] | list[Any]:
+        """Handle blocked content based on its structure."""
+        block_msg = f"🛡️ BLOCKED: {block_message}"
+
+        if isinstance(content, dict):
+            blocked_content = content.copy()
+
+            # Handle tool_responses (like in tool outputs)
+            if "tool_responses" in blocked_content and blocked_content["tool_responses"]:
+                blocked_content["content"] = block_msg
+                blocked_content["tool_responses"] = [
+                    {**response, "content": block_msg} for response in blocked_content["tool_responses"]
+                ]
+            # Handle tool_calls (like in tool inputs)
+            elif "tool_calls" in blocked_content and blocked_content["tool_calls"]:
+                blocked_content["tool_calls"] = [
+                    {**tool_call, "function": {**tool_call["function"], "arguments": block_msg}}
+                    for tool_call in blocked_content["tool_calls"]
+                ]
+            # Handle regular content
+            elif "content" in blocked_content:
+                blocked_content["content"] = block_msg
+            # Handle arguments (for some tool formats)
+            elif "arguments" in blocked_content:
+                blocked_content["arguments"] = block_msg
+            else:
+                # Default case - add content field
+                blocked_content["content"] = block_msg
+
+            return blocked_content
+
+        elif isinstance(content, list):
+            # Handle list of messages (like LLM inputs)
+            blocked_list = []
+            for item in content:
+                if isinstance(item, dict):
+                    blocked_item = item.copy()
+                    if "content" in blocked_item:
+                        blocked_item["content"] = block_msg
+                    if "tool_calls" in blocked_item:
+                        blocked_item["tool_calls"] = [
+                            {**tool_call, "function": {**tool_call["function"], "arguments": block_msg}}
+                            for tool_call in blocked_item["tool_calls"]
+                        ]
+                    if "tool_responses" in blocked_item:
+                        blocked_item["tool_responses"] = [
+                            {**response, "content": block_msg} for response in blocked_item["tool_responses"]
+                        ]
+                    blocked_list.append(blocked_item)
+                else:
+                    blocked_list.append({"content": block_msg, "role": "function"})
+            return blocked_list
+
+        else:
+            # String or other content - return as function message
+            return {"content": block_msg, "role": "function"}
+
+    def _handle_masked_content(
+        self, content: str | dict[str, Any] | list[Any], mask_func: Callable[[str], str]
+    ) -> str | dict[str, Any] | list[Any]:
+        """Handle masked content based on its structure."""
+        if isinstance(content, dict):
+            masked_content = content.copy()
+
+            # Handle tool_responses
+            if "tool_responses" in masked_content and masked_content["tool_responses"]:
+                if "content" in masked_content:
+                    masked_content["content"] = mask_func(str(masked_content["content"]))
+                masked_content["tool_responses"] = [
+                    {**response, "content": mask_func(str(response.get("content", "")))}
+                    for response in masked_content["tool_responses"]
+                ]
+            # Handle tool_calls
+            elif "tool_calls" in masked_content and masked_content["tool_calls"]:
+                masked_content["tool_calls"] = [
+                    {
+                        **tool_call,
+                        "function": {
+                            **tool_call["function"],
+                            "arguments": mask_func(str(tool_call["function"].get("arguments", ""))),
+                        },
+                    }
+                    for tool_call in masked_content["tool_calls"]
+                ]
+            # Handle regular content
+            elif "content" in masked_content:
+                masked_content["content"] = mask_func(str(masked_content["content"]))
+            # Handle arguments
+            elif "arguments" in masked_content:
+                masked_content["arguments"] = mask_func(str(masked_content["arguments"]))
+
+            return masked_content
+
+        elif isinstance(content, list):
+            # Handle list of messages
+            masked_list = []
+            for item in content:
+                if isinstance(item, dict):
+                    masked_item = item.copy()
+                    if "content" in masked_item:
+                        masked_item["content"] = mask_func(str(masked_item["content"]))
+                    if "tool_calls" in masked_item:
+                        masked_item["tool_calls"] = [
+                            {
+                                **tool_call,
+                                "function": {
+                                    **tool_call["function"],
+                                    "arguments": mask_func(str(tool_call["function"].get("arguments", ""))),
+                                },
+                            }
+                            for tool_call in masked_item["tool_calls"]
+                        ]
+                    if "tool_responses" in masked_item:
+                        masked_item["tool_responses"] = [
+                            {**response, "content": mask_func(str(response.get("content", "")))}
+                            for response in masked_item["tool_responses"]
+                        ]
+                    masked_list.append(masked_item)
+                else:
+                    # For non-dict items, wrap the masked content in a dict
+                    masked_item_content: str = mask_func(str(item))
+                    masked_list.append({"content": masked_item_content, "role": "function"})
+            return masked_list
+
+        else:
+            # String content
+            return mask_func(str(content))
+
+    def _check_inter_agent_communication(
+        self, sender_name: str, recipient_name: str, message: str | dict[str, Any]
+    ) -> str | dict[str, Any]:
+        """Check inter-agent communication."""
+        content = message.get("content", "") if isinstance(message, dict) else str(message)
+
+        for rule in self.inter_agent_rules:
+            if rule["type"] == "agent_transition":
+                # Check if this rule applies
+                source_match = rule["source"] == "*" or rule["source"] == sender_name
+                target_match = rule["target"] == "*" or rule["target"] == recipient_name
+
+                if source_match and target_match:
+                    # Prepare content preview
+                    content_preview = content[:100] + ("..." if len(content) > 100 else "")
+
+                    # Use guardrail if available
+                    if "guardrail" in rule and rule["guardrail"]:
+                        # Send single check event with guardrail info
+                        self._send_safeguard_event(
+                            event_type="check",
+                            message="Checking inter-agent communication",
+                            source_agent=sender_name,
+                            target_agent=recipient_name,
+                            guardrail_type=type(rule["guardrail"]).__name__,
+                            # action=rule.get('action', 'N/A'),
+                            content_preview=content_preview,
+                        )
+
+                        try:
+                            result = rule["guardrail"].check(content)
+                            if result.activated:
+                                self._send_safeguard_event(
+                                    event_type="violation",
+                                    message=f"VIOLATION DETECTED: {result.justification}",
+                                    source_agent=sender_name,
+                                    target_agent=recipient_name,
+                                    guardrail_type=type(rule["guardrail"]).__name__,
+                                    content_preview=content_preview,
+                                )
+                                # Pass the pattern if it's a regex guardrail
+                                pattern = rule.get("pattern") if isinstance(rule["guardrail"], RegexGuardrail) else None
+                                action_result = self._apply_action(
+                                    action=rule["action"],
+                                    content=message,
+                                    disallow_items=[],
+                                    explanation=result.justification,
+                                    custom_message=rule.get("activation_message", result.justification),
+                                    pattern=pattern,
+                                    guardrail_type=type(rule["guardrail"]).__name__,
+                                    source_agent=sender_name,
+                                    target_agent=recipient_name,
+                                    content_preview=content_preview,
+                                )
+                                if isinstance(action_result, (str, dict)):
+                                    return action_result
+                                else:
+                                    return message
+                            else:
+                                # Content passed - no additional event needed, already sent check event above
+                                pass
+                        except Exception as e:
+                            raise ValueError(f"Guardrail check failed: {e}")
+
+                    # Handle legacy pattern-based rules
+                    elif "pattern" in rule and rule["pattern"]:
+                        # Send single check event for pattern-based rules
+                        self._send_safeguard_event(
+                            event_type="check",
+                            message="Checking inter-agent communication",
+                            source_agent=sender_name,
+                            target_agent=recipient_name,
+                            guardrail_type="RegexGuardrail",
+                            # action=rule.get('action', 'N/A'),
+                            content_preview=content_preview,
+                        )
+                        is_violation, explanation = self._check_regex_violation(content, rule["pattern"])
+                        if is_violation:
+                            result_value = self._apply_action(
+                                action=rule["action"],
+                                content=message,
+                                disallow_items=[],
+                                explanation=explanation,
+                                custom_message=rule.get("activation_message"),
+                                pattern=rule["pattern"],
+                                guardrail_type="RegexGuardrail",
+                                source_agent=sender_name,
+                                target_agent=recipient_name,
+                                content_preview=content_preview,
+                            )
+                            if isinstance(result_value, (str, dict)):
+                                return result_value
+                            else:
+                                return message
+                        else:
+                            pass
+
+                    # Handle legacy disallow-based rules and custom prompts
+                    elif "disallow" in rule or "custom_prompt" in rule:
+                        # Send single check event for LLM-based legacy rules
+                        self._send_safeguard_event(
+                            event_type="check",
+                            message="Checking inter-agent communication",
+                            source_agent=sender_name,
+                            target_agent=recipient_name,
+                            guardrail_type="LLMGuardrail",
+                            # action=rule.get('action', 'N/A'),
+                            content_preview=content_preview,
+                        )
+                        if "custom_prompt" in rule:
+                            is_violation, explanation = self._check_llm_violation(
+                                content, custom_prompt=rule["custom_prompt"]
+                            )
+                        else:
+                            is_violation, explanation = self._check_llm_violation(
+                                content, disallow_items=rule["disallow"]
+                            )
+
+                        if is_violation:
+                            result_value = self._apply_action(
+                                action=rule["action"],
+                                content=message,
+                                disallow_items=rule.get("disallow", []),
+                                explanation=explanation,
+                                custom_message=None,
+                                pattern=None,
+                                guardrail_type="LLMGuardrail",
+                                source_agent=sender_name,
+                                target_agent=recipient_name,
+                                content_preview=content_preview,
+                            )
+                            if isinstance(result_value, (str, dict)):
+                                return result_value
+                            else:
+                                return message
+                        else:
+                            pass
+
+        return message
+
+    def _check_interaction(
+        self,
+        interaction_type: str,
+        source_name: str,
+        dest_name: str,
+        content: str,
+        data: str | dict[str, Any] | list[dict[str, Any]],
+        context_info: str,
+    ) -> str | dict[str, Any] | list[dict[str, Any]] | None:
+        """Unified method to check any type of interaction."""
+        for rule in self.environment_rules:
+            if (
+                rule["type"] == interaction_type
+                and "message_source" in rule
+                and "message_destination" in rule
+                and rule["message_source"] == source_name
+                and rule["message_destination"] == dest_name
+            ):
+                content_preview = content[:100] + ("..." if len(content) > 100 else "")
+                check_method = rule.get("check_method", "regex")
+                guardrail_type = "LLMGuardrail" if check_method == "llm" else "RegexGuardrail"
+
+                # Send check event
+                self._send_safeguard_event(
+                    event_type="check",
+                    message=f"Checking {interaction_type.replace('_', ' ')}: {context_info}",
+                    source_agent=source_name,
+                    target_agent=dest_name,
+                    guardrail_type=guardrail_type,
+                    content_preview=content_preview,
+                )
+
+                # Perform check based on method
+                is_violation, explanation = self._perform_check(rule, content, check_method)
+
+                if is_violation:
+                    # Send violation event
+                    self._send_safeguard_event(
+                        event_type="violation",
+                        message=f"{guardrail_type.replace('Guardrail', '').upper()} VIOLATION: {explanation}",
+                        source_agent=source_name,
+                        target_agent=dest_name,
+                        guardrail_type=guardrail_type,
+                        content_preview=content_preview,
+                    )
+
+                    # Apply action
+                    result = self._apply_action(
+                        action=rule["action"],
+                        content=data,
+                        disallow_items=rule.get("disallow_item", []),
+                        explanation=explanation,
+                        custom_message=rule.get("activation_message"),
+                        pattern=rule.get("pattern"),
+                        guardrail_type=guardrail_type,
+                        source_agent=source_name,
+                        target_agent=dest_name,
+                        content_preview=content_preview,
+                    )
+                    return result
+
+        return None
+
+    def _perform_check(self, rule: dict[str, Any], content: str, check_method: str) -> tuple[bool, str]:
+        """Perform the actual check based on the method."""
+        if check_method == "llm":
+            if not self.safeguard_llm_config:
+                raise ValueError(
+                    f"safeguard_llm_config is required for LLM-based {rule['type']} rule: {rule['message_source']} -> {rule['message_destination']}"
+                )
+
+            if "custom_prompt" in rule:
+                return self._check_llm_violation(content, custom_prompt=rule["custom_prompt"])
+            elif "disallow_item" in rule:
+                return self._check_llm_violation(content, disallow_items=rule["disallow_item"])
+            else:
+                raise ValueError(
+                    f"Either custom_prompt or disallow_item must be provided for LLM-based {rule['type']}: {rule['message_source']} -> {rule['message_destination']}"
+                )
+
+        elif check_method == "regex":
+            if "pattern" not in rule:
+                raise ValueError(
+                    f"pattern is required for regex-based {rule['type']}: {rule['message_source']} -> {rule['message_destination']}"
+                )
+            return self._check_regex_violation(content, rule["pattern"])
+
+        else:
+            raise ValueError(f"Unsupported check_method: {check_method}")
+
+    def _check_tool_interaction(self, agent_name: str, data: dict[str, Any], direction: str) -> dict[str, Any]:
+        """Check tool interactions."""
+        # Extract tool name from data
+        tool_name = data.get("name", data.get("tool_name", ""))
+
+        # Determine source/destination based on direction
+        if direction == "output":
+            source_name, dest_name = tool_name, agent_name
+            content = str(data.get("content", ""))
+        else:  # input
+            source_name, dest_name = agent_name, tool_name
+            content = str(data.get("arguments", ""))
+
+        result = self._check_interaction(
+            interaction_type="tool_interaction",
+            source_name=source_name,
+            dest_name=dest_name,
+            content=content,
+            data=data,
+            context_info=f"{agent_name} <-> {tool_name} ({direction})",
+        )
+
+        if result is not None:
+            if isinstance(result, dict):
+                return result
+            else:
+                # Convert string or list result back to dict format
+                return {"content": str(result), "name": tool_name}
+        return data
+
+    def _check_llm_interaction(
+        self, agent_name: str, data: str | dict[str, Any] | list[dict[str, Any]], direction: str
+    ) -> str | dict[str, Any] | list[dict[str, Any]]:
+        """Check LLM interactions."""
+        content = str(data)
+
+        # Determine source/destination based on direction
+        if direction == "input":
+            source_name, dest_name = agent_name, "llm"
+        else:  # output
+            source_name, dest_name = "llm", agent_name
+
+        result = self._check_interaction(
+            interaction_type="llm_interaction",
+            source_name=source_name,
+            dest_name=dest_name,
+            content=content,
+            data=data,
+            context_info=f"{agent_name} <-> llm ({direction})",
+        )
+
+        return result if result is not None else data
+
+    def _check_user_interaction(self, agent_name: str, user_input: str) -> str | None:
+        """Check user interactions."""
+        result = self._check_interaction(
+            interaction_type="user_interaction",
+            source_name="user",
+            dest_name=agent_name,
+            content=user_input,
+            data=user_input,
+            context_info=f"user <-> {agent_name}",
+        )
+
+        if result is not None and isinstance(result, str):
+            return result
+        return user_input
+
+    def check_and_act(
+        self, src_agent_name: str, dst_agent_name: str, message_content: str | dict[str, Any]
+    ) -> str | dict[str, Any] | None:
+        """Check and act on inter-agent communication for GroupChat integration.
+
+        This method is called by GroupChat._run_inter_agent_guardrails to check
+        messages between agents and potentially modify or block them.
+
+        Args:
+            src_agent_name: Name of the source agent
+            dst_agent_name: Name of the destination agent
+            message_content: The message content to check
+
+        Returns:
+            Optional replacement message if a safeguard triggers, None otherwise
+        """
+        # Store original content for comparison
+        original_content = (
+            message_content.get("content", "") if isinstance(message_content, dict) else str(message_content)
+        )
+
+        result = self._check_inter_agent_communication(src_agent_name, dst_agent_name, message_content)
+
+        if result != original_content:
+            # Return the complete modified message structure to preserve tool_calls/tool_responses pairing
+            return result
+
+        return None
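
For orientation, the sketch below shows how the new enforcer might be driven, pieced together from the parsing code in this hunk. The policy keys (inter_agent_safeguards, agent_transitions, agent_environment_safeguards, tool_interaction, message_source, message_destination, check_method, pattern, violation_response, activation_message) are read directly from _parse_inter_agent_rules and _parse_environment_rules above; the agent and tool names, the regex patterns, and the import path are illustrative assumptions, and whether this exact policy satisfies the new SafeguardValidator (added in validator.py, not shown in this hunk) cannot be verified from the diff alone.

# Sketch only: policy keys come from the enforcer code above; names, patterns,
# and the import path are assumptions for illustration.
from autogen.agentchat.group.safeguards import SafeguardEnforcer  # assumed re-export from the new __init__.py

policy = {
    "inter_agent_safeguards": {
        "agent_transitions": [
            {
                "message_source": "coder",          # assumed agent name
                "message_destination": "executor",  # assumed agent name
                "check_method": "regex",
                "pattern": r"sk-[A-Za-z0-9]{20,}",  # assumed pattern: naive API-key match
                "violation_response": "mask",
                "activation_message": "API keys must not be forwarded",
            }
        ]
    },
    "agent_environment_safeguards": {
        "tool_interaction": [
            {
                "message_source": "web_search",       # assumed tool name
                "message_destination": "researcher",  # assumed agent name
                "check_method": "regex",
                "pattern": r"\b\d{3}-\d{2}-\d{4}\b",  # assumed pattern: SSN-like strings
                "violation_response": "block",
            }
        ]
    },
}

# Regex-only rules need no LLM config; rules with check_method "llm" would
# require safeguard_llm_config, as enforced in _parse_inter_agent_rules and
# _perform_check above.
enforcer = SafeguardEnforcer(policy=policy)

# create_agent_hooks returns hook callables keyed by hookable-method name
# (e.g. "safeguard_tool_inputs", "process_message_before_send"); how they are
# registered on agents is handled by the new safeguards API, which this hunk omits.
hooks = enforcer.create_agent_hooks("researcher")
print(sorted(hooks))

A regex-only policy like this keeps the enforcer fully local, while the LLM-backed checks and masking paths shown above only come into play when safeguard_llm_config or mask_llm_config is supplied.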