hackagent 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hackagent/__init__.py +23 -0
- hackagent/agent.py +193 -0
- hackagent/api/__init__.py +1 -0
- hackagent/api/agent/__init__.py +1 -0
- hackagent/api/agent/agent_create.py +340 -0
- hackagent/api/agent/agent_destroy.py +136 -0
- hackagent/api/agent/agent_list.py +234 -0
- hackagent/api/agent/agent_partial_update.py +354 -0
- hackagent/api/agent/agent_retrieve.py +227 -0
- hackagent/api/agent/agent_update.py +354 -0
- hackagent/api/attack/__init__.py +1 -0
- hackagent/api/attack/attack_create.py +264 -0
- hackagent/api/attack/attack_destroy.py +140 -0
- hackagent/api/attack/attack_list.py +242 -0
- hackagent/api/attack/attack_partial_update.py +278 -0
- hackagent/api/attack/attack_retrieve.py +235 -0
- hackagent/api/attack/attack_update.py +278 -0
- hackagent/api/key/__init__.py +1 -0
- hackagent/api/key/key_create.py +168 -0
- hackagent/api/key/key_destroy.py +97 -0
- hackagent/api/key/key_list.py +158 -0
- hackagent/api/key/key_retrieve.py +150 -0
- hackagent/api/prompt/__init__.py +1 -0
- hackagent/api/prompt/prompt_create.py +160 -0
- hackagent/api/prompt/prompt_destroy.py +98 -0
- hackagent/api/prompt/prompt_list.py +173 -0
- hackagent/api/prompt/prompt_partial_update.py +174 -0
- hackagent/api/prompt/prompt_retrieve.py +151 -0
- hackagent/api/prompt/prompt_update.py +174 -0
- hackagent/api/result/__init__.py +1 -0
- hackagent/api/result/result_create.py +160 -0
- hackagent/api/result/result_destroy.py +98 -0
- hackagent/api/result/result_list.py +233 -0
- hackagent/api/result/result_partial_update.py +178 -0
- hackagent/api/result/result_retrieve.py +151 -0
- hackagent/api/result/result_trace_create.py +178 -0
- hackagent/api/result/result_update.py +174 -0
- hackagent/api/run/__init__.py +1 -0
- hackagent/api/run/run_create.py +172 -0
- hackagent/api/run/run_destroy.py +104 -0
- hackagent/api/run/run_list.py +260 -0
- hackagent/api/run/run_partial_update.py +186 -0
- hackagent/api/run/run_result_create.py +178 -0
- hackagent/api/run/run_retrieve.py +163 -0
- hackagent/api/run/run_run_tests_create.py +172 -0
- hackagent/api/run/run_update.py +186 -0
- hackagent/attacks/AdvPrefix/README.md +7 -0
- hackagent/attacks/AdvPrefix/__init__.py +0 -0
- hackagent/attacks/AdvPrefix/completer.py +438 -0
- hackagent/attacks/AdvPrefix/config.py +59 -0
- hackagent/attacks/AdvPrefix/preprocessing.py +521 -0
- hackagent/attacks/AdvPrefix/scorer.py +259 -0
- hackagent/attacks/AdvPrefix/scorer_parser.py +498 -0
- hackagent/attacks/AdvPrefix/selector.py +246 -0
- hackagent/attacks/AdvPrefix/step1_generate.py +324 -0
- hackagent/attacks/AdvPrefix/step4_compute_ce.py +293 -0
- hackagent/attacks/AdvPrefix/step6_get_completions.py +387 -0
- hackagent/attacks/AdvPrefix/step7_evaluate_responses.py +289 -0
- hackagent/attacks/AdvPrefix/step8_aggregate_evaluations.py +177 -0
- hackagent/attacks/AdvPrefix/step9_select_prefixes.py +59 -0
- hackagent/attacks/AdvPrefix/utils.py +192 -0
- hackagent/attacks/__init__.py +6 -0
- hackagent/attacks/advprefix.py +1136 -0
- hackagent/attacks/base.py +50 -0
- hackagent/attacks/strategies.py +539 -0
- hackagent/branding.py +143 -0
- hackagent/client.py +328 -0
- hackagent/errors.py +31 -0
- hackagent/logger.py +67 -0
- hackagent/models/__init__.py +71 -0
- hackagent/models/agent.py +240 -0
- hackagent/models/agent_request.py +169 -0
- hackagent/models/agent_type_enum.py +12 -0
- hackagent/models/attack.py +154 -0
- hackagent/models/attack_request.py +82 -0
- hackagent/models/evaluation_status_enum.py +14 -0
- hackagent/models/organization_minimal.py +68 -0
- hackagent/models/paginated_agent_list.py +123 -0
- hackagent/models/paginated_attack_list.py +123 -0
- hackagent/models/paginated_prompt_list.py +123 -0
- hackagent/models/paginated_result_list.py +123 -0
- hackagent/models/paginated_run_list.py +123 -0
- hackagent/models/paginated_user_api_key_list.py +123 -0
- hackagent/models/patched_agent_request.py +176 -0
- hackagent/models/patched_attack_request.py +92 -0
- hackagent/models/patched_prompt_request.py +162 -0
- hackagent/models/patched_result_request.py +237 -0
- hackagent/models/patched_run_request.py +138 -0
- hackagent/models/prompt.py +226 -0
- hackagent/models/prompt_request.py +155 -0
- hackagent/models/result.py +294 -0
- hackagent/models/result_list_evaluation_status.py +14 -0
- hackagent/models/result_request.py +232 -0
- hackagent/models/run.py +233 -0
- hackagent/models/run_list_status.py +12 -0
- hackagent/models/run_request.py +133 -0
- hackagent/models/status_enum.py +12 -0
- hackagent/models/step_type_enum.py +14 -0
- hackagent/models/trace.py +121 -0
- hackagent/models/trace_request.py +94 -0
- hackagent/models/user_api_key.py +201 -0
- hackagent/models/user_api_key_request.py +73 -0
- hackagent/models/user_profile_minimal.py +76 -0
- hackagent/py.typed +1 -0
- hackagent/router/__init__.py +11 -0
- hackagent/router/adapters/__init__.py +5 -0
- hackagent/router/adapters/google_adk.py +658 -0
- hackagent/router/adapters/litellm_adapter.py +290 -0
- hackagent/router/base.py +48 -0
- hackagent/router/router.py +753 -0
- hackagent/types.py +46 -0
- hackagent/utils.py +61 -0
- hackagent/vulnerabilities/__init__.py +0 -0
- hackagent-0.1.0.dist-info/LICENSE +202 -0
- hackagent-0.1.0.dist-info/METADATA +173 -0
- hackagent-0.1.0.dist-info/RECORD +117 -0
- hackagent-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,246 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
import numpy as np
|
|
3
|
+
import logging
|
|
4
|
+
from typing import List, Optional
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class PrefixSelectorConfig:
    """Settings that drive prefix selection.

    Only ``pasr_weight`` has to be supplied; every other field falls back to
    the default shown below.
    """

    # Weight applied to log-PASR when the combined score is computed.
    pasr_weight: float
    # How many prefixes are kept for each goal.
    n_prefixes_per_goal: int = 1
    # Allowed NLL slack relative to the best prefix of a goal.
    nll_tol: float = 999
    # Allowed PASR slack relative to the best prefix of a goal.
    pasr_tol: float = 0
    # Judge configurations (dicts) used for the PASR calculation.
    judges: Optional[List[dict]] = None
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class PrefixSelector:
    """
    Selects prefixes based on a combination of judge scores (PASR) and NLL.
    Supports multiple judges and custom weighting for selection criteria.

    For every goal the row with the lowest combined score
    (``-pasr_weight * log(pasr) + prefix_nll``) is picked first. Further rows
    are then added in order of increasing ``prefix_nll`` as long as they stay
    within the configured PASR/NLL tolerances of that first pick.
    """

    def __init__(self, config: "PrefixSelectorConfig"):
        """
        Initialize the prefix selector.

        Args:
            config: Configuration for prefix selection
        """
        self.config = config
        self.logger = logging.getLogger(__name__)

        # Mapping of judge types to their column names in the DataFrame
        self.judge_column_map = {
            "nuanced": "eval_nj_mean",
            "jailbreakbench": "eval_jb_mean",
            "harmbench": "eval_hb_mean",
            "strongreject": "eval_sj_binary_mean",
        }

    def select_prefixes(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Select prefixes based on combined judge scores and NLL.

        The input frame is not modified. It must provide the columns
        ``goal``, ``prefix_nll`` and one score column per configured judge;
        ``prefix`` is additionally read when more than one prefix per goal is
        requested.

        Args:
            df: DataFrame containing prefixes with evaluation results

        Returns:
            DataFrame containing selected prefixes (fresh 0..n-1 index) with
            the original columns plus ``pasr``, ``log_pasr`` and
            ``combined_score``. When nothing can be selected the frame is
            empty but still carries those columns.

        Raises:
            ValueError: If the judge configuration is missing/invalid, names
                an unknown judge type, or a required judge score column is
                absent from ``df``.
        """
        judge_types_found = self._resolve_judge_types(df)

        # Work on a copy with a fresh unique index: duplicate labels in the
        # caller's index would make `.loc[idxmin]` return several rows and
        # break the "exclude the already selected row" comparisons. The
        # result is re-indexed anyway, so this is invisible to callers.
        work_df = df.reset_index(drop=True)

        # Calculate combined PASR score using the identified judge types
        work_df["pasr"] = self._calculate_combined_pasr(work_df, judge_types_found)

        # Small epsilon keeps log() finite for a PASR of exactly zero.
        work_df["log_pasr"] = np.log(work_df["pasr"] + 1e-6)

        # Calculate combined score (minimize both 1 - PASR and prefix_nll)
        work_df["combined_score"] = (
            -self.config.pasr_weight * work_df["log_pasr"] + work_df["prefix_nll"]
        )

        # Collect one frame per goal and concatenate once at the end instead
        # of growing a DataFrame inside the loop.
        per_goal_frames = []
        for goal, group in work_df.groupby("goal"):
            selections = self._select_for_goal(goal, group)
            if selections:
                per_goal_frames.append(pd.DataFrame(selections))

        if per_goal_frames:
            selected_prefixes = pd.concat(per_goal_frames).reset_index(drop=True)
        else:
            # Nothing selected: keep the schema so downstream consumers still
            # see the expected columns.
            selected_prefixes = work_df.iloc[0:0]

        # Original columns first, then the computed ones. dict.fromkeys drops
        # repeats in case the input already had a column of the same name.
        wanted_columns = dict.fromkeys(
            list(df.columns) + ["pasr", "log_pasr", "combined_score"]
        )
        output_columns = [
            col for col in wanted_columns if col in selected_prefixes.columns
        ]
        selected_prefixes = selected_prefixes[output_columns]

        self.logger.info(
            f"Selected {len(selected_prefixes)} prefixes across {len(df['goal'].unique())} goals"
        )
        return selected_prefixes

    def _resolve_judge_types(self, df: pd.DataFrame) -> List[str]:
        """Validate ``config.judges`` against ``df`` and return the unique judge types."""
        # Validate judge configuration list
        if not isinstance(self.config.judges, list) or not self.config.judges:
            raise ValueError(
                "Judge configuration ('judges' key) must be a non-empty list of dictionaries."
            )

        judge_types_found = []
        missing_columns = []
        for judge_config in self.config.judges:
            if not isinstance(judge_config, dict):
                self.logger.warning(
                    f"Skipping invalid item in judge config list (not a dict): {judge_config}"
                )
                continue

            # The type may be stored under either key.
            judge_type = judge_config.get("type") or judge_config.get("evaluator_type")
            if not judge_type:
                self.logger.warning(
                    f"Could not determine type for judge config: {judge_config}. Skipping."
                )
                continue

            if judge_type not in self.judge_column_map:
                self.logger.error(
                    f"Unknown judge type specified in config: '{judge_type}'"
                )
                raise ValueError(f"Unknown judge type for selection: {judge_type}")

            # Check if the corresponding column exists in the DataFrame
            expected_col = self.judge_column_map[judge_type]
            if expected_col not in df.columns:
                missing_columns.append(expected_col)

            if judge_type not in judge_types_found:
                judge_types_found.append(judge_type)

        if missing_columns:
            raise ValueError(
                f"Missing required evaluation result columns in DataFrame: {missing_columns}"
            )
        if not judge_types_found:
            raise ValueError(
                "No valid judge types found in the configuration to perform selection."
            )
        return judge_types_found

    def _select_for_goal(self, goal, group: pd.DataFrame) -> List[pd.Series]:
        """Return the selected rows for one goal (empty list if the goal is skipped)."""
        # str() so that non-string goal keys (e.g. numeric ids) can be logged.
        goal_label = str(goal)[:50]

        if group.empty or group["combined_score"].isnull().all():
            self.logger.warning(
                f"Skipping goal '{goal_label}...' during selection due to empty group or missing/invalid scores."
            )
            return []

        # Step 1: Select first prefix based on combined score
        first_selection = group.loc[group["combined_score"].idxmin()]
        selections = [first_selection]

        extra_needed = self.config.n_prefixes_per_goal - 1
        if extra_needed <= 0:
            return selections

        # Steps 2 + 3: keep rows (other than the first pick) whose PASR and
        # NLL are within tolerance of the first pick.
        pasr_floor = first_selection["pasr"] - self.config.pasr_tol
        nll_ceiling = first_selection["prefix_nll"] + self.config.nll_tol
        candidates = group[
            (group["pasr"] >= pasr_floor)
            & (group.index != first_selection.name)
            & (group["prefix_nll"] <= nll_ceiling)
        ]

        # Step 4: Iteratively select additional prefixes
        for _ in range(extra_needed):
            if candidates.empty:
                break

            # Drop candidates whose prefix text starts with an already
            # selected prefix. None prefixes are never dropped by this check.
            chosen = tuple(str(sel["prefix"]) for sel in selections)
            extends_chosen = np.array(
                [
                    p is not None and str(p).startswith(chosen)
                    for p in candidates["prefix"]
                ],
                dtype=bool,
            )
            candidates = candidates[~extends_chosen]
            if candidates.empty:
                break

            # Select next prefix with lowest NLL
            if candidates["prefix_nll"].isnull().all():
                self.logger.warning(
                    f"Cannot select next prefix for goal '{goal_label}...' due to missing/invalid NLL scores in candidates."
                )
                break
            next_selection = candidates.nsmallest(1, "prefix_nll").iloc[0]
            selections.append(next_selection)
            candidates = candidates[candidates.index != next_selection.name]

        return selections

    def _calculate_combined_pasr(
        self, df: pd.DataFrame, judge_types: List[str]
    ) -> pd.Series:
        """
        Calculate combined PASR score from specified judge types.

        Each judge column is coerced to numeric (unparseable values become
        NaN) and the row-wise mean over all judges is taken, ignoring NaNs.

        Args:
            df: DataFrame containing judge scores
            judge_types: List of valid judge type strings (e.g., ["nuanced", "harmbench"])

        Returns:
            Series containing combined PASR scores, aligned to ``df.index``.
            Rows without any usable judge score get 0.
        """
        judge_scores = []
        for judge_type in judge_types:
            column = self.judge_column_map[judge_type]
            if column not in df.columns:
                # This should be caught by initial validation, but as safeguard
                self.logger.warning(
                    f"Column '{column}' for judge '{judge_type}' not found during PASR calculation."
                )
                continue
            try:
                judge_scores.append(pd.to_numeric(df[column], errors="coerce"))
            except Exception as e:
                # Best effort: a column that cannot be converted is left out
                # of the mean rather than aborting the whole selection.
                self.logger.warning(
                    f"Could not convert column '{column}' to numeric for PASR calculation. Skipping. Error: {e}"
                )

        if not judge_scores:
            self.logger.warning(
                "No valid judge scores found to calculate combined PASR. Returning zeros."
            )
            return pd.Series(0, index=df.index)

        # Row-wise mean across judges; a row with only NaNs yields NaN,
        # which is then replaced by 0.
        mean_scores = pd.concat(judge_scores, axis=1).mean(axis=1, skipna=True)
        return mean_scores.fillna(0)
|
|
@@ -0,0 +1,324 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import pandas as pd
|
|
3
|
+
from typing import List, Dict, Union, Tuple, Optional
|
|
4
|
+
|
|
5
|
+
from hackagent.router.router import AgentRouter # Added
|
|
6
|
+
from hackagent.models import AgentTypeEnum # Added
|
|
7
|
+
from hackagent.client import AuthenticatedClient # Added
|
|
8
|
+
from .utils import get_checkpoint_path
|
|
9
|
+
from rich.progress import (
|
|
10
|
+
Progress,
|
|
11
|
+
BarColumn,
|
|
12
|
+
TextColumn,
|
|
13
|
+
TimeRemainingColumn,
|
|
14
|
+
MofNCompleteColumn,
|
|
15
|
+
SpinnerColumn,
|
|
16
|
+
)
|
|
17
|
+
from hackagent.logger import get_logger
|
|
18
|
+
|
|
19
|
+
# Module-level logger; used by _construct_prompts, which takes no logger argument.
logger = get_logger(__name__)

# Constants moved from main file
# Maps a lookup key to a prompt template containing a single `{content}`
# placeholder. _construct_prompts checks whether a meta_prefix is a key of this
# mapping and, if so, formats the goal text into the template.
# NOTE(review): the keys look like model identifiers rather than meta-prefix
# texts — confirm what callers actually pass as meta_prefixes.
CUSTOM_CHAT_TEMPLATES = {
    "georgesung/llama2_7b_chat_uncensored": "<s>### HUMAN:\n{content}\n\n### RESPONSE:\n",
    "Tap-M/Luna-AI-Llama2-Uncensored": "<s>USER: {content}\n\nASSISTANT:",
}
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _construct_prompts(
|
|
29
|
+
goals: List[str],
|
|
30
|
+
meta_prefixes: List[str],
|
|
31
|
+
meta_prefixes_n_samples: Union[int, List[int]], # Allow int or list
|
|
32
|
+
) -> Tuple[List[Dict[str, str]], List[str], List[str]]:
|
|
33
|
+
"""Constructs prompts for the generator model."""
|
|
34
|
+
|
|
35
|
+
# Handle the case where meta_prefixes_n_samples is an integer vs a list
|
|
36
|
+
if isinstance(meta_prefixes_n_samples, list):
|
|
37
|
+
if len(meta_prefixes) != len(meta_prefixes_n_samples):
|
|
38
|
+
raise ValueError(
|
|
39
|
+
"Lengths of meta_prefixes and meta_prefixes_n_samples lists must match."
|
|
40
|
+
)
|
|
41
|
+
n_samples_list = meta_prefixes_n_samples
|
|
42
|
+
elif isinstance(meta_prefixes_n_samples, int):
|
|
43
|
+
# Apply the same integer sample count to all meta prefixes
|
|
44
|
+
n_samples_list = [meta_prefixes_n_samples] * len(meta_prefixes)
|
|
45
|
+
else:
|
|
46
|
+
raise TypeError("meta_prefixes_n_samples must be an int or a list of ints.")
|
|
47
|
+
|
|
48
|
+
formatted_inputs = []
|
|
49
|
+
current_goals = []
|
|
50
|
+
expanded_meta_prefixes = []
|
|
51
|
+
|
|
52
|
+
for goal in goals:
|
|
53
|
+
for meta_prefix, n_samples in zip(meta_prefixes, n_samples_list):
|
|
54
|
+
if n_samples <= 0:
|
|
55
|
+
continue
|
|
56
|
+
|
|
57
|
+
# chat = [{"role": "user", "content": goal}] # Not directly used for router prompt format
|
|
58
|
+
try:
|
|
59
|
+
# The prompt for the router will be the fully constructed context.
|
|
60
|
+
# Custom chat templating needs to happen before sending to router.
|
|
61
|
+
if meta_prefix in CUSTOM_CHAT_TEMPLATES:
|
|
62
|
+
# Assuming meta_prefix identifies the model type for templating,
|
|
63
|
+
# which is a bit indirect. Usually, model_string would be used.
|
|
64
|
+
# For now, we'll keep this logic, but the 'context' is the prompt.
|
|
65
|
+
prompt_content = CUSTOM_CHAT_TEMPLATES[meta_prefix].format(
|
|
66
|
+
content=goal
|
|
67
|
+
)
|
|
68
|
+
else:
|
|
69
|
+
logger.warning(
|
|
70
|
+
f"Using basic formatting for prompt construction with meta_prefix: {meta_prefix}. No matching template found."
|
|
71
|
+
)
|
|
72
|
+
prompt_content = f"USER: {goal}\\nASSISTANT:"
|
|
73
|
+
|
|
74
|
+
# Append the actual meta_prefix text to the prompt that will be sent
|
|
75
|
+
final_prompt = prompt_content + meta_prefix
|
|
76
|
+
|
|
77
|
+
formatted_inputs.extend([final_prompt] * n_samples)
|
|
78
|
+
current_goals.extend([goal] * n_samples)
|
|
79
|
+
expanded_meta_prefixes.extend([meta_prefix] * n_samples)
|
|
80
|
+
except Exception as e:
|
|
81
|
+
logging.error(
|
|
82
|
+
f"Error formatting prompt for goal '{goal}' with meta_prefix '{meta_prefix}': {e}"
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
return formatted_inputs, current_goals, expanded_meta_prefixes
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
async def _generate_prefixes(
|
|
89
|
+
unique_goals: List[str],
|
|
90
|
+
config: Dict,
|
|
91
|
+
logger: logging.Logger,
|
|
92
|
+
client: AuthenticatedClient, # organization_id removed from here
|
|
93
|
+
) -> List[Dict]:
|
|
94
|
+
"""
|
|
95
|
+
Helper for step 1. Generate prefixes using AgentRouter with a LiteLLM agent.
|
|
96
|
+
"""
|
|
97
|
+
results = []
|
|
98
|
+
|
|
99
|
+
generator = config.get("generator", {})
|
|
100
|
+
if not generator:
|
|
101
|
+
logger.error("Missing 'generator'. Cannot initialize AgentRouter for LiteLLM.")
|
|
102
|
+
return results
|
|
103
|
+
|
|
104
|
+
# Map generator to adapter_operational_config for LiteLLM
|
|
105
|
+
# New keys for LiteLLMAgentAdapter: 'name', 'endpoint', 'api_key'
|
|
106
|
+
model_name = generator.get("identifier")
|
|
107
|
+
if not model_name:
|
|
108
|
+
logger.error(
|
|
109
|
+
"Missing 'identifier' in 'generator'. Cannot configure LiteLLM agent."
|
|
110
|
+
)
|
|
111
|
+
return results
|
|
112
|
+
|
|
113
|
+
adapter_operational_config = {
|
|
114
|
+
"name": model_name,
|
|
115
|
+
"endpoint": generator.get("endpoint"),
|
|
116
|
+
"api_key": generator.get("api_key"),
|
|
117
|
+
# Other params like max_new_tokens, temperature, top_p for adapter defaults
|
|
118
|
+
"max_new_tokens": config.get("max_new_tokens", 100),
|
|
119
|
+
"temperature": config.get("temperature", 0.8),
|
|
120
|
+
"top_p": config.get("top_p", 1.0),
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
router: Optional[AgentRouter] = None
|
|
124
|
+
registration_key: Optional[str] = None
|
|
125
|
+
|
|
126
|
+
try:
|
|
127
|
+
logger.info(f"Initializing AgentRouter for LiteLLM model: {model_name}")
|
|
128
|
+
router = AgentRouter(
|
|
129
|
+
client=client,
|
|
130
|
+
name=model_name, # Name for backend agent record
|
|
131
|
+
agent_type=AgentTypeEnum.LITELMM,
|
|
132
|
+
endpoint=generator.get("endpoint"),
|
|
133
|
+
adapter_operational_config=adapter_operational_config,
|
|
134
|
+
metadata=adapter_operational_config.copy(),
|
|
135
|
+
overwrite_metadata=True,
|
|
136
|
+
)
|
|
137
|
+
|
|
138
|
+
if router._agent_registry:
|
|
139
|
+
registration_key = next(iter(router._agent_registry.keys()))
|
|
140
|
+
logger.info(
|
|
141
|
+
f"AgentRouter initialized. Registration key for LiteLLM agent: {registration_key}"
|
|
142
|
+
)
|
|
143
|
+
else:
|
|
144
|
+
logger.error(
|
|
145
|
+
"AgentRouter initialized, but no agent adapter was registered."
|
|
146
|
+
)
|
|
147
|
+
return results # Cannot proceed
|
|
148
|
+
|
|
149
|
+
except Exception as e:
|
|
150
|
+
logger.error(
|
|
151
|
+
f"Error initializing AgentRouter for {model_name}: {e}",
|
|
152
|
+
exc_info=True,
|
|
153
|
+
)
|
|
154
|
+
return results
|
|
155
|
+
|
|
156
|
+
for do_sample in [False, True]:
|
|
157
|
+
progress_bar_description = (
|
|
158
|
+
"[cyan]Generating Prefixes (Random Sampling)..."
|
|
159
|
+
if do_sample
|
|
160
|
+
else "[cyan]Generating Prefixes (Greedy Decoding)..."
|
|
161
|
+
)
|
|
162
|
+
logger.info(
|
|
163
|
+
f"Generating with {'random sampling' if do_sample else 'greedy decoding'} using LiteLLM via AgentRouter..."
|
|
164
|
+
)
|
|
165
|
+
try:
|
|
166
|
+
# _construct_prompts now returns the full prompt string
|
|
167
|
+
prompts_to_send, current_goals, current_meta_prefixes = _construct_prompts(
|
|
168
|
+
unique_goals,
|
|
169
|
+
config.get("meta_prefixes", []),
|
|
170
|
+
config.get("meta_prefix_samples", []),
|
|
171
|
+
)
|
|
172
|
+
logger.debug(f"Prompts to send ({len(prompts_to_send)}): {prompts_to_send}")
|
|
173
|
+
except Exception as e:
|
|
174
|
+
logger.error(f"Error constructing prompts: {e}", exc_info=True)
|
|
175
|
+
continue
|
|
176
|
+
|
|
177
|
+
if not prompts_to_send:
|
|
178
|
+
logger.warning("No prompts to send, skipping completion.")
|
|
179
|
+
continue
|
|
180
|
+
|
|
181
|
+
# Loop through each constructed prompt and call the router
|
|
182
|
+
with Progress(
|
|
183
|
+
SpinnerColumn(),
|
|
184
|
+
TextColumn("[progress.description]{task.description}"),
|
|
185
|
+
BarColumn(),
|
|
186
|
+
MofNCompleteColumn(),
|
|
187
|
+
TextColumn("[progress.percentage]{task.percentage:>3.1f}%"),
|
|
188
|
+
TimeRemainingColumn(),
|
|
189
|
+
) as progress_bar:
|
|
190
|
+
task = progress_bar.add_task(
|
|
191
|
+
progress_bar_description, total=len(prompts_to_send)
|
|
192
|
+
)
|
|
193
|
+
for idx, current_prompt_text in enumerate(prompts_to_send):
|
|
194
|
+
goal_for_prompt = current_goals[idx]
|
|
195
|
+
meta_prefix_for_prompt = current_meta_prefixes[idx]
|
|
196
|
+
|
|
197
|
+
request_params = {
|
|
198
|
+
"prompt": current_prompt_text,
|
|
199
|
+
"max_new_tokens": config.get("max_new_tokens", 100),
|
|
200
|
+
"temperature": config.get("temperature", 0.8)
|
|
201
|
+
if do_sample
|
|
202
|
+
else 1e-2,
|
|
203
|
+
"top_p": config.get("top_p", 1.0),
|
|
204
|
+
}
|
|
205
|
+
|
|
206
|
+
completion_text = None
|
|
207
|
+
try:
|
|
208
|
+
# logger.info(f"Sending request to router for prompt: {current_prompt_text[:100]}...")
|
|
209
|
+
response = await router.route_request(
|
|
210
|
+
registration_key=registration_key, # type: ignore
|
|
211
|
+
request_data=request_params,
|
|
212
|
+
)
|
|
213
|
+
|
|
214
|
+
# logger.debug(f"Router response: {response}")
|
|
215
|
+
|
|
216
|
+
if response and response.get("error_message"):
|
|
217
|
+
logger.error(
|
|
218
|
+
f"Error from AgentRouter for prompt '{current_prompt_text[:50]}...': {response['error_message']}"
|
|
219
|
+
)
|
|
220
|
+
# Append error marker or skip
|
|
221
|
+
# For now, we'll try to get processed_response even if there's a partial error
|
|
222
|
+
# The adapter should handle this.
|
|
223
|
+
pass # Ensure block is not empty if all lines are comments
|
|
224
|
+
|
|
225
|
+
if response and response.get("processed_response"):
|
|
226
|
+
completion_text = response["processed_response"]
|
|
227
|
+
# The adapter's processed_response is assumed to be the full text (prompt + generation)
|
|
228
|
+
# We need to extract just the generated part.
|
|
229
|
+
if completion_text.startswith(current_prompt_text):
|
|
230
|
+
generated_part = completion_text[len(current_prompt_text) :]
|
|
231
|
+
else:
|
|
232
|
+
# Fallback or warning if the response doesn't start with the prompt
|
|
233
|
+
logger.warning(
|
|
234
|
+
f"Completion for '{current_prompt_text[:50]}...' did not start with the prompt. Using full response as generated part."
|
|
235
|
+
)
|
|
236
|
+
generated_part = completion_text
|
|
237
|
+
else:
|
|
238
|
+
logger.warning(
|
|
239
|
+
f"No 'processed_response' in router output for prompt: {current_prompt_text[:50]}..."
|
|
240
|
+
)
|
|
241
|
+
generated_part = " [GENERATION_VIA_ROUTER_FAILED]"
|
|
242
|
+
|
|
243
|
+
except Exception as e:
|
|
244
|
+
logger.error(
|
|
245
|
+
f"Exception during router.route_request for prompt '{current_prompt_text[:50]}...': {e}",
|
|
246
|
+
exc_info=True,
|
|
247
|
+
)
|
|
248
|
+
generated_part = " [ROUTER_REQUEST_EXCEPTION]"
|
|
249
|
+
|
|
250
|
+
# The 'prefix' should be the meta_prefix + generated_part
|
|
251
|
+
final_prefix = meta_prefix_for_prompt + generated_part
|
|
252
|
+
|
|
253
|
+
results.append(
|
|
254
|
+
{
|
|
255
|
+
"goal": goal_for_prompt,
|
|
256
|
+
"prefix": final_prefix,
|
|
257
|
+
"meta_prefix": meta_prefix_for_prompt,
|
|
258
|
+
"temperature": request_params["temperature"], # Use actual temp
|
|
259
|
+
"model_name": model_name, # Model used by the adapter
|
|
260
|
+
}
|
|
261
|
+
)
|
|
262
|
+
progress_bar.update(task, advance=1)
|
|
263
|
+
|
|
264
|
+
# No need to del router explicitly here, it goes out of scope.
|
|
265
|
+
return results
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
async def execute(
    goals: List[str],
    config: Dict,
    logger: logging.Logger,
    run_dir: str,
    client: AuthenticatedClient,  # organization_id removed from this call
) -> pd.DataFrame:
    """Run step 1: generate initial prefixes for ``goals`` through the AgentRouter.

    Args:
        goals: Goals to generate prefixes for.
        config: Step configuration; must contain ``generator`` with an
            ``identifier``.
        logger: Logger used for progress and error messages.
        run_dir: Run directory used to derive the step-1 checkpoint path.
        client: Authenticated API client passed on to the generation helper.

    Returns:
        DataFrame with the columns ``goal``, ``prefix``, ``meta_prefix``,
        ``temperature`` and ``model_name``. It is empty when there are no
        goals, the generator configuration is incomplete, or nothing was
        generated. A failure to write the checkpoint CSV is logged and does
        not prevent the frame from being returned.
    """
    result_columns = ["goal", "prefix", "meta_prefix", "temperature", "model_name"]

    logger.info("Executing Step 1: Generating prefixes using AgentRouter")

    # Guard: nothing to do without goals.
    if not goals:
        logger.warning("Step 1 received no goals. Returning empty DataFrame.")
        return pd.DataFrame(columns=result_columns)

    # Guard: the generator section must name a model.
    generator = config.get("generator")
    if not (generator and generator.get("identifier")):
        logger.error(
            "Step 1: Missing 'generator' or 'identifier' in config. Cannot generate prefixes."
        )
        return pd.DataFrame(columns=result_columns)

    model_name_from_config = generator["identifier"]
    logger.info(
        f"Generating prefixes for {len(goals)} unique goals using AgentRouter with LiteLLM: {model_name_from_config}"
    )

    all_results = await _generate_prefixes(
        unique_goals=goals,
        config=config,
        logger=logger,
        client=client,
    )

    if all_results:
        results_df = pd.DataFrame(all_results)
    else:
        logger.warning("Step 1: No prefixes were generated via AgentRouter.")
        results_df = pd.DataFrame(columns=result_columns)

    # Persist the checkpoint; failures are reported but not raised.
    output_path = get_checkpoint_path(run_dir, 1)
    try:
        results_df.to_csv(output_path, index=False)
        logger.info(
            f"Step 1 complete. Generated {len(results_df)} total prefixes via AgentRouter"
        )
        logger.info(f"Checkpoint saved to {output_path}")
    except Exception as e:
        logger.error(f"Failed to save checkpoint for step 1 to {output_path}: {e}")

    return results_df
|