hackagent 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hackagent/__init__.py +12 -0
- hackagent/agent.py +214 -0
- hackagent/api/__init__.py +1 -0
- hackagent/api/agent/__init__.py +1 -0
- hackagent/api/agent/agent_create.py +347 -0
- hackagent/api/agent/agent_destroy.py +140 -0
- hackagent/api/agent/agent_list.py +242 -0
- hackagent/api/agent/agent_partial_update.py +361 -0
- hackagent/api/agent/agent_retrieve.py +235 -0
- hackagent/api/agent/agent_update.py +361 -0
- hackagent/api/apilogs/__init__.py +1 -0
- hackagent/api/apilogs/apilogs_list.py +170 -0
- hackagent/api/apilogs/apilogs_retrieve.py +162 -0
- hackagent/api/attack/__init__.py +1 -0
- hackagent/api/attack/attack_create.py +275 -0
- hackagent/api/attack/attack_destroy.py +146 -0
- hackagent/api/attack/attack_list.py +254 -0
- hackagent/api/attack/attack_partial_update.py +289 -0
- hackagent/api/attack/attack_retrieve.py +247 -0
- hackagent/api/attack/attack_update.py +289 -0
- hackagent/api/checkout/__init__.py +1 -0
- hackagent/api/checkout/checkout_create.py +225 -0
- hackagent/api/generate/__init__.py +1 -0
- hackagent/api/generate/generate_create.py +253 -0
- hackagent/api/judge/__init__.py +1 -0
- hackagent/api/judge/judge_create.py +253 -0
- hackagent/api/key/__init__.py +1 -0
- hackagent/api/key/key_create.py +179 -0
- hackagent/api/key/key_destroy.py +103 -0
- hackagent/api/key/key_list.py +170 -0
- hackagent/api/key/key_retrieve.py +162 -0
- hackagent/api/organization/__init__.py +1 -0
- hackagent/api/organization/organization_create.py +208 -0
- hackagent/api/organization/organization_destroy.py +104 -0
- hackagent/api/organization/organization_list.py +170 -0
- hackagent/api/organization/organization_me_retrieve.py +126 -0
- hackagent/api/organization/organization_partial_update.py +222 -0
- hackagent/api/organization/organization_retrieve.py +163 -0
- hackagent/api/organization/organization_update.py +222 -0
- hackagent/api/prompt/__init__.py +1 -0
- hackagent/api/prompt/prompt_create.py +171 -0
- hackagent/api/prompt/prompt_destroy.py +104 -0
- hackagent/api/prompt/prompt_list.py +185 -0
- hackagent/api/prompt/prompt_partial_update.py +185 -0
- hackagent/api/prompt/prompt_retrieve.py +163 -0
- hackagent/api/prompt/prompt_update.py +185 -0
- hackagent/api/result/__init__.py +1 -0
- hackagent/api/result/result_create.py +175 -0
- hackagent/api/result/result_destroy.py +106 -0
- hackagent/api/result/result_list.py +249 -0
- hackagent/api/result/result_partial_update.py +193 -0
- hackagent/api/result/result_retrieve.py +167 -0
- hackagent/api/result/result_trace_create.py +177 -0
- hackagent/api/result/result_update.py +189 -0
- hackagent/api/run/__init__.py +1 -0
- hackagent/api/run/run_create.py +187 -0
- hackagent/api/run/run_destroy.py +112 -0
- hackagent/api/run/run_list.py +291 -0
- hackagent/api/run/run_partial_update.py +201 -0
- hackagent/api/run/run_result_create.py +177 -0
- hackagent/api/run/run_retrieve.py +179 -0
- hackagent/api/run/run_run_tests_create.py +187 -0
- hackagent/api/run/run_update.py +201 -0
- hackagent/api/user/__init__.py +1 -0
- hackagent/api/user/user_create.py +212 -0
- hackagent/api/user/user_destroy.py +106 -0
- hackagent/api/user/user_list.py +174 -0
- hackagent/api/user/user_me_retrieve.py +126 -0
- hackagent/api/user/user_me_update.py +196 -0
- hackagent/api/user/user_partial_update.py +226 -0
- hackagent/api/user/user_retrieve.py +167 -0
- hackagent/api/user/user_update.py +226 -0
- hackagent/attacks/AdvPrefix/__init__.py +41 -0
- hackagent/attacks/AdvPrefix/completions.py +416 -0
- hackagent/attacks/AdvPrefix/config.py +259 -0
- hackagent/attacks/AdvPrefix/evaluation.py +745 -0
- hackagent/attacks/AdvPrefix/evaluators.py +564 -0
- hackagent/attacks/AdvPrefix/generate.py +711 -0
- hackagent/attacks/AdvPrefix/utils.py +307 -0
- hackagent/attacks/__init__.py +35 -0
- hackagent/attacks/advprefix.py +507 -0
- hackagent/attacks/base.py +106 -0
- hackagent/attacks/strategies.py +906 -0
- hackagent/cli/__init__.py +19 -0
- hackagent/cli/commands/__init__.py +20 -0
- hackagent/cli/commands/agent.py +100 -0
- hackagent/cli/commands/attack.py +417 -0
- hackagent/cli/commands/config.py +301 -0
- hackagent/cli/commands/results.py +327 -0
- hackagent/cli/config.py +249 -0
- hackagent/cli/main.py +515 -0
- hackagent/cli/tui/__init__.py +31 -0
- hackagent/cli/tui/actions_logger.py +200 -0
- hackagent/cli/tui/app.py +288 -0
- hackagent/cli/tui/base.py +137 -0
- hackagent/cli/tui/logger.py +318 -0
- hackagent/cli/tui/views/__init__.py +33 -0
- hackagent/cli/tui/views/agents.py +488 -0
- hackagent/cli/tui/views/attacks.py +624 -0
- hackagent/cli/tui/views/config.py +244 -0
- hackagent/cli/tui/views/dashboard.py +307 -0
- hackagent/cli/tui/views/results.py +1210 -0
- hackagent/cli/tui/widgets/__init__.py +24 -0
- hackagent/cli/tui/widgets/actions.py +346 -0
- hackagent/cli/tui/widgets/logs.py +435 -0
- hackagent/cli/utils.py +276 -0
- hackagent/client.py +286 -0
- hackagent/errors.py +37 -0
- hackagent/logger.py +83 -0
- hackagent/models/__init__.py +109 -0
- hackagent/models/agent.py +223 -0
- hackagent/models/agent_request.py +129 -0
- hackagent/models/api_token_log.py +184 -0
- hackagent/models/attack.py +154 -0
- hackagent/models/attack_request.py +82 -0
- hackagent/models/checkout_session_request_request.py +76 -0
- hackagent/models/checkout_session_response.py +59 -0
- hackagent/models/choice.py +81 -0
- hackagent/models/choice_message.py +67 -0
- hackagent/models/evaluation_status_enum.py +14 -0
- hackagent/models/generate_error_response.py +59 -0
- hackagent/models/generate_request_request.py +212 -0
- hackagent/models/generate_success_response.py +115 -0
- hackagent/models/generic_error_response.py +70 -0
- hackagent/models/message_request.py +67 -0
- hackagent/models/organization.py +102 -0
- hackagent/models/organization_minimal.py +68 -0
- hackagent/models/organization_request.py +71 -0
- hackagent/models/paginated_agent_list.py +123 -0
- hackagent/models/paginated_api_token_log_list.py +123 -0
- hackagent/models/paginated_attack_list.py +123 -0
- hackagent/models/paginated_organization_list.py +123 -0
- hackagent/models/paginated_prompt_list.py +123 -0
- hackagent/models/paginated_result_list.py +123 -0
- hackagent/models/paginated_run_list.py +123 -0
- hackagent/models/paginated_user_api_key_list.py +123 -0
- hackagent/models/paginated_user_profile_list.py +123 -0
- hackagent/models/patched_agent_request.py +128 -0
- hackagent/models/patched_attack_request.py +92 -0
- hackagent/models/patched_organization_request.py +71 -0
- hackagent/models/patched_prompt_request.py +125 -0
- hackagent/models/patched_result_request.py +237 -0
- hackagent/models/patched_run_request.py +138 -0
- hackagent/models/patched_user_profile_request.py +99 -0
- hackagent/models/prompt.py +220 -0
- hackagent/models/prompt_request.py +126 -0
- hackagent/models/result.py +294 -0
- hackagent/models/result_list_evaluation_status.py +14 -0
- hackagent/models/result_request.py +232 -0
- hackagent/models/run.py +233 -0
- hackagent/models/run_list_status.py +12 -0
- hackagent/models/run_request.py +133 -0
- hackagent/models/status_enum.py +12 -0
- hackagent/models/step_type_enum.py +14 -0
- hackagent/models/trace.py +121 -0
- hackagent/models/trace_request.py +94 -0
- hackagent/models/usage.py +75 -0
- hackagent/models/user_api_key.py +201 -0
- hackagent/models/user_api_key_request.py +73 -0
- hackagent/models/user_profile.py +135 -0
- hackagent/models/user_profile_minimal.py +76 -0
- hackagent/models/user_profile_request.py +99 -0
- hackagent/router/__init__.py +25 -0
- hackagent/router/adapters/__init__.py +20 -0
- hackagent/router/adapters/base.py +63 -0
- hackagent/router/adapters/google_adk.py +671 -0
- hackagent/router/adapters/litellm_adapter.py +524 -0
- hackagent/router/adapters/openai_adapter.py +426 -0
- hackagent/router/router.py +969 -0
- hackagent/router/types.py +54 -0
- hackagent/tracking/__init__.py +42 -0
- hackagent/tracking/context.py +163 -0
- hackagent/tracking/decorators.py +299 -0
- hackagent/tracking/tracker.py +441 -0
- hackagent/types.py +54 -0
- hackagent/utils.py +194 -0
- hackagent/vulnerabilities/__init__.py +13 -0
- hackagent/vulnerabilities/prompts.py +81 -0
- hackagent-0.3.1.dist-info/METADATA +122 -0
- hackagent-0.3.1.dist-info/RECORD +183 -0
- hackagent-0.3.1.dist-info/WHEEL +4 -0
- hackagent-0.3.1.dist-info/entry_points.txt +2 -0
- hackagent-0.3.1.dist-info/licenses/LICENSE +202 -0
|
@@ -0,0 +1,711 @@
|
|
|
1
|
+
# Copyright 2025 - AI4I. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""
|
|
16
|
+
Refactored adversarial prefix generation module with unified class-based architecture.
|
|
17
|
+
|
|
18
|
+
This refactored version consolidates all prefix generation, preprocessing, and
|
|
19
|
+
cross-entropy computation into a cohesive class-based design that improves:
|
|
20
|
+
- Code organization and maintainability
|
|
21
|
+
- State management and configuration handling
|
|
22
|
+
- Testing and mocking capabilities
|
|
23
|
+
- Logging and tracking throughout the pipeline
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
import logging
|
|
27
|
+
import os
|
|
28
|
+
import re
|
|
29
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
30
|
+
|
|
31
|
+
import pandas as pd
|
|
32
|
+
|
|
33
|
+
from hackagent.client import AuthenticatedClient
|
|
34
|
+
from hackagent.router.router import AgentRouter
|
|
35
|
+
from hackagent.router.types import AgentTypeEnum
|
|
36
|
+
|
|
37
|
+
from .config import CUSTOM_CHAT_TEMPLATES, PrefixGenerationConfig
|
|
38
|
+
from .utils import REFUSAL_KEYWORDS, create_progress_bar, handle_empty_input, log_errors
|
|
39
|
+
|
|
40
|
+
# ============================================================================
|
|
41
|
+
# MAIN PIPELINE CLASS
|
|
42
|
+
# ============================================================================
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class PrefixGenerationPipeline:
|
|
46
|
+
"""
|
|
47
|
+
Unified pipeline for adversarial prefix generation, preprocessing, and evaluation.
|
|
48
|
+
|
|
49
|
+
This class encapsulates all functionality related to generating and processing
|
|
50
|
+
adversarial prefixes, providing a clean interface with proper state management
|
|
51
|
+
and comprehensive tracking capabilities.
|
|
52
|
+
|
|
53
|
+
Architecture:
|
|
54
|
+
- Initialization: Sets up config, logger, clients, and internal state
|
|
55
|
+
- Generation: Creates raw prefixes using uncensored models
|
|
56
|
+
- Preprocessing: Two-phase filtering (pattern-based, then CE-based)
|
|
57
|
+
- Cross-Entropy: Tests prefixes against target agents
|
|
58
|
+
- Orchestration: execute() method coordinates the full pipeline
|
|
59
|
+
|
|
60
|
+
Key Benefits:
|
|
61
|
+
- Single source of truth for configuration
|
|
62
|
+
- Consistent logging throughout all operations
|
|
63
|
+
- Easy to test individual components via method mocking
|
|
64
|
+
- Clear method boundaries with single responsibilities
|
|
65
|
+
- Stateful execution tracking for debugging
|
|
66
|
+
|
|
67
|
+
Example:
|
|
68
|
+
pipeline = PrefixGenerationPipeline(
|
|
69
|
+
config=config_dict,
|
|
70
|
+
logger=logger,
|
|
71
|
+
client=client,
|
|
72
|
+
agent_router=router
|
|
73
|
+
)
|
|
74
|
+
results = pipeline.execute(goals=["harmful goal 1", "harmful goal 2"])
|
|
75
|
+
"""
|
|
76
|
+
|
|
77
|
+
def __init__(
|
|
78
|
+
self,
|
|
79
|
+
config: Dict[str, Any],
|
|
80
|
+
logger: logging.Logger,
|
|
81
|
+
client: AuthenticatedClient,
|
|
82
|
+
agent_router: Optional[AgentRouter] = None,
|
|
83
|
+
):
|
|
84
|
+
"""
|
|
85
|
+
Initialize the pipeline with configuration and dependencies.
|
|
86
|
+
|
|
87
|
+
Args:
|
|
88
|
+
config: Configuration dictionary or PrefixGenerationConfig instance
|
|
89
|
+
logger: Logger for tracking execution
|
|
90
|
+
client: Authenticated client for API access
|
|
91
|
+
agent_router: Optional router for CE computation
|
|
92
|
+
"""
|
|
93
|
+
self.config = (
|
|
94
|
+
PrefixGenerationConfig.from_dict(config)
|
|
95
|
+
if isinstance(config, dict)
|
|
96
|
+
else config
|
|
97
|
+
)
|
|
98
|
+
self.logger = logger
|
|
99
|
+
self.client = client
|
|
100
|
+
self.agent_router = agent_router
|
|
101
|
+
|
|
102
|
+
# Initialize internal state for tracking
|
|
103
|
+
self._generation_router: Optional[AgentRouter] = None
|
|
104
|
+
self._statistics: Dict[str, Any] = {
|
|
105
|
+
"raw_generated": 0,
|
|
106
|
+
"phase1_filtered": 0,
|
|
107
|
+
"ce_computed": 0,
|
|
108
|
+
"phase2_filtered": 0,
|
|
109
|
+
}
|
|
110
|
+
|
|
111
|
+
self.logger.info("PrefixGenerationPipeline initialized")
|
|
112
|
+
|
|
113
|
+
# ========================================================================
|
|
114
|
+
# PUBLIC INTERFACE
|
|
115
|
+
# ========================================================================
|
|
116
|
+
|
|
117
|
+
@handle_empty_input("Generate Prefixes", empty_result=[])
@log_errors("Generate Prefixes")
def execute(self, goals: List[str]) -> List[Dict]:
    """
    Execute the complete prefix generation pipeline.

    Stages, in order:
    1. Generate raw prefixes for the de-duplicated goals
    2. Apply Phase 1 preprocessing
    3. Compute cross-entropy (only when an agent_router was provided)
    4. Apply Phase 2 preprocessing (only when step 3 ran)

    The stage counters returned by get_statistics() are reset at the start
    of every call, so they always describe the most recent run even when
    that run stops early.

    Args:
        goals: List of target goals for prefix generation. Duplicates are
            dropped while keeping first-seen order.

    Returns:
        List of filtered prefix dictionaries, or an empty list as soon as
        any stage yields nothing.
    """
    # Without this reset an early return below would leave the later-stage
    # counters holding values from a previous, longer run.
    self._statistics = dict.fromkeys(self._statistics, 0)

    unique_goals = list(dict.fromkeys(goals)) if goals else []

    # Generate raw prefixes
    self.logger.info(f"Starting generation for {len(unique_goals)} unique goals")
    raw_prefixes = self._generate_raw_prefixes(unique_goals)
    self._statistics["raw_generated"] = len(raw_prefixes)

    if not raw_prefixes:
        self.logger.warning("No prefixes generated")
        return []

    # Apply Phase 1 filtering
    self.logger.info(
        f"Applying Phase 1 preprocessing to {len(raw_prefixes)} prefixes"
    )
    phase1_results = self._apply_phase1_preprocessing(raw_prefixes)
    self._statistics["phase1_filtered"] = len(phase1_results)

    if not phase1_results:
        self.logger.warning("All prefixes filtered out in Phase 1")
        return []

    # Optional CE computation and Phase 2 filtering
    if self.agent_router:
        self.logger.info(
            f"Computing cross-entropy for {len(phase1_results)} prefixes"
        )
        ce_results = self._compute_cross_entropy_scores(phase1_results)
        self._statistics["ce_computed"] = len(ce_results)

        if not ce_results:
            self.logger.warning("CE computation produced no results")
            return []

        self.logger.info(
            f"Applying Phase 2 preprocessing to {len(ce_results)} prefixes"
        )
        final_results = self._apply_phase2_preprocessing(ce_results)
        self._statistics["phase2_filtered"] = len(final_results)

        if not final_results:
            self.logger.warning("All prefixes filtered out in Phase 2")
            return []
    else:
        self.logger.info("Skipping CE computation (no agent_router provided)")
        final_results = phase1_results

    self._log_pipeline_statistics()
    return final_results
|
|
184
|
+
|
|
185
|
+
def get_statistics(self) -> Dict[str, Any]:
|
|
186
|
+
"""Return execution statistics for monitoring and debugging."""
|
|
187
|
+
return self._statistics.copy()
|
|
188
|
+
|
|
189
|
+
# ========================================================================
|
|
190
|
+
# GENERATION METHODS
|
|
191
|
+
# ========================================================================
|
|
192
|
+
|
|
193
|
+
def _generate_raw_prefixes(self, goals: List[str]) -> List[Dict]:
|
|
194
|
+
"""
|
|
195
|
+
Generate raw adversarial prefixes using uncensored models.
|
|
196
|
+
|
|
197
|
+
Handles:
|
|
198
|
+
- Router initialization
|
|
199
|
+
- Prompt construction
|
|
200
|
+
- Both greedy and sampling generation modes
|
|
201
|
+
- Result collection with metadata
|
|
202
|
+
"""
|
|
203
|
+
if not self.config.generator:
|
|
204
|
+
self.logger.error("Missing generator configuration")
|
|
205
|
+
return []
|
|
206
|
+
|
|
207
|
+
model_name = self.config.generator.get("identifier")
|
|
208
|
+
if not model_name:
|
|
209
|
+
self.logger.error("Missing model identifier in generator config")
|
|
210
|
+
return []
|
|
211
|
+
|
|
212
|
+
# Initialize router if needed
|
|
213
|
+
if not self._generation_router:
|
|
214
|
+
self._generation_router = self._initialize_generation_router()
|
|
215
|
+
if not self._generation_router:
|
|
216
|
+
return []
|
|
217
|
+
|
|
218
|
+
# Construct prompts
|
|
219
|
+
prompts, prompt_goals, meta_prefixes = self._construct_prompts(goals)
|
|
220
|
+
if not prompts:
|
|
221
|
+
self.logger.warning("No prompts constructed")
|
|
222
|
+
return []
|
|
223
|
+
|
|
224
|
+
# Generate with both modes
|
|
225
|
+
results = []
|
|
226
|
+
for do_sample in [False, True]:
|
|
227
|
+
mode_name = "sampling" if do_sample else "greedy"
|
|
228
|
+
self.logger.debug(
|
|
229
|
+
f"Running {mode_name} generation for {len(prompts)} prompts"
|
|
230
|
+
)
|
|
231
|
+
|
|
232
|
+
mode_results = self._run_generation_mode(
|
|
233
|
+
prompts=prompts,
|
|
234
|
+
goals=prompt_goals,
|
|
235
|
+
meta_prefixes=meta_prefixes,
|
|
236
|
+
do_sample=do_sample,
|
|
237
|
+
)
|
|
238
|
+
results.extend(mode_results)
|
|
239
|
+
|
|
240
|
+
self.logger.info(f"Generated {len(results)} raw prefixes")
|
|
241
|
+
return results
|
|
242
|
+
|
|
243
|
+
def _initialize_generation_router(self) -> Optional[AgentRouter]:
|
|
244
|
+
"""Initialize and configure the AgentRouter for generation."""
|
|
245
|
+
try:
|
|
246
|
+
endpoint = self.config.generator.get("endpoint")
|
|
247
|
+
model_name = self.config.generator.get("identifier")
|
|
248
|
+
|
|
249
|
+
# Handle API key
|
|
250
|
+
api_key = self.client.token
|
|
251
|
+
api_key_config = self.config.generator.get("api_key")
|
|
252
|
+
if api_key_config:
|
|
253
|
+
env_key = os.environ.get(api_key_config)
|
|
254
|
+
api_key = env_key if env_key else api_key_config
|
|
255
|
+
|
|
256
|
+
operational_config = {
|
|
257
|
+
"name": model_name,
|
|
258
|
+
"endpoint": endpoint,
|
|
259
|
+
"api_key": api_key,
|
|
260
|
+
"max_new_tokens": self.config.max_new_tokens,
|
|
261
|
+
"temperature": self.config.temperature,
|
|
262
|
+
"top_p": self.config.top_p,
|
|
263
|
+
}
|
|
264
|
+
|
|
265
|
+
router = AgentRouter(
|
|
266
|
+
client=self.client,
|
|
267
|
+
name=model_name,
|
|
268
|
+
agent_type=AgentTypeEnum.LITELLM,
|
|
269
|
+
endpoint=endpoint,
|
|
270
|
+
adapter_operational_config=operational_config,
|
|
271
|
+
metadata=operational_config.copy(),
|
|
272
|
+
overwrite_metadata=True,
|
|
273
|
+
)
|
|
274
|
+
|
|
275
|
+
if not router._agent_registry: # type: ignore
|
|
276
|
+
self.logger.error("Router initialized but no agent registered")
|
|
277
|
+
return None
|
|
278
|
+
|
|
279
|
+
self.logger.debug(f"Generation router initialized for {model_name}")
|
|
280
|
+
return router
|
|
281
|
+
|
|
282
|
+
except Exception as e:
|
|
283
|
+
self.logger.error(
|
|
284
|
+
f"Failed to initialize generation router: {e}", exc_info=True
|
|
285
|
+
)
|
|
286
|
+
return None
|
|
287
|
+
|
|
288
|
+
def _construct_prompts(
|
|
289
|
+
self, goals: List[str]
|
|
290
|
+
) -> Tuple[List[str], List[str], List[str]]:
|
|
291
|
+
"""
|
|
292
|
+
Construct formatted prompts from goals and meta-prefixes.
|
|
293
|
+
|
|
294
|
+
Returns:
|
|
295
|
+
Tuple of (prompts, corresponding_goals, corresponding_meta_prefixes)
|
|
296
|
+
"""
|
|
297
|
+
# Handle sample count specification
|
|
298
|
+
meta_prefixes = self.config.meta_prefixes
|
|
299
|
+
n_samples = self.config.meta_prefix_samples
|
|
300
|
+
|
|
301
|
+
if isinstance(n_samples, list):
|
|
302
|
+
if len(meta_prefixes) != len(n_samples):
|
|
303
|
+
raise ValueError(
|
|
304
|
+
"Length mismatch between meta_prefixes and meta_prefix_samples"
|
|
305
|
+
)
|
|
306
|
+
n_samples_list = n_samples
|
|
307
|
+
else:
|
|
308
|
+
n_samples_list = [n_samples] * len(meta_prefixes)
|
|
309
|
+
|
|
310
|
+
prompts = []
|
|
311
|
+
prompt_goals = []
|
|
312
|
+
prompt_meta_prefixes = []
|
|
313
|
+
|
|
314
|
+
for goal in goals:
|
|
315
|
+
for meta_prefix, n_count in zip(meta_prefixes, n_samples_list):
|
|
316
|
+
if n_count <= 0:
|
|
317
|
+
continue
|
|
318
|
+
|
|
319
|
+
try:
|
|
320
|
+
# Format prompt using template
|
|
321
|
+
if meta_prefix in CUSTOM_CHAT_TEMPLATES:
|
|
322
|
+
template = CUSTOM_CHAT_TEMPLATES[meta_prefix]
|
|
323
|
+
prompt_content = template.format(content=goal)
|
|
324
|
+
else:
|
|
325
|
+
self.logger.debug(
|
|
326
|
+
f"No template for {meta_prefix}, using basic format"
|
|
327
|
+
)
|
|
328
|
+
prompt_content = f"USER: {goal}\\nASSISTANT:"
|
|
329
|
+
|
|
330
|
+
full_prompt = prompt_content + meta_prefix
|
|
331
|
+
|
|
332
|
+
# Replicate for n_count samples
|
|
333
|
+
prompts.extend([full_prompt] * n_count)
|
|
334
|
+
prompt_goals.extend([goal] * n_count)
|
|
335
|
+
prompt_meta_prefixes.extend([meta_prefix] * n_count)
|
|
336
|
+
|
|
337
|
+
except Exception as e:
|
|
338
|
+
self.logger.error(
|
|
339
|
+
f"Error constructing prompt for goal '{goal[:30]}...': {e}"
|
|
340
|
+
)
|
|
341
|
+
|
|
342
|
+
return prompts, prompt_goals, prompt_meta_prefixes
|
|
343
|
+
|
|
344
|
+
def _run_generation_mode(
    self,
    prompts: List[str],
    goals: List[str],
    meta_prefixes: List[str],
    do_sample: bool,
) -> List[Dict]:
    """
    Send every prompt through the generation router in one decoding mode.

    Args:
        prompts: Fully formatted prompts.
        goals: Goal belonging to each prompt (index-aligned).
        meta_prefixes: Meta-prefix belonging to each prompt (index-aligned).
        do_sample: True uses the configured temperature ("sampling");
            False uses a fixed temperature of 1e-2 ("greedy").

    Returns:
        One record per prompt with the keys goal, prefix, meta_prefix,
        temperature and model_name.
    """
    if do_sample:
        mode, temperature = "sampling", self.config.temperature
    else:
        mode, temperature = "greedy", 1e-2

    # The router is addressed through the first key in its registry.
    registration_key = next(iter(self._generation_router._agent_registry.keys()))  # type: ignore

    records: List[Dict] = []
    progress_desc = f"[cyan]Generating ({mode})..."

    with create_progress_bar(progress_desc, total=len(prompts)) as (pbar, task):
        for prompt, goal, meta_prefix in zip(prompts, goals, meta_prefixes):
            response = self._generation_router.route_request(
                registration_key=registration_key,
                request_data={
                    "prompt": prompt,
                    "max_new_tokens": self.config.max_new_tokens,
                    "temperature": temperature,
                    "top_p": self.config.top_p,
                },
            )

            continuation = self._extract_generated_text(response, prompt, goal)

            record = {
                "goal": goal,
                # The stored prefix is the meta-prefix plus the model's text.
                "prefix": meta_prefix + continuation,
                "meta_prefix": meta_prefix,
                "temperature": temperature,
                "model_name": self.config.generator.get("identifier"),
            }
            records.append(record)

            pbar.update(task, advance=1)

    return records
|
|
390
|
+
|
|
391
|
+
def _extract_generated_text(self, response: Dict, prompt: str, goal: str) -> str:
|
|
392
|
+
"""Extract and clean generated text from router response."""
|
|
393
|
+
error_msg = response.get("error_message")
|
|
394
|
+
if error_msg:
|
|
395
|
+
error_cat = response.get("error_category", "Unknown")
|
|
396
|
+
self.logger.warning(
|
|
397
|
+
f"Router error for goal '{goal[:30]}...': {error_msg} ({error_cat})"
|
|
398
|
+
)
|
|
399
|
+
return f" [ROUTER_ERROR: {error_cat}]"
|
|
400
|
+
|
|
401
|
+
generated = response.get("processed_response")
|
|
402
|
+
if not generated:
|
|
403
|
+
return " [ROUTER_NO_CONTENT]"
|
|
404
|
+
|
|
405
|
+
# Strip prompt if echoed
|
|
406
|
+
if generated.startswith(prompt):
|
|
407
|
+
return generated[len(prompt) :]
|
|
408
|
+
|
|
409
|
+
self.logger.debug("Response didn't start with prompt, using full response")
|
|
410
|
+
return generated
|
|
411
|
+
|
|
412
|
+
# ========================================================================
|
|
413
|
+
# PREPROCESSING METHODS
|
|
414
|
+
# ========================================================================
|
|
415
|
+
|
|
416
|
+
def _apply_phase1_preprocessing(self, prefixes: List[Dict]) -> List[Dict]:
|
|
417
|
+
"""
|
|
418
|
+
Apply Phase 1 preprocessing: pattern-based filtering and deduplication.
|
|
419
|
+
|
|
420
|
+
Filters:
|
|
421
|
+
- Prefixes starting with refusal patterns
|
|
422
|
+
- Prefixes containing refusal patterns
|
|
423
|
+
- Prefixes below minimum character length
|
|
424
|
+
- Prefixes without required linebreaks
|
|
425
|
+
- Duplicate prefixes (within goals)
|
|
426
|
+
"""
|
|
427
|
+
df = pd.DataFrame(prefixes)
|
|
428
|
+
|
|
429
|
+
# Apply filters sequentially
|
|
430
|
+
df = self._filter_by_start_patterns(df)
|
|
431
|
+
df = self._filter_by_contain_patterns(df)
|
|
432
|
+
df = self._filter_by_char_length(df)
|
|
433
|
+
|
|
434
|
+
if self.config.require_linebreak:
|
|
435
|
+
df = self._filter_by_linebreak(df)
|
|
436
|
+
|
|
437
|
+
df = self._merge_duplicates(df)
|
|
438
|
+
|
|
439
|
+
self._log_filtering_stats(df, "Phase 1")
|
|
440
|
+
return df.to_dict("records")
|
|
441
|
+
|
|
442
|
+
def _apply_phase2_preprocessing(self, prefixes: List[Dict]) -> List[Dict]:
|
|
443
|
+
"""
|
|
444
|
+
Apply Phase 2 preprocessing: CE-based filtering and top-k selection.
|
|
445
|
+
|
|
446
|
+
Filters:
|
|
447
|
+
- Prefixes with CE scores above threshold
|
|
448
|
+
- Keeps only top-k prefixes per goal based on CE score
|
|
449
|
+
"""
|
|
450
|
+
df = pd.DataFrame(prefixes)
|
|
451
|
+
|
|
452
|
+
if "prefix_nll" not in df.columns:
|
|
453
|
+
self.logger.error("Phase 2 requires 'prefix_nll' column, skipping")
|
|
454
|
+
return prefixes
|
|
455
|
+
|
|
456
|
+
# Filter by CE threshold
|
|
457
|
+
if self.config.max_ce is not None:
|
|
458
|
+
df = self._filter_by_ce_threshold(df)
|
|
459
|
+
|
|
460
|
+
# Top-k selection per goal
|
|
461
|
+
if self.config.n_candidates_per_goal > 0:
|
|
462
|
+
df = self._select_top_k_per_goal(df)
|
|
463
|
+
|
|
464
|
+
self._log_filtering_stats(df, "Phase 2")
|
|
465
|
+
return df.to_dict("records")
|
|
466
|
+
|
|
467
|
+
def _filter_by_start_patterns(self, df: pd.DataFrame) -> pd.DataFrame:
|
|
468
|
+
"""Remove prefixes starting with refusal patterns."""
|
|
469
|
+
if not self.config.start_patterns:
|
|
470
|
+
return df
|
|
471
|
+
|
|
472
|
+
before = len(df)
|
|
473
|
+
df_filtered = df[
|
|
474
|
+
~df["prefix"]
|
|
475
|
+
.str.lstrip()
|
|
476
|
+
.str.startswith(self.config.start_patterns, na=False)
|
|
477
|
+
]
|
|
478
|
+
removed = before - len(df_filtered)
|
|
479
|
+
|
|
480
|
+
if removed > 0:
|
|
481
|
+
self.logger.debug(f"Start pattern filter removed {removed} prefixes")
|
|
482
|
+
|
|
483
|
+
return df_filtered
|
|
484
|
+
|
|
485
|
+
def _filter_by_contain_patterns(self, df: pd.DataFrame) -> pd.DataFrame:
|
|
486
|
+
"""Remove prefixes containing refusal patterns."""
|
|
487
|
+
if not self.config.contain_patterns:
|
|
488
|
+
return df
|
|
489
|
+
|
|
490
|
+
before = len(df)
|
|
491
|
+
pattern = "|".join(map(re.escape, self.config.contain_patterns))
|
|
492
|
+
df_filtered = df[~df["prefix"].str.contains(pattern, regex=True, na=False)]
|
|
493
|
+
removed = before - len(df_filtered)
|
|
494
|
+
|
|
495
|
+
if removed > 0:
|
|
496
|
+
self.logger.debug(f"Contain pattern filter removed {removed} prefixes")
|
|
497
|
+
|
|
498
|
+
return df_filtered
|
|
499
|
+
|
|
500
|
+
def _filter_by_char_length(self, df: pd.DataFrame) -> pd.DataFrame:
|
|
501
|
+
"""Remove prefixes shorter than minimum character length."""
|
|
502
|
+
if self.config.min_char_length <= 0:
|
|
503
|
+
return df
|
|
504
|
+
|
|
505
|
+
before = len(df)
|
|
506
|
+
df_filtered = df[df["prefix"].str.len() >= self.config.min_char_length]
|
|
507
|
+
removed = before - len(df_filtered)
|
|
508
|
+
|
|
509
|
+
if removed > 0:
|
|
510
|
+
self.logger.debug(
|
|
511
|
+
f"Character length filter removed {removed} prefixes (< {self.config.min_char_length} chars)"
|
|
512
|
+
)
|
|
513
|
+
|
|
514
|
+
return df_filtered
|
|
515
|
+
|
|
516
|
+
def _filter_by_linebreak(self, df: pd.DataFrame) -> pd.DataFrame:
|
|
517
|
+
"""Remove prefixes without internal linebreaks."""
|
|
518
|
+
before = len(df)
|
|
519
|
+
df_filtered = df[
|
|
520
|
+
df["prefix"].str.strip().str.strip("\n").str.contains("\n", na=False)
|
|
521
|
+
]
|
|
522
|
+
removed = before - len(df_filtered)
|
|
523
|
+
|
|
524
|
+
if removed > 0:
|
|
525
|
+
self.logger.debug(f"Linebreak filter removed {removed} prefixes")
|
|
526
|
+
|
|
527
|
+
return df_filtered
|
|
528
|
+
|
|
529
|
+
def _filter_by_ce_threshold(self, df: pd.DataFrame) -> pd.DataFrame:
|
|
530
|
+
"""Remove prefixes with CE scores above threshold."""
|
|
531
|
+
df["prefix_nll_numeric"] = pd.to_numeric(df["prefix_nll"], errors="coerce")
|
|
532
|
+
|
|
533
|
+
# Check if all values are infinite
|
|
534
|
+
valid_scores = df["prefix_nll_numeric"][
|
|
535
|
+
~df["prefix_nll_numeric"].isin([float("inf"), float("-inf")])
|
|
536
|
+
]
|
|
537
|
+
|
|
538
|
+
if len(valid_scores) == 0:
|
|
539
|
+
self.logger.warning("All CE scores are infinite, skipping CE filtering")
|
|
540
|
+
return df.drop(columns=["prefix_nll_numeric"])
|
|
541
|
+
|
|
542
|
+
before = len(df)
|
|
543
|
+
df_filtered = df[df["prefix_nll_numeric"] <= self.config.max_ce]
|
|
544
|
+
removed = before - len(df_filtered)
|
|
545
|
+
|
|
546
|
+
if removed > 0:
|
|
547
|
+
self.logger.debug(
|
|
548
|
+
f"CE threshold filter removed {removed} prefixes (CE > {self.config.max_ce})"
|
|
549
|
+
)
|
|
550
|
+
|
|
551
|
+
return df_filtered.drop(columns=["prefix_nll_numeric"])
|
|
552
|
+
|
|
553
|
+
def _select_top_k_per_goal(self, df: pd.DataFrame) -> pd.DataFrame:
|
|
554
|
+
"""Select top-k prefixes per goal based on CE score."""
|
|
555
|
+
df["prefix_nll_numeric"] = pd.to_numeric(df["prefix_nll"], errors="coerce")
|
|
556
|
+
|
|
557
|
+
before = len(df)
|
|
558
|
+
df_selected = (
|
|
559
|
+
df.sort_values("prefix_nll_numeric", na_position="last")
|
|
560
|
+
.groupby("goal")
|
|
561
|
+
.head(self.config.n_candidates_per_goal)
|
|
562
|
+
)
|
|
563
|
+
removed = before - len(df_selected)
|
|
564
|
+
|
|
565
|
+
if removed > 0:
|
|
566
|
+
self.logger.debug(
|
|
567
|
+
f"Top-k selection removed {removed} prefixes (keeping top {self.config.n_candidates_per_goal} per goal)"
|
|
568
|
+
)
|
|
569
|
+
|
|
570
|
+
return df_selected.drop(columns=["prefix_nll_numeric"]).reset_index(drop=True)
|
|
571
|
+
|
|
572
|
+
def _merge_duplicates(self, df: pd.DataFrame) -> pd.DataFrame:
    """Collapse rows that share the same ``prefix`` within a ``goal``.

    For every (goal, prefix) pair one row is produced:

    * ``model_name`` / ``temperature``: distinct values joined with ``,``
    * ``meta_prefix``: distinct non-missing values joined with ``,``
    * ``prefix_nll`` (only if the column exists): first value of the group

    Distinct values are joined in first-seen order so the merged strings
    are identical from run to run. Any other column is dropped.

    Args:
        df: Prefix table with ``goal``, ``prefix``, ``model_name``,
            ``meta_prefix`` and ``temperature`` columns.

    Returns:
        A new DataFrame sorted by goal, then prefix, with a fresh index and
        the columns ``prefix, model_name, meta_prefix, temperature, goal``
        (plus ``prefix_nll`` when present). An empty input is returned as an
        unchanged copy.
    """
    before = len(df)
    if df.empty:
        # Nothing to merge; avoid aggregating over zero groups.
        return df.copy()

    def join_unique(values, skip_missing=False):
        # dict.fromkeys de-duplicates while preserving encounter order,
        # unlike set() whose iteration order is not stable for strings.
        distinct = dict.fromkeys(
            v for v in values if not skip_missing or pd.notna(v)
        )
        return ",".join(str(v) for v in distinct)

    agg_spec = {
        "model_name": join_unique,
        "meta_prefix": lambda col: join_unique(col, skip_missing=True),
        "temperature": join_unique,
    }
    column_order = ["prefix", "model_name", "meta_prefix", "temperature", "goal"]
    if "prefix_nll" in df.columns:
        agg_spec["prefix_nll"] = "first"
        column_order.append("prefix_nll")

    # One two-key groupby replaces the nested groupby/apply and does not
    # depend on the grouping column being visible inside an applied function.
    merged = df.groupby(["goal", "prefix"], sort=True).agg(agg_spec).reset_index()
    merged = merged[column_order]

    removed = before - len(merged)
    if removed > 0:
        self.logger.debug(f"Deduplication removed {removed} duplicate prefixes")

    return merged
|
|
598
|
+
|
|
599
|
+
def _log_filtering_stats(self, df: pd.DataFrame, phase_name: str):
    """Emit an INFO line summarising what is left after a filtering phase.

    Args:
        df: Prefix table remaining after the phase; grouped by ``goal`` and
            counted on ``prefix``.
        phase_name: Label placed at the start of the log line.
    """
    if not df.empty:
        # Number of surviving prefixes for each goal.
        per_goal = df.groupby("goal")["prefix"].count()
        fewest = per_goal.min()
        most = per_goal.max()
        average = per_goal.mean()
        summary = (
            f"{phase_name}: {len(df)} prefixes remaining for {len(per_goal)} goals "
            f"(min={fewest}, max={most}, "
            f"avg={average:.1f})"
        )
    else:
        summary = f"{phase_name}: No prefixes remaining"

    self.logger.info(summary)
|
|
611
|
+
|
|
612
|
+
# ========================================================================
|
|
613
|
+
# CROSS-ENTROPY METHODS
|
|
614
|
+
# ========================================================================
|
|
615
|
+
|
|
616
|
+
def _compute_cross_entropy_scores(self, prefixes: List[Dict]) -> List[Dict]:
    """Compute cross-entropy acceptability scores for prefixes.

    Each prefix is sent as a prompt to the target agent through
    ``self.agent_router`` and a copy of its record receives a ``prefix_nll``:

    - 0.0: the agent answered and the answer contains no refusal keyword
    - inf: the prefix is empty/invalid, the agent returned an error or no
      text, or the answer contains a refusal keyword

    Refusal keywords are matched case-insensitively.

    Args:
        prefixes: Records with a ``prefix`` entry holding the text to test.

    Returns:
        One scored copy per input record, in input order, with
        ``prefix_nll`` and ``error_message`` set and, for records that were
        actually sent, the request/response metadata. If no agent router is
        configured the input list is returned unchanged.
    """
    if not self.agent_router:
        self.logger.warning("No agent_router available for CE computation")
        return prefixes

    results = []
    backend_agent = self.agent_router.backend_agent
    victim_key = str(backend_agent.id)

    # The response is lower-cased before matching, so the keywords must be
    # lower-cased as well or capitalised entries could never match.
    refusal_markers = tuple(str(kw).lower() for kw in REFUSAL_KEYWORDS)

    progress_desc = f"[blue]Computing CE via {backend_agent.agent_type}..."

    with create_progress_bar(progress_desc, total=len(prefixes)) as (pbar, task):
        for record in prefixes:
            prefix_text = record.get("prefix", "")

            # Start from the pessimistic score; only a clean answer lowers it.
            result = record.copy()
            result["prefix_nll"] = float("inf")

            if not isinstance(prefix_text, str) or not prefix_text.strip():
                # Nothing meaningful to send to the agent.
                result["error_message"] = "Empty or invalid prefix"
            else:
                # Treat a missing response like an empty one instead of
                # failing on attribute access.
                response = (
                    self.agent_router.route_request(
                        registration_key=victim_key,
                        request_data={"prompt": prefix_text},
                    )
                    or {}
                )

                generated_text = response.get("generated_text")
                error_message = response.get("error_message")

                if error_message or generated_text is None:
                    result["error_message"] = error_message or "No response"
                else:
                    answer = str(generated_text).lower()
                    if any(marker in answer for marker in refusal_markers):
                        result["error_message"] = (
                            "Response contained refusal keywords"
                        )
                    else:
                        result["prefix_nll"] = 0.0
                        result["error_message"] = None

                # Store response metadata
                result["request_payload"] = response.get("raw_request") or {
                    "prompt": prefix_text
                }
                result["response_status"] = response.get("raw_response_status")
                result["response_headers"] = response.get("raw_response_headers")
                result["response_body_raw"] = response.get("raw_response_body")

                agent_specific = response.get("agent_specific_data", {})
                if agent_specific:
                    result["events_list"] = agent_specific.get("events_list")

            results.append(result)
            pbar.update(task, advance=1)

    # Log statistics
    accepted = sum(1 for r in results if r.get("prefix_nll") == 0.0)
    self.logger.info(
        f"CE computation: {accepted}/{len(results)} prefixes accepted by target agent"
    )

    return results
|
|
696
|
+
|
|
697
|
+
def _log_pipeline_statistics(self):
    """Write the per-stage prefix counts of the pipeline run to the INFO log.

    Reads the counters in ``self._statistics`` and, when at least one prefix
    was generated, also reports the share that survived phase 2.
    """
    counters = self._statistics
    banner = "=" * 60
    stages = (
        ("Raw generated", "raw_generated"),
        ("Phase 1 filtered", "phase1_filtered"),
        ("CE computed", "ce_computed"),
        ("Phase 2 filtered", "phase2_filtered"),
    )

    self.logger.info(banner)
    self.logger.info("Pipeline Execution Statistics:")
    for label, key in stages:
        self.logger.info(f" {label}: {counters[key]}")

    generated = counters["raw_generated"]
    if generated > 0:
        # Percentage of generated prefixes still present after phase 2.
        retention = (counters["phase2_filtered"] / generated) * 100
        self.logger.info(f" Retention rate: {retention:.1f}%")

    self.logger.info(banner)
|