eval-protocol 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- development/__init__.py +1 -0
- development/normalize_sandbox_fusion.py +628 -0
- development/utils/__init__.py +1 -0
- development/utils/generate_api_key.py +31 -0
- development/utils/subprocess_manager.py +481 -0
- eval_protocol/__init__.py +86 -0
- eval_protocol/__main__.py +10 -0
- eval_protocol/_version.py +21 -0
- eval_protocol/adapters/__init__.py +1 -0
- eval_protocol/adapters/braintrust.py +8 -0
- eval_protocol/adapters/trl.py +8 -0
- eval_protocol/agent/__init__.py +29 -0
- eval_protocol/agent/models.py +69 -0
- eval_protocol/agent/orchestrator.py +893 -0
- eval_protocol/agent/resource_abc.py +89 -0
- eval_protocol/agent/resource_pool.py +184 -0
- eval_protocol/agent/resources/__init__.py +44 -0
- eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
- eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
- eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
- eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
- eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
- eval_protocol/agent/resources/docker_resource.py +479 -0
- eval_protocol/agent/resources/filesystem_resource.py +371 -0
- eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
- eval_protocol/agent/resources/http_rollout_resource.py +325 -0
- eval_protocol/agent/resources/python_state_resource.py +170 -0
- eval_protocol/agent/resources/sql_resource.py +271 -0
- eval_protocol/agent/task_manager.py +1064 -0
- eval_protocol/agent/tool_registry.py +111 -0
- eval_protocol/auth.py +156 -0
- eval_protocol/cli.py +425 -0
- eval_protocol/cli_commands/__init__.py +1 -0
- eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
- eval_protocol/cli_commands/common.py +242 -0
- eval_protocol/cli_commands/deploy.py +486 -0
- eval_protocol/cli_commands/deploy_mcp.py +287 -0
- eval_protocol/cli_commands/preview.py +186 -0
- eval_protocol/cli_commands/run_eval_cmd.py +202 -0
- eval_protocol/common_utils.py +36 -0
- eval_protocol/config.py +180 -0
- eval_protocol/datasets/__init__.py +1 -0
- eval_protocol/datasets/loader.py +521 -0
- eval_protocol/evaluation.py +1045 -0
- eval_protocol/execution/__init__.py +1 -0
- eval_protocol/execution/pipeline.py +920 -0
- eval_protocol/gcp_tools.py +484 -0
- eval_protocol/generation/cache.py +141 -0
- eval_protocol/generation/clients/base.py +67 -0
- eval_protocol/generation/clients.py +248 -0
- eval_protocol/generic_server.py +165 -0
- eval_protocol/integrations/__init__.py +12 -0
- eval_protocol/integrations/braintrust.py +51 -0
- eval_protocol/integrations/deepeval.py +106 -0
- eval_protocol/integrations/openeval.py +40 -0
- eval_protocol/integrations/trl.py +187 -0
- eval_protocol/mcp/__init__.py +48 -0
- eval_protocol/mcp/adapter.py +131 -0
- eval_protocol/mcp/client/__init__.py +12 -0
- eval_protocol/mcp/client/connection.py +499 -0
- eval_protocol/mcp/clients.py +195 -0
- eval_protocol/mcp/execution/__init__.py +23 -0
- eval_protocol/mcp/execution/base_policy.py +227 -0
- eval_protocol/mcp/execution/fireworks_policy.py +209 -0
- eval_protocol/mcp/execution/manager.py +506 -0
- eval_protocol/mcp/execution/policy.py +421 -0
- eval_protocol/mcp/grid_renderer.py +54 -0
- eval_protocol/mcp/mcpgym.py +637 -0
- eval_protocol/mcp/process_manager.py +177 -0
- eval_protocol/mcp/session/__init__.py +11 -0
- eval_protocol/mcp/session/manager.py +228 -0
- eval_protocol/mcp/simple_process_manager.py +291 -0
- eval_protocol/mcp/simulation_server.py +458 -0
- eval_protocol/mcp/types.py +80 -0
- eval_protocol/mcp_agent/__init__.py +1 -0
- eval_protocol/mcp_agent/config.py +147 -0
- eval_protocol/mcp_agent/intermediary_server.py +542 -0
- eval_protocol/mcp_agent/main.py +210 -0
- eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
- eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
- eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
- eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
- eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
- eval_protocol/mcp_agent/session.py +79 -0
- eval_protocol/mcp_env.py +304 -0
- eval_protocol/models.py +366 -0
- eval_protocol/packaging.py +219 -0
- eval_protocol/platform_api.py +360 -0
- eval_protocol/playback_policy.py +396 -0
- eval_protocol/resources.py +128 -0
- eval_protocol/reward_function.py +410 -0
- eval_protocol/rewards/__init__.py +94 -0
- eval_protocol/rewards/accuracy.py +454 -0
- eval_protocol/rewards/accuracy_length.py +173 -0
- eval_protocol/rewards/apps_coding_reward.py +331 -0
- eval_protocol/rewards/apps_execution_utils.py +149 -0
- eval_protocol/rewards/apps_testing_util.py +559 -0
- eval_protocol/rewards/bfcl_reward.py +313 -0
- eval_protocol/rewards/code_execution.py +1620 -0
- eval_protocol/rewards/code_execution_utils.py +72 -0
- eval_protocol/rewards/cpp_code.py +861 -0
- eval_protocol/rewards/deepcoder_reward.py +161 -0
- eval_protocol/rewards/format.py +129 -0
- eval_protocol/rewards/function_calling.py +541 -0
- eval_protocol/rewards/json_schema.py +422 -0
- eval_protocol/rewards/language_consistency.py +700 -0
- eval_protocol/rewards/lean_prover.py +479 -0
- eval_protocol/rewards/length.py +375 -0
- eval_protocol/rewards/list_comparison_math_reward.py +221 -0
- eval_protocol/rewards/math.py +762 -0
- eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
- eval_protocol/rewards/reasoning_steps.py +249 -0
- eval_protocol/rewards/repetition.py +342 -0
- eval_protocol/rewards/tag_count.py +162 -0
- eval_protocol/rl_processing.py +82 -0
- eval_protocol/server.py +271 -0
- eval_protocol/typed_interface.py +260 -0
- eval_protocol/utils/__init__.py +8 -0
- eval_protocol/utils/batch_evaluation.py +217 -0
- eval_protocol/utils/batch_transformation.py +205 -0
- eval_protocol/utils/dataset_helpers.py +112 -0
- eval_protocol/utils/module_loader.py +56 -0
- eval_protocol/utils/packaging_utils.py +108 -0
- eval_protocol/utils/static_policy.py +305 -0
- eval_protocol-0.0.3.dist-info/METADATA +635 -0
- eval_protocol-0.0.3.dist-info/RECORD +130 -0
- eval_protocol-0.0.3.dist-info/WHEEL +5 -0
- eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
- eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
- eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CLI command for running agent evaluations using the ForkableResource framework.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
|
|
7
|
+
try:
|
|
8
|
+
import yaml
|
|
9
|
+
except ImportError:
|
|
10
|
+
import sys
|
|
11
|
+
import types
|
|
12
|
+
|
|
13
|
+
# Create a stub module if yaml is not installed
|
|
14
|
+
yaml = types.ModuleType("yaml")
|
|
15
|
+
|
|
16
|
+
def dummy_safe_load(x):
|
|
17
|
+
return None
|
|
18
|
+
|
|
19
|
+
def dummy_dump(x, **kwargs):
|
|
20
|
+
return None
|
|
21
|
+
|
|
22
|
+
yaml.safe_load = dummy_safe_load # type: ignore[assignment]
|
|
23
|
+
yaml.dump = dummy_dump # type: ignore[assignment]
|
|
24
|
+
|
|
25
|
+
import json # Fallback or for explicit JSON files
|
|
26
|
+
import logging # For logger instance
|
|
27
|
+
import os # For environment variables
|
|
28
|
+
from pathlib import Path
|
|
29
|
+
|
|
30
|
+
from pydantic import ValidationError
|
|
31
|
+
|
|
32
|
+
from eval_protocol.agent import Orchestrator
|
|
33
|
+
from eval_protocol.agent.task_manager import TaskManager
|
|
34
|
+
from eval_protocol.models import TaskDefinitionModel # Import the new Pydantic model
|
|
35
|
+
|
|
36
|
+
# setup_logging is already called in cli.py's main, but good for standalone use if any
|
|
37
|
+
# from .common import setup_logging
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _log_agent_task_result(logger, task_id, result):
    """Log the outcome of a single task result.

    Args:
        logger: Logger that receives the messages.
        task_id: Identifier of the task the result belongs to.
        result: Result object for the task. Dictionaries are checked for an
            ``"error"`` key, a truthy ``"aggregated"`` flag and a ``"score"``
            key, in that order; anything else is reported as a plain
            completion.
    """
    if not isinstance(result, dict):
        logger.info(f"Task '{task_id}' completed")
        return

    if "error" in result:
        logger.error(f"Task '{task_id}' failed: {result['error']}")
        return

    if result.get("aggregated", False):
        # Aggregated results summarise multiple rollouts of the same task.
        logger.info(f"Task '{task_id}' batch results:")
        logger.info(
            f" - Rollouts: {result['successful_rollouts']}/{result['num_rollouts']} successful ({result.get('failed_rollouts', 0)} failed)"
        )
        logger.info(f" - Success rate: {result['success_rate']:.2%}")
        logger.info(f" - Average score: {result['avg_score']:.4f}")
        logger.info(f" - Standard deviation: {result.get('std_dev', 0.0):.4f}")
        logger.info(f" - Score range: {result['min_score']:.4f} - {result['max_score']:.4f}")
        if "aggregated_metrics" in result:
            logger.info(" - Aggregated metrics:")
            for metric_name, metric_data in result["aggregated_metrics"].items():
                logger.info(
                    f" * {metric_name}: avg={metric_data['avg_score']:.4f}, range={metric_data['min_score']:.4f}-{metric_data['max_score']:.4f}"
                )

        if result.get("timestamp"):
            # Compact the ISO-style timestamp ("2024-01-02T03:04:05.678" ->
            # "20240102_030405") to build the trajectory file name.
            # NOTE(review): presumably mirrors the name TaskManager writes;
            # only the bare file name (no directory) is logged here.
            timestamp = result["timestamp"].replace(":", "").replace("-", "").replace("T", "_").split(".")[0]
            trajectory_file = f"trajectory_{task_id}_{timestamp}.jsonl"
            logger.info(f" - Trajectory data saved to: {trajectory_file}")
        return

    if "score" in result:
        logger.info(f"Task '{task_id}' score: {result['score']}")
        return

    logger.info(f"Task '{task_id}' completed")


def agent_eval_command(args):
    """
    Run agent evaluation using the Orchestrator and ForkableResource framework.

    Task definitions are loaded from ``args.task_def`` (a single file or a
    directory), registered with a ``TaskManager`` and executed. Per-task
    results are written to the ``agent_eval`` logger.

    Args:
        args: Parsed CLI namespace. ``task_def`` is required. ``model``,
            ``parallel``, ``max_concurrency``, ``filter`` and ``num_rollouts``
            are optional and read with ``getattr`` defaults.

    Returns:
        0 when the command completes (including the case where ``filter``
        matches no task), 1 when the task definition cannot be loaded or an
        exception is raised during execution.
    """
    logger = logging.getLogger("agent_eval")
    logger.info("Starting agent-eval command.")

    task_manager = TaskManager()

    if not args.task_def:
        logger.error("Error: --task-def (path to task definition YAML file or directory) is required.")
        return 1

    task_def_path = Path(args.task_def)

    registered_task_ids = []
    if task_def_path.is_file():
        task_def = task_manager._load_task_from_file(str(task_def_path))
        if task_def:
            task_id = task_manager.register_task(task_def)
            registered_task_ids.append(task_id)
        else:
            logger.error(f"Failed to load task definition from {task_def_path}")
            return 1
    elif task_def_path.is_dir():
        registered_task_ids = task_manager.register_tasks_from_directory(str(task_def_path))
        if not registered_task_ids:
            logger.error(f"No valid task definitions found in directory: {task_def_path}")
            return 1
    else:
        logger.error(f"Task definition path not found or invalid: {task_def_path}")
        return 1

    logger.info(f"Registered {len(registered_task_ids)} tasks: {registered_task_ids}")

    try:

        async def main_flow():
            if getattr(args, "model", None):
                # The override is communicated through the environment and is
                # left in place for the remainder of the process.
                os.environ["MODEL_AGENT"] = args.model
                logger.info(f"Model overridden to: {args.model}")

            parallel = getattr(args, "parallel", False)
            max_concurrency = getattr(args, "max_concurrency", 3)
            filter_tasks = getattr(args, "filter", None)
            num_rollouts_override = getattr(args, "num_rollouts", None)

            # Everything after registration runs under the same ``finally`` so
            # the task manager is cleaned up on every exit path, including the
            # early return when the filter matches nothing.
            try:
                tasks_to_run = registered_task_ids
                if filter_tasks:
                    tasks_to_run = [tid for tid in registered_task_ids if tid in filter_tasks]
                    if not tasks_to_run:
                        logger.warning(f"No tasks match the specified filter: {filter_tasks}")
                        return

                results = await task_manager.execute_tasks(
                    task_ids=tasks_to_run,
                    parallel=parallel,
                    max_concurrency=max_concurrency,
                    num_rollouts_override=num_rollouts_override,
                )

                logger.info(f"Execution completed for {len(results)} tasks")
                for task_id, result in results.items():
                    _log_agent_task_result(logger, task_id, result)
            finally:
                await task_manager.cleanup()

        asyncio.run(main_flow())
        logger.info("agent-eval command finished successfully.")
        return 0
    except Exception as e:
        # Top-level CLI boundary: report the failure and signal it through the
        # exit code; the full traceback is only emitted at DEBUG level.
        logger.error(f"Error during agent-eval execution: {e}")
        import traceback

        logger.debug(traceback.format_exc())
        return 1
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def _bfcl_results_to_serializable(results):
    """Convert a ``{task_id: result}`` mapping into JSON-friendly values.

    Objects exposing a ``dict()`` method (e.g. Pydantic models) are converted
    through it, values carrying a ``__dict__`` are stringified, dictionaries
    are converted one level deep, and any other result is stringified.

    Args:
        results: Mapping of task id to the result produced for that task.

    Returns:
        A new dictionary with the same keys and converted values.
    """
    serializable_results = {}
    for task_id, result in results.items():
        if hasattr(result, "dict"):
            # Handle Pydantic models
            serializable_results[task_id] = result.dict()
        elif isinstance(result, dict):
            # Handle dictionaries with potentially non-serializable values
            converted = {}
            for key, value in result.items():
                if hasattr(value, "dict"):
                    converted[key] = value.dict()
                elif hasattr(value, "__dict__"):
                    converted[key] = str(value)
                else:
                    converted[key] = value
            serializable_results[task_id] = converted
        else:
            # Handle other objects by converting to string
            serializable_results[task_id] = str(result)
    return serializable_results


def bfcl_eval_command(args):
    """
    Run BFCL agent evaluations using the refactored framework.
    This command specifically manages BFCL task evaluation.

    Tasks are registered from ``args.task_dir`` (either the single file
    ``<task_id>.yaml`` or every task in the directory), executed through a
    ``TaskManager`` and, when ``args.output_dir`` is set, the results are
    written to ``bfcl_results.json`` inside that directory.

    Args:
        args: Parsed CLI namespace providing ``task_dir``, ``task_id``,
            ``model``, ``output_dir``, ``parallel`` and ``max_concurrency``.

    Returns:
        0 on success, 1 when the task directory or task file is missing, no
        task could be registered, or an exception is raised during execution.
    """
    logger = logging.getLogger("bfcl_eval")
    logger.info("Starting BFCL evaluation command.")

    task_manager = TaskManager()

    task_dir = Path(args.task_dir)
    if not task_dir.is_dir():
        logger.error(f"Task directory not found: {task_dir}")
        return 1

    logger.info(f"Registering BFCL tasks from {task_dir}")

    try:
        registered_task_ids = []

        if args.task_id:
            task_path = task_dir / f"{args.task_id}.yaml"
            if not task_path.exists():
                logger.error(f"Task file not found: {task_path}")
                return 1

            task_def = task_manager._load_task_from_file(str(task_path))
            if task_def:
                task_id = task_manager.register_task(task_def)
                registered_task_ids.append(task_id)
                logger.info(f"Registered task: {task_id}")
            else:
                logger.error(f"Failed to load task from {task_path}")
                return 1
        else:
            registered_task_ids = task_manager.register_tasks_from_directory(str(task_dir))
            if not registered_task_ids:
                logger.error(f"No valid BFCL tasks found in directory: {task_dir}")
                return 1
            logger.info(f"Registered {len(registered_task_ids)} BFCL tasks")

        async def main_flow():
            if args.model:
                # The override is communicated through the environment and is
                # left in place for the remainder of the process.
                os.environ["MODEL_AGENT"] = args.model
                logger.info(f"Model overridden to: {args.model}")

            output_path = None
            if args.output_dir:
                output_path = Path(args.output_dir)
                output_path.mkdir(parents=True, exist_ok=True)
                logger.info(f"Results will be saved to {output_path}")

            try:
                results = await task_manager.execute_tasks(
                    task_ids=registered_task_ids,
                    parallel=args.parallel,
                    max_concurrency=args.max_concurrency,
                )

                logger.info(f"BFCL evaluation completed for {len(results)} tasks")
                for task_id, result in results.items():
                    if isinstance(result, dict) and "error" in result:
                        logger.error(f"Task '{task_id}' failed: {result['error']}")
                    elif isinstance(result, dict) and "score" in result:
                        logger.info(f"Task '{task_id}' score: {result['score']}")

                        # More detailed results for BFCL
                        if "format_score" in result:
                            logger.info(f"Task '{task_id}' format score: {result['format_score']}")
                        if "state_match" in result:
                            logger.info(f"Task '{task_id}' state match: {result['state_match']}")
                    else:
                        logger.info(f"Task '{task_id}' completed with result: {result}")

                if output_path is not None:
                    results_file = output_path / "bfcl_results.json"
                    serializable_results = _bfcl_results_to_serializable(results)

                    # Explicit encoding keeps the file platform independent.
                    # ``default=str`` stringifies nested values the one-level
                    # conversion above missed, so a finished evaluation is not
                    # lost to a TypeError while saving.
                    with open(results_file, "w", encoding="utf-8") as f:
                        json.dump(serializable_results, f, indent=2, default=str)
                    logger.info(f"Results saved to {results_file}")

            finally:
                await task_manager.cleanup()

        asyncio.run(main_flow())
        logger.info("BFCL evaluation completed successfully.")
        return 0

    except Exception as e:
        # Top-level CLI boundary: report the failure and signal it through the
        # exit code; the full traceback is only emitted at DEBUG level.
        logger.error(f"Error during BFCL evaluation: {e}")
        import traceback

        logger.debug(traceback.format_exc())
        return 1
|
|
@@ -0,0 +1,242 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Common utility functions for the Reward Kit CLI.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
import os
|
|
8
|
+
from typing import Any, Dict, Iterator, List, Optional
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def setup_logging(verbose=False, debug=False):
    """Setup logging configuration.

    Args:
        verbose: Kept for CLI compatibility. The default level is already
            INFO, so this flag currently selects the same level and format as
            passing nothing.
        debug: When true, use DEBUG level with a more detailed format that
            includes the source path and line number, and leave third-party
            loggers untouched.

    Note:
        ``logging.basicConfig`` does nothing if the root logger already has
        handlers, so the format only applies on the first effective call. The
        per-logger levels below are applied on every call.
    """
    if debug:
        log_level = logging.DEBUG
        # More detailed format for debug
        format_str = "[%(asctime)s][%(name)s][%(levelname)s] - %(pathname)s:%(lineno)d - %(message)s"
    else:
        # --verbose and the default share the same INFO configuration.
        log_level = logging.INFO
        format_str = "[%(asctime)s][%(name)s][%(levelname)s] - %(message)s"

    logging.basicConfig(level=log_level, format=format_str, datefmt="%Y-%m-%d %H:%M:%S")

    # Set higher levels for noisy libraries unless in full debug mode
    if not debug:
        for logger_name in ("httpx", "mcp", "urllib3", "asyncio", "hpack", "httpcore"):
            logging.getLogger(logger_name).setLevel(logging.WARNING)

    # Ensure eval_protocol's own loggers respect the overall log_level:
    # - INFO (default and --verbose): eval_protocol DEBUG logs are suppressed.
    # - DEBUG (--debug): all eval_protocol logs show.
    logging.getLogger("eval_protocol").setLevel(log_level)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def check_environment():
    """Report whether ``FIREWORKS_API_KEY`` is set to a non-empty value.

    Returns:
        True when the variable is present and non-empty; otherwise False,
        after logging three warnings that explain how to provide it.
    """
    if os.environ.get("FIREWORKS_API_KEY"):
        return True

    hints = (
        "FIREWORKS_API_KEY environment variable is not set.",
        "This is required for API calls. Set this variable before running the command.",
        "Example: FIREWORKS_API_KEY=$DEV_FIREWORKS_API_KEY reward-kit [command]",
    )
    for hint in hints:
        logger.warning(hint)
    return False
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def check_agent_environment(test_mode=False):
    """Report whether the environment is ready for agent evaluation commands.

    Only ``MODEL_AGENT`` is checked.

    Args:
        test_mode: When true, a missing variable is only noted at INFO level
            and the check still passes.

    Returns:
        True when nothing is missing or ``test_mode`` is set; False when a
        variable is missing outside test mode (after logging warnings).
    """
    missing_vars = [name for name in ("MODEL_AGENT",) if not os.environ.get(name)]
    if not missing_vars:
        return True

    missing_list = ", ".join(missing_vars)

    if test_mode:
        logger.info(f"Note: The following environment variables are not set: {missing_list}")
        logger.info("Since you're running in test mode, these are not strictly required for all operations.")
        return True

    logger.warning(f"The following environment variables are not set: {missing_list}")
    logger.warning(
        "These are typically required for full agent evaluation. Set these variables for full functionality."
    )
    logger.warning("Example: MODEL_AGENT=openai/gpt-4o-mini reward-kit agent-eval [args]")
    logger.warning("Alternatively, use --test-mode for certain validation tasks without requiring all API keys.")
    return False
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
# --- Sample Loading Helper Functions ---
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _validate_sample_messages(messages: Any, sample_index: int, line_number: int) -> bool:
|
|
81
|
+
"""Helper to validate the 'messages' field of a sample."""
|
|
82
|
+
if not isinstance(messages, list):
|
|
83
|
+
logger.warning(f"Sample {sample_index} (line {line_number}): 'messages' field is not a list. Skipping sample.")
|
|
84
|
+
return False
|
|
85
|
+
if not messages:
|
|
86
|
+
logger.warning(f"Sample {sample_index} (line {line_number}): 'messages' list is empty. Skipping sample.")
|
|
87
|
+
return False
|
|
88
|
+
for i, msg in enumerate(messages):
|
|
89
|
+
if not isinstance(msg, dict):
|
|
90
|
+
logger.warning(
|
|
91
|
+
f"Sample {sample_index} (line {line_number}): message item {i} is not a dictionary. Skipping sample."
|
|
92
|
+
)
|
|
93
|
+
return False
|
|
94
|
+
role = msg.get("role")
|
|
95
|
+
content = msg.get("content")
|
|
96
|
+
if not isinstance(role, str) or not isinstance(content, str):
|
|
97
|
+
logger.warning(
|
|
98
|
+
f"Sample {sample_index} (line {line_number}): message item {i} missing 'role' or 'content' string fields. Skipping sample."
|
|
99
|
+
)
|
|
100
|
+
return False
|
|
101
|
+
return True
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def load_samples_from_file(filepath: str, max_samples: int) -> Iterator[Dict[str, Any]]:
    """
    Stream validated samples from a JSONL file.

    Every non-blank line is parsed as JSON. A line is skipped (with a warning)
    when it is not valid JSON, is not a JSON object, has no 'messages' value,
    or its 'messages' fail ``_validate_sample_messages``.

    Args:
        filepath: Path of the JSONL file, read as UTF-8.
        max_samples: Maximum number of samples to yield.

    Yields:
        The parsed sample dictionaries, unchanged.

    A missing file or any other error while reading is logged, not raised.
    """
    yielded = 0
    line_number = 0
    try:
        with open(filepath, "r", encoding="utf-8") as handle:
            for line_number, raw_line in enumerate(handle, start=1):
                if yielded >= max_samples:
                    logger.info(f"Reached max_samples ({max_samples}). Stopping sample loading from {filepath}.")
                    break

                text = raw_line.strip()
                if not text:
                    continue

                try:
                    sample = json.loads(text)
                except json.JSONDecodeError:
                    logger.warning(f"Line {line_number}: Invalid JSON. Skipping line: {text[:100]}...")
                    continue

                if not isinstance(sample, dict):
                    logger.warning(f"Line {line_number}: Content is not a JSON object. Skipping line.")
                    continue

                messages = sample.get("messages")
                if messages is None:
                    logger.warning(f"Sample (line {line_number}): 'messages' field is missing. Skipping sample.")
                    continue

                if _validate_sample_messages(messages, yielded + 1, line_number):
                    yield sample
                    yielded += 1
    except FileNotFoundError:
        logger.error(f"Sample file not found: {filepath}")
    except Exception as e:
        logger.error(f"Error reading or processing sample file {filepath}: {e}")

    if yielded == 0:
        logger.info(f"No valid samples loaded from {filepath} after processing {line_number} lines.")
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def load_samples_from_huggingface(
    dataset_name: str,
    split: str,
    prompt_key: str,
    response_key: str,
    key_map: Optional[Dict[str, str]],
    max_samples: int,
) -> Iterator[Dict[str, Any]]:
    """
    Stream samples from a HuggingFace dataset via the 'datasets' library.

    Each record becomes a two-message conversation: the ``prompt_key`` value
    as the user message and the ``response_key`` value as the assistant
    message. Records where either value is not a string are skipped.

    Args:
        dataset_name: Name passed to ``datasets.load_dataset``.
        split: Dataset split to stream.
        prompt_key: Record field holding the user prompt.
        response_key: Record field holding the assistant response.
        key_map: Optional mapping of record field name -> output key; mapped
            fields that exist in the record are copied into the sample.
        max_samples: Maximum number of samples to yield.

    Yields:
        Dictionaries with a 'messages' list plus any mapped fields.

    A missing 'datasets' package or a dataset loading failure is logged and
    ends the generator without yielding.
    """
    try:
        from datasets import (
            Dataset,
            DatasetDict,
            IterableDataset,
            IterableDatasetDict,
            load_dataset,
        )

        # Also consider specific exceptions from datasets like DatasetNotFoundError
    except ImportError:
        logger.error(
            "The 'datasets' library is required to load samples from HuggingFace. "
            "Please install it with 'pip install datasets'."
        )
        return

    try:
        # Streaming mode is requested when loading the dataset.
        hf_dataset = load_dataset(dataset_name, split=split, streaming=True)
    except Exception as e:  # Broad exception for now, can be more specific
        logger.error(f"Error loading HuggingFace dataset '{dataset_name}' (split: {split}): {e}")
        return

    # Should be IterableDataset due to streaming=True
    known_types = (DatasetDict, Dataset, IterableDatasetDict, IterableDataset)
    if not isinstance(hf_dataset, known_types):
        logger.error(f"Loaded HuggingFace dataset '{dataset_name}' is not a recognized Dataset type.")
        return

    logger.info(f"Streaming samples from HuggingFace dataset '{dataset_name}' (split: {split}).")

    yielded = 0
    record_number = 0
    for record_number, record in enumerate(hf_dataset, start=1):
        if yielded >= max_samples:
            logger.info(f"Reached max_samples ({max_samples}). Stopping HuggingFace sample loading.")
            break

        if not isinstance(record, dict):
            logger.warning(f"HuggingFace dataset record {record_number} is not a dictionary. Skipping.")
            continue

        prompt = record.get(prompt_key)
        if not isinstance(prompt, str):
            logger.warning(
                f"HuggingFace record {record_number}: Prompt key '{prompt_key}' (value: {str(prompt)[:50]}...) did not yield a string. Skipping sample."
            )
            continue

        response_content = record.get(response_key)
        if not isinstance(response_content, str):
            logger.warning(
                f"HuggingFace record {record_number}: Response key '{response_key}' (value: {str(response_content)[:50]}...) did not yield a string. Skipping sample."
            )
            continue

        messages = [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": response_content},
        ]
        if not _validate_sample_messages(messages, yielded + 1, record_number):
            continue

        sample_output: Dict[str, Any] = {"messages": messages}

        for source_key, target_key in (key_map or {}).items():
            if source_key not in record:
                logger.warning(
                    f"HuggingFace record {record_number}: Key '{source_key}' from key_map not found. It will be omitted."
                )
                continue
            sample_output[target_key] = record[source_key]

        yield sample_output
        yielded += 1

    if yielded == 0:
        logger.info(
            f"No valid samples loaded from HuggingFace dataset '{dataset_name}' (split: {split}) after processing {record_number} records."
        )
|