hackagent 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117) hide show
  1. hackagent/__init__.py +23 -0
  2. hackagent/agent.py +193 -0
  3. hackagent/api/__init__.py +1 -0
  4. hackagent/api/agent/__init__.py +1 -0
  5. hackagent/api/agent/agent_create.py +340 -0
  6. hackagent/api/agent/agent_destroy.py +136 -0
  7. hackagent/api/agent/agent_list.py +234 -0
  8. hackagent/api/agent/agent_partial_update.py +354 -0
  9. hackagent/api/agent/agent_retrieve.py +227 -0
  10. hackagent/api/agent/agent_update.py +354 -0
  11. hackagent/api/attack/__init__.py +1 -0
  12. hackagent/api/attack/attack_create.py +264 -0
  13. hackagent/api/attack/attack_destroy.py +140 -0
  14. hackagent/api/attack/attack_list.py +242 -0
  15. hackagent/api/attack/attack_partial_update.py +278 -0
  16. hackagent/api/attack/attack_retrieve.py +235 -0
  17. hackagent/api/attack/attack_update.py +278 -0
  18. hackagent/api/key/__init__.py +1 -0
  19. hackagent/api/key/key_create.py +168 -0
  20. hackagent/api/key/key_destroy.py +97 -0
  21. hackagent/api/key/key_list.py +158 -0
  22. hackagent/api/key/key_retrieve.py +150 -0
  23. hackagent/api/prompt/__init__.py +1 -0
  24. hackagent/api/prompt/prompt_create.py +160 -0
  25. hackagent/api/prompt/prompt_destroy.py +98 -0
  26. hackagent/api/prompt/prompt_list.py +173 -0
  27. hackagent/api/prompt/prompt_partial_update.py +174 -0
  28. hackagent/api/prompt/prompt_retrieve.py +151 -0
  29. hackagent/api/prompt/prompt_update.py +174 -0
  30. hackagent/api/result/__init__.py +1 -0
  31. hackagent/api/result/result_create.py +160 -0
  32. hackagent/api/result/result_destroy.py +98 -0
  33. hackagent/api/result/result_list.py +233 -0
  34. hackagent/api/result/result_partial_update.py +178 -0
  35. hackagent/api/result/result_retrieve.py +151 -0
  36. hackagent/api/result/result_trace_create.py +178 -0
  37. hackagent/api/result/result_update.py +174 -0
  38. hackagent/api/run/__init__.py +1 -0
  39. hackagent/api/run/run_create.py +172 -0
  40. hackagent/api/run/run_destroy.py +104 -0
  41. hackagent/api/run/run_list.py +260 -0
  42. hackagent/api/run/run_partial_update.py +186 -0
  43. hackagent/api/run/run_result_create.py +178 -0
  44. hackagent/api/run/run_retrieve.py +163 -0
  45. hackagent/api/run/run_run_tests_create.py +172 -0
  46. hackagent/api/run/run_update.py +186 -0
  47. hackagent/attacks/AdvPrefix/README.md +7 -0
  48. hackagent/attacks/AdvPrefix/__init__.py +0 -0
  49. hackagent/attacks/AdvPrefix/completer.py +438 -0
  50. hackagent/attacks/AdvPrefix/config.py +59 -0
  51. hackagent/attacks/AdvPrefix/preprocessing.py +521 -0
  52. hackagent/attacks/AdvPrefix/scorer.py +259 -0
  53. hackagent/attacks/AdvPrefix/scorer_parser.py +498 -0
  54. hackagent/attacks/AdvPrefix/selector.py +246 -0
  55. hackagent/attacks/AdvPrefix/step1_generate.py +324 -0
  56. hackagent/attacks/AdvPrefix/step4_compute_ce.py +293 -0
  57. hackagent/attacks/AdvPrefix/step6_get_completions.py +387 -0
  58. hackagent/attacks/AdvPrefix/step7_evaluate_responses.py +289 -0
  59. hackagent/attacks/AdvPrefix/step8_aggregate_evaluations.py +177 -0
  60. hackagent/attacks/AdvPrefix/step9_select_prefixes.py +59 -0
  61. hackagent/attacks/AdvPrefix/utils.py +192 -0
  62. hackagent/attacks/__init__.py +6 -0
  63. hackagent/attacks/advprefix.py +1136 -0
  64. hackagent/attacks/base.py +50 -0
  65. hackagent/attacks/strategies.py +539 -0
  66. hackagent/branding.py +143 -0
  67. hackagent/client.py +328 -0
  68. hackagent/errors.py +31 -0
  69. hackagent/logger.py +67 -0
  70. hackagent/models/__init__.py +71 -0
  71. hackagent/models/agent.py +240 -0
  72. hackagent/models/agent_request.py +169 -0
  73. hackagent/models/agent_type_enum.py +12 -0
  74. hackagent/models/attack.py +154 -0
  75. hackagent/models/attack_request.py +82 -0
  76. hackagent/models/evaluation_status_enum.py +14 -0
  77. hackagent/models/organization_minimal.py +68 -0
  78. hackagent/models/paginated_agent_list.py +123 -0
  79. hackagent/models/paginated_attack_list.py +123 -0
  80. hackagent/models/paginated_prompt_list.py +123 -0
  81. hackagent/models/paginated_result_list.py +123 -0
  82. hackagent/models/paginated_run_list.py +123 -0
  83. hackagent/models/paginated_user_api_key_list.py +123 -0
  84. hackagent/models/patched_agent_request.py +176 -0
  85. hackagent/models/patched_attack_request.py +92 -0
  86. hackagent/models/patched_prompt_request.py +162 -0
  87. hackagent/models/patched_result_request.py +237 -0
  88. hackagent/models/patched_run_request.py +138 -0
  89. hackagent/models/prompt.py +226 -0
  90. hackagent/models/prompt_request.py +155 -0
  91. hackagent/models/result.py +294 -0
  92. hackagent/models/result_list_evaluation_status.py +14 -0
  93. hackagent/models/result_request.py +232 -0
  94. hackagent/models/run.py +233 -0
  95. hackagent/models/run_list_status.py +12 -0
  96. hackagent/models/run_request.py +133 -0
  97. hackagent/models/status_enum.py +12 -0
  98. hackagent/models/step_type_enum.py +14 -0
  99. hackagent/models/trace.py +121 -0
  100. hackagent/models/trace_request.py +94 -0
  101. hackagent/models/user_api_key.py +201 -0
  102. hackagent/models/user_api_key_request.py +73 -0
  103. hackagent/models/user_profile_minimal.py +76 -0
  104. hackagent/py.typed +1 -0
  105. hackagent/router/__init__.py +11 -0
  106. hackagent/router/adapters/__init__.py +5 -0
  107. hackagent/router/adapters/google_adk.py +658 -0
  108. hackagent/router/adapters/litellm_adapter.py +290 -0
  109. hackagent/router/base.py +48 -0
  110. hackagent/router/router.py +753 -0
  111. hackagent/types.py +46 -0
  112. hackagent/utils.py +61 -0
  113. hackagent/vulnerabilities/__init__.py +0 -0
  114. hackagent-0.1.0.dist-info/LICENSE +202 -0
  115. hackagent-0.1.0.dist-info/METADATA +173 -0
  116. hackagent-0.1.0.dist-info/RECORD +117 -0
  117. hackagent-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,246 @@
1
+ import pandas as pd
2
+ import numpy as np
3
+ import logging
4
+ from typing import List, Optional
5
+ from dataclasses import dataclass
6
+
7
+
8
@dataclass
class PrefixSelectorConfig:
    """Settings that steer how prefixes are picked for each goal.

    Only ``pasr_weight`` is mandatory; every other field has a default.
    """

    # Multiplier applied to log-PASR when the combined score is computed.
    pasr_weight: float
    # How many prefixes to keep for every goal.
    n_prefixes_per_goal: int = 1
    # Additional prefixes may exceed the best prefix's NLL by at most this much.
    nll_tol: float = 999
    # Additional prefixes may fall below the best prefix's PASR by at most this much.
    pasr_tol: float = 0
    # Judge configuration dicts whose scores feed the PASR calculation.
    judges: Optional[List[dict]] = None
17
+
18
+
19
class PrefixSelector:
    """
    Selects prefixes based on a combination of judge scores (PASR) and NLL.
    Supports multiple judges and custom weighting for selection criteria.

    For every goal the prefix with the lowest combined score
    (``-pasr_weight * log(pasr + 1e-6) + prefix_nll``) is chosen first.
    Further prefixes are then added in order of increasing ``prefix_nll``
    from the candidates that stay within the configured PASR and NLL
    tolerances.
    """

    def __init__(self, config: "PrefixSelectorConfig"):
        """
        Initialize the prefix selector.

        Args:
            config: Configuration for prefix selection
        """
        self.config = config
        self.logger = logging.getLogger(__name__)

        # Mapping of judge types to their column names in the DataFrame
        self.judge_column_map = {
            "nuanced": "eval_nj_mean",
            "jailbreakbench": "eval_jb_mean",
            "harmbench": "eval_hb_mean",
            "strongreject": "eval_sj_binary_mean",
        }

    def select_prefixes(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Select prefixes based on combined judge scores and NLL.

        Args:
            df: DataFrame containing prefixes with evaluation results. The
                code reads the ``goal``, ``prefix`` and ``prefix_nll``
                columns plus one score column per configured judge.

        Returns:
            DataFrame containing selected prefixes (index reset to 0..n-1)
            with the input columns followed by ``pasr``, ``log_pasr`` and
            ``combined_score``.

        Raises:
            ValueError: If the judge configuration is missing or empty,
                names an unknown judge type, yields no usable judge type, or
                refers to score columns that are absent from ``df``.
        """
        judge_types = self._resolve_judge_types(df)

        # Work on a copy with a fresh positional index. Selection below
        # relies on unique row labels (idxmin / .loc / .name), and the
        # caller's index is discarded in the output anyway.
        work_df = df.copy().reset_index(drop=True)

        # Calculate combined PASR score using the identified judge types
        work_df["pasr"] = self._calculate_combined_pasr(work_df, judge_types)

        # Calculate log PASR for scoring (epsilon keeps log(0) finite)
        work_df["log_pasr"] = np.log(work_df["pasr"] + 1e-6)

        # Lower is better: high PASR and low NLL both reduce the score
        work_df["combined_score"] = (
            -self.config.pasr_weight * work_df["log_pasr"] + work_df["prefix_nll"]
        )

        # Collect the per-goal picks and concatenate once at the end instead
        # of growing a DataFrame inside the loop.
        per_goal = []
        for goal, group in work_df.groupby("goal"):
            picked = self._select_for_goal(goal, group)
            if picked is not None:
                per_goal.append(picked)

        selected_prefixes = pd.concat(per_goal) if per_goal else pd.DataFrame()
        selected_prefixes.reset_index(drop=True, inplace=True)

        # Input columns first, then the computed ones. dict.fromkeys removes
        # repeats so an input that already carries e.g. "pasr" does not end
        # up with duplicated output columns.
        wanted = dict.fromkeys(
            list(df.columns) + ["pasr", "log_pasr", "combined_score"]
        )
        output_columns = [col for col in wanted if col in selected_prefixes.columns]
        selected_prefixes = selected_prefixes[output_columns]

        self.logger.info(
            f"Selected {len(selected_prefixes)} prefixes across {len(df['goal'].unique())} goals"
        )
        return selected_prefixes

    def _resolve_judge_types(self, df: pd.DataFrame) -> List[str]:
        """Validate ``config.judges`` against ``df`` and return the judge types to use.

        Entries that are not dicts, or that have neither a ``type`` nor an
        ``evaluator_type`` value, are skipped with a warning.

        Raises:
            ValueError: See ``select_prefixes``.
        """
        if not isinstance(self.config.judges, list) or not self.config.judges:
            raise ValueError(
                "Judge configuration ('judges' key) must be a non-empty list of dictionaries."
            )

        judge_types_found = []
        missing_columns = []
        for judge_config in self.config.judges:
            if not isinstance(judge_config, dict):
                self.logger.warning(
                    f"Skipping invalid item in judge config list (not a dict): {judge_config}"
                )
                continue

            judge_type = judge_config.get("type") or judge_config.get("evaluator_type")
            if not judge_type:
                self.logger.warning(
                    f"Could not determine type for judge config: {judge_config}. Skipping."
                )
                continue

            if judge_type not in self.judge_column_map:
                self.logger.error(
                    f"Unknown judge type specified in config: '{judge_type}'"
                )
                raise ValueError(f"Unknown judge type for selection: {judge_type}")

            # Missing columns are collected so they can be reported together
            expected_col = self.judge_column_map[judge_type]
            if expected_col not in df.columns:
                missing_columns.append(expected_col)

            if judge_type not in judge_types_found:
                judge_types_found.append(judge_type)

        if missing_columns:
            raise ValueError(
                f"Missing required evaluation result columns in DataFrame: {missing_columns}"
            )
        if not judge_types_found:
            raise ValueError(
                "No valid judge types found in the configuration to perform selection."
            )
        return judge_types_found

    def _select_for_goal(self, goal, group: pd.DataFrame) -> Optional[pd.DataFrame]:
        """Pick up to ``n_prefixes_per_goal`` rows from one goal's group.

        Returns:
            DataFrame with the selected rows, or ``None`` when the group has
            no usable ``combined_score``.
        """
        # Goals are usually strings, but a groupby key may be any hashable;
        # str() keeps the log messages from raising on non-string goals.
        goal_label = str(goal)[:50]

        if (
            group.empty
            or "combined_score" not in group.columns
            or group["combined_score"].isnull().all()
        ):
            self.logger.warning(
                f"Skipping goal '{goal_label}...' during selection due to empty group or missing/invalid scores."
            )
            return None

        # Step 1: best prefix by combined score
        first_selection = group.loc[group["combined_score"].idxmin()]

        # Step 2: other prefixes whose PASR is within tolerance of the best
        remaining_candidates = group[
            (group["pasr"] >= first_selection["pasr"] - self.config.pasr_tol)
            & (group.index != first_selection.name)
        ]

        # Step 3: ... and whose NLL is within tolerance of the best
        valid_candidates = remaining_candidates[
            remaining_candidates["prefix_nll"]
            <= first_selection["prefix_nll"] + self.config.nll_tol
        ]

        selections = [first_selection]

        # Step 4: iteratively add the lowest-NLL remaining candidate
        for _ in range(self.config.n_prefixes_per_goal - 1):
            if valid_candidates.empty:
                break

            # Drop candidates whose text starts with an already selected prefix
            chosen_prefixes = tuple(str(sel["prefix"]) for sel in selections)

            def _starts_with_chosen(value, _chosen=chosen_prefixes) -> bool:
                return value is not None and str(value).startswith(_chosen)

            valid_candidates = valid_candidates[
                ~valid_candidates["prefix"].apply(_starts_with_chosen)
            ]

            if valid_candidates.empty:
                break

            if (
                "prefix_nll" not in valid_candidates.columns
                or valid_candidates["prefix_nll"].isnull().all()
            ):
                self.logger.warning(
                    f"Cannot select next prefix for goal '{goal_label}...' due to missing/invalid NLL scores in candidates."
                )
                break

            next_selection = valid_candidates.nsmallest(1, "prefix_nll").iloc[0]
            selections.append(next_selection)
            valid_candidates = valid_candidates[
                valid_candidates.index != next_selection.name
            ]

        return pd.DataFrame(selections)

    def _calculate_combined_pasr(
        self, df: pd.DataFrame, judge_types: List[str]
    ) -> pd.Series:
        """
        Calculate combined PASR score from specified judge types.

        Args:
            df: DataFrame containing judge scores
            judge_types: List of valid judge type strings (e.g., ["nuanced", "harmbench"])

        Returns:
            Series containing combined PASR scores: the row-wise mean of the
            judges' numeric scores, with 0 where no judge has a usable value.
        """
        judge_scores = []
        for judge_type in judge_types:
            column = self.judge_column_map[judge_type]
            if column in df.columns:
                try:
                    # Non-numeric entries become NaN and are ignored by the mean
                    numeric_scores = pd.to_numeric(df[column], errors="coerce")
                    judge_scores.append(numeric_scores)
                except Exception as e:
                    self.logger.warning(
                        f"Could not convert column '{column}' to numeric for PASR calculation. Skipping. Error: {e}"
                    )
            else:
                # Normally caught by the validation in select_prefixes; kept as a safeguard
                self.logger.warning(
                    f"Column '{column}' for judge '{judge_type}' not found during PASR calculation."
                )

        if not judge_scores:
            self.logger.warning(
                "No valid judge scores found to calculate combined PASR. Returning zeros."
            )
            return pd.Series(0, index=df.index)

        combined_scores_df = pd.concat(judge_scores, axis=1)
        # Mean over judges, skipping NaNs; rows where every judge is NaN stay NaN ...
        mean_scores = combined_scores_df.mean(axis=1, skipna=True)
        # ... and are mapped to 0 here
        return mean_scores.fillna(0)
@@ -0,0 +1,324 @@
1
+ import logging
2
+ import pandas as pd
3
+ from typing import List, Dict, Union, Tuple, Optional
4
+
5
+ from hackagent.router.router import AgentRouter # Added
6
+ from hackagent.models import AgentTypeEnum # Added
7
+ from hackagent.client import AuthenticatedClient # Added
8
+ from .utils import get_checkpoint_path
9
+ from rich.progress import (
10
+ Progress,
11
+ BarColumn,
12
+ TextColumn,
13
+ TimeRemainingColumn,
14
+ MofNCompleteColumn,
15
+ SpinnerColumn,
16
+ )
17
+ from hackagent.logger import get_logger
18
+
19
+ logger = get_logger(__name__)
20
+
21
# Chat templates for specific entries (the keys look like model identifiers).
# Each template has a single ``{content}`` placeholder that is filled in with
# ``str.format``.
CUSTOM_CHAT_TEMPLATES = {
    "georgesung/llama2_7b_chat_uncensored": (
        "<s>### HUMAN:\n{content}\n\n### RESPONSE:\n"
    ),
    "Tap-M/Luna-AI-Llama2-Uncensored": (
        "<s>USER: {content}\n\nASSISTANT:"
    ),
}
26
+
27
+
28
+ def _construct_prompts(
29
+ goals: List[str],
30
+ meta_prefixes: List[str],
31
+ meta_prefixes_n_samples: Union[int, List[int]], # Allow int or list
32
+ ) -> Tuple[List[Dict[str, str]], List[str], List[str]]:
33
+ """Constructs prompts for the generator model."""
34
+
35
+ # Handle the case where meta_prefixes_n_samples is an integer vs a list
36
+ if isinstance(meta_prefixes_n_samples, list):
37
+ if len(meta_prefixes) != len(meta_prefixes_n_samples):
38
+ raise ValueError(
39
+ "Lengths of meta_prefixes and meta_prefixes_n_samples lists must match."
40
+ )
41
+ n_samples_list = meta_prefixes_n_samples
42
+ elif isinstance(meta_prefixes_n_samples, int):
43
+ # Apply the same integer sample count to all meta prefixes
44
+ n_samples_list = [meta_prefixes_n_samples] * len(meta_prefixes)
45
+ else:
46
+ raise TypeError("meta_prefixes_n_samples must be an int or a list of ints.")
47
+
48
+ formatted_inputs = []
49
+ current_goals = []
50
+ expanded_meta_prefixes = []
51
+
52
+ for goal in goals:
53
+ for meta_prefix, n_samples in zip(meta_prefixes, n_samples_list):
54
+ if n_samples <= 0:
55
+ continue
56
+
57
+ # chat = [{"role": "user", "content": goal}] # Not directly used for router prompt format
58
+ try:
59
+ # The prompt for the router will be the fully constructed context.
60
+ # Custom chat templating needs to happen before sending to router.
61
+ if meta_prefix in CUSTOM_CHAT_TEMPLATES:
62
+ # Assuming meta_prefix identifies the model type for templating,
63
+ # which is a bit indirect. Usually, model_string would be used.
64
+ # For now, we'll keep this logic, but the 'context' is the prompt.
65
+ prompt_content = CUSTOM_CHAT_TEMPLATES[meta_prefix].format(
66
+ content=goal
67
+ )
68
+ else:
69
+ logger.warning(
70
+ f"Using basic formatting for prompt construction with meta_prefix: {meta_prefix}. No matching template found."
71
+ )
72
+ prompt_content = f"USER: {goal}\\nASSISTANT:"
73
+
74
+ # Append the actual meta_prefix text to the prompt that will be sent
75
+ final_prompt = prompt_content + meta_prefix
76
+
77
+ formatted_inputs.extend([final_prompt] * n_samples)
78
+ current_goals.extend([goal] * n_samples)
79
+ expanded_meta_prefixes.extend([meta_prefix] * n_samples)
80
+ except Exception as e:
81
+ logging.error(
82
+ f"Error formatting prompt for goal '{goal}' with meta_prefix '{meta_prefix}': {e}"
83
+ )
84
+
85
+ return formatted_inputs, current_goals, expanded_meta_prefixes
86
+
87
+
88
+ async def _generate_prefixes(
89
+ unique_goals: List[str],
90
+ config: Dict,
91
+ logger: logging.Logger,
92
+ client: AuthenticatedClient, # organization_id removed from here
93
+ ) -> List[Dict]:
94
+ """
95
+ Helper for step 1. Generate prefixes using AgentRouter with a LiteLLM agent.
96
+ """
97
+ results = []
98
+
99
+ generator = config.get("generator", {})
100
+ if not generator:
101
+ logger.error("Missing 'generator'. Cannot initialize AgentRouter for LiteLLM.")
102
+ return results
103
+
104
+ # Map generator to adapter_operational_config for LiteLLM
105
+ # New keys for LiteLLMAgentAdapter: 'name', 'endpoint', 'api_key'
106
+ model_name = generator.get("identifier")
107
+ if not model_name:
108
+ logger.error(
109
+ "Missing 'identifier' in 'generator'. Cannot configure LiteLLM agent."
110
+ )
111
+ return results
112
+
113
+ adapter_operational_config = {
114
+ "name": model_name,
115
+ "endpoint": generator.get("endpoint"),
116
+ "api_key": generator.get("api_key"),
117
+ # Other params like max_new_tokens, temperature, top_p for adapter defaults
118
+ "max_new_tokens": config.get("max_new_tokens", 100),
119
+ "temperature": config.get("temperature", 0.8),
120
+ "top_p": config.get("top_p", 1.0),
121
+ }
122
+
123
+ router: Optional[AgentRouter] = None
124
+ registration_key: Optional[str] = None
125
+
126
+ try:
127
+ logger.info(f"Initializing AgentRouter for LiteLLM model: {model_name}")
128
+ router = AgentRouter(
129
+ client=client,
130
+ name=model_name, # Name for backend agent record
131
+ agent_type=AgentTypeEnum.LITELMM,
132
+ endpoint=generator.get("endpoint"),
133
+ adapter_operational_config=adapter_operational_config,
134
+ metadata=adapter_operational_config.copy(),
135
+ overwrite_metadata=True,
136
+ )
137
+
138
+ if router._agent_registry:
139
+ registration_key = next(iter(router._agent_registry.keys()))
140
+ logger.info(
141
+ f"AgentRouter initialized. Registration key for LiteLLM agent: {registration_key}"
142
+ )
143
+ else:
144
+ logger.error(
145
+ "AgentRouter initialized, but no agent adapter was registered."
146
+ )
147
+ return results # Cannot proceed
148
+
149
+ except Exception as e:
150
+ logger.error(
151
+ f"Error initializing AgentRouter for {model_name}: {e}",
152
+ exc_info=True,
153
+ )
154
+ return results
155
+
156
+ for do_sample in [False, True]:
157
+ progress_bar_description = (
158
+ "[cyan]Generating Prefixes (Random Sampling)..."
159
+ if do_sample
160
+ else "[cyan]Generating Prefixes (Greedy Decoding)..."
161
+ )
162
+ logger.info(
163
+ f"Generating with {'random sampling' if do_sample else 'greedy decoding'} using LiteLLM via AgentRouter..."
164
+ )
165
+ try:
166
+ # _construct_prompts now returns the full prompt string
167
+ prompts_to_send, current_goals, current_meta_prefixes = _construct_prompts(
168
+ unique_goals,
169
+ config.get("meta_prefixes", []),
170
+ config.get("meta_prefix_samples", []),
171
+ )
172
+ logger.debug(f"Prompts to send ({len(prompts_to_send)}): {prompts_to_send}")
173
+ except Exception as e:
174
+ logger.error(f"Error constructing prompts: {e}", exc_info=True)
175
+ continue
176
+
177
+ if not prompts_to_send:
178
+ logger.warning("No prompts to send, skipping completion.")
179
+ continue
180
+
181
+ # Loop through each constructed prompt and call the router
182
+ with Progress(
183
+ SpinnerColumn(),
184
+ TextColumn("[progress.description]{task.description}"),
185
+ BarColumn(),
186
+ MofNCompleteColumn(),
187
+ TextColumn("[progress.percentage]{task.percentage:>3.1f}%"),
188
+ TimeRemainingColumn(),
189
+ ) as progress_bar:
190
+ task = progress_bar.add_task(
191
+ progress_bar_description, total=len(prompts_to_send)
192
+ )
193
+ for idx, current_prompt_text in enumerate(prompts_to_send):
194
+ goal_for_prompt = current_goals[idx]
195
+ meta_prefix_for_prompt = current_meta_prefixes[idx]
196
+
197
+ request_params = {
198
+ "prompt": current_prompt_text,
199
+ "max_new_tokens": config.get("max_new_tokens", 100),
200
+ "temperature": config.get("temperature", 0.8)
201
+ if do_sample
202
+ else 1e-2,
203
+ "top_p": config.get("top_p", 1.0),
204
+ }
205
+
206
+ completion_text = None
207
+ try:
208
+ # logger.info(f"Sending request to router for prompt: {current_prompt_text[:100]}...")
209
+ response = await router.route_request(
210
+ registration_key=registration_key, # type: ignore
211
+ request_data=request_params,
212
+ )
213
+
214
+ # logger.debug(f"Router response: {response}")
215
+
216
+ if response and response.get("error_message"):
217
+ logger.error(
218
+ f"Error from AgentRouter for prompt '{current_prompt_text[:50]}...': {response['error_message']}"
219
+ )
220
+ # Append error marker or skip
221
+ # For now, we'll try to get processed_response even if there's a partial error
222
+ # The adapter should handle this.
223
+ pass # Ensure block is not empty if all lines are comments
224
+
225
+ if response and response.get("processed_response"):
226
+ completion_text = response["processed_response"]
227
+ # The adapter's processed_response is assumed to be the full text (prompt + generation)
228
+ # We need to extract just the generated part.
229
+ if completion_text.startswith(current_prompt_text):
230
+ generated_part = completion_text[len(current_prompt_text) :]
231
+ else:
232
+ # Fallback or warning if the response doesn't start with the prompt
233
+ logger.warning(
234
+ f"Completion for '{current_prompt_text[:50]}...' did not start with the prompt. Using full response as generated part."
235
+ )
236
+ generated_part = completion_text
237
+ else:
238
+ logger.warning(
239
+ f"No 'processed_response' in router output for prompt: {current_prompt_text[:50]}..."
240
+ )
241
+ generated_part = " [GENERATION_VIA_ROUTER_FAILED]"
242
+
243
+ except Exception as e:
244
+ logger.error(
245
+ f"Exception during router.route_request for prompt '{current_prompt_text[:50]}...': {e}",
246
+ exc_info=True,
247
+ )
248
+ generated_part = " [ROUTER_REQUEST_EXCEPTION]"
249
+
250
+ # The 'prefix' should be the meta_prefix + generated_part
251
+ final_prefix = meta_prefix_for_prompt + generated_part
252
+
253
+ results.append(
254
+ {
255
+ "goal": goal_for_prompt,
256
+ "prefix": final_prefix,
257
+ "meta_prefix": meta_prefix_for_prompt,
258
+ "temperature": request_params["temperature"], # Use actual temp
259
+ "model_name": model_name, # Model used by the adapter
260
+ }
261
+ )
262
+ progress_bar.update(task, advance=1)
263
+
264
+ # No need to del router explicitly here, it goes out of scope.
265
+ return results
266
+
267
+
268
async def execute(
    goals: List[str],
    config: Dict,
    logger: logging.Logger,
    run_dir: str,
    client: AuthenticatedClient,
) -> pd.DataFrame:
    """Generate initial prefixes using provided goals via AgentRouter.

    Args:
        goals: Goals to generate prefixes for.
        config: Step configuration; must contain ``generator`` with an
            ``identifier``.
        logger: Logger used for all messages of this step.
        run_dir: Run directory passed to ``get_checkpoint_path``.
        client: Client forwarded to the prefix generation helper.

    Returns:
        DataFrame with the columns ``goal``, ``prefix``, ``meta_prefix``,
        ``temperature`` and ``model_name``. It is empty when there are no
        goals, the generator configuration is incomplete, or nothing was
        generated. In the first two cases no checkpoint is written.
    """
    result_columns = ["goal", "prefix", "meta_prefix", "temperature", "model_name"]

    logger.info("Executing Step 1: Generating prefixes using AgentRouter")

    if not goals:
        logger.warning("Step 1 received no goals. Returning empty DataFrame.")
        return pd.DataFrame(columns=result_columns)

    generator_cfg = config.get("generator")
    if not (generator_cfg and generator_cfg.get("identifier")):
        logger.error(
            "Step 1: Missing 'generator' or 'identifier' in config. Cannot generate prefixes."
        )
        return pd.DataFrame(columns=result_columns)

    model_name_from_config = generator_cfg["identifier"]
    logger.info(
        f"Generating prefixes for {len(goals)} unique goals using AgentRouter with LiteLLM: {model_name_from_config}"
    )

    generated_rows = await _generate_prefixes(
        unique_goals=goals,
        config=config,
        logger=logger,
        client=client,
    )

    if generated_rows:
        results_df = pd.DataFrame(generated_rows)
    else:
        logger.warning("Step 1: No prefixes were generated via AgentRouter.")
        results_df = pd.DataFrame(columns=result_columns)

    # Saving the checkpoint is best effort: a failure is logged and the
    # results are still returned.
    output_path = get_checkpoint_path(run_dir, 1)
    try:
        results_df.to_csv(output_path, index=False)
        logger.info(
            f"Step 1 complete. Generated {len(results_df)} total prefixes via AgentRouter"
        )
        logger.info(f"Checkpoint saved to {output_path}")
    except Exception as e:
        logger.error(f"Failed to save checkpoint for step 1 to {output_path}: {e}")

    return results_df