hackagent 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117) hide show
  1. hackagent/__init__.py +23 -0
  2. hackagent/agent.py +193 -0
  3. hackagent/api/__init__.py +1 -0
  4. hackagent/api/agent/__init__.py +1 -0
  5. hackagent/api/agent/agent_create.py +340 -0
  6. hackagent/api/agent/agent_destroy.py +136 -0
  7. hackagent/api/agent/agent_list.py +234 -0
  8. hackagent/api/agent/agent_partial_update.py +354 -0
  9. hackagent/api/agent/agent_retrieve.py +227 -0
  10. hackagent/api/agent/agent_update.py +354 -0
  11. hackagent/api/attack/__init__.py +1 -0
  12. hackagent/api/attack/attack_create.py +264 -0
  13. hackagent/api/attack/attack_destroy.py +140 -0
  14. hackagent/api/attack/attack_list.py +242 -0
  15. hackagent/api/attack/attack_partial_update.py +278 -0
  16. hackagent/api/attack/attack_retrieve.py +235 -0
  17. hackagent/api/attack/attack_update.py +278 -0
  18. hackagent/api/key/__init__.py +1 -0
  19. hackagent/api/key/key_create.py +168 -0
  20. hackagent/api/key/key_destroy.py +97 -0
  21. hackagent/api/key/key_list.py +158 -0
  22. hackagent/api/key/key_retrieve.py +150 -0
  23. hackagent/api/prompt/__init__.py +1 -0
  24. hackagent/api/prompt/prompt_create.py +160 -0
  25. hackagent/api/prompt/prompt_destroy.py +98 -0
  26. hackagent/api/prompt/prompt_list.py +173 -0
  27. hackagent/api/prompt/prompt_partial_update.py +174 -0
  28. hackagent/api/prompt/prompt_retrieve.py +151 -0
  29. hackagent/api/prompt/prompt_update.py +174 -0
  30. hackagent/api/result/__init__.py +1 -0
  31. hackagent/api/result/result_create.py +160 -0
  32. hackagent/api/result/result_destroy.py +98 -0
  33. hackagent/api/result/result_list.py +233 -0
  34. hackagent/api/result/result_partial_update.py +178 -0
  35. hackagent/api/result/result_retrieve.py +151 -0
  36. hackagent/api/result/result_trace_create.py +178 -0
  37. hackagent/api/result/result_update.py +174 -0
  38. hackagent/api/run/__init__.py +1 -0
  39. hackagent/api/run/run_create.py +172 -0
  40. hackagent/api/run/run_destroy.py +104 -0
  41. hackagent/api/run/run_list.py +260 -0
  42. hackagent/api/run/run_partial_update.py +186 -0
  43. hackagent/api/run/run_result_create.py +178 -0
  44. hackagent/api/run/run_retrieve.py +163 -0
  45. hackagent/api/run/run_run_tests_create.py +172 -0
  46. hackagent/api/run/run_update.py +186 -0
  47. hackagent/attacks/AdvPrefix/README.md +7 -0
  48. hackagent/attacks/AdvPrefix/__init__.py +0 -0
  49. hackagent/attacks/AdvPrefix/completer.py +438 -0
  50. hackagent/attacks/AdvPrefix/config.py +59 -0
  51. hackagent/attacks/AdvPrefix/preprocessing.py +521 -0
  52. hackagent/attacks/AdvPrefix/scorer.py +259 -0
  53. hackagent/attacks/AdvPrefix/scorer_parser.py +498 -0
  54. hackagent/attacks/AdvPrefix/selector.py +246 -0
  55. hackagent/attacks/AdvPrefix/step1_generate.py +324 -0
  56. hackagent/attacks/AdvPrefix/step4_compute_ce.py +293 -0
  57. hackagent/attacks/AdvPrefix/step6_get_completions.py +387 -0
  58. hackagent/attacks/AdvPrefix/step7_evaluate_responses.py +289 -0
  59. hackagent/attacks/AdvPrefix/step8_aggregate_evaluations.py +177 -0
  60. hackagent/attacks/AdvPrefix/step9_select_prefixes.py +59 -0
  61. hackagent/attacks/AdvPrefix/utils.py +192 -0
  62. hackagent/attacks/__init__.py +6 -0
  63. hackagent/attacks/advprefix.py +1136 -0
  64. hackagent/attacks/base.py +50 -0
  65. hackagent/attacks/strategies.py +539 -0
  66. hackagent/branding.py +143 -0
  67. hackagent/client.py +328 -0
  68. hackagent/errors.py +31 -0
  69. hackagent/logger.py +67 -0
  70. hackagent/models/__init__.py +71 -0
  71. hackagent/models/agent.py +240 -0
  72. hackagent/models/agent_request.py +169 -0
  73. hackagent/models/agent_type_enum.py +12 -0
  74. hackagent/models/attack.py +154 -0
  75. hackagent/models/attack_request.py +82 -0
  76. hackagent/models/evaluation_status_enum.py +14 -0
  77. hackagent/models/organization_minimal.py +68 -0
  78. hackagent/models/paginated_agent_list.py +123 -0
  79. hackagent/models/paginated_attack_list.py +123 -0
  80. hackagent/models/paginated_prompt_list.py +123 -0
  81. hackagent/models/paginated_result_list.py +123 -0
  82. hackagent/models/paginated_run_list.py +123 -0
  83. hackagent/models/paginated_user_api_key_list.py +123 -0
  84. hackagent/models/patched_agent_request.py +176 -0
  85. hackagent/models/patched_attack_request.py +92 -0
  86. hackagent/models/patched_prompt_request.py +162 -0
  87. hackagent/models/patched_result_request.py +237 -0
  88. hackagent/models/patched_run_request.py +138 -0
  89. hackagent/models/prompt.py +226 -0
  90. hackagent/models/prompt_request.py +155 -0
  91. hackagent/models/result.py +294 -0
  92. hackagent/models/result_list_evaluation_status.py +14 -0
  93. hackagent/models/result_request.py +232 -0
  94. hackagent/models/run.py +233 -0
  95. hackagent/models/run_list_status.py +12 -0
  96. hackagent/models/run_request.py +133 -0
  97. hackagent/models/status_enum.py +12 -0
  98. hackagent/models/step_type_enum.py +14 -0
  99. hackagent/models/trace.py +121 -0
  100. hackagent/models/trace_request.py +94 -0
  101. hackagent/models/user_api_key.py +201 -0
  102. hackagent/models/user_api_key_request.py +73 -0
  103. hackagent/models/user_profile_minimal.py +76 -0
  104. hackagent/py.typed +1 -0
  105. hackagent/router/__init__.py +11 -0
  106. hackagent/router/adapters/__init__.py +5 -0
  107. hackagent/router/adapters/google_adk.py +658 -0
  108. hackagent/router/adapters/litellm_adapter.py +290 -0
  109. hackagent/router/base.py +48 -0
  110. hackagent/router/router.py +753 -0
  111. hackagent/types.py +46 -0
  112. hackagent/utils.py +61 -0
  113. hackagent/vulnerabilities/__init__.py +0 -0
  114. hackagent-0.1.0.dist-info/LICENSE +202 -0
  115. hackagent-0.1.0.dist-info/METADATA +173 -0
  116. hackagent-0.1.0.dist-info/RECORD +117 -0
  117. hackagent-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,289 @@
1
+ import logging
2
+ import pandas as pd
3
+ import multiprocessing
4
+ import os
5
+ from typing import Dict
6
+ from dataclasses import fields # Import fields to inspect dataclass
7
+
8
+ from hackagent.attacks.AdvPrefix.scorer_parser import (
9
+ EvaluatorConfig,
10
+ NuancedEvaluator,
11
+ HarmBenchEvaluator,
12
+ JailbreakBenchEvaluator,
13
+ )
14
+ from hackagent.attacks.AdvPrefix.utils import get_checkpoint_path
15
+
16
# Judge type name -> evaluator class instantiated for that judge.
EVALUATOR_MAP = {
    "nuanced": NuancedEvaluator,
    "jailbreakbench": JailbreakBenchEvaluator,
    "harmbench": HarmBenchEvaluator,
}

# Columns that identify one evaluated row; judge results are joined back on these.
MERGE_KEYS = ["goal", "prefix", "completion"]

# Judge type name -> result columns that judge is expected to add
# (listed as score column first, explanation column second).
JUDGE_COLUMN_MAP = {
    "nuanced": ["eval_nj", "explanation_nj"],
    "jailbreakbench": ["eval_jb", "explanation_jb"],
    "harmbench": ["eval_hb", "explanation_hb"],
}
29
+
30
+
31
def _run_evaluator_process_wrapper(
    judge_type: str, config_dict_serializable: Dict, df: pd.DataFrame
):
    """Run a single judge over ``df``; intended as a multiprocessing worker target.

    Args:
        judge_type: Key into ``EVALUATOR_MAP`` selecting the evaluator class.
        config_dict_serializable: Plain dict of settings. Keys that match
            ``EvaluatorConfig`` dataclass fields are forwarded. ``model_id`` is
            forwarded when truthy; otherwise a truthy ``identifier`` is used
            as ``model_id``.
        df: Rows to evaluate; handed unchanged to ``evaluator.evaluate``.

    Returns:
        A DataFrame holding ``MERGE_KEYS`` plus whichever of the judge's
        ``JUDGE_COLUMN_MAP`` columns the evaluator produced, or ``None`` when
        the judge type is unknown, the result lacks a merge key, or any
        exception is raised (failures are logged, never propagated).
    """
    process_logger = logging.getLogger(__name__ + f".evaluator_process_{judge_type}")
    process_logger.info(f"Evaluator process started for judge: {judge_type}")

    evaluator_class = EVALUATOR_MAP.get(judge_type)
    if not evaluator_class:
        process_logger.warning(f"Unknown judge type: {judge_type}, skipping")
        return None

    evaluator = None  # Bound up front so the finally clause can always release it
    try:
        # Forward only the keys EvaluatorConfig declares as dataclass fields.
        expected_fields = {f.name for f in fields(EvaluatorConfig)}
        filtered_config_dict = {
            k: v for k, v in config_dict_serializable.items() if k in expected_fields
        }

        # An explicit, truthy model_id wins; otherwise fall back to 'identifier'.
        model_id = config_dict_serializable.get(
            "model_id"
        ) or config_dict_serializable.get("identifier")
        if model_id:
            filtered_config_dict["model_id"] = model_id

        # Never write credentials to the log: mask API keys in the logged copy only.
        loggable_config = {
            key: "***REDACTED***" if "api_key" in str(key).lower() and value else value
            for key, value in filtered_config_dict.items()
        }
        process_logger.debug(
            f"Filtered config for {judge_type} evaluator: {loggable_config}"
        )
        evaluator_config = EvaluatorConfig(**filtered_config_dict)

        # Instantiate the specific evaluator class and run it.
        evaluator = evaluator_class(evaluator_config)
        evaluated_df = evaluator.evaluate(df)

        process_logger.info(f"Evaluator process finished for judge: {judge_type}")

        # The caller merges on MERGE_KEYS, so a result without them is unusable.
        if not all(key in evaluated_df.columns for key in MERGE_KEYS):
            process_logger.error(
                f"Evaluation result for {judge_type} is missing merge keys {MERGE_KEYS}. Available: {evaluated_df.columns}. Returning None."
            )
            return None

        # Return only the essential columns: merge keys + judge-specific columns.
        eval_cols = JUDGE_COLUMN_MAP.get(judge_type, [])
        cols_to_return = MERGE_KEYS + [
            col for col in eval_cols if col in evaluated_df.columns
        ]
        return evaluated_df[cols_to_return]

    except Exception as e:
        process_logger.error(
            f"Error occurred while running {judge_type} evaluator: {str(e)}",
            exc_info=True,
        )
        return None  # Indicate failure
    finally:
        # Drop the local reference to the evaluator before the task returns.
        del evaluator
        process_logger.info(
            f"Evaluator process cleanup finished for judge: {judge_type}"
        )
102
+
103
+
104
def _redact_judge_config(config_dict: Dict) -> Dict:
    """Return a copy of ``config_dict`` that is safe to log (API keys masked)."""
    return {
        key: "***REDACTED***" if "api_key" in str(key).lower() and value else value
        for key, value in config_dict.items()
    }


def _resolve_judge_type(judge_config_item: Dict):
    """Return the evaluator type named by, or inferred from, one judge config.

    An explicit ``evaluator_type`` (or ``type``) wins. Otherwise the type is
    inferred from substrings of a string ``identifier``. Returns ``None`` when
    neither source yields a type, including when ``identifier`` is missing or
    not a string.
    """
    explicit_type = judge_config_item.get("evaluator_type") or judge_config_item.get(
        "type"
    )
    if explicit_type:
        return explicit_type

    identifier = judge_config_item.get("identifier")
    if not isinstance(identifier, str):
        return None

    lowered = identifier.lower()
    for marker, inferred_type in (
        ("nuanced", "nuanced"),
        ("harmbench", "harmbench"),
        ("jailbreak", "jailbreakbench"),
    ):
        if marker in lowered:
            return inferred_type
    return None


def execute(
    input_df: pd.DataFrame, config: Dict, logger: logging.Logger, run_dir: str
) -> pd.DataFrame:
    """Evaluate completions with every configured judge and merge their verdicts.

    Args:
        input_df: Rows to evaluate. Judge results are joined back on
            ``MERGE_KEYS``.
        config: Step configuration. ``config["judges"]`` must be a non-empty
            list of dicts; each dict names its evaluator via ``evaluator_type``
            / ``type`` or via an ``identifier`` from which the type is inferred.
            ``batch_size_judge``, ``max_new_tokens_eval``, ``filter_len``,
            ``judge_endpoint``, ``judge_api_key`` and ``judge_request_timeout``
            provide base settings that individual judge dicts may override.
        logger: Logger used for all progress and error reporting.
        run_dir: Directory passed to ``get_checkpoint_path`` for the step-7 CSV.

    Returns:
        ``input_df`` itself when there is nothing to do (empty input, or no
        valid judges); otherwise a copy of it with the columns of every
        successful judge left-merged in. Failures of individual judges and of
        the checkpoint save are logged, not raised.
    """
    logger.info("Executing Step 7: Evaluating responses")
    original_df = input_df

    if original_df.empty:
        logger.warning("Step 7 received an empty DataFrame. Skipping evaluation.")
        return original_df

    # Config key 'judges' should be a list of dictionaries
    judge_configs_list = config.get("judges")
    if not isinstance(judge_configs_list, list) or not judge_configs_list:
        logger.warning(
            "Step 7: 'judges' key in configuration is missing, not a list, or empty. Skipping evaluation."
        )
        return original_df

    # Base config for evaluators (non-judge-specific params; judges may override)
    evaluator_base_config_dict = {
        "batch_size": config.get("batch_size_judge"),
        "max_new_tokens_eval": config.get("max_new_tokens_eval"),
        "filter_len": config.get("filter_len"),
        "endpoint": config.get("judge_endpoint"),
        "api_key": config.get("judge_api_key"),
        "request_timeout": config.get("judge_request_timeout"),
    }

    judge_results_dfs = {}
    failed_judges = []
    async_results = []
    judges_to_run = []  # Valid (type, config_dict) tuples

    # --- Prepare Judge Runs ---
    for judge_config_item in judge_configs_list:
        if not isinstance(judge_config_item, dict):
            logger.warning(
                f"Skipping invalid item in 'judges' list (not a dict): {judge_config_item}"
            )
            continue

        # A judge entry lacking both a type and a usable identifier is skipped
        # here instead of crashing on None.lower().
        judge_type_str = _resolve_judge_type(judge_config_item)
        if not judge_type_str:
            logger.warning(
                f"Could not determine evaluator type for judge config: {_redact_judge_config(judge_config_item)}. Skipping."
            )
            continue

        if judge_type_str not in EVALUATOR_MAP:
            logger.warning(
                f"Skipping unknown judge type '{judge_type_str}' specified in config: {_redact_judge_config(judge_config_item)}"
            )
            continue

        # Start with base settings, then let the judge-specific ones override.
        subprocess_config = {**evaluator_base_config_dict, **judge_config_item}
        judge_identifier = judge_config_item.get("identifier")
        if judge_identifier:
            subprocess_config["model_id"] = judge_identifier

        judges_to_run.append((judge_type_str, subprocess_config))

    if not judges_to_run:
        logger.warning(
            "Step 7: No valid judges found after processing configuration. Skipping evaluation."
        )
        return original_df

    # --- Setup Multiprocessing Pool ---
    try:
        current_start_method = multiprocessing.get_start_method(allow_none=True)
        if current_start_method != "spawn":
            multiprocessing.set_start_method("spawn", force=True)
            logger.info("Set multiprocessing start method to 'spawn' for Step 7.")
    except Exception as e:
        logger.warning(f"Could not set multiprocessing start method to spawn: {e}")

    num_judges = len(judges_to_run)
    num_workers = min(num_judges, os.cpu_count() or 1, 4)
    logger.info(
        f"Starting evaluation pool with {num_workers} workers for {num_judges} judges."
    )

    # --- Dispatch and Collect Results ---
    with multiprocessing.Pool(processes=num_workers) as pool:
        for judge_type_str, subprocess_config in judges_to_run:
            # Log a masked copy so API keys never reach the log output.
            logger.info(
                f"Dispatching evaluation with {judge_type_str} judge. Config: {_redact_judge_config(subprocess_config)}"
            )
            args = (judge_type_str, subprocess_config, original_df.copy())
            async_results.append(
                pool.apply_async(_run_evaluator_process_wrapper, args=args)
            )

        # async_results is index-aligned with judges_to_run.
        for (judge_type_str, _), result in zip(judges_to_run, async_results):
            try:
                evaluated_df_subset = result.get()
                if evaluated_df_subset is not None:
                    judge_results_dfs[judge_type_str] = evaluated_df_subset
                    logger.info(
                        f"Successfully completed evaluation for judge: {judge_type_str}"
                    )
                else:
                    failed_judges.append(judge_type_str)
                    logger.error(
                        f"Evaluation failed for judge: {judge_type_str} (process returned None)"
                    )
            except Exception as e:
                failed_judges.append(judge_type_str)
                logger.error(
                    f"Evaluation task failed for judge {judge_type_str}: {e}",
                    exc_info=True,
                )

    # --- Merge Results ---
    final_df = original_df.copy()
    successful_judges = list(judge_results_dfs.keys())

    if not successful_judges:
        logger.warning(
            "Step 7: No judges completed successfully. Returning original DataFrame."
        )
    else:
        logger.info(f"Merging results from successful judges: {successful_judges}")
        for judge_type_str in successful_judges:
            judge_df_subset = judge_results_dfs[judge_type_str]
            eval_cols = JUDGE_COLUMN_MAP.get(judge_type_str, [])
            judge_cols_present = [
                col for col in eval_cols if col in judge_df_subset.columns
            ]

            if not judge_cols_present:
                logger.warning(
                    f"No specific evaluation columns found in result for judge {judge_type_str}. Skipping merge for this judge."
                )
                continue

            try:
                # Repeated (goal, prefix, completion) rows would otherwise make
                # the left merge many-to-many and multiply rows (k copies -> k*k).
                # Keep one verdict per key so the row count stays that of the input.
                unique_judge_rows = judge_df_subset.drop_duplicates(
                    subset=MERGE_KEYS, keep="first"
                )
                dropped = len(judge_df_subset) - len(unique_judge_rows)
                if dropped:
                    logger.debug(
                        f"Dropped {dropped} duplicate-key rows from judge {judge_type_str} before merging"
                    )
                final_df = final_df.merge(
                    unique_judge_rows,
                    on=MERGE_KEYS,
                    how="left",
                    suffixes=("", f"_{judge_type_str}_dup"),
                )
                logger.debug(f"Merged results from judge {judge_type_str}")
            except Exception as e:
                logger.error(f"Error merging results for judge {judge_type_str}: {e}")

    # Save final merged results checkpoint (best effort: a failed save is logged only)
    output_path = get_checkpoint_path(run_dir, 7)
    try:
        final_df.to_csv(output_path, index=False)
        logger.info(f"Step 7 complete. Evaluated {len(final_df)} responses.")
        if successful_judges:
            logger.info(
                f"Successfully completed judges: {', '.join(successful_judges)}"
            )
        if failed_judges:
            logger.warning(f"Failed judges: {', '.join(failed_judges)}")
        logger.info(f"Final evaluation results checkpoint saved to {output_path}")
    except Exception as e:
        logger.error(f"Failed to save checkpoint for step 7 to {output_path}: {e}")

    return final_df
@@ -0,0 +1,177 @@
1
+ import pandas as pd
2
+ from typing import Dict, Any
3
+
4
+ from .utils import get_checkpoint_path
5
+
6
# Map judge type to the score column that judge adds; aggregated outputs are
# named "<column>_mean" and "<column>_count".
JUDGE_AGG_COLUMN_MAP = {
    "nuanced": "eval_nj",
    "jailbreakbench": "eval_jb",
    "harmbench": "eval_hb",
}

GROUP_KEYS = ["goal", "prefix"]

# Columns carried through aggregation by taking the first value in each group.
_METADATA_COLUMNS = ["prefix_nll", "model_name", "meta_prefix", "temperature"]


def _configured_judge_types(config: Dict[str, Any]) -> set:
    """Return the judge type names referenced by ``config["judges"]``.

    Entries may be plain type strings or dicts. For dicts the type comes from
    ``evaluator_type`` / ``type``, or is inferred from substrings of a string
    ``identifier``. Anything else is ignored.
    """
    judges = config.get("judges")
    if not isinstance(judges, (list, tuple)):
        return set()

    identifier_markers = (
        ("nuanced", "nuanced"),
        ("harmbench", "harmbench"),
        ("jailbreak", "jailbreakbench"),
    )
    found = set()
    for entry in judges:
        if isinstance(entry, str):
            found.add(entry)
            continue
        if not isinstance(entry, dict):
            continue
        explicit_type = entry.get("evaluator_type") or entry.get("type")
        if isinstance(explicit_type, str):
            found.add(explicit_type)
            continue
        identifier = entry.get("identifier")
        if isinstance(identifier, str):
            lowered = identifier.lower()
            found.update(
                judge_type
                for marker, judge_type in identifier_markers
                if marker in lowered
            )
    return found


def _save_unaggregated(analysis: pd.DataFrame, run_dir: str, reason: str) -> None:
    """Best-effort save of the unaggregated frame as the step-8 checkpoint."""
    output_path = get_checkpoint_path(run_dir, 8)
    try:
        analysis.to_csv(output_path, index=False)
        print(
            f"WARNING: Step 8 saving unaggregated data to {output_path} due to {reason}."
        )
    except Exception as e:
        print(
            f"ERROR: Failed to save unaggregated data checkpoint for step 8 to {output_path}: {e}"
        )


def execute(
    input_df: pd.DataFrame, config: Dict[str, Any], run_dir: str
) -> pd.DataFrame:
    """Aggregate per-completion judge scores into one row per (goal, prefix).

    Args:
        input_df: Evaluated rows. Judge score columns are those listed in
            ``JUDGE_AGG_COLUMN_MAP``.
        config: Step configuration. ``max_ce`` (optional) drops rows whose
            ``prefix_nll`` is not below it; ``judges`` is used only to warn
            about configured judges whose score column is absent.
        run_dir: Directory passed to ``get_checkpoint_path`` for the step-8 CSV.

    Returns:
        For empty input, an empty frame with the full expected column set.
        Otherwise one row per ``GROUP_KEYS`` group holding the first value of
        each metadata column that is present, ``n_eval_samples`` (group size)
        and ``<col>_mean`` / ``<col>_count`` for every judge column found.
        When no judge column or a group key is missing, or the aggregation
        itself fails, the (NLL-filtered) unaggregated frame is returned.
    """
    print("Executing Step 8: Aggregating evaluation results")

    if input_df.empty:
        print("WARNING: Step 8 received an empty DataFrame. Skipping aggregation.")
        # Return an empty frame that still has the aggregated schema.
        cols = GROUP_KEYS + _METADATA_COLUMNS + ["n_eval_samples"]
        for col_base in JUDGE_AGG_COLUMN_MAP.values():
            cols.extend([f"{col_base}_mean", f"{col_base}_count"])
        return pd.DataFrame(columns=cols)

    analysis = input_df.copy()

    # Optionally filter based on cross-entropy / NLL score
    max_ce = config.get("max_ce")
    if "prefix_nll" not in analysis.columns:
        print(
            "WARNING: Column 'prefix_nll' not found. Skipping NLL filtering in aggregation step."
        )
    elif max_ce is not None:
        try:
            max_ce_threshold = float(max_ce)
            keep_mask = analysis["prefix_nll"] < max_ce_threshold
        except (TypeError, ValueError) as e:
            # Continue without NLL filtering if the threshold or column is unusable
            print(f"ERROR: Error during NLL filtering in aggregation: {e}")
        else:
            initial_count = len(analysis)
            # .copy() so the numeric coercion below writes to an owned frame
            analysis = analysis[keep_mask].copy()
            print(
                f"Filtered {initial_count - len(analysis)} rows based on prefix_nll >= {max_ce_threshold}"
            )

    # Detect available judges based on column names for aggregation
    configured_types = _configured_judge_types(config)
    available_judges_agg_cols = {}
    for judge_type, col_name in JUDGE_AGG_COLUMN_MAP.items():
        if col_name in analysis.columns:
            available_judges_agg_cols[judge_type] = col_name
        elif judge_type in configured_types:
            print(
                f"WARNING: Expected aggregation column '{col_name}' for judge '{judge_type}' not found in the dataframe for Step 8."
            )

    if not available_judges_agg_cols:
        print(
            "ERROR: No recognized evaluation result columns found for aggregation. Check step 7 output."
        )
        _save_unaggregated(analysis, run_dir, "missing judge columns")
        return analysis  # Return unaggregated data

    print(
        f"Found aggregation columns for judges: {list(available_judges_agg_cols.keys())}"
    )

    # Ensure group keys exist
    missing_keys = [key for key in GROUP_KEYS if key not in analysis.columns]
    if missing_keys:
        print(
            f"ERROR: Missing required grouping keys for aggregation: {missing_keys}. Cannot aggregate."
        )
        _save_unaggregated(analysis, run_dir, "missing group keys")
        return analysis

    # Metadata: aggregate only what is present so one absent optional column
    # does not abort the whole aggregation.
    agg_funcs = {
        col: pd.NamedAgg(column=col, aggfunc="first")
        for col in _METADATA_COLUMNS
        if col in analysis.columns
    }
    missing_metadata = [
        col for col in _METADATA_COLUMNS if col not in analysis.columns
    ]
    if missing_metadata:
        print(
            f"WARNING: Metadata columns {missing_metadata} not found; omitting them from the aggregated output."
        )
    # Number of evaluated samples per group
    agg_funcs["n_eval_samples"] = pd.NamedAgg(column=GROUP_KEYS[0], aggfunc="size")

    # Add judge-specific aggregations
    for col_name in available_judges_agg_cols.values():
        try:
            # Non-numeric verdicts become NaN and are excluded from mean/count
            analysis[col_name] = pd.to_numeric(analysis[col_name], errors="coerce")
        except (TypeError, ValueError) as e:
            print(
                f"ERROR: Could not convert column '{col_name}' to numeric for aggregation. Skipping mean/count. Error: {e}"
            )
            agg_funcs[f"{col_name}_size"] = pd.NamedAgg(column=col_name, aggfunc="size")
            continue
        agg_funcs[f"{col_name}_mean"] = pd.NamedAgg(column=col_name, aggfunc="mean")
        # Count of non-NA numeric values
        agg_funcs[f"{col_name}_count"] = pd.NamedAgg(column=col_name, aggfunc="count")
        print(f"DEBUG: Added mean/count aggregation for numeric column '{col_name}'")

    # Perform aggregation
    try:
        grouped = analysis.groupby(GROUP_KEYS, observed=False, dropna=False)
        aggregated = grouped.agg(**agg_funcs).reset_index()
    except Exception as e:
        print(
            f"ERROR: Error during aggregation: {e}. Check aggregation functions and column types."
        )
        _save_unaggregated(analysis, run_dir, "aggregation error")
        return analysis  # Return unaggregated on error

    # Save results checkpoint (best effort: a failed save is reported only)
    output_path = get_checkpoint_path(run_dir, 8)
    try:
        aggregated.to_csv(output_path, index=False)
        print(f"Step 8 complete. Aggregated {len(aggregated)} prefix results.")
        print(f"Checkpoint saved to {output_path}")
    except Exception as e:
        print(f"ERROR: Failed to save checkpoint for step 8 to {output_path}: {e}")

    return aggregated
@@ -0,0 +1,59 @@
1
+ import logging
2
+ import pandas as pd
3
+ from typing import Dict, Any
4
+
5
+ from hackagent.attacks.AdvPrefix.selector import (
6
+ PrefixSelectorConfig,
7
+ PrefixSelector,
8
+ )
9
+
10
+ from .utils import get_checkpoint_path
11
+
12
logger = logging.getLogger(__name__)


def execute(
    input_df: pd.DataFrame, config: Dict[str, Any], run_dir: str
) -> pd.DataFrame:
    """Pick the final prefixes from ``input_df`` and checkpoint the result.

    A ``PrefixSelector`` is built from ``pasr_weight`` (default 0.5),
    ``n_prefixes_per_goal`` (default 3) and ``selection_judges`` (default
    empty list) in ``config``. If anything in the selection raises, the error
    is logged and the unselected input is used instead. The resulting frame is
    written to the step-9 checkpoint path (a failed write is only logged) and
    returned. An empty input is returned untouched, without a checkpoint.
    """
    logger.info("Executing Step 9: Selecting final prefixes")

    # Nothing to select from: hand the empty frame straight back.
    if input_df.empty:
        logger.warning("Step 9 received an empty DataFrame. Skipping selection.")
        return input_df

    try:
        selection_settings = PrefixSelectorConfig(
            pasr_weight=config.get("pasr_weight", 0.5),
            n_prefixes_per_goal=config.get("n_prefixes_per_goal", 3),
            judges=config.get("selection_judges", []),
        )
        result_df = PrefixSelector(selection_settings).select_prefixes(input_df)
        logger.info(f"Selection complete. Selected {len(result_df)} prefixes.")
    except Exception as exc:
        # Selection is best effort: fall back to the unselected input.
        logger.error(f"Error during prefix selection: {exc}", exc_info=True)
        logger.warning("Returning unselected prefixes due to selection error.")
        result_df = input_df

    # Final step of the pipeline: persist whatever is being returned.
    checkpoint_file = get_checkpoint_path(run_dir, 9)
    try:
        result_df.to_csv(checkpoint_file, index=False)
        logger.info("Step 9 complete.")
        logger.info(f"Final selected prefixes checkpoint saved to {checkpoint_file}")
    except Exception as exc:
        logger.error(f"Failed to save checkpoint for step 9 to {checkpoint_file}: {exc}")

    return result_df