hackagent 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hackagent/__init__.py +23 -0
- hackagent/agent.py +193 -0
- hackagent/api/__init__.py +1 -0
- hackagent/api/agent/__init__.py +1 -0
- hackagent/api/agent/agent_create.py +340 -0
- hackagent/api/agent/agent_destroy.py +136 -0
- hackagent/api/agent/agent_list.py +234 -0
- hackagent/api/agent/agent_partial_update.py +354 -0
- hackagent/api/agent/agent_retrieve.py +227 -0
- hackagent/api/agent/agent_update.py +354 -0
- hackagent/api/attack/__init__.py +1 -0
- hackagent/api/attack/attack_create.py +264 -0
- hackagent/api/attack/attack_destroy.py +140 -0
- hackagent/api/attack/attack_list.py +242 -0
- hackagent/api/attack/attack_partial_update.py +278 -0
- hackagent/api/attack/attack_retrieve.py +235 -0
- hackagent/api/attack/attack_update.py +278 -0
- hackagent/api/key/__init__.py +1 -0
- hackagent/api/key/key_create.py +168 -0
- hackagent/api/key/key_destroy.py +97 -0
- hackagent/api/key/key_list.py +158 -0
- hackagent/api/key/key_retrieve.py +150 -0
- hackagent/api/prompt/__init__.py +1 -0
- hackagent/api/prompt/prompt_create.py +160 -0
- hackagent/api/prompt/prompt_destroy.py +98 -0
- hackagent/api/prompt/prompt_list.py +173 -0
- hackagent/api/prompt/prompt_partial_update.py +174 -0
- hackagent/api/prompt/prompt_retrieve.py +151 -0
- hackagent/api/prompt/prompt_update.py +174 -0
- hackagent/api/result/__init__.py +1 -0
- hackagent/api/result/result_create.py +160 -0
- hackagent/api/result/result_destroy.py +98 -0
- hackagent/api/result/result_list.py +233 -0
- hackagent/api/result/result_partial_update.py +178 -0
- hackagent/api/result/result_retrieve.py +151 -0
- hackagent/api/result/result_trace_create.py +178 -0
- hackagent/api/result/result_update.py +174 -0
- hackagent/api/run/__init__.py +1 -0
- hackagent/api/run/run_create.py +172 -0
- hackagent/api/run/run_destroy.py +104 -0
- hackagent/api/run/run_list.py +260 -0
- hackagent/api/run/run_partial_update.py +186 -0
- hackagent/api/run/run_result_create.py +178 -0
- hackagent/api/run/run_retrieve.py +163 -0
- hackagent/api/run/run_run_tests_create.py +172 -0
- hackagent/api/run/run_update.py +186 -0
- hackagent/attacks/AdvPrefix/README.md +7 -0
- hackagent/attacks/AdvPrefix/__init__.py +0 -0
- hackagent/attacks/AdvPrefix/completer.py +438 -0
- hackagent/attacks/AdvPrefix/config.py +59 -0
- hackagent/attacks/AdvPrefix/preprocessing.py +521 -0
- hackagent/attacks/AdvPrefix/scorer.py +259 -0
- hackagent/attacks/AdvPrefix/scorer_parser.py +498 -0
- hackagent/attacks/AdvPrefix/selector.py +246 -0
- hackagent/attacks/AdvPrefix/step1_generate.py +324 -0
- hackagent/attacks/AdvPrefix/step4_compute_ce.py +293 -0
- hackagent/attacks/AdvPrefix/step6_get_completions.py +387 -0
- hackagent/attacks/AdvPrefix/step7_evaluate_responses.py +289 -0
- hackagent/attacks/AdvPrefix/step8_aggregate_evaluations.py +177 -0
- hackagent/attacks/AdvPrefix/step9_select_prefixes.py +59 -0
- hackagent/attacks/AdvPrefix/utils.py +192 -0
- hackagent/attacks/__init__.py +6 -0
- hackagent/attacks/advprefix.py +1136 -0
- hackagent/attacks/base.py +50 -0
- hackagent/attacks/strategies.py +539 -0
- hackagent/branding.py +143 -0
- hackagent/client.py +328 -0
- hackagent/errors.py +31 -0
- hackagent/logger.py +67 -0
- hackagent/models/__init__.py +71 -0
- hackagent/models/agent.py +240 -0
- hackagent/models/agent_request.py +169 -0
- hackagent/models/agent_type_enum.py +12 -0
- hackagent/models/attack.py +154 -0
- hackagent/models/attack_request.py +82 -0
- hackagent/models/evaluation_status_enum.py +14 -0
- hackagent/models/organization_minimal.py +68 -0
- hackagent/models/paginated_agent_list.py +123 -0
- hackagent/models/paginated_attack_list.py +123 -0
- hackagent/models/paginated_prompt_list.py +123 -0
- hackagent/models/paginated_result_list.py +123 -0
- hackagent/models/paginated_run_list.py +123 -0
- hackagent/models/paginated_user_api_key_list.py +123 -0
- hackagent/models/patched_agent_request.py +176 -0
- hackagent/models/patched_attack_request.py +92 -0
- hackagent/models/patched_prompt_request.py +162 -0
- hackagent/models/patched_result_request.py +237 -0
- hackagent/models/patched_run_request.py +138 -0
- hackagent/models/prompt.py +226 -0
- hackagent/models/prompt_request.py +155 -0
- hackagent/models/result.py +294 -0
- hackagent/models/result_list_evaluation_status.py +14 -0
- hackagent/models/result_request.py +232 -0
- hackagent/models/run.py +233 -0
- hackagent/models/run_list_status.py +12 -0
- hackagent/models/run_request.py +133 -0
- hackagent/models/status_enum.py +12 -0
- hackagent/models/step_type_enum.py +14 -0
- hackagent/models/trace.py +121 -0
- hackagent/models/trace_request.py +94 -0
- hackagent/models/user_api_key.py +201 -0
- hackagent/models/user_api_key_request.py +73 -0
- hackagent/models/user_profile_minimal.py +76 -0
- hackagent/py.typed +1 -0
- hackagent/router/__init__.py +11 -0
- hackagent/router/adapters/__init__.py +5 -0
- hackagent/router/adapters/google_adk.py +658 -0
- hackagent/router/adapters/litellm_adapter.py +290 -0
- hackagent/router/base.py +48 -0
- hackagent/router/router.py +753 -0
- hackagent/types.py +46 -0
- hackagent/utils.py +61 -0
- hackagent/vulnerabilities/__init__.py +0 -0
- hackagent-0.1.0.dist-info/LICENSE +202 -0
- hackagent-0.1.0.dist-info/METADATA +173 -0
- hackagent-0.1.0.dist-info/RECORD +117 -0
- hackagent-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,289 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import multiprocessing
|
|
4
|
+
import os
|
|
5
|
+
from typing import Dict
|
|
6
|
+
from dataclasses import fields # Import fields to inspect dataclass
|
|
7
|
+
|
|
8
|
+
from hackagent.attacks.AdvPrefix.scorer_parser import (
|
|
9
|
+
EvaluatorConfig,
|
|
10
|
+
NuancedEvaluator,
|
|
11
|
+
HarmBenchEvaluator,
|
|
12
|
+
JailbreakBenchEvaluator,
|
|
13
|
+
)
|
|
14
|
+
from hackagent.attacks.AdvPrefix.utils import get_checkpoint_path
|
|
15
|
+
|
|
16
|
+
# Judge type name -> evaluator class used to score completions for that judge.
EVALUATOR_MAP = dict(
    nuanced=NuancedEvaluator,
    jailbreakbench=JailbreakBenchEvaluator,
    harmbench=HarmBenchEvaluator,
)

# Columns that identify one evaluated row; judge output is joined back on these.
MERGE_KEYS = ["goal", "prefix", "completion"]

# Judge type name -> [score column, explanation column] that judge contributes.
JUDGE_COLUMN_MAP = dict(
    nuanced=["eval_nj", "explanation_nj"],
    jailbreakbench=["eval_jb", "explanation_jb"],
    harmbench=["eval_hb", "explanation_hb"],
)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _run_evaluator_process_wrapper(
    judge_type: str, config_dict_serializable: Dict, df: pd.DataFrame
):
    """Run a single judge's evaluator on *df*; intended as a worker-process target.

    This is a plain module-level function (not a method) so that it can be
    pickled and dispatched to a multiprocessing pool.

    Args:
        judge_type: Key into ``EVALUATOR_MAP`` selecting the evaluator class.
        config_dict_serializable: Picklable dict of evaluator settings. Only keys
            that are fields of ``EvaluatorConfig`` are forwarded, after which
            ``model_id`` is set from ``model_id`` or, if that is missing/falsy,
            from ``identifier``.
        df: Rows to evaluate.

    Returns:
        A DataFrame restricted to ``MERGE_KEYS`` plus whichever of the judge's
        ``JUDGE_COLUMN_MAP`` columns the evaluator produced. Returns ``None`` if
        the judge type is unknown, the evaluator output lacks a merge key, or any
        exception is raised. Exceptions are logged and never propagated.
    """
    process_logger = logging.getLogger(__name__ + f".evaluator_process_{judge_type}")
    process_logger.info(f"Evaluator process started for judge: {judge_type}")

    evaluator_class = EVALUATOR_MAP.get(judge_type)
    if not evaluator_class:
        process_logger.warning(f"Unknown judge type: {judge_type}, skipping")
        return None

    evaluator = None  # bound before ``try`` so the ``finally`` cleanup is always valid
    try:
        # Forward only the keys EvaluatorConfig declares as dataclass fields.
        expected_fields = {f.name for f in fields(EvaluatorConfig)}
        filtered_config_dict = {
            k: v for k, v in config_dict_serializable.items() if k in expected_fields
        }

        # model_id is always forwarded explicitly. An explicit, truthy model_id
        # wins; otherwise fall back to the judge's 'identifier'.
        model_id = config_dict_serializable.get(
            "model_id"
        ) or config_dict_serializable.get("identifier")
        if model_id:
            filtered_config_dict["model_id"] = model_id

        # The config may carry credentials (e.g. api_key); mask them before logging.
        loggable_config = {
            k: (
                "***"
                if v
                and str(k).lower().endswith(("api_key", "token", "secret", "password"))
                else v
            )
            for k, v in filtered_config_dict.items()
        }
        process_logger.debug(
            f"Filtered config for {judge_type} evaluator: {loggable_config}"
        )
        evaluator_config = EvaluatorConfig(**filtered_config_dict)

        evaluator = evaluator_class(evaluator_config)
        evaluated_df = evaluator.evaluate(df)

        process_logger.info(f"Evaluator process finished for judge: {judge_type}")

        # The caller joins results back on MERGE_KEYS, so they must all be present.
        if not set(MERGE_KEYS).issubset(evaluated_df.columns):
            process_logger.error(
                f"Evaluation result for {judge_type} is missing merge keys {MERGE_KEYS}. Available: {evaluated_df.columns}. Returning None."
            )
            return None

        # Return only merge keys + this judge's columns to keep the pickled result small.
        eval_cols = JUDGE_COLUMN_MAP.get(judge_type, [])
        cols_to_return = MERGE_KEYS + [
            col for col in eval_cols if col in evaluated_df.columns
        ]
        return evaluated_df[cols_to_return]

    except Exception as e:
        process_logger.error(
            f"Error occurred while running {judge_type} evaluator: {str(e)}",
            exc_info=True,
        )
        return None  # Indicate failure
    finally:
        # Drop the evaluator reference so its resources can be released.
        del evaluator
        process_logger.info(
            f"Evaluator process cleanup finished for judge: {judge_type}"
        )
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
# Config key suffixes whose values are credentials and must never be logged.
_SECRET_KEY_SUFFIXES = ("api_key", "token", "secret", "password")


def _redact_secrets(config_dict: Dict) -> Dict:
    """Return a shallow copy of *config_dict* with credential-like values masked.

    A value is replaced by ``"***"`` when it is truthy and its key, lower-cased,
    ends with one of ``_SECRET_KEY_SUFFIXES``. Used only for log messages.
    """
    return {
        key: (
            "***"
            if value and str(key).lower().endswith(_SECRET_KEY_SUFFIXES)
            else value
        )
        for key, value in config_dict.items()
    }


def _resolve_judge_type(judge_config_item: Dict):
    """Return the evaluator type for one judge config dict, or None if undeterminable.

    An explicit ``evaluator_type`` (or ``type``) entry wins. Otherwise the type is
    inferred from substrings of ``identifier``. A missing or None identifier
    yields None instead of raising.
    """
    explicit_type = judge_config_item.get("evaluator_type") or judge_config_item.get(
        "type"
    )
    if explicit_type:
        return explicit_type

    identifier = str(judge_config_item.get("identifier") or "").lower()
    if "nuanced" in identifier:
        return "nuanced"
    if "harmbench" in identifier:
        return "harmbench"
    if "jailbreak" in identifier:
        return "jailbreakbench"
    return None


def _prepare_judge_runs(
    judge_configs_list, base_config: Dict, logger: logging.Logger
):
    """Validate 'judges' entries and build the per-judge subprocess configs.

    Returns a list of ``(judge_type, subprocess_config)`` tuples. Each
    subprocess_config is *base_config* overridden by the judge's own entries.
    When the judge has a truthy ``identifier``, ``model_id`` is set to it.
    Invalid entries are logged (with credentials masked) and skipped.
    """
    judges_to_run = []
    for judge_config_item in judge_configs_list:
        if not isinstance(judge_config_item, dict):
            logger.warning(
                f"Skipping invalid item in 'judges' list (not a dict): {judge_config_item}"
            )
            continue

        loggable_item = _redact_secrets(judge_config_item)
        judge_type_str = _resolve_judge_type(judge_config_item)
        if not judge_type_str:
            logger.warning(
                f"Could not determine evaluator type for judge config: {loggable_item}. Skipping."
            )
            continue
        if judge_type_str not in EVALUATOR_MAP:
            logger.warning(
                f"Skipping unknown judge type '{judge_type_str}' specified in config: {loggable_item}"
            )
            continue

        # Judge-specific settings override the shared base settings.
        subprocess_config = {**base_config, **judge_config_item}
        judge_identifier = judge_config_item.get("identifier")
        if judge_identifier:
            subprocess_config["model_id"] = judge_identifier
        judges_to_run.append((judge_type_str, subprocess_config))
    return judges_to_run


def _merge_judge_results(
    final_df: pd.DataFrame, judge_results_dfs: Dict, logger: logging.Logger
) -> pd.DataFrame:
    """Left-join each judge's evaluation columns onto *final_df* using MERGE_KEYS.

    Judges whose result has none of their JUDGE_COLUMN_MAP columns are skipped.
    Each judge subset is first de-duplicated on MERGE_KEYS. Identical
    (goal, prefix, completion) rows, such as repeated samples, would otherwise
    make the join many-to-many and multiply rows in the output. A failed merge
    for one judge is logged, and the remaining judges are still merged.
    """
    for judge_type_str, judge_df_subset in judge_results_dfs.items():
        eval_cols = JUDGE_COLUMN_MAP.get(judge_type_str, [])
        judge_cols_present = [
            col for col in eval_cols if col in judge_df_subset.columns
        ]
        if not judge_cols_present:
            logger.warning(
                f"No specific evaluation columns found in result for judge {judge_type_str}. Skipping merge for this judge."
            )
            continue

        try:
            right_df = judge_df_subset[MERGE_KEYS + judge_cols_present].drop_duplicates(
                subset=MERGE_KEYS, keep="first"
            )
            final_df = final_df.merge(
                right_df,
                on=MERGE_KEYS,
                how="left",
                suffixes=("", f"_{judge_type_str}_dup"),
                validate="many_to_one",
            )
            logger.debug(f"Merged results from judge {judge_type_str}")
        except Exception as e:
            logger.error(f"Error merging results for judge {judge_type_str}: {e}")
    return final_df


def execute(
    input_df: pd.DataFrame, config: Dict, logger: logging.Logger, run_dir: str
) -> pd.DataFrame:
    """Evaluate completions with every configured judge and merge the verdicts.

    Each valid entry of ``config["judges"]`` (a list of dicts) runs in its own
    worker process via ``_run_evaluator_process_wrapper``. The judges' columns
    are then left-joined onto a copy of *input_df* on ``MERGE_KEYS``, and the
    result is written to the step-7 checkpoint.

    Args:
        input_df: Completions to evaluate; expected to contain the MERGE_KEYS columns.
        config: Step configuration. Reads ``judges`` plus the shared judge
            settings ``batch_size_judge``, ``max_new_tokens_eval``, ``filter_len``,
            ``judge_endpoint``, ``judge_api_key`` and ``judge_request_timeout``.
        logger: Logger for progress and error messages.
        run_dir: Directory passed to ``get_checkpoint_path`` for the checkpoint.

    Returns:
        *input_df* itself when it is empty or no valid judge is configured; no
        checkpoint is written in that case. Otherwise, a copy with the columns
        of every successful judge added. Judge failures are logged, not raised.
    """
    logger.info("Executing Step 7: Evaluating responses")
    original_df = input_df

    if original_df.empty:
        logger.warning("Step 7 received an empty DataFrame. Skipping evaluation.")
        return original_df

    judge_configs_list = config.get("judges")
    if not isinstance(judge_configs_list, list) or not judge_configs_list:
        logger.warning(
            "Step 7: 'judges' key in configuration is missing, not a list, or empty. Skipping evaluation."
        )
        return original_df

    # Shared evaluator settings; individual judge dicts may override any of these.
    evaluator_base_config_dict = {
        "batch_size": config.get("batch_size_judge"),
        "max_new_tokens_eval": config.get("max_new_tokens_eval"),
        "filter_len": config.get("filter_len"),
        "endpoint": config.get("judge_endpoint"),
        "api_key": config.get("judge_api_key"),
        "request_timeout": config.get("judge_request_timeout"),
    }

    judges_to_run = _prepare_judge_runs(
        judge_configs_list, evaluator_base_config_dict, logger
    )
    if not judges_to_run:
        logger.warning(
            "Step 7: No valid judges found after processing configuration. Skipping evaluation."
        )
        return original_df

    num_judges = len(judges_to_run)
    num_workers = min(num_judges, os.cpu_count() or 1, 4)
    logger.info(
        f"Starting evaluation pool with {num_workers} workers for {num_judges} judges."
    )

    judge_results_dfs = {}
    failed_judges = []

    # Workers use the 'spawn' start method through a local context, so the
    # process-wide default start method is left untouched.
    spawn_ctx = multiprocessing.get_context("spawn")
    with spawn_ctx.Pool(processes=num_workers) as pool:
        async_results = []
        for judge_type_str, subprocess_config in judges_to_run:
            logger.info(
                f"Dispatching evaluation with {judge_type_str} judge. Config: {_redact_secrets(subprocess_config)}"
            )
            args = (judge_type_str, subprocess_config, original_df.copy())
            async_results.append(
                pool.apply_async(_run_evaluator_process_wrapper, args=args)
            )

        # Results must be collected before the pool context exits (it terminates workers).
        for (judge_type_str, _), result in zip(judges_to_run, async_results):
            try:
                evaluated_df_subset = result.get()
            except Exception as e:
                failed_judges.append(judge_type_str)
                logger.error(
                    f"Evaluation task failed for judge {judge_type_str}: {e}",
                    exc_info=True,
                )
                continue

            if evaluated_df_subset is None:
                failed_judges.append(judge_type_str)
                logger.error(
                    f"Evaluation failed for judge: {judge_type_str} (process returned None)"
                )
            else:
                judge_results_dfs[judge_type_str] = evaluated_df_subset
                logger.info(
                    f"Successfully completed evaluation for judge: {judge_type_str}"
                )

    final_df = original_df.copy()
    successful_judges = list(judge_results_dfs.keys())
    if not successful_judges:
        logger.warning(
            "Step 7: No judges completed successfully. Returning original DataFrame."
        )
    else:
        logger.info(f"Merging results from successful judges: {successful_judges}")
        final_df = _merge_judge_results(final_df, judge_results_dfs, logger)

    output_path = get_checkpoint_path(run_dir, 7)
    try:
        final_df.to_csv(output_path, index=False)
        logger.info(f"Step 7 complete. Evaluated {len(final_df)} responses.")
        if successful_judges:
            logger.info(
                f"Successfully completed judges: {', '.join(successful_judges)}"
            )
        if failed_judges:
            logger.warning(f"Failed judges: {', '.join(failed_judges)}")
        logger.info(f"Final evaluation results checkpoint saved to {output_path}")
    except Exception as e:
        logger.error(f"Failed to save checkpoint for step 7 to {output_path}: {e}")

    return final_df
|
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
from typing import Dict, Any
|
|
3
|
+
|
|
4
|
+
from .utils import get_checkpoint_path
|
|
5
|
+
|
|
6
|
+
# Judge type -> score column that is averaged and counted per group in step 8.
JUDGE_AGG_COLUMN_MAP = dict(
    nuanced="eval_nj",
    jailbreakbench="eval_jb",
    harmbench="eval_hb",
)

# Columns identifying one candidate prefix; evaluation rows are grouped by these.
GROUP_KEYS = ["goal", "prefix"]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Per-prefix metadata columns carried through aggregation by taking the first value per group.
_FIRST_VALUE_COLUMNS = ("prefix_nll", "model_name", "meta_prefix", "temperature")


def _configured_judge_types(judges_config) -> set:
    """Return the judge type names named by the ``judges`` config value.

    Accepts a list/tuple whose items are either plain type strings or judge
    dicts. For dicts, the type comes from ``evaluator_type``/``type`` or, failing
    that, from substrings of ``identifier``. Any other value (including None)
    yields an empty set.
    """
    judge_types = set()
    if not isinstance(judges_config, (list, tuple)):
        return judge_types
    for item in judges_config:
        if isinstance(item, str):
            judge_types.add(item)
            continue
        if not isinstance(item, dict):
            continue
        explicit_type = item.get("evaluator_type") or item.get("type")
        if explicit_type:
            judge_types.add(str(explicit_type))
            continue
        identifier = str(item.get("identifier") or "").lower()
        if "nuanced" in identifier:
            judge_types.add("nuanced")
        elif "harmbench" in identifier:
            judge_types.add("harmbench")
        elif "jailbreak" in identifier:
            judge_types.add("jailbreakbench")
    return judge_types


def _save_unaggregated(analysis: pd.DataFrame, run_dir: str, reason: str) -> None:
    """Write *analysis* as the step-8 checkpoint when aggregation cannot proceed.

    Save failures are printed, not raised.
    """
    output_path = get_checkpoint_path(run_dir, 8)
    try:
        analysis.to_csv(output_path, index=False)
        print(
            f"WARNING: Step 8 saving unaggregated data to {output_path} due to {reason}."
        )
    except Exception as e:
        print(
            f"ERROR: Failed to save unaggregated data checkpoint for step 8 to {output_path}: {e}"
        )


def execute(
    input_df: pd.DataFrame, config: Dict[str, Any], run_dir: str
) -> pd.DataFrame:
    """
    Aggregate evaluation results from different judges using the input DataFrame.

    Rows are optionally filtered to ``prefix_nll < config["max_ce"]``. They are
    then grouped by GROUP_KEYS, producing one row per (goal, prefix) with:

    - the first value of each available metadata column;
    - ``n_eval_samples``, the group size;
    - ``<col>_mean`` and ``<col>_count`` for each judge column found.

    Metadata columns missing from the input are omitted from the output (with a
    warning) instead of aborting aggregation. If no judge column or group key is
    present, or aggregation itself fails, the (filtered) unaggregated data is
    saved and returned instead. The result is also written to the step-8
    checkpoint under *run_dir*. An empty input returns an empty frame with the
    expected columns.
    """
    print("Executing Step 8: Aggregating evaluation results")

    if input_df.empty:
        print("WARNING: Step 8 received an empty DataFrame. Skipping aggregation.")
        cols = GROUP_KEYS + list(_FIRST_VALUE_COLUMNS) + ["n_eval_samples"]
        for col_base in JUDGE_AGG_COLUMN_MAP.values():
            cols.extend([f"{col_base}_mean", f"{col_base}_count"])
        return pd.DataFrame(columns=cols)

    analysis = input_df.copy()

    # Optionally filter based on cross-entropy / NLL score.
    max_ce = config.get("max_ce")
    if "prefix_nll" not in analysis.columns:
        print(
            "WARNING: Column 'prefix_nll' not found. Skipping NLL filtering in aggregation step."
        )
    elif max_ce is not None:
        try:
            max_ce_threshold = float(max_ce)
            initial_count = len(analysis)
            # .copy() so later column assignments act on an independent frame.
            analysis = analysis[analysis["prefix_nll"] < max_ce_threshold].copy()
            print(
                f"Filtered {initial_count - len(analysis)} rows based on prefix_nll >= {max_ce_threshold}"
            )
        except Exception as e:
            # Best effort: continue without NLL filtering.
            print(f"ERROR: Error during NLL filtering in aggregation: {e}")

    # Detect which judges produced a score column.
    configured_types = _configured_judge_types(config.get("judges"))
    available_judges_agg_cols = {}
    for judge_type, col_name in JUDGE_AGG_COLUMN_MAP.items():
        if col_name in analysis.columns:
            available_judges_agg_cols[judge_type] = col_name
        elif judge_type in configured_types:
            print(
                f"WARNING: Expected aggregation column '{col_name}' for judge '{judge_type}' not found in the dataframe for Step 8."
            )

    if not available_judges_agg_cols:
        print(
            "ERROR: No recognized evaluation result columns found for aggregation. Check step 7 output."
        )
        _save_unaggregated(analysis, run_dir, "missing judge columns")
        return analysis

    print(
        f"Found aggregation columns for judges: {list(available_judges_agg_cols.keys())}"
    )

    missing_keys = [key for key in GROUP_KEYS if key not in analysis.columns]
    if missing_keys:
        print(
            f"ERROR: Missing required grouping keys for aggregation: {missing_keys}. Cannot aggregate."
        )
        _save_unaggregated(analysis, run_dir, "missing group keys")
        return analysis

    # Aggregate only the metadata columns actually present; an absent column
    # would otherwise make the whole groupby.agg call raise.
    missing_metadata = [c for c in _FIRST_VALUE_COLUMNS if c not in analysis.columns]
    if missing_metadata:
        print(
            f"WARNING: Metadata columns {missing_metadata} not found; they will be omitted from the aggregated output."
        )
    agg_funcs = {
        col: pd.NamedAgg(column=col, aggfunc="first")
        for col in _FIRST_VALUE_COLUMNS
        if col in analysis.columns
    }
    agg_funcs["n_eval_samples"] = pd.NamedAgg(column=GROUP_KEYS[0], aggfunc="size")

    # Judge-specific aggregations on numeric scores.
    for col_name in available_judges_agg_cols.values():
        try:
            analysis[col_name] = pd.to_numeric(analysis[col_name], errors="coerce")
            agg_funcs[f"{col_name}_mean"] = pd.NamedAgg(column=col_name, aggfunc="mean")
            # 'count' counts non-NA values, so coerced non-numeric entries are excluded.
            agg_funcs[f"{col_name}_count"] = pd.NamedAgg(
                column=col_name, aggfunc="count"
            )
            print(
                f"DEBUG: Added mean/count aggregation for numeric column '{col_name}'"
            )
        except Exception as e:
            print(
                f"ERROR: Could not convert column '{col_name}' to numeric for aggregation. Skipping mean/count. Error: {e}"
            )
            agg_funcs[f"{col_name}_size"] = pd.NamedAgg(column=col_name, aggfunc="size")

    try:
        grouped = analysis.groupby(GROUP_KEYS, observed=False, dropna=False)
        aggregated = grouped.agg(**agg_funcs).reset_index()
    except Exception as e:
        print(
            f"ERROR: Error during aggregation: {e}. Check aggregation functions and column types."
        )
        _save_unaggregated(analysis, run_dir, "aggregation error")
        return analysis  # Return unaggregated on error

    output_path = get_checkpoint_path(run_dir, 8)
    try:
        aggregated.to_csv(output_path, index=False)
        print(f"Step 8 complete. Aggregated {len(aggregated)} prefix results.")
        print(f"Checkpoint saved to {output_path}")
    except Exception as e:
        print(f"ERROR: Failed to save checkpoint for step 8 to {output_path}: {e}")

    return aggregated
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import pandas as pd
|
|
3
|
+
from typing import Dict, Any
|
|
4
|
+
|
|
5
|
+
from hackagent.attacks.AdvPrefix.selector import (
|
|
6
|
+
PrefixSelectorConfig,
|
|
7
|
+
PrefixSelector,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
from .utils import get_checkpoint_path
|
|
11
|
+
|
|
12
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def execute(
    input_df: pd.DataFrame, config: Dict[str, Any], run_dir: str
) -> pd.DataFrame:
    """Choose the final prefixes from the aggregated evaluation results.

    A PrefixSelector is built from ``pasr_weight`` (default 0.5),
    ``n_prefixes_per_goal`` (default 3) and ``selection_judges`` (default [])
    in *config* and applied to *input_df*. If selection raises, the error is
    logged and *input_df* is used unchanged. The outcome is written to the
    step-9 checkpoint under *run_dir*; save errors are logged, not raised.
    An empty *input_df* is returned immediately and nothing is saved.
    """
    logger.info("Executing Step 9: Selecting final prefixes")

    if input_df.empty:
        logger.warning("Step 9 received an empty DataFrame. Skipping selection.")
        return input_df

    result_df = input_df  # fallback when selection fails
    prefix_selector = None  # bound before try so the finally clause can always delete it

    try:
        prefix_selector = PrefixSelector(
            PrefixSelectorConfig(
                pasr_weight=config.get("pasr_weight", 0.5),
                n_prefixes_per_goal=config.get("n_prefixes_per_goal", 3),
                judges=config.get("selection_judges", []),
            )
        )
        result_df = prefix_selector.select_prefixes(input_df)
        logger.info(f"Selection complete. Selected {len(result_df)} prefixes.")
    except Exception as err:
        logger.error(f"Error during prefix selection: {err}", exc_info=True)
        logger.warning("Returning unselected prefixes due to selection error.")
        result_df = input_df
    finally:
        del prefix_selector

    checkpoint_path = get_checkpoint_path(run_dir, 9)
    try:
        result_df.to_csv(checkpoint_path, index=False)
        logger.info("Step 9 complete.")
        logger.info(f"Final selected prefixes checkpoint saved to {checkpoint_path}")
    except Exception as err:
        logger.error(f"Failed to save checkpoint for step 9 to {checkpoint_path}: {err}")

    return result_df
|