eval-protocol 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- development/__init__.py +1 -0
- development/normalize_sandbox_fusion.py +628 -0
- development/utils/__init__.py +1 -0
- development/utils/generate_api_key.py +31 -0
- development/utils/subprocess_manager.py +481 -0
- eval_protocol/__init__.py +86 -0
- eval_protocol/__main__.py +10 -0
- eval_protocol/_version.py +21 -0
- eval_protocol/adapters/__init__.py +1 -0
- eval_protocol/adapters/braintrust.py +8 -0
- eval_protocol/adapters/trl.py +8 -0
- eval_protocol/agent/__init__.py +29 -0
- eval_protocol/agent/models.py +69 -0
- eval_protocol/agent/orchestrator.py +893 -0
- eval_protocol/agent/resource_abc.py +89 -0
- eval_protocol/agent/resource_pool.py +184 -0
- eval_protocol/agent/resources/__init__.py +44 -0
- eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
- eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
- eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
- eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
- eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
- eval_protocol/agent/resources/docker_resource.py +479 -0
- eval_protocol/agent/resources/filesystem_resource.py +371 -0
- eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
- eval_protocol/agent/resources/http_rollout_resource.py +325 -0
- eval_protocol/agent/resources/python_state_resource.py +170 -0
- eval_protocol/agent/resources/sql_resource.py +271 -0
- eval_protocol/agent/task_manager.py +1064 -0
- eval_protocol/agent/tool_registry.py +111 -0
- eval_protocol/auth.py +156 -0
- eval_protocol/cli.py +425 -0
- eval_protocol/cli_commands/__init__.py +1 -0
- eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
- eval_protocol/cli_commands/common.py +242 -0
- eval_protocol/cli_commands/deploy.py +486 -0
- eval_protocol/cli_commands/deploy_mcp.py +287 -0
- eval_protocol/cli_commands/preview.py +186 -0
- eval_protocol/cli_commands/run_eval_cmd.py +202 -0
- eval_protocol/common_utils.py +36 -0
- eval_protocol/config.py +180 -0
- eval_protocol/datasets/__init__.py +1 -0
- eval_protocol/datasets/loader.py +521 -0
- eval_protocol/evaluation.py +1045 -0
- eval_protocol/execution/__init__.py +1 -0
- eval_protocol/execution/pipeline.py +920 -0
- eval_protocol/gcp_tools.py +484 -0
- eval_protocol/generation/cache.py +141 -0
- eval_protocol/generation/clients/base.py +67 -0
- eval_protocol/generation/clients.py +248 -0
- eval_protocol/generic_server.py +165 -0
- eval_protocol/integrations/__init__.py +12 -0
- eval_protocol/integrations/braintrust.py +51 -0
- eval_protocol/integrations/deepeval.py +106 -0
- eval_protocol/integrations/openeval.py +40 -0
- eval_protocol/integrations/trl.py +187 -0
- eval_protocol/mcp/__init__.py +48 -0
- eval_protocol/mcp/adapter.py +131 -0
- eval_protocol/mcp/client/__init__.py +12 -0
- eval_protocol/mcp/client/connection.py +499 -0
- eval_protocol/mcp/clients.py +195 -0
- eval_protocol/mcp/execution/__init__.py +23 -0
- eval_protocol/mcp/execution/base_policy.py +227 -0
- eval_protocol/mcp/execution/fireworks_policy.py +209 -0
- eval_protocol/mcp/execution/manager.py +506 -0
- eval_protocol/mcp/execution/policy.py +421 -0
- eval_protocol/mcp/grid_renderer.py +54 -0
- eval_protocol/mcp/mcpgym.py +637 -0
- eval_protocol/mcp/process_manager.py +177 -0
- eval_protocol/mcp/session/__init__.py +11 -0
- eval_protocol/mcp/session/manager.py +228 -0
- eval_protocol/mcp/simple_process_manager.py +291 -0
- eval_protocol/mcp/simulation_server.py +458 -0
- eval_protocol/mcp/types.py +80 -0
- eval_protocol/mcp_agent/__init__.py +1 -0
- eval_protocol/mcp_agent/config.py +147 -0
- eval_protocol/mcp_agent/intermediary_server.py +542 -0
- eval_protocol/mcp_agent/main.py +210 -0
- eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
- eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
- eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
- eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
- eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
- eval_protocol/mcp_agent/session.py +79 -0
- eval_protocol/mcp_env.py +304 -0
- eval_protocol/models.py +366 -0
- eval_protocol/packaging.py +219 -0
- eval_protocol/platform_api.py +360 -0
- eval_protocol/playback_policy.py +396 -0
- eval_protocol/resources.py +128 -0
- eval_protocol/reward_function.py +410 -0
- eval_protocol/rewards/__init__.py +94 -0
- eval_protocol/rewards/accuracy.py +454 -0
- eval_protocol/rewards/accuracy_length.py +173 -0
- eval_protocol/rewards/apps_coding_reward.py +331 -0
- eval_protocol/rewards/apps_execution_utils.py +149 -0
- eval_protocol/rewards/apps_testing_util.py +559 -0
- eval_protocol/rewards/bfcl_reward.py +313 -0
- eval_protocol/rewards/code_execution.py +1620 -0
- eval_protocol/rewards/code_execution_utils.py +72 -0
- eval_protocol/rewards/cpp_code.py +861 -0
- eval_protocol/rewards/deepcoder_reward.py +161 -0
- eval_protocol/rewards/format.py +129 -0
- eval_protocol/rewards/function_calling.py +541 -0
- eval_protocol/rewards/json_schema.py +422 -0
- eval_protocol/rewards/language_consistency.py +700 -0
- eval_protocol/rewards/lean_prover.py +479 -0
- eval_protocol/rewards/length.py +375 -0
- eval_protocol/rewards/list_comparison_math_reward.py +221 -0
- eval_protocol/rewards/math.py +762 -0
- eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
- eval_protocol/rewards/reasoning_steps.py +249 -0
- eval_protocol/rewards/repetition.py +342 -0
- eval_protocol/rewards/tag_count.py +162 -0
- eval_protocol/rl_processing.py +82 -0
- eval_protocol/server.py +271 -0
- eval_protocol/typed_interface.py +260 -0
- eval_protocol/utils/__init__.py +8 -0
- eval_protocol/utils/batch_evaluation.py +217 -0
- eval_protocol/utils/batch_transformation.py +205 -0
- eval_protocol/utils/dataset_helpers.py +112 -0
- eval_protocol/utils/module_loader.py +56 -0
- eval_protocol/utils/packaging_utils.py +108 -0
- eval_protocol/utils/static_policy.py +305 -0
- eval_protocol-0.0.3.dist-info/METADATA +635 -0
- eval_protocol-0.0.3.dist-info/RECORD +130 -0
- eval_protocol-0.0.3.dist-info/WHEEL +5 -0
- eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
- eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
- eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utilities for running batch evaluation on transformed N-variant data.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
from typing import Any, Dict, List, Optional, Union
|
|
8
|
+
|
|
9
|
+
from ..models import EvaluateResult
|
|
10
|
+
from ..utils.module_loader import load_function as load_reward_function
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def run_batch_evaluation(
    batch_jsonl_path: str,
    reward_function_path: str,
    output_path: str,
    reward_function_kwargs: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
    """
    Run batch evaluation on transformed N-variant data.

    Every non-blank line of the input file is parsed as one JSON object that
    must carry a truthy ``request_id`` and a non-empty list under
    ``rollouts_messages``. The reward function is called once per line with
    ``rollouts_messages`` plus the merged keyword arguments and must return a
    list with one ``EvaluateResult`` per rollout. Lines that cannot be parsed
    or validated are logged and skipped; a failing reward call produces one
    error entry per rollout instead of aborting the run.

    Args:
        batch_jsonl_path: Path to the batch evaluation JSONL file
        reward_function_path: Path to the batch reward function (e.g., "module.function")
        output_path: Path to write the batch evaluation results
        reward_function_kwargs: Additional kwargs for the reward function.
            Fields found on an input line override entries of the same name.

    Returns:
        List of batch evaluation results (the same entries written to ``output_path``)

    Raises:
        FileNotFoundError: If ``batch_jsonl_path`` does not exist.
        ValueError: If the reward function declares a mode other than "batch".
    """
    if reward_function_kwargs is None:
        reward_function_kwargs = {}

    # Load the batch reward function (loader errors propagate to the caller)
    reward_function = load_reward_function(reward_function_path)

    # Verify it's a batch mode function
    if not hasattr(reward_function, "_reward_function_mode"):
        logger.warning(f"Reward function {reward_function_path} doesn't have mode metadata. Assuming batch mode.")
    elif getattr(reward_function, "_reward_function_mode") != "batch":
        raise ValueError(
            f"Reward function {reward_function_path} is not configured for batch mode. Expected mode='batch'."
        )

    # Fields consumed by this runner itself; everything else on a line is
    # forwarded to the reward function and echoed into the result entries.
    excluded_fields = {
        "request_id",
        "rollouts_messages",
        "num_variants",
        "response_ids",
    }

    results: List[Dict[str, Any]] = []

    try:
        with open(batch_jsonl_path, "r", encoding="utf-8") as f:
            for line_num, line in enumerate(f, 1):
                stripped = line.strip()
                if not stripped:
                    # Blank lines are not records; skip them instead of
                    # reporting them as invalid JSON.
                    continue

                try:
                    data = json.loads(stripped)

                    # Extract required fields
                    request_id = data.get("request_id")
                    rollouts_messages = data.get("rollouts_messages")

                    if not request_id:
                        logger.error(f"Line {line_num}: Missing request_id")
                        continue

                    if not rollouts_messages or not isinstance(rollouts_messages, list):
                        logger.error(f"Line {line_num}: Missing or invalid rollouts_messages")
                        continue

                    # Prepare kwargs for the batch reward function; per-line
                    # metadata takes precedence over the caller's defaults.
                    metadata = {k: v for k, v in data.items() if k not in excluded_fields}
                    batch_kwargs = dict(reward_function_kwargs)
                    batch_kwargs.update(metadata)

                    # Resolve response ids once for both the success and the
                    # error path. A missing, null or wrongly sized list would
                    # otherwise make zip() silently drop results, so fall back
                    # to positional indices in that case.
                    num_rollouts = len(rollouts_messages)
                    response_ids = data.get("response_ids")
                    if not isinstance(response_ids, list) or len(response_ids) != num_rollouts:
                        if response_ids is not None:
                            logger.warning(
                                f"Line {line_num}: response_ids does not match the {num_rollouts} rollouts; "
                                "using rollout indices instead"
                            )
                        response_ids = list(range(num_rollouts))

                    # Call the batch reward function
                    try:
                        batch_results = reward_function(rollouts_messages=rollouts_messages, **batch_kwargs)

                        # Validate results
                        if not isinstance(batch_results, list):
                            raise ValueError(f"Batch reward function must return a list, got {type(batch_results)}")

                        if len(batch_results) != num_rollouts:
                            raise ValueError(
                                f"Batch reward function returned {len(batch_results)} results "
                                f"but expected {num_rollouts} (one per rollout)"
                            )

                        # Build the entries for this request separately so a
                        # failure midway cannot leave partial entries next to
                        # the error entries created below.
                        request_entries: List[Dict[str, Any]] = []
                        for i, (response_id, eval_result) in enumerate(zip(response_ids, batch_results)):
                            if not isinstance(eval_result, EvaluateResult):
                                logger.error(f"Result {i} is not an EvaluateResult: {type(eval_result)}")
                                continue

                            request_entries.append(
                                {
                                    "request_id": request_id,
                                    "response_id": response_id,
                                    "rollout_index": i,
                                    "evaluation_score": eval_result.score,
                                    "evaluation_reason": eval_result.reason,
                                    "is_score_valid": eval_result.is_score_valid,
                                    "evaluation_metrics": (
                                        {k: v.model_dump() for k, v in eval_result.metrics.items()}
                                        if eval_result.metrics
                                        else {}
                                    ),
                                    # Include original metadata (last, so it
                                    # keeps overriding same-named keys)
                                    **metadata,
                                }
                            )

                        results.extend(request_entries)

                    except Exception as e:
                        logger.error(f"Error calling batch reward function for request {request_id}: {e}")
                        # Create error entries for each expected result
                        for i, response_id in enumerate(response_ids):
                            results.append(
                                {
                                    "request_id": request_id,
                                    "response_id": response_id,
                                    "rollout_index": i,
                                    "error": f"Batch evaluation failed: {str(e)}",
                                    "evaluation_score": 0.0,
                                    "is_score_valid": False,
                                }
                            )

                except json.JSONDecodeError as e:
                    logger.error(f"Invalid JSON on line {line_num}: {e}")
                    continue
                except Exception as e:
                    # e.g. a JSON value that is not an object
                    logger.error(f"Error processing line {line_num}: {e}")
                    continue

    except FileNotFoundError as exc:
        raise FileNotFoundError(f"Batch JSONL file not found: {batch_jsonl_path}") from exc

    # Write results
    with open(output_path, "w", encoding="utf-8") as f:
        for result in results:
            f.write(json.dumps(result) + "\n")

    logger.info(f"Batch evaluation completed. {len(results)} results written to {output_path}")
    return results
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def create_sample_batch_reward_function():
    """
    Create a sample batch reward function for testing.
    This is a simple function that compares all variants and scores them relative to each other.

    Returns:
        The ``sample_batch_reward`` callable, wrapped by ``reward_function(mode="batch")``.
    """
    from ..models import EvaluateResult, Message, MetricResult
    from ..typed_interface import reward_function

    @reward_function(mode="batch")
    def sample_batch_reward(
        rollouts_messages: List[List[Message]],
        ground_truth_for_eval: Optional[str] = None,
        **kwargs: Any,
    ) -> List[EvaluateResult]:
        """
        Sample batch reward function that scores variants relative to each other.

        This function demonstrates how to process multiple rollouts (variants) together
        and return comparative scores. ``ground_truth_for_eval`` and ``kwargs`` are
        accepted but not used by the scoring.
        """
        # Extract the last assistant response from each rollout ("" if there is none)
        assistant_responses = []
        for rollout in rollouts_messages:
            assistant_msg = None
            for msg in reversed(rollout):
                if msg.role == "assistant":
                    assistant_msg = msg.content
                    break
            assistant_responses.append(assistant_msg or "")

        # Simple scoring: longer responses get higher scores (just for demonstration)
        response_lengths = [len(response) for response in assistant_responses]
        # Fall back to 1 both for an empty batch and for a batch where every
        # response is empty; otherwise the division below would be by zero.
        max_length = max(response_lengths, default=0) or 1
        num_responses = len(assistant_responses)

        results = []
        for i, length in enumerate(response_lengths):
            # Normalize score based on length (0.1 to 1.0)
            base_score = 0.1 + 0.9 * (length / max_length)

            # Add some variation based on position (earlier variants get slight bonus)
            position_bonus = 0.1 * (1 - i / num_responses)
            final_score = min(1.0, base_score + position_bonus)

            results.append(
                EvaluateResult(
                    score=final_score,
                    reason=f"Variant {i}: Length={length}, Base={base_score:.2f}, Position bonus={position_bonus:.2f}",
                    is_score_valid=True,
                    metrics={
                        "response_length": MetricResult(
                            score=min(1.0, length / 100.0),  # Normalize length score
                            reason=f"Response contains {length} characters",
                            is_score_valid=True,
                        )
                    },
                )
            )

        return results

    return sample_batch_reward
|
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utilities for transforming N-variant generation results into batch evaluation format.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
from collections import defaultdict
|
|
8
|
+
from typing import Any, Dict, List, Optional, Union
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def transform_n_variant_jsonl_to_batch_format(
    input_file_path: str,
    output_file_path: Optional[str] = None,
    request_id_field: str = "request_id",
    response_id_field: str = "response_id",
    messages_field: str = "full_conversation_history",
    fallback_messages_fields: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
    """
    Transform N-variant generation JSONL output into batch evaluation format.

    This function groups variants by request_id and creates rollouts_messages
    containing all variant conversations for each original request.

    Input lines that are blank are ignored. Lines that contain an "error" key,
    are not valid JSON, lack one of the id fields, or yield no messages are
    logged and skipped; they do not abort the transformation.

    Args:
        input_file_path: Path to the N-variant generation JSONL file
        output_file_path: Optional path to write the transformed data (if None, returns data only)
        request_id_field: Field name containing the original request ID (default: "request_id")
        response_id_field: Field name containing the variant response ID (default: "response_id")
        messages_field: Primary field containing conversation messages (default: "full_conversation_history")
        fallback_messages_fields: Fallback fields to construct messages if primary field is missing;
            forwarded unchanged to ``_extract_messages_from_data``

    Returns:
        List of batch evaluation entries, each containing rollouts_messages and other metadata

    Raises:
        FileNotFoundError: If input file doesn't exist
    """
    if fallback_messages_fields is None:
        fallback_messages_fields = ["user_query", "system_prompt", "assistant_response"]

    # Group variants by request_id
    grouped_variants = defaultdict(list)

    try:
        with open(input_file_path, "r", encoding="utf-8") as f:
            for line_num, line in enumerate(f, 1):
                stripped = line.strip()
                if not stripped:
                    # Blank lines are not records; don't report them as JSON errors.
                    continue

                try:
                    data = json.loads(stripped)

                    # Skip lines with errors
                    if "error" in data:
                        logger.warning(f"Skipping line {line_num} due to error: {data.get('error')}")
                        continue

                    # Validate required fields (raised here, logged by the
                    # generic handler below, and the line is skipped)
                    if request_id_field not in data:
                        raise ValueError(f"Line {line_num}: Missing required field '{request_id_field}'")

                    if response_id_field not in data:
                        raise ValueError(f"Line {line_num}: Missing required field '{response_id_field}'")

                    request_id = data[request_id_field]
                    response_id = data[response_id_field]

                    # Extract messages
                    messages = _extract_messages_from_data(data, messages_field, fallback_messages_fields, line_num)

                    grouped_variants[request_id].append(
                        {
                            "response_id": response_id,
                            "messages": messages,
                            "original_data": data,  # Keep original data for metadata extraction
                        }
                    )

                except json.JSONDecodeError as e:
                    logger.error(f"Invalid JSON on line {line_num}: {e}")
                    continue
                except Exception as e:
                    logger.error(f"Error processing line {line_num}: {e}")
                    continue

    except FileNotFoundError as exc:
        raise FileNotFoundError(f"Input file not found: {input_file_path}") from exc

    # Variant-specific fields that must not be copied into the shared batch entry
    excluded_fields = {
        "id",
        request_id_field,
        response_id_field,
        messages_field,
        "full_conversation_history",
        "assistant_response",
        "evaluation_score",
        "evaluation_reason",
        "evaluation_metrics",
        "executed_tool_calls",
        "discovered_tools",
        "final_mcp_state_captured",
    }

    # Transform grouped variants into batch format
    batch_entries = []

    for request_id, variants in grouped_variants.items():
        # Sort variants by response_id to ensure consistent ordering. Ids of
        # mixed, mutually unorderable types (e.g. int and str) would raise
        # TypeError and abort the whole transform, so fall back to ordering by
        # their string form in that case.
        try:
            ordered = sorted(variants, key=lambda variant: variant["response_id"])
        except TypeError:
            logger.warning(
                f"Request {request_id}: response ids are not mutually comparable; ordering by their string form"
            )
            ordered = sorted(variants, key=lambda variant: str(variant["response_id"]))

        # Extract common metadata from the first variant (assuming it's consistent across variants)
        first_variant = ordered[0]["original_data"]

        # Create batch entry
        batch_entry = {
            "request_id": request_id,
            "rollouts_messages": [variant["messages"] for variant in ordered],
            "num_variants": len(ordered),
            "response_ids": [variant["response_id"] for variant in ordered],
        }

        # Add common fields as kwargs (excluding variant-specific fields)
        for key, value in first_variant.items():
            if key not in excluded_fields:
                batch_entry[key] = value

        batch_entries.append(batch_entry)

    # Write output file if specified
    if output_file_path:
        with open(output_file_path, "w", encoding="utf-8") as f:
            for entry in batch_entries:
                f.write(json.dumps(entry) + "\n")
        logger.info(f"Transformed {len(batch_entries)} batch entries written to {output_file_path}")

    return batch_entries
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def _extract_messages_from_data(
|
|
145
|
+
data: Dict[str, Any], primary_field: str, fallback_fields: List[str], line_num: int
|
|
146
|
+
) -> List[Dict[str, Any]]:
|
|
147
|
+
"""
|
|
148
|
+
Extract conversation messages from variant data.
|
|
149
|
+
|
|
150
|
+
Args:
|
|
151
|
+
data: Variant data dictionary
|
|
152
|
+
primary_field: Primary field containing messages
|
|
153
|
+
fallback_fields: Fallback fields to construct messages
|
|
154
|
+
line_num: Line number for error reporting
|
|
155
|
+
|
|
156
|
+
Returns:
|
|
157
|
+
List of message dictionaries
|
|
158
|
+
"""
|
|
159
|
+
# Try primary field first
|
|
160
|
+
if primary_field in data and data[primary_field]:
|
|
161
|
+
messages = data[primary_field]
|
|
162
|
+
if isinstance(messages, list):
|
|
163
|
+
return messages
|
|
164
|
+
else:
|
|
165
|
+
logger.warning(f"Line {line_num}: {primary_field} is not a list, trying fallback")
|
|
166
|
+
|
|
167
|
+
# Try to construct messages from fallback fields
|
|
168
|
+
messages = []
|
|
169
|
+
|
|
170
|
+
# Add system message if available
|
|
171
|
+
if "system_prompt" in data and data["system_prompt"]:
|
|
172
|
+
messages.append({"role": "system", "content": data["system_prompt"]})
|
|
173
|
+
|
|
174
|
+
# Add user message if available
|
|
175
|
+
if "user_query" in data and data["user_query"]:
|
|
176
|
+
messages.append({"role": "user", "content": data["user_query"]})
|
|
177
|
+
|
|
178
|
+
# Add assistant message if available
|
|
179
|
+
if "assistant_response" in data and data["assistant_response"]:
|
|
180
|
+
messages.append({"role": "assistant", "content": data["assistant_response"]})
|
|
181
|
+
|
|
182
|
+
if not messages:
|
|
183
|
+
raise ValueError(f"Line {line_num}: Could not extract messages from any available fields")
|
|
184
|
+
|
|
185
|
+
return messages
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def create_batch_evaluation_dataset(n_variant_jsonl_path: str, output_jsonl_path: str, **transform_kwargs) -> str:
    """
    Build a batch evaluation JSONL file from N-variant generation output.

    Thin wrapper around ``transform_n_variant_jsonl_to_batch_format`` that
    always writes the result to disk and hands back the output location
    instead of the transformed entries.

    Args:
        n_variant_jsonl_path: Path to N-variant generation JSONL file
        output_jsonl_path: Path for the batch evaluation JSONL file
        **transform_kwargs: Additional arguments for transform_n_variant_jsonl_to_batch_format

    Returns:
        ``output_jsonl_path``, unchanged
    """
    # The entries returned by the transform are intentionally discarded; the
    # file written at output_jsonl_path is the product of this call.
    _ = transform_n_variant_jsonl_to_batch_format(
        output_file_path=output_jsonl_path,
        input_file_path=n_variant_jsonl_path,
        **transform_kwargs,
    )
    return output_jsonl_path
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
|
|
3
|
+
|
|
4
|
+
if TYPE_CHECKING:
|
|
5
|
+
from datasets import Dataset
|
|
6
|
+
|
|
7
|
+
try:
|
|
8
|
+
from datasets import Dataset
|
|
9
|
+
|
|
10
|
+
HAS_DATASETS_LIB = True
|
|
11
|
+
except ImportError:
|
|
12
|
+
HAS_DATASETS_LIB = False
|
|
13
|
+
if not TYPE_CHECKING:
|
|
14
|
+
|
|
15
|
+
class Dataset:
|
|
16
|
+
"""Placeholder for HuggingFace Dataset for when the library is not installed."""
|
|
17
|
+
|
|
18
|
+
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
|
19
|
+
pass
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def load_jsonl_to_hf_dataset(
    dataset_path: str,
    transform_fn: Optional[Callable[[Dict[str, Any]], Optional[Dict[str, Any]]]] = None,
    prompt_column: str = "prompt",
    required_columns: Optional[List[str]] = None,
    dataset_filter_fn: Optional[Callable[[Dict[str, Any]], bool]] = None,
) -> Optional["Dataset"]:
    """
    Read a JSONL file and turn its rows into a HuggingFace Dataset.

    Each line is decoded, optionally passed through ``transform_fn`` and then
    through ``dataset_filter_fn``. Lines that fail to decode, or for which
    either callback raises, are reported on stdout and skipped. The resulting
    dataset must expose ``prompt_column`` and every entry of
    ``required_columns``; if it does not, the problem is printed and None is
    returned.

    Args:
        dataset_path: Path to the JSONL file.
        transform_fn: Optional per-sample mapping from the raw dict to the dict
                      that should be stored; returning None drops the sample.
        prompt_column: Name of the column expected to contain the prompt for TRL.
                       Defaults to "prompt".
        required_columns: Column names that must be present in the final dataset
                          (after transformation).
        dataset_filter_fn: Optional predicate applied after transformation;
                           samples for which it returns a falsy value are dropped.

    Returns:
        A HuggingFace Dataset object (empty if no sample survived), or None if
        the datasets library is not installed or if an error occurs.
    """
    if not HAS_DATASETS_LIB:
        print("The 'datasets' library is not installed. Please install it with 'pip install datasets'.")
        return None

    kept_rows: List[Dict[str, Any]] = []
    try:
        with open(dataset_path, "r", encoding="utf-8") as handle:
            for line_number, raw_line in enumerate(handle, 1):
                try:
                    record = json.loads(raw_line.strip())
                    candidate: Optional[Dict[str, Any]] = transform_fn(record) if transform_fn else record

                    # Dropped by the transform, or rejected by the filter
                    if candidate is None:
                        continue
                    if dataset_filter_fn and not dataset_filter_fn(candidate):
                        continue

                    kept_rows.append(candidate)
                except json.JSONDecodeError as e:
                    print(f"Warning: Skipping line {line_number} in {dataset_path} due to JSON decode error: {e}")
                except Exception as e:
                    print(
                        f"Warning: Skipping line {line_number} in {dataset_path} due to error in transform_fn or filter_fn: {e}"
                    )

        if not kept_rows:
            print(f"Warning: No samples were processed from {dataset_path}.")
            return Dataset.from_list([])

        hf_dataset = Dataset.from_list(kept_rows)

        # The prompt column is checked first so its absence always yields the
        # prompt-specific message.
        if prompt_column not in hf_dataset.column_names:
            raise ValueError(
                f"Dataset from {dataset_path} must contain a '{prompt_column}' column after transformation."
            )

        final_required_columns = set(required_columns or [])
        final_required_columns.add(prompt_column)

        for col in final_required_columns:
            if col not in hf_dataset.column_names:
                raise ValueError(
                    f"Dataset from {dataset_path} must contain a '{col}' column after transformation for the reward function/TRL."
                )

        return hf_dataset

    except FileNotFoundError:
        print(f"Error: Dataset file not found at {dataset_path}")
        return None
    except ValueError as ve:
        # Includes the missing-column errors raised above
        print(f"Error: {ve}")
        return None
    except Exception as e:
        print(f"An unexpected error occurred while loading dataset {dataset_path}: {e}")
        return None
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utility for dynamically loading modules and functions.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import importlib
|
|
6
|
+
import logging
|
|
7
|
+
from typing import Any, Callable
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def load_function(import_path: str) -> Callable[..., Any]:
    """
    Dynamically loads a function given its full import path.
    Example: "my_package.my_module.my_function"

    Args:
        import_path: Dotted path whose last component names the callable and
            whose prefix names the module to import.

    Returns:
        The callable found at ``import_path``.

    Raises:
        ValueError: If ``import_path`` has no "." separating module and function.
        ImportError: If the module cannot be imported.
        AttributeError: If the attribute does not exist or is not callable.
    """
    try:
        parts = import_path.rsplit(".", 1)
        if len(parts) != 2:
            # Without this check the tuple unpacking below fails with an
            # opaque "not enough values to unpack" ValueError.
            raise ValueError(f"Import path must look like 'module.function', got '{import_path}'")
        module_path, function_name = parts

        module = importlib.import_module(module_path)
        func = getattr(module, function_name)
        if not callable(func):
            raise AttributeError(f"'{function_name}' in module '{module_path}' is not callable.")
        logger.info(f"Successfully loaded function '{function_name}' from '{module_path}'.")
        return func
    except ImportError as e:
        logger.error(f"Failed to import module from path '{import_path}': {e}")
        raise
    except AttributeError as e:
        logger.error(f"Failed to find or access function in path '{import_path}': {e}")
        raise
    except ValueError as e:
        logger.error(f"Invalid import path '{import_path}': {e}")
        raise
    except Exception as e:
        logger.error(f"An unexpected error occurred while loading function from '{import_path}': {e}")
        raise
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# Example usage:
|
|
37
|
+
# if __name__ == '__main__':
|
|
38
|
+
# try:
|
|
39
|
+
# # Assuming you have a eval_protocol.rewards.math module with math_reward function
|
|
40
|
+
# math_reward_func = load_function("eval_protocol.rewards.math.math_reward")
|
|
41
|
+
# print(f"Loaded: {math_reward_func}")
|
|
42
|
+
# # You could then call it, e.g., if it took simple args: math_reward_func(arg1="test")
|
|
43
|
+
# except Exception as e:
|
|
44
|
+
# print(f"Test loading failed: {e}")
|
|
45
|
+
|
|
46
|
+
# try:
|
|
47
|
+
# # Test with a non-existent function
|
|
48
|
+
# load_function("eval_protocol.rewards.math.non_existent_function")
|
|
49
|
+
# except Exception as e:
|
|
50
|
+
# print(f"Test loading non-existent function failed as expected: {e}")
|
|
51
|
+
|
|
52
|
+
# try:
|
|
53
|
+
# # Test with a non-existent module
|
|
54
|
+
# load_function("non_existent_package.non_existent_module.some_function")
|
|
55
|
+
# except Exception as e:
|
|
56
|
+
# print(f"Test loading non-existent module failed as expected: {e}")
|