hackagent 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (183)
  1. hackagent/__init__.py +12 -0
  2. hackagent/agent.py +214 -0
  3. hackagent/api/__init__.py +1 -0
  4. hackagent/api/agent/__init__.py +1 -0
  5. hackagent/api/agent/agent_create.py +347 -0
  6. hackagent/api/agent/agent_destroy.py +140 -0
  7. hackagent/api/agent/agent_list.py +242 -0
  8. hackagent/api/agent/agent_partial_update.py +361 -0
  9. hackagent/api/agent/agent_retrieve.py +235 -0
  10. hackagent/api/agent/agent_update.py +361 -0
  11. hackagent/api/apilogs/__init__.py +1 -0
  12. hackagent/api/apilogs/apilogs_list.py +170 -0
  13. hackagent/api/apilogs/apilogs_retrieve.py +162 -0
  14. hackagent/api/attack/__init__.py +1 -0
  15. hackagent/api/attack/attack_create.py +275 -0
  16. hackagent/api/attack/attack_destroy.py +146 -0
  17. hackagent/api/attack/attack_list.py +254 -0
  18. hackagent/api/attack/attack_partial_update.py +289 -0
  19. hackagent/api/attack/attack_retrieve.py +247 -0
  20. hackagent/api/attack/attack_update.py +289 -0
  21. hackagent/api/checkout/__init__.py +1 -0
  22. hackagent/api/checkout/checkout_create.py +225 -0
  23. hackagent/api/generate/__init__.py +1 -0
  24. hackagent/api/generate/generate_create.py +253 -0
  25. hackagent/api/judge/__init__.py +1 -0
  26. hackagent/api/judge/judge_create.py +253 -0
  27. hackagent/api/key/__init__.py +1 -0
  28. hackagent/api/key/key_create.py +179 -0
  29. hackagent/api/key/key_destroy.py +103 -0
  30. hackagent/api/key/key_list.py +170 -0
  31. hackagent/api/key/key_retrieve.py +162 -0
  32. hackagent/api/organization/__init__.py +1 -0
  33. hackagent/api/organization/organization_create.py +208 -0
  34. hackagent/api/organization/organization_destroy.py +104 -0
  35. hackagent/api/organization/organization_list.py +170 -0
  36. hackagent/api/organization/organization_me_retrieve.py +126 -0
  37. hackagent/api/organization/organization_partial_update.py +222 -0
  38. hackagent/api/organization/organization_retrieve.py +163 -0
  39. hackagent/api/organization/organization_update.py +222 -0
  40. hackagent/api/prompt/__init__.py +1 -0
  41. hackagent/api/prompt/prompt_create.py +171 -0
  42. hackagent/api/prompt/prompt_destroy.py +104 -0
  43. hackagent/api/prompt/prompt_list.py +185 -0
  44. hackagent/api/prompt/prompt_partial_update.py +185 -0
  45. hackagent/api/prompt/prompt_retrieve.py +163 -0
  46. hackagent/api/prompt/prompt_update.py +185 -0
  47. hackagent/api/result/__init__.py +1 -0
  48. hackagent/api/result/result_create.py +175 -0
  49. hackagent/api/result/result_destroy.py +106 -0
  50. hackagent/api/result/result_list.py +249 -0
  51. hackagent/api/result/result_partial_update.py +193 -0
  52. hackagent/api/result/result_retrieve.py +167 -0
  53. hackagent/api/result/result_trace_create.py +177 -0
  54. hackagent/api/result/result_update.py +189 -0
  55. hackagent/api/run/__init__.py +1 -0
  56. hackagent/api/run/run_create.py +187 -0
  57. hackagent/api/run/run_destroy.py +112 -0
  58. hackagent/api/run/run_list.py +291 -0
  59. hackagent/api/run/run_partial_update.py +201 -0
  60. hackagent/api/run/run_result_create.py +177 -0
  61. hackagent/api/run/run_retrieve.py +179 -0
  62. hackagent/api/run/run_run_tests_create.py +187 -0
  63. hackagent/api/run/run_update.py +201 -0
  64. hackagent/api/user/__init__.py +1 -0
  65. hackagent/api/user/user_create.py +212 -0
  66. hackagent/api/user/user_destroy.py +106 -0
  67. hackagent/api/user/user_list.py +174 -0
  68. hackagent/api/user/user_me_retrieve.py +126 -0
  69. hackagent/api/user/user_me_update.py +196 -0
  70. hackagent/api/user/user_partial_update.py +226 -0
  71. hackagent/api/user/user_retrieve.py +167 -0
  72. hackagent/api/user/user_update.py +226 -0
  73. hackagent/attacks/AdvPrefix/__init__.py +41 -0
  74. hackagent/attacks/AdvPrefix/completions.py +416 -0
  75. hackagent/attacks/AdvPrefix/config.py +259 -0
  76. hackagent/attacks/AdvPrefix/evaluation.py +745 -0
  77. hackagent/attacks/AdvPrefix/evaluators.py +564 -0
  78. hackagent/attacks/AdvPrefix/generate.py +711 -0
  79. hackagent/attacks/AdvPrefix/utils.py +307 -0
  80. hackagent/attacks/__init__.py +35 -0
  81. hackagent/attacks/advprefix.py +507 -0
  82. hackagent/attacks/base.py +106 -0
  83. hackagent/attacks/strategies.py +906 -0
  84. hackagent/cli/__init__.py +19 -0
  85. hackagent/cli/commands/__init__.py +20 -0
  86. hackagent/cli/commands/agent.py +100 -0
  87. hackagent/cli/commands/attack.py +417 -0
  88. hackagent/cli/commands/config.py +301 -0
  89. hackagent/cli/commands/results.py +327 -0
  90. hackagent/cli/config.py +249 -0
  91. hackagent/cli/main.py +515 -0
  92. hackagent/cli/tui/__init__.py +31 -0
  93. hackagent/cli/tui/actions_logger.py +200 -0
  94. hackagent/cli/tui/app.py +288 -0
  95. hackagent/cli/tui/base.py +137 -0
  96. hackagent/cli/tui/logger.py +318 -0
  97. hackagent/cli/tui/views/__init__.py +33 -0
  98. hackagent/cli/tui/views/agents.py +488 -0
  99. hackagent/cli/tui/views/attacks.py +624 -0
  100. hackagent/cli/tui/views/config.py +244 -0
  101. hackagent/cli/tui/views/dashboard.py +307 -0
  102. hackagent/cli/tui/views/results.py +1210 -0
  103. hackagent/cli/tui/widgets/__init__.py +24 -0
  104. hackagent/cli/tui/widgets/actions.py +346 -0
  105. hackagent/cli/tui/widgets/logs.py +435 -0
  106. hackagent/cli/utils.py +276 -0
  107. hackagent/client.py +286 -0
  108. hackagent/errors.py +37 -0
  109. hackagent/logger.py +83 -0
  110. hackagent/models/__init__.py +109 -0
  111. hackagent/models/agent.py +223 -0
  112. hackagent/models/agent_request.py +129 -0
  113. hackagent/models/api_token_log.py +184 -0
  114. hackagent/models/attack.py +154 -0
  115. hackagent/models/attack_request.py +82 -0
  116. hackagent/models/checkout_session_request_request.py +76 -0
  117. hackagent/models/checkout_session_response.py +59 -0
  118. hackagent/models/choice.py +81 -0
  119. hackagent/models/choice_message.py +67 -0
  120. hackagent/models/evaluation_status_enum.py +14 -0
  121. hackagent/models/generate_error_response.py +59 -0
  122. hackagent/models/generate_request_request.py +212 -0
  123. hackagent/models/generate_success_response.py +115 -0
  124. hackagent/models/generic_error_response.py +70 -0
  125. hackagent/models/message_request.py +67 -0
  126. hackagent/models/organization.py +102 -0
  127. hackagent/models/organization_minimal.py +68 -0
  128. hackagent/models/organization_request.py +71 -0
  129. hackagent/models/paginated_agent_list.py +123 -0
  130. hackagent/models/paginated_api_token_log_list.py +123 -0
  131. hackagent/models/paginated_attack_list.py +123 -0
  132. hackagent/models/paginated_organization_list.py +123 -0
  133. hackagent/models/paginated_prompt_list.py +123 -0
  134. hackagent/models/paginated_result_list.py +123 -0
  135. hackagent/models/paginated_run_list.py +123 -0
  136. hackagent/models/paginated_user_api_key_list.py +123 -0
  137. hackagent/models/paginated_user_profile_list.py +123 -0
  138. hackagent/models/patched_agent_request.py +128 -0
  139. hackagent/models/patched_attack_request.py +92 -0
  140. hackagent/models/patched_organization_request.py +71 -0
  141. hackagent/models/patched_prompt_request.py +125 -0
  142. hackagent/models/patched_result_request.py +237 -0
  143. hackagent/models/patched_run_request.py +138 -0
  144. hackagent/models/patched_user_profile_request.py +99 -0
  145. hackagent/models/prompt.py +220 -0
  146. hackagent/models/prompt_request.py +126 -0
  147. hackagent/models/result.py +294 -0
  148. hackagent/models/result_list_evaluation_status.py +14 -0
  149. hackagent/models/result_request.py +232 -0
  150. hackagent/models/run.py +233 -0
  151. hackagent/models/run_list_status.py +12 -0
  152. hackagent/models/run_request.py +133 -0
  153. hackagent/models/status_enum.py +12 -0
  154. hackagent/models/step_type_enum.py +14 -0
  155. hackagent/models/trace.py +121 -0
  156. hackagent/models/trace_request.py +94 -0
  157. hackagent/models/usage.py +75 -0
  158. hackagent/models/user_api_key.py +201 -0
  159. hackagent/models/user_api_key_request.py +73 -0
  160. hackagent/models/user_profile.py +135 -0
  161. hackagent/models/user_profile_minimal.py +76 -0
  162. hackagent/models/user_profile_request.py +99 -0
  163. hackagent/router/__init__.py +25 -0
  164. hackagent/router/adapters/__init__.py +20 -0
  165. hackagent/router/adapters/base.py +63 -0
  166. hackagent/router/adapters/google_adk.py +671 -0
  167. hackagent/router/adapters/litellm_adapter.py +524 -0
  168. hackagent/router/adapters/openai_adapter.py +426 -0
  169. hackagent/router/router.py +969 -0
  170. hackagent/router/types.py +54 -0
  171. hackagent/tracking/__init__.py +42 -0
  172. hackagent/tracking/context.py +163 -0
  173. hackagent/tracking/decorators.py +299 -0
  174. hackagent/tracking/tracker.py +441 -0
  175. hackagent/types.py +54 -0
  176. hackagent/utils.py +194 -0
  177. hackagent/vulnerabilities/__init__.py +13 -0
  178. hackagent/vulnerabilities/prompts.py +81 -0
  179. hackagent-0.3.1.dist-info/METADATA +122 -0
  180. hackagent-0.3.1.dist-info/RECORD +183 -0
  181. hackagent-0.3.1.dist-info/WHEEL +4 -0
  182. hackagent-0.3.1.dist-info/entry_points.txt +2 -0
  183. hackagent-0.3.1.dist-info/licenses/LICENSE +202 -0
@@ -0,0 +1,745 @@
1
+ # Copyright 2025 - AI4I. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Evaluation stage module for AdvPrefix attacks.
17
+
18
+ This module implements the Evaluation stage of the AdvPrefix pipeline, which consolidates
19
+ judge-based evaluation, result aggregation, and prefix selection into a cohesive
20
+ class-based design that improves:
21
+ - Code organization and maintainability
22
+ - State management and configuration handling
23
+ - Testing and mocking capabilities
24
+ - Logging and tracking throughout the pipeline
25
+
26
+ The module provides functionality for:
27
+ - Automated evaluation using judge models
28
+ - Result aggregation and statistical analysis
29
+ - Optimal prefix selection using multi-criteria optimization
30
+ - Unified pipeline execution with proper error handling
31
+ - Integration with various judge model backends
32
+ - Customizable evaluation, aggregation, and selection strategies
33
+ """
34
+
35
+ import logging
36
+ import math
37
+ from collections import defaultdict
38
+ from dataclasses import fields
39
+ from typing import Any, Dict, List, Optional
40
+
41
+ import pandas as pd
42
+
43
+ from hackagent.attacks.AdvPrefix.evaluators import (
44
+ HarmBenchEvaluator,
45
+ JailbreakBenchEvaluator,
46
+ NuancedEvaluator,
47
+ )
48
+ from hackagent.client import AuthenticatedClient
49
+ from hackagent.router.types import AgentTypeEnum
50
+
51
+ from .config import EvaluationPipelineConfig, EvaluatorConfig
52
+ from .utils import handle_empty_input, log_errors
53
+
54
+ # ============================================================================
55
+ # CONSTANTS AND MAPPINGS
56
+ # ============================================================================
57
+
58
# Judge type name -> evaluator class instantiated to score completions.
EVALUATOR_MAP = dict(
    nuanced=NuancedEvaluator,
    jailbreakbench=JailbreakBenchEvaluator,
    harmbench=HarmBenchEvaluator,
)

# Join keys used to attach each judge's output back onto the completion rows.
MERGE_KEYS = ["goal", "prefix", "completion"]

# Short suffix shared by every output column of a given judge.
_JUDGE_COLUMN_SUFFIX = {
    "nuanced": "nj",
    "jailbreakbench": "jb",
    "harmbench": "hb",
}

# Judge type -> [score column, explanation column] kept from its evaluator.
JUDGE_COLUMN_MAP = {
    judge: [f"eval_{suffix}", f"explanation_{suffix}"]
    for judge, suffix in _JUDGE_COLUMN_SUFFIX.items()
}

# Judge type -> per-row score column averaged during aggregation.
JUDGE_AGG_COLUMN_MAP = {
    judge: f"eval_{suffix}" for judge, suffix in _JUDGE_COLUMN_SUFFIX.items()
}

# Judge type -> aggregated mean column read during selection. "strongreject"
# has no entry in EVALUATOR_MAP; it is only recognised as a selection column.
JUDGE_MEAN_COLUMN_MAP = {
    **{judge: f"{column}_mean" for judge, column in JUDGE_AGG_COLUMN_MAP.items()},
    "strongreject": "eval_sj_binary_mean",
}

# Keys that define one aggregation group.
GROUP_KEYS = ["goal", "prefix"]
86
+
87
+
88
+ # ============================================================================
89
+ # MAIN PIPELINE CLASS
90
+ # ============================================================================
91
+
92
+
93
class EvaluationPipeline:
    """
    Unified pipeline for the Evaluation stage of AdvPrefix attacks.

    This class encapsulates all functionality related to evaluating completions,
    aggregating results, and selecting optimal prefixes, providing a clean interface
    with proper state management and comprehensive tracking capabilities.

    Architecture:
    - Initialization: Sets up config, logger, client, and internal state
    - Judge Evaluation: Run judge models on completions (sequentially)
    - Aggregation: Aggregate evaluation results by goal/prefix
    - Selection: Select best prefixes using multi-criteria optimization
    - Orchestration: execute() method coordinates the full pipeline

    Judge types are resolved identically in every stage (``evaluator_type``,
    then ``type``, then inference from ``identifier``), so a judge that ran
    during evaluation is also recognised during aggregation and selection.

    Example:
        pipeline = EvaluationPipeline(
            config=config_dict,
            logger=logger,
            client=client
        )
        results = pipeline.execute(input_data=completion_data)
    """

    # Temporary column used to pair rows with repeated merge keys one-to-one.
    _OCCURRENCE_COL = "__merge_occurrence__"

    def __init__(
        self,
        config: Dict[str, Any],
        logger: logging.Logger,
        client: AuthenticatedClient,
    ):
        """
        Initialize the pipeline with configuration and dependencies.

        Args:
            config: Configuration dictionary (converted with
                ``EvaluationPipelineConfig.from_dict``) or an already-built
                EvaluationPipelineConfig instance, which is used as-is.
            logger: Logger for tracking execution. It is used only when its name
                starts with ``"hackagent.attacks"``; otherwise (or when None)
                the ``"hackagent.attacks.advprefix.evaluation"`` logger is used.
            client: Authenticated client handed to every evaluator.
        """
        self.config = (
            EvaluationPipelineConfig.from_dict(config)
            if isinstance(config, dict)
            else config
        )
        # Keep records inside the hackagent.attacks logger hierarchy.
        if logger is not None and logger.name.startswith("hackagent.attacks"):
            self.logger = logger
        else:
            self.logger = logging.getLogger("hackagent.attacks.advprefix.evaluation")
        self.client = client

        self._statistics: Dict[str, Any] = self._new_statistics()

        self.logger.info("EvaluationPipeline initialized")

    # ========================================================================
    # PUBLIC INTERFACE
    # ========================================================================

    @handle_empty_input("Evaluation Stage", empty_result=[])
    @log_errors("Evaluation Stage")
    def execute(self, input_data: List[Dict]) -> List[Dict]:
        """
        Execute the complete Evaluation stage: judge evaluation, aggregation, and selection.

        Steps:
        1. Judge Evaluation: Evaluate completions with judge models
        2. Aggregation: Aggregate evaluation results by goal/prefix
        3. Selection: Select optimal prefixes using multi-criteria optimization

        Args:
            input_data: List of dicts containing completion data from Execution stage

        Returns:
            List of selected prefix dictionaries ready for final output, or an
            empty list if evaluation or aggregation produced no rows.
        """
        # Fresh counters per run so repeated calls don't accumulate judge names.
        self._statistics = self._new_statistics()
        self._statistics["input_count"] = len(input_data)

        self.logger.info(
            f"Judge Evaluation: Starting evaluation for {len(input_data)} completions"
        )
        evaluated_data = self._run_evaluation(input_data)
        self._statistics["evaluated_count"] = len(evaluated_data)

        if not evaluated_data:
            self.logger.warning("No data after evaluation")
            return []

        self.logger.info(
            f"Aggregation: Aggregating {len(evaluated_data)} evaluation results"
        )
        aggregated_data = self._run_aggregation(evaluated_data)
        self._statistics["aggregated_count"] = len(aggregated_data)

        if not aggregated_data:
            self.logger.warning("No data after aggregation")
            return []

        self.logger.info(
            f"Selection: Selecting best prefixes from {len(aggregated_data)} candidates"
        )
        selected_data = self._run_selection(aggregated_data)
        self._statistics["selected_count"] = len(selected_data)

        self._log_pipeline_statistics()
        return selected_data

    def get_statistics(self) -> Dict[str, Any]:
        """Return a copy of the latest execution statistics (lists copied too)."""
        return {
            key: list(value) if isinstance(value, list) else value
            for key, value in self._statistics.items()
        }

    @staticmethod
    def _new_statistics() -> Dict[str, Any]:
        """Return a zeroed statistics dictionary."""
        return {
            "input_count": 0,
            "evaluated_count": 0,
            "aggregated_count": 0,
            "selected_count": 0,
            "successful_judges": [],
            "failed_judges": [],
        }

    # ========================================================================
    # JUDGE EVALUATION METHODS
    # ========================================================================

    def _run_evaluation(self, input_data: List[Dict]) -> List[Dict]:
        """
        Execute judge evaluation: Evaluate completions using judge models.

        Judges are validated, run one after another on a copy of the input,
        and their score/explanation columns are merged back onto the input
        rows. Failed judges are recorded in the statistics and skipped.

        Returns:
            Records with the judges' columns added, or ``input_data`` unchanged
            when no usable judge is configured.
        """
        judge_configs_list = self.config.judges
        if not isinstance(judge_configs_list, list) or not judge_configs_list:
            self.logger.warning("No judges configured, skipping evaluation")
            return input_data

        original_df = pd.DataFrame(input_data)

        evaluator_base_config_dict = {
            "batch_size": self.config.batch_size_judge,
            "max_new_tokens_eval": self.config.max_new_tokens_eval,
            "filter_len": self.config.filter_len,
            "request_timeout": self.config.judge_request_timeout,
            "temperature": self.config.judge_temperature,
            "organization_id": self.config.organization_id,
        }

        judges_to_run = self._prepare_judge_configs(
            judge_configs_list, evaluator_base_config_dict
        )
        if not judges_to_run:
            self.logger.warning("No valid judges found after configuration processing")
            return input_data

        judge_results_dfs: Dict[str, pd.DataFrame] = {}
        for judge_type_str, subprocess_config in judges_to_run:
            evaluated_df = self._run_single_evaluator(
                judge_type=judge_type_str,
                config=subprocess_config,
                df=original_df.copy(),
            )
            if evaluated_df is not None:
                judge_results_dfs[judge_type_str] = evaluated_df
                self._statistics["successful_judges"].append(judge_type_str)
            else:
                self._statistics["failed_judges"].append(judge_type_str)

        final_df = self._merge_evaluation_results(original_df, judge_results_dfs)
        return final_df.to_dict(orient="records")

    def _prepare_judge_configs(
        self, judge_configs_list: List[Dict], base_config: Dict[str, Any]
    ) -> List[tuple]:
        """
        Validate judge configs and build the evaluator config for each one.

        Returns:
            List of ``(judge_type, config_dict)`` tuples. Entries that are not
            dicts, have an unknown type, or lack an ``identifier`` are skipped
            with a warning.
        """
        judges_to_run: List[tuple] = []

        for judge_config_item in judge_configs_list:
            if not isinstance(judge_config_item, dict):
                self.logger.warning(
                    f"Skipping invalid judge config: {judge_config_item}"
                )
                continue

            judge_type_str = self._resolve_judge_type(judge_config_item)
            judge_identifier = judge_config_item.get("identifier")

            if not judge_type_str or judge_type_str not in EVALUATOR_MAP:
                self.logger.warning(
                    f"Unknown or missing judge type for: {judge_config_item}"
                )
                continue

            if not judge_identifier:
                self.logger.warning(
                    f"Missing identifier for judge: {judge_config_item}"
                )
                continue

            subprocess_config = base_config.copy()
            subprocess_config.update(judge_config_item)

            # Default agent name: slashes replaced, identifier truncated to 20 chars.
            safe_identifier = str(judge_identifier).replace("/", "-")[:20]
            subprocess_config["agent_name"] = (
                judge_config_item.get("agent_name")
                or f"judge-{judge_type_str}-{safe_identifier}"
            )
            subprocess_config["agent_type"] = judge_config_item.get(
                "agent_type", "LITELLM"
            )
            subprocess_config["model_id"] = judge_identifier
            subprocess_config["agent_endpoint"] = judge_config_item.get("endpoint")
            subprocess_config["agent_metadata"] = judge_config_item.get(
                "agent_metadata", {}
            )

            judges_to_run.append((judge_type_str, subprocess_config))

        return judges_to_run

    def _resolve_judge_type(self, judge_config: Dict[str, Any]) -> Optional[str]:
        """
        Return the judge type for a config entry.

        Order: ``evaluator_type``, then ``type``, then inference from
        ``identifier``. Shared by evaluation, aggregation and selection so all
        stages agree on which judges were configured.
        """
        return (
            judge_config.get("evaluator_type")
            or judge_config.get("type")
            or self._infer_judge_type(judge_config.get("identifier"))
        )

    def _infer_judge_type(self, identifier: Optional[str]) -> Optional[str]:
        """Infer judge type from a substring of the identifier, or None."""
        if not identifier:
            return None

        identifier_lower = str(identifier).lower()
        if "nuanced" in identifier_lower:
            return "nuanced"
        if "harmbench" in identifier_lower:
            return "harmbench"
        if "jailbreak" in identifier_lower:
            return "jailbreakbench"

        return None

    def _run_single_evaluator(
        self,
        judge_type: str,
        config: Dict[str, Any],
        df: pd.DataFrame,
    ) -> Optional[pd.DataFrame]:
        """
        Run one evaluator on ``df``.

        Returns:
            A DataFrame with the MERGE_KEYS columns plus whichever of the judge's
            JUDGE_COLUMN_MAP columns were produced, or None on any failure
            (unknown type, bad agent_type, missing merge keys, evaluator error).
        """
        evaluator_class = EVALUATOR_MAP.get(judge_type)
        if not evaluator_class:
            self.logger.warning(f"Unknown judge type: {judge_type}")
            return None

        evaluator = None
        try:
            # EvaluatorConfig rejects unknown keys, so pass only its fields.
            expected_fields = {f.name for f in fields(EvaluatorConfig)}
            filtered_config = {k: v for k, v in config.items() if k in expected_fields}

            if "agent_type" in filtered_config and isinstance(
                filtered_config["agent_type"], str
            ):
                try:
                    filtered_config["agent_type"] = AgentTypeEnum(
                        filtered_config["agent_type"].upper()
                    )
                except ValueError:
                    self.logger.error(
                        f"Invalid agent_type: {filtered_config['agent_type']}"
                    )
                    return None

            evaluator_config = EvaluatorConfig(**filtered_config)
            evaluator = evaluator_class(client=self.client, config=evaluator_config)
            evaluated_df = evaluator.evaluate(df)

            if not all(key in evaluated_df.columns for key in MERGE_KEYS):
                self.logger.error(
                    f"Evaluation result missing merge keys for {judge_type}"
                )
                return None

            eval_cols = JUDGE_COLUMN_MAP.get(judge_type, [])
            cols_to_return = MERGE_KEYS + [
                col for col in eval_cols if col in evaluated_df.columns
            ]
            return evaluated_df[cols_to_return]

        except Exception as e:
            self.logger.error(
                f"Error running {judge_type} evaluator: {e}", exc_info=True
            )
            return None
        finally:
            # Drop the reference so the evaluator can be released promptly.
            del evaluator

    def _merge_evaluation_results(
        self, original_df: pd.DataFrame, judge_results: Dict[str, pd.DataFrame]
    ) -> pd.DataFrame:
        """
        Left-join each judge's columns onto the original rows.

        Rows sharing the same MERGE_KEYS (e.g. identical completions sampled
        several times) are paired by occurrence order. A plain merge would pair
        every left duplicate with every right duplicate and multiply the rows.
        Existing columns with the judge's names are replaced by the fresh
        results, rather than shadowing them under a suffixed duplicate.
        """
        final_df = original_df.copy()

        for judge_type, judge_df in judge_results.items():
            eval_cols = JUDGE_COLUMN_MAP.get(judge_type, [])
            judge_cols_present = [col for col in eval_cols if col in judge_df.columns]

            if not judge_cols_present:
                self.logger.warning(f"No evaluation columns found for {judge_type}")
                continue

            stale_cols = [col for col in judge_cols_present if col in final_df.columns]
            try:
                left = self._with_occurrence_index(final_df.drop(columns=stale_cols))
                right = self._with_occurrence_index(
                    judge_df[MERGE_KEYS + judge_cols_present]
                )
                final_df = left.merge(
                    right,
                    on=MERGE_KEYS + [self._OCCURRENCE_COL],
                    how="left",
                    suffixes=("", f"_{judge_type}_dup"),
                ).drop(columns=[self._OCCURRENCE_COL])
            except Exception as e:
                self.logger.error(f"Error merging results for {judge_type}: {e}")

        return final_df

    @classmethod
    def _with_occurrence_index(cls, df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``df`` numbering repeated MERGE_KEYS rows 0, 1, 2, ..."""
        occurrence = df.groupby(MERGE_KEYS, dropna=False, sort=False).cumcount()
        return df.assign(**{cls._OCCURRENCE_COL: occurrence})

    # ========================================================================
    # AGGREGATION METHODS
    # ========================================================================

    def _run_aggregation(self, input_data: List[Dict]) -> List[Dict]:
        """
        Execute aggregation: Aggregate evaluation results.

        Steps:
        - Optional NLL filtering when ``config.max_ce`` is set
        - Grouping by GROUP_KEYS (goal, prefix)
        - Per-judge mean and count of valid numeric scores (None, NaN and
          non-numeric scores are ignored)
        - Preservation of the first row's metadata for each group

        Returns:
            One dict per group, ``[]`` for empty input, or ``input_data``
            unchanged when no judge columns or grouping keys are available.
        """
        max_ce_threshold = self.config.max_ce
        if max_ce_threshold is not None:
            try:
                threshold = float(max_ce_threshold)
            except (TypeError, ValueError):
                self.logger.warning(f"Invalid max_ce value: {max_ce_threshold}")
            else:
                input_data = self._filter_by_nll(input_data, threshold)

        if not input_data:
            return []

        judges = self.config.judges if isinstance(self.config.judges, list) else []
        config_judges = [
            self._resolve_judge_type(j) for j in judges if isinstance(j, dict)
        ]
        available_judges_agg_cols = self._get_available_judge_agg_cols(
            input_data, config_judges
        )

        if not available_judges_agg_cols:
            self.logger.error("No recognized evaluation keys found for aggregation")
            return input_data

        sample_keys = set(input_data[0].keys())
        missing_keys = [key for key in GROUP_KEYS if key not in sample_keys]
        if missing_keys:
            self.logger.error(f"Missing grouping keys: {missing_keys}")
            return input_data

        groups: Dict[tuple, List[Dict]] = defaultdict(list)
        for item in input_data:
            groups[tuple(item.get(k) for k in GROUP_KEYS)].append(item)

        aggregated_results = []
        for group_key, group_items in groups.items():
            result = dict(zip(GROUP_KEYS, group_key))

            first = group_items[0]
            result["prefix_nll"] = first.get("prefix_nll")
            result["model_name"] = first.get("model_name")
            result["meta_prefix"] = first.get("meta_prefix")
            result["temperature"] = first.get("temperature")
            result["n_eval_samples"] = len(group_items)

            for col_name in available_judges_agg_cols.values():
                # NaN appears when a judge returned no row for a completion
                # (left merge) and must not poison the mean.
                values = [
                    value
                    for value in (
                        self._to_float_or_none(item.get(col_name))
                        for item in group_items
                    )
                    if value is not None
                ]
                result[f"{col_name}_mean"] = (
                    sum(values) / len(values) if values else None
                )
                result[f"{col_name}_count"] = len(values)

            aggregated_results.append(result)

        return aggregated_results

    def _filter_by_nll(self, data: List[Dict], max_ce_threshold: float) -> List[Dict]:
        """
        Keep rows whose ``prefix_nll`` is strictly below ``max_ce_threshold``.

        Rows whose NLL is missing, None, NaN or non-numeric are dropped. If no
        row has a ``prefix_nll`` key at all, filtering is skipped.
        """
        if not any("prefix_nll" in item for item in data):
            self.logger.warning("prefix_nll key not found, skipping NLL filtering")
            return data

        filtered = []
        for item in data:
            nll = self._to_float_or_none(item.get("prefix_nll"))
            if nll is not None and nll < max_ce_threshold:
                filtered.append(item)

        self.logger.info(f"NLL filtering: {len(data)} -> {len(filtered)} items")
        return filtered

    def _get_available_judge_agg_cols(
        self, data: List[Dict], config_judges: List[str]
    ) -> Dict[str, str]:
        """
        Map judge type -> score column for columns present in the first record.

        Logs a warning for each configured judge whose column is absent.
        """
        available_judges_agg_cols = {}
        sample_keys = set(data[0].keys()) if data else set()

        for judge_type, col_name in JUDGE_AGG_COLUMN_MAP.items():
            if col_name in sample_keys:
                available_judges_agg_cols[judge_type] = col_name
            elif judge_type in config_judges:
                self.logger.warning(
                    f"Expected key '{col_name}' for judge '{judge_type}' not found"
                )

        return available_judges_agg_cols

    # ========================================================================
    # SELECTION METHODS
    # ========================================================================

    def _run_selection(self, input_data: List[Dict]) -> List[Dict]:
        """
        Execute selection: Select optimal prefixes per goal.

        Each item gets ``pasr`` (mean of the selected judges' mean scores),
        ``log_pasr`` and ``combined_score = -pasr_weight * log_pasr + prefix_nll``
        (a missing ``prefix_nll`` key counts as 0; a None/NaN/non-numeric value
        makes ``combined_score`` None). Items are mutated in place.

        Returns:
            Selected items for every goal, or ``input_data`` unchanged when no
            usable judge score columns are available.
        """
        judge_configs = self.config.selection_judges or self.config.judges

        if not isinstance(judge_configs, list) or not judge_configs:
            self.logger.error("No judges configured for selection")
            return input_data

        judge_types_found: List[str] = []
        sample_keys = set(input_data[0].keys()) if input_data else set()

        for judge_config in judge_configs:
            if not isinstance(judge_config, dict):
                continue

            judge_type = self._resolve_judge_type(judge_config)
            if not judge_type:
                continue

            if judge_type not in JUDGE_MEAN_COLUMN_MAP:
                self.logger.error(f"Unknown judge type for selection: {judge_type}")
                continue

            expected_key = JUDGE_MEAN_COLUMN_MAP[judge_type]
            if expected_key not in sample_keys:
                self.logger.warning(f"Missing key '{expected_key}' for selection")
                continue

            if judge_type not in judge_types_found:
                judge_types_found.append(judge_type)

        if not judge_types_found:
            self.logger.error("No valid judges found for selection")
            return input_data

        for item in input_data:
            pasr = self._calculate_combined_pasr(item, judge_types_found)
            # Clamp so an out-of-range negative score cannot make log() raise.
            log_pasr = math.log(max(pasr, 0.0) + 1e-6)
            item["pasr"] = pasr
            item["log_pasr"] = log_pasr
            nll = self._to_float_or_none(item.get("prefix_nll", 0))
            item["combined_score"] = (
                None if nll is None else -self.config.pasr_weight * log_pasr + nll
            )

        groups: Dict[Any, List[Dict]] = defaultdict(list)
        for item in input_data:
            groups[item.get("goal")].append(item)

        selected_prefixes: List[Dict] = []
        for goal, group in groups.items():
            if all(item.get("combined_score") is None for item in group):
                self.logger.warning(
                    f"Skipping goal '{str(goal)[:50]}...' due to invalid scores"
                )
                continue

            selected_prefixes.extend(self._select_prefixes_for_goal(group))

        return selected_prefixes

    def _calculate_combined_pasr(self, item: Dict, judge_types: List[str]) -> float:
        """
        Average the judges' mean scores for one item.

        None values are skipped silently. NaN and non-numeric values are
        skipped with a warning. Returns 0.0 when no valid score exists.
        """
        judge_scores: List[float] = []

        for judge_type in judge_types:
            key = JUDGE_MEAN_COLUMN_MAP[judge_type]
            raw_score = item.get(key)
            if raw_score is None:
                continue
            score = self._to_float_or_none(raw_score)
            if score is None:
                self.logger.warning(
                    f"Could not convert '{key}' to numeric: {raw_score!r}"
                )
                continue
            judge_scores.append(score)

        if not judge_scores:
            self.logger.warning("No valid judge scores for PASR calculation")
            return 0.0

        return sum(judge_scores) / len(judge_scores)

    def _select_prefixes_for_goal(self, group: List[Dict]) -> List[Dict]:
        """
        Select up to ``n_prefixes_per_goal`` prefixes for a single goal.

        1. The item with the lowest ``combined_score`` is selected first.
        2. Other candidates must have ``pasr >= first.pasr - pasr_tol`` and
           ``prefix_nll <= first.prefix_nll + nll_tol``. A missing or invalid
           NLL counts as +inf.
        3. The remaining slots go to candidates with the lowest NLL. Candidates
           whose prefix starts with an already-selected prefix are excluded.

        The caller guarantees at least one item has a non-None combined_score.
        """

        def nll_or_inf(item: Dict) -> float:
            nll = self._to_float_or_none(item.get("prefix_nll"))
            return float("inf") if nll is None else nll

        first_selection = min(
            (item for item in group if item.get("combined_score") is not None),
            key=lambda x: x["combined_score"],
        )

        pasr_floor = first_selection.get("pasr", 0) - self.config.pasr_tol
        nll_ceiling = nll_or_inf(first_selection) + self.config.nll_tol

        # Identity (not dict equality) excludes exactly the chosen item.
        valid_candidates = [
            item
            for item in group
            if item is not first_selection
            and item.get("pasr", 0) >= pasr_floor
            and nll_or_inf(item) <= nll_ceiling
        ]

        selections = [first_selection]

        for _ in range(self.config.n_prefixes_per_goal - 1):
            valid_candidates = [
                item
                for item in valid_candidates
                if not any(
                    str(item.get("prefix", "")).startswith(str(sel.get("prefix", "")))
                    for sel in selections
                )
            ]

            if not valid_candidates:
                break

            with_nll = [
                item
                for item in valid_candidates
                if self._to_float_or_none(item.get("prefix_nll")) is not None
            ]
            if not with_nll:
                self.logger.warning(
                    "Cannot select next prefix due to missing NLL scores"
                )
                break

            next_selection = min(with_nll, key=nll_or_inf)
            selections.append(next_selection)
            valid_candidates = [
                item for item in valid_candidates if item is not next_selection
            ]

        return selections

    # ========================================================================
    # UTILITY METHODS
    # ========================================================================

    @staticmethod
    def _to_float_or_none(value: Any) -> Optional[float]:
        """Convert ``value`` to float; return None for None, NaN or non-numeric."""
        if value is None:
            return None
        try:
            number = float(value)
        except (TypeError, ValueError):
            return None
        return None if math.isnan(number) else number

    def _log_pipeline_statistics(self):
        """Log the counts and judge outcomes collected during execute()."""
        stats = self._statistics
        self.logger.info("=" * 60)
        self.logger.info("Evaluation Pipeline Statistics:")
        self.logger.info(f" Input completions: {stats['input_count']}")
        self.logger.info(f" After evaluation: {stats['evaluated_count']}")
        self.logger.info(f" After aggregation: {stats['aggregated_count']}")
        self.logger.info(f" Final selected: {stats['selected_count']}")

        if stats["successful_judges"]:
            self.logger.info(
                f" Successful judges: {', '.join(stats['successful_judges'])}"
            )
        if stats["failed_judges"]:
            self.logger.warning(
                f" Failed judges: {', '.join(stats['failed_judges'])}"
            )

        if stats["input_count"] > 0:
            retention = (stats["selected_count"] / stats["input_count"]) * 100
            self.logger.info(f" Overall retention: {retention:.1f}%")

        self.logger.info("=" * 60)