eval-protocol 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- development/__init__.py +1 -0
- development/normalize_sandbox_fusion.py +628 -0
- development/utils/__init__.py +1 -0
- development/utils/generate_api_key.py +31 -0
- development/utils/subprocess_manager.py +481 -0
- eval_protocol/__init__.py +86 -0
- eval_protocol/__main__.py +10 -0
- eval_protocol/_version.py +21 -0
- eval_protocol/adapters/__init__.py +1 -0
- eval_protocol/adapters/braintrust.py +8 -0
- eval_protocol/adapters/trl.py +8 -0
- eval_protocol/agent/__init__.py +29 -0
- eval_protocol/agent/models.py +69 -0
- eval_protocol/agent/orchestrator.py +893 -0
- eval_protocol/agent/resource_abc.py +89 -0
- eval_protocol/agent/resource_pool.py +184 -0
- eval_protocol/agent/resources/__init__.py +44 -0
- eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
- eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
- eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
- eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
- eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
- eval_protocol/agent/resources/docker_resource.py +479 -0
- eval_protocol/agent/resources/filesystem_resource.py +371 -0
- eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
- eval_protocol/agent/resources/http_rollout_resource.py +325 -0
- eval_protocol/agent/resources/python_state_resource.py +170 -0
- eval_protocol/agent/resources/sql_resource.py +271 -0
- eval_protocol/agent/task_manager.py +1064 -0
- eval_protocol/agent/tool_registry.py +111 -0
- eval_protocol/auth.py +156 -0
- eval_protocol/cli.py +425 -0
- eval_protocol/cli_commands/__init__.py +1 -0
- eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
- eval_protocol/cli_commands/common.py +242 -0
- eval_protocol/cli_commands/deploy.py +486 -0
- eval_protocol/cli_commands/deploy_mcp.py +287 -0
- eval_protocol/cli_commands/preview.py +186 -0
- eval_protocol/cli_commands/run_eval_cmd.py +202 -0
- eval_protocol/common_utils.py +36 -0
- eval_protocol/config.py +180 -0
- eval_protocol/datasets/__init__.py +1 -0
- eval_protocol/datasets/loader.py +521 -0
- eval_protocol/evaluation.py +1045 -0
- eval_protocol/execution/__init__.py +1 -0
- eval_protocol/execution/pipeline.py +920 -0
- eval_protocol/gcp_tools.py +484 -0
- eval_protocol/generation/cache.py +141 -0
- eval_protocol/generation/clients/base.py +67 -0
- eval_protocol/generation/clients.py +248 -0
- eval_protocol/generic_server.py +165 -0
- eval_protocol/integrations/__init__.py +12 -0
- eval_protocol/integrations/braintrust.py +51 -0
- eval_protocol/integrations/deepeval.py +106 -0
- eval_protocol/integrations/openeval.py +40 -0
- eval_protocol/integrations/trl.py +187 -0
- eval_protocol/mcp/__init__.py +48 -0
- eval_protocol/mcp/adapter.py +131 -0
- eval_protocol/mcp/client/__init__.py +12 -0
- eval_protocol/mcp/client/connection.py +499 -0
- eval_protocol/mcp/clients.py +195 -0
- eval_protocol/mcp/execution/__init__.py +23 -0
- eval_protocol/mcp/execution/base_policy.py +227 -0
- eval_protocol/mcp/execution/fireworks_policy.py +209 -0
- eval_protocol/mcp/execution/manager.py +506 -0
- eval_protocol/mcp/execution/policy.py +421 -0
- eval_protocol/mcp/grid_renderer.py +54 -0
- eval_protocol/mcp/mcpgym.py +637 -0
- eval_protocol/mcp/process_manager.py +177 -0
- eval_protocol/mcp/session/__init__.py +11 -0
- eval_protocol/mcp/session/manager.py +228 -0
- eval_protocol/mcp/simple_process_manager.py +291 -0
- eval_protocol/mcp/simulation_server.py +458 -0
- eval_protocol/mcp/types.py +80 -0
- eval_protocol/mcp_agent/__init__.py +1 -0
- eval_protocol/mcp_agent/config.py +147 -0
- eval_protocol/mcp_agent/intermediary_server.py +542 -0
- eval_protocol/mcp_agent/main.py +210 -0
- eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
- eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
- eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
- eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
- eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
- eval_protocol/mcp_agent/session.py +79 -0
- eval_protocol/mcp_env.py +304 -0
- eval_protocol/models.py +366 -0
- eval_protocol/packaging.py +219 -0
- eval_protocol/platform_api.py +360 -0
- eval_protocol/playback_policy.py +396 -0
- eval_protocol/resources.py +128 -0
- eval_protocol/reward_function.py +410 -0
- eval_protocol/rewards/__init__.py +94 -0
- eval_protocol/rewards/accuracy.py +454 -0
- eval_protocol/rewards/accuracy_length.py +173 -0
- eval_protocol/rewards/apps_coding_reward.py +331 -0
- eval_protocol/rewards/apps_execution_utils.py +149 -0
- eval_protocol/rewards/apps_testing_util.py +559 -0
- eval_protocol/rewards/bfcl_reward.py +313 -0
- eval_protocol/rewards/code_execution.py +1620 -0
- eval_protocol/rewards/code_execution_utils.py +72 -0
- eval_protocol/rewards/cpp_code.py +861 -0
- eval_protocol/rewards/deepcoder_reward.py +161 -0
- eval_protocol/rewards/format.py +129 -0
- eval_protocol/rewards/function_calling.py +541 -0
- eval_protocol/rewards/json_schema.py +422 -0
- eval_protocol/rewards/language_consistency.py +700 -0
- eval_protocol/rewards/lean_prover.py +479 -0
- eval_protocol/rewards/length.py +375 -0
- eval_protocol/rewards/list_comparison_math_reward.py +221 -0
- eval_protocol/rewards/math.py +762 -0
- eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
- eval_protocol/rewards/reasoning_steps.py +249 -0
- eval_protocol/rewards/repetition.py +342 -0
- eval_protocol/rewards/tag_count.py +162 -0
- eval_protocol/rl_processing.py +82 -0
- eval_protocol/server.py +271 -0
- eval_protocol/typed_interface.py +260 -0
- eval_protocol/utils/__init__.py +8 -0
- eval_protocol/utils/batch_evaluation.py +217 -0
- eval_protocol/utils/batch_transformation.py +205 -0
- eval_protocol/utils/dataset_helpers.py +112 -0
- eval_protocol/utils/module_loader.py +56 -0
- eval_protocol/utils/packaging_utils.py +108 -0
- eval_protocol/utils/static_policy.py +305 -0
- eval_protocol-0.0.3.dist-info/METADATA +635 -0
- eval_protocol-0.0.3.dist-info/RECORD +130 -0
- eval_protocol-0.0.3.dist-info/WHEEL +5 -0
- eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
- eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
- eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
eval_protocol/mcp_env.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MCP Environment API for eval-protocol - Backward Compatibility Facade
|
|
3
|
+
|
|
4
|
+
This module has been refactored into modular components for better maintainability.
|
|
5
|
+
This file now serves as a backward compatibility facade.
|
|
6
|
+
|
|
7
|
+
New modular structure:
|
|
8
|
+
- mcp.client.connection: MCP client connection management
|
|
9
|
+
- mcp.execution.policy: LLMBasePolicy and the Fireworks/OpenAI/Anthropic policies for tool calling
|
|
10
|
+
- mcp.execution.manager: ExecutionManager for rollout coordination and lifecycle
|
|
11
|
+
- mcp.session.manager: Session and environment management
|
|
12
|
+
|
|
13
|
+
Usage remains the same:
|
|
14
|
+
import eval_protocol as ep
|
|
15
|
+
|
|
16
|
+
# Load dataset with environment configuration and prompts
|
|
17
|
+
dataset = load_jsonl("dataset.jsonl")
|
|
18
|
+
|
|
19
|
+
# Create general policy (environment-agnostic)
|
|
20
|
+
policy = ep.FireworksPolicy(model_id="accounts/fireworks/models/qwen3-235b-a22b")
|
|
21
|
+
|
|
22
|
+
# Create environments with dataset-driven configuration
|
|
23
|
+
envs = ep.make("http://localhost:8000/mcp", dataset=dataset)
|
|
24
|
+
|
|
25
|
+
# Execute tool-calling rollouts
|
|
26
|
+
trajectories = await ep.rollout(envs, policy=policy, steps=512)
|
|
27
|
+
|
|
28
|
+
Key Features:
|
|
29
|
+
- General tool-calling interface that works with any MCP environment
|
|
30
|
+
- Dataset-driven configuration with system prompts and user prompt templates
|
|
31
|
+
- Automatic MCP tool discovery from servers
|
|
32
|
+
- **PROPER MCP PATTERN**: Initial state obtained from MCP resources during session establishment
|
|
33
|
+
- Tools used only for actions/interactions, not for getting initial state
|
|
34
|
+
- Dynamic user prompt formatting based on current observations
|
|
35
|
+
- Environment-agnostic policy that receives tool schemas and makes structured calls
|
|
36
|
+
- Backward compatibility with servers that don't expose resources
|
|
37
|
+
- **NEW**: LLMBasePolicy abstraction enables easy OpenAI integration
|
|
38
|
+
|
|
39
|
+
MCP Integration:
|
|
40
|
+
- Session establishment creates MCP connection and discovers resources and tools
|
|
41
|
+
- Initial state comes from MCP resources (list_resources + read_resource calls)
|
|
42
|
+
- Tools are used for subsequent actions during rollout steps
|
|
43
|
+
- Resources provide static/configuration data, tools provide dynamic actions
|
|
44
|
+
"""
|
|
45
|
+
|
|
46
|
+
# For legacy compatibility - import the facade functions
|
|
47
|
+
import logging
|
|
48
|
+
import os
|
|
49
|
+
import random
|
|
50
|
+
from typing import Any, Callable, Dict, List, Optional, Union
|
|
51
|
+
|
|
52
|
+
# Import all functionality from the new modular components
|
|
53
|
+
from .mcp.execution.manager import ExecutionManager
|
|
54
|
+
from .mcp.execution.policy import AnthropicPolicy, LLMBasePolicy, OpenAIPolicy
|
|
55
|
+
from .mcp.session.manager import GeneralMCPVectorEnv
|
|
56
|
+
from .mcp.types import DatasetRow, MCPSession, MCPToolCall, Trajectory
|
|
57
|
+
|
|
58
|
+
# Try to import FireworksPolicy - it may fail if fireworks-ai is not installed
|
|
59
|
+
# or if a different 'fireworks' package is installed
|
|
60
|
+
try:
    from .mcp.execution.policy import FireworksPolicy
except Exception as exc:  # noqa: BLE001 - any failure means "FireworksPolicy unavailable"
    # The import can fail because fireworks-ai is not installed (ImportError),
    # or with other errors when a different 'fireworks' package is installed.
    # Catch Exception rather than using a bare except, so KeyboardInterrupt and
    # SystemExit still propagate.
    #
    # Bind a None placeholder so that module-level references to
    # FireworksPolicy (such as the rollout() signature and __all__) do not
    # raise NameError or AttributeError. Callers can test
    # `FireworksPolicy is None` to detect that it is unavailable.
    FireworksPolicy = None  # type: ignore[assignment,misc]
    logging.getLogger(__name__).debug("FireworksPolicy unavailable: %s", exc)
|
|
65
|
+
|
|
66
|
+
# Logger shared by this facade module.
logger: logging.Logger = logging.getLogger(__name__)

# Backward-compatible alias: code that still imports the old MCPVectorEnv name
# receives the very same class object as GeneralMCPVectorEnv.
MCPVectorEnv = GeneralMCPVectorEnv
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _normalize_base_url(env_spec: str) -> str:
    """Validate an MCP server URL and return it with a trailing slash.

    Args:
        env_spec: MCP server URL; must start with ``http://`` or ``https://``.

    Returns:
        ``env_spec`` with ``/`` appended if it did not already end with one.

    Raises:
        ValueError: If ``env_spec`` is not an HTTP(S) URL.
    """
    # A bare "http" prefix check would also accept strings such as "httpfoo",
    # so require the full scheme separator.
    if not env_spec.startswith(("http://", "https://")):
        raise ValueError("Environment spec must be a valid HTTP URL")
    # Ensure we HAVE a trailing slash to avoid 307 redirects that break POST requests.
    return env_spec if env_spec.endswith("/") else env_spec + "/"


def _dataset_row_from_dict(row: Dict[str, Any]) -> DatasetRow:
    """Build a ``DatasetRow`` from one raw dataset dict.

    ``id``, ``system_prompt`` and ``user_prompt_template`` are required; a
    missing key raises ``KeyError``. ``environment_context`` and
    ``user_simulation`` are optional. The seed is read from
    ``environment_context["seed"]``. As a backward-compatible fallback, a
    top-level ``seed`` key is used when the context has no seed.
    """
    # `or {}` also covers an explicit "environment_context": None, which would
    # otherwise fail on the .get() below.
    environment_context = row.get("environment_context") or {}
    # NOTE(review): assumes the legacy seed location is a top-level "seed" key;
    # confirm against older datasets.
    seed = environment_context.get("seed", row.get("seed"))
    return DatasetRow(
        id=row["id"],
        seed=seed,
        system_prompt=row["system_prompt"],
        user_prompt_template=row["user_prompt_template"],
        environment_context=environment_context,
        user_simulation=row.get("user_simulation"),
    )


def _legacy_dataset_rows(n: Optional[int], seeds: Optional[List[int]]) -> List[DatasetRow]:
    """Build default, environment-agnostic dataset rows for the legacy ``n``/``seeds`` API.

    Row ``i`` gets id ``session_{i}`` and ``seeds[i]`` as its seed. Random seeds
    in ``[0, 2**31 - 1]`` are generated when ``seeds`` is None.

    Raises:
        ValueError: If ``n`` is None, or if ``seeds`` does not have exactly ``n`` entries.
    """
    if n is None:
        raise ValueError("Either 'dataset' or 'n' must be provided")
    if seeds is None:
        seeds = [random.randint(0, 2**31 - 1) for _ in range(n)]
    elif len(seeds) != n:
        raise ValueError(f"Expected {n} seeds, got {len(seeds)}")
    return [
        DatasetRow(
            id=f"session_{i}",
            seed=seed,
            system_prompt="You are an AI agent interacting with an environment via available tools.",
            user_prompt_template="Current observation: {observation}. Use available tools to interact with the environment.",
            environment_context={},
        )
        for i, seed in enumerate(seeds)
    ]


def make(
    env_spec: str,
    dataset: Optional[List[Dict]] = None,
    n: Optional[int] = None,
    seeds: Optional[List[int]] = None,
    model_id: str = "unknown",
    user_prompt_formatter: Optional[Callable] = None,
) -> GeneralMCPVectorEnv:
    """
    Create general MCP environments driven by dataset configuration.

    Args:
        env_spec: MCP server URL (must start with ``http://`` or ``https://``)
        dataset: List of dataset rows with prompts and context (preferred). Each
            entry is either a dict, which is converted to a ``DatasetRow``, or an
            existing ``DatasetRow``. When given, ``n`` and ``seeds`` are ignored.
        n: Number of environments (for backward compatibility)
        seeds: List of seeds (for backward compatibility). Must have exactly ``n``
            entries; random seeds are generated when omitted.
        model_id: Model identifier
        user_prompt_formatter: Optional callback for formatting user prompts

    Returns:
        General MCP environment that works with any MCP server

    Raises:
        ValueError: If ``env_spec`` is not an HTTP(S) URL, if neither ``dataset``
            nor ``n`` is provided, or if ``len(seeds) != n``.
        KeyError: If a dataset dict lacks ``id``, ``system_prompt`` or
            ``user_prompt_template``.

    Example:
        # New dataset-driven approach (preferred)
        dataset = load_jsonl("dataset.jsonl")
        envs = ep.make("http://localhost:8000/mcp", dataset=dataset)

        # Legacy approach (backward compatibility)
        envs = ep.make("http://localhost:8000/mcp", n=10, seeds=seeds)
    """
    base_url = _normalize_base_url(env_spec)

    if dataset is not None:
        # Dataset-driven approach. Non-dict rows are assumed to already be DatasetRow objects.
        dataset_rows = [
            _dataset_row_from_dict(row) if isinstance(row, dict) else row
            for row in dataset
        ]
    else:
        dataset_rows = _legacy_dataset_rows(n, seeds)

    # Every session mirrors its dataset row: the session id is the row id, and
    # the seed is the row's seed.
    sessions = [
        MCPSession(
            session_id=dataset_row.id,
            base_url=base_url,
            seed=dataset_row.seed,
            model_id=model_id,
            dataset_row=dataset_row,
        )
        for dataset_row in dataset_rows
    ]
    return GeneralMCPVectorEnv(sessions, dataset_rows, user_prompt_formatter)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
async def rollout(
    envs: Union[GeneralMCPVectorEnv, "MCPVectorEnv"],
    # "FireworksPolicy" is a string forward reference. Importing FireworksPolicy
    # is optional in this module, and an unquoted name would raise NameError
    # when this function is defined if that import failed.
    policy: Union["FireworksPolicy", LLMBasePolicy, Callable],
    steps: int = 512,
    openai_format_log_file: Optional[str] = None,
    max_concurrent_rollouts: int = 8,
) -> List[Trajectory]:
    """
    Execute general rollouts using tool calling interface with automatic record/playback.

    Uses concurrent execution with semaphore-based concurrency control for efficiency.
    The work is delegated to ``ExecutionManager.execute_rollouts``.

    This works with ANY MCP environment because:
    1. Policy receives tool schemas and makes tool calls
    2. Environment prompts come from dataset
    3. No hardcoded environment logic

    Args:
        envs: GeneralMCPVectorEnv instance
        policy: Policy that takes tool schemas, observations, prompts and returns tool calls
        steps: Maximum steps per rollout
        openai_format_log_file: Optional file to log clean OpenAI format for terminated trajectories only
        max_concurrent_rollouts: Maximum number of concurrent rollouts to run

    Environment Variable Control:
        EP_PLAYBACK_FILE: Controls record/playback mode
        - Not set: Normal live mode
        - Set but file doesn't exist: Record mode (file will be created)
        - Set and file exists: Playback mode (uses recorded data)

    Returns:
        List of Trajectory objects with complete rollout data

    Example:
        # Live mode
        trajectories = await ep.rollout(envs, policy)

        # Recording mode
        os.environ["EP_PLAYBACK_FILE"] = "record.jsonl"
        trajectories = await ep.rollout(envs, policy, openai_format_log_file="sft_data.jsonl")

        # Playback mode (after recording file exists)
        trajectories = await ep.rollout(envs, policy)
    """
    execution_manager = ExecutionManager()
    return await execution_manager.execute_rollouts(
        envs, policy, steps, openai_format_log_file, max_concurrent_rollouts
    )
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
async def test_mcp(base_url: str, seeds: List[int]) -> Dict[str, Any]:
    """
    Smoke-test an MCP server by running a short rollout for each seed.

    For every seed, this function:
    - creates one environment with ``make``;
    - builds ``FireworksPolicy("test-model")``;
    - runs a 10-step ``rollout``.

    A seed counts as successful when the first trajectory has more than one
    observation. Any exception raised while handling a seed is recorded as a
    failure for that seed rather than propagated.

    Args:
        base_url: Base URL of MCP server (e.g., "http://localhost:8000/mcp")
        seeds: List of seeds to test (may be empty)

    Returns:
        Dict with ``total_tests``, ``successful`` and ``failed`` counts, plus a
        per-seed ``results`` list.
    """
    print(f"🧪 Testing MCP server at {base_url} with {len(seeds)} seeds...")

    results = {"total_tests": len(seeds), "successful": 0, "failed": 0, "results": []}

    for seed in seeds:
        try:
            # Create single environment
            envs = make(base_url, n=1, seeds=[seed], model_id="test-model")

            # Simple policy for testing
            policy = FireworksPolicy("test-model")

            # Run short rollout
            trajectories = await rollout(envs, policy=policy, steps=10)

            if trajectories and len(trajectories[0].observations) > 1:
                results["successful"] += 1
                results["results"].append(
                    {
                        "seed": seed,
                        "status": "success",
                        "steps": trajectories[0].steps,
                        "total_reward": trajectories[0].total_reward,
                    }
                )
            else:
                results["failed"] += 1
                results["results"].append({"seed": seed, "status": "failed", "error": "empty_trajectory"})

        except Exception as e:
            results["failed"] += 1
            results["results"].append({"seed": seed, "status": "failed", "error": str(e)})

    total = results["total_tests"]
    # Guard against ZeroDivisionError when no seeds were supplied.
    success_rate = results["successful"] / total * 100 if total else 0.0
    print(f"✅ Test complete: {results['successful']}/{total} successful ({success_rate:.1f}%)")

    return results


# Despite its name, this is a library helper, not a pytest test. It is listed in
# __all__, so a star-import into a test module would otherwise make pytest try
# to collect it (and fail on the unknown "base_url"/"seeds" fixtures).
test_mcp.__test__ = False  # type: ignore[attr-defined]
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
# Public API of this facade: the names exported by
# `from eval_protocol.mcp_env import *`.
__all__ = [
    # Entry points
    "make",
    "rollout",
    # Policies
    "AnthropicPolicy",
    "FireworksPolicy",
    "OpenAIPolicy",
    "LLMBasePolicy",  # shared base class for LLM-backed policies
    # Vector environments (MCPVectorEnv is the legacy alias)
    "MCPVectorEnv",
    "GeneralMCPVectorEnv",
    # Data types
    "MCPToolCall",
    "DatasetRow",
    "Trajectory",
    # Diagnostics
    "test_mcp",
]
|
eval_protocol/models.py
ADDED
|
@@ -0,0 +1,366 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from typing import Any, Dict, List, Optional, Union
|
|
3
|
+
|
|
4
|
+
from openai.types.chat.chat_completion_message import (
|
|
5
|
+
ChatCompletionMessageToolCall,
|
|
6
|
+
FunctionCall,
|
|
7
|
+
)
|
|
8
|
+
from pydantic import BaseModel, Field
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Message(BaseModel):
    """A single chat message, shaped like an OpenAI chat-completions message.

    ``role`` is mandatory. ``content`` defaults to the empty string but may be
    ``None``, as the OpenAI API allows for messages that only carry tool calls.
    """

    role: str
    content: Optional[str] = ""  # None is allowed for tool-call-only messages
    name: Optional[str] = None
    tool_call_id: Optional[str] = None
    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
    function_call: Optional[FunctionCall] = None

    @classmethod
    def model_validate(cls, obj, *args, **kwargs):
        """Validate ``obj``, rejecting dict input without a ``"role"`` key up front.

        Raises:
            ValueError: If ``obj`` is a dict that has no ``"role"`` key.
        """
        lacks_role = isinstance(obj, dict) and "role" not in obj
        if lacks_role:
            raise ValueError("Role is required")
        return super().model_validate(obj, *args, **kwargs)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class MetricResult(BaseModel):
    """Result of a single metric evaluation.

    Instances also support read-only, dict-style access to their declared
    fields: ``m["score"]``, ``"reason" in m``, ``m.get(...)``, ``keys()``,
    ``values()``, ``items()`` and iteration over field names.

    Attributes:
        is_score_valid (bool): Whether the score is valid for this metric. Defaults to True.
        score (float): The score for this metric; validated to lie in [0.0, 1.0].
        reason (str): Explanation for the score (required).
    """

    is_score_valid: bool = True
    score: float = Field(..., ge=0.0, le=1.0)
    reason: str

    @classmethod
    def _declared_fields(cls):
        """Return the mapping of declared field names to field info.

        Pydantic v2 exposes this as ``model_fields``. There, the ``__fields__``
        alias is deprecated and warns on every access, so it is only used as a
        fallback for Pydantic v1.
        """
        fields = getattr(cls, "model_fields", None)
        return fields if fields is not None else cls.__fields__

    def __getitem__(self, key: str) -> Any:
        if key in self._declared_fields():
            return getattr(self, key)
        raise KeyError(f"'{key}'")

    def __contains__(self, key: str) -> bool:
        return key in self._declared_fields()

    def get(self, key: str, default: Any = None) -> Any:
        # Expose only declared fields, consistent with __contains__. A plain
        # getattr() would also return methods such as "keys".
        return getattr(self, key) if key in self._declared_fields() else default

    def keys(self):
        return self._declared_fields().keys()

    def values(self):
        # Raw attribute values, consistent with __getitem__.
        return [getattr(self, key) for key in self._declared_fields()]

    def items(self):
        return [(key, getattr(self, key)) for key in self._declared_fields()]

    def __iter__(self):
        return iter(self._declared_fields())
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class StepOutput(BaseModel):
    """Output of the user's reward function for one conceptual step of a rollout.

    Holds the step's base reward, plus an optional termination flag,
    control-plane info, custom metrics and an explanation.
    """

    step_index: Union[int, str] = Field(
        description=(
            "User-defined index for the step (e.g., assistant message index, turn number). "
            "This is used by the system to map this output to the internal StepData."
        )
    )
    base_reward: float = Field(
        description="Base reward calculated by the user's reward function for this step."
    )
    terminated: bool = Field(
        description="Whether the environment signaled termination at this step.",
        default=False,
    )
    control_plane_info: Optional[Dict[str, Any]] = Field(
        description="Structured info from the environment's control plane.",
        default=None,
    )
    metrics: Dict[str, Any] = Field(
        description="Optional dictionary of custom metrics for this step.",
        default_factory=dict,
    )
    reason: Optional[str] = Field(
        description="Optional explanation for the step's base reward or metrics.",
        default=None,
    )
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class EvaluateResult(BaseModel):
    """The complete result of an evaluator.

    For standard evaluation, it provides an overall score and component metrics.
    For Reinforcement Learning, it can also provide per-step base rewards via 'step_outputs'.

    This unified model serves both per-turn and per-trajectory evaluation scenarios.
    Instances also support read-only, dict-style access to their declared fields
    (``r["score"]``, ``"metrics" in r``, ``r.get(...)``, ``keys()``, ...).

    Attributes:
        score (float): The overall evaluation score (typically 0.0-1.0; the range is not validated).
        is_score_valid (bool): Whether the overall score is valid. Defaults to True.
        reason (Optional[str]): Optional explanation for the overall score.
        metrics (Dict[str, MetricResult]): Dictionary of component metrics for detailed evaluation.
        step_outputs (Optional[List[StepOutput]]): For RL, a list of outputs for each conceptual step,
            providing base rewards.
        error (Optional[str]): Optional error message if evaluation failed.
        trajectory_info (Optional[Dict[str, Any]]): Additional trajectory-level information.
        final_control_plane_info (Optional[Dict[str, Any]]): The final control plane state that led to termination.
    """

    score: float = Field(..., description="The overall evaluation score, typically between 0.0 and 1.0.")
    is_score_valid: bool = Field(default=True, description="Whether the overall score is valid.")
    reason: Optional[str] = Field(default=None, description="Optional explanation for the overall score.")
    metrics: Dict[str, MetricResult] = Field(
        default_factory=dict,
        description="Dictionary of component metrics for detailed breakdown.",
    )

    # Per-step base rewards for RL.
    step_outputs: Optional[List[StepOutput]] = Field(
        default=None,
        description="For RL, a list of outputs for each conceptual step, providing base rewards.",
    )

    error: Optional[str] = Field(
        default=None,
        description="Optional error message if the evaluation itself encountered an issue.",
    )

    # Trajectory-level and row-wise result details.
    trajectory_info: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Additional trajectory-level information (duration, steps, termination_reason, etc.)."
    )

    final_control_plane_info: Optional[Dict[str, Any]] = Field(
        default=None,
        description="The final control plane state that led to termination."
    )

    @classmethod
    def _declared_fields(cls):
        """Return the mapping of declared field names to field info.

        Pydantic v2 exposes this as ``model_fields``. There, the ``__fields__``
        alias is deprecated and warns on every access, so it is only used as a
        fallback for Pydantic v1.
        """
        fields = getattr(cls, "model_fields", None)
        return fields if fields is not None else cls.__fields__

    def __getitem__(self, key: str) -> Any:
        # Return the raw attribute value. For "metrics" this is a dict of
        # MetricResult objects, not a dict of dicts.
        if key in self._declared_fields():
            return getattr(self, key)
        raise KeyError(f"'{key}'")

    def __contains__(self, key: str) -> bool:
        return key in self._declared_fields()

    def get(self, key: str, default: Any = None) -> Any:
        # Expose only declared fields, consistent with __contains__. A plain
        # getattr() would also return methods such as "keys".
        return getattr(self, key) if key in self._declared_fields() else default

    def keys(self):
        return self._declared_fields().keys()

    def values(self):
        # Raw attribute values, consistent with __getitem__.
        return [getattr(self, key) for key in self._declared_fields()]

    def items(self):
        return [(key, getattr(self, key)) for key in self._declared_fields()]

    def __iter__(self):
        return iter(self._declared_fields())
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class EvaluationRow(BaseModel):
    """
    One evaluation unit: the conversation, the tools offered to the agent,
    input metadata, and the evaluation result.

    The same shape covers a single-turn evaluation and a full trajectory
    evaluation; ``is_trajectory_evaluation`` distinguishes the two. It is the
    canonical record for both row-wise batch evaluation and trajectory-based
    RL evaluation.
    """

    # Conversation / trajectory messages.
    messages: List[Message] = Field(description="List of messages in the conversation/trajectory.")

    # Tool/function schemas that were available to the agent.
    tools: Optional[List[Dict[str, Any]]] = Field(
        description="Available tools/functions that were provided to the agent.",
        default=None,
    )

    # Everything describing the input (dataset, model config, session, ...).
    input_metadata: Optional[Dict[str, Any]] = Field(
        description="Metadata related to the input (dataset info, model config, session data, etc.).",
        default=None,
    )

    # Result produced by the evaluator for this row/trajectory.
    evaluation_result: Optional[EvaluateResult] = Field(
        description="The evaluation result for this row/trajectory.",
        default=None,
    )

    def is_trajectory_evaluation(self) -> bool:
        """
        True when the attached result carries at least one step output
        (a trajectory evaluation), False otherwise (single-turn evaluation).
        """
        result = self.evaluation_result
        if result is None or result.step_outputs is None:
            return False
        return len(result.step_outputs) > 0

    def get_conversation_length(self) -> int:
        """Number of messages in the conversation."""
        return len(self.messages)

    def get_assistant_messages(self) -> List[Message]:
        """Messages whose role is ``"assistant"``, in conversation order."""
        return list(filter(lambda message: message.role == "assistant", self.messages))

    def get_user_messages(self) -> List[Message]:
        """Messages whose role is ``"user"``, in conversation order."""
        return list(filter(lambda message: message.role == "user", self.messages))

    def get_input_metadata(self, key: str, default: Any = None) -> Any:
        """Look up ``key`` in ``input_metadata``, returning ``default`` if absent or unset."""
        metadata = self.input_metadata
        return default if metadata is None else metadata.get(key, default)
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
# The original dataclass-based reward models (MetricRewardOutput and
# RewardOutput) have been removed; use EvaluateResult and MetricResult instead.
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
# --- Models for New Agent Evaluation Framework (V2) ---
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
class ResourceServerConfig(BaseModel):
    """How to launch, and health-check, a resource server that a task depends on.

    Both fields may contain the placeholder ``{port}``, which is substituted
    with the port allocated for the server.
    """

    # Shell command used to launch the server.
    start_command: str = Field(
        description=(
            "The command to start the server. "
            "The string '{port}' will be replaced with a dynamically allocated free port."
        )
    )
    # Endpoint polled until the server reports ready.
    health_check_url: str = Field(
        description=(
            "The URL to poll to check if the server is ready. "
            "The string '{port}' will be replaced with the allocated port."
        )
    )
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
class EvaluationCriteriaModel(BaseModel):
    """Criteria for deciding whether a task succeeded.

    Every field is optional (defaults to None). Two kinds of criteria can be
    expressed:

    * a query run against the resource's final state, plus a lambda string
      that turns the query result into a boolean; and
    * explicit BFCL ground-truth data (expected function calls and the
      expected comparable state).
    """

    # --- Final-state query criteria ---
    final_state_query: Optional[str] = Field(
        None,
        description="A query (e.g., SQL) to run on the final state of the resource.",
    )
    # NOTE(review): this holds Python *source text* for a lambda; the consumer
    # (not shown here) presumably evaluates it — confirm, and only load task
    # definitions from trusted sources.
    expected_query_result_transform: Optional[str] = Field(
        None,
        description=(
            "A Python lambda string (e.g., 'lambda x: x > 0') "
            "to transform and evaluate the query result to a boolean."
        ),
    )

    # --- Explicit ground-truth data for BFCL evaluation ---
    ground_truth_function_calls: Optional[List[List[str]]] = Field(
        None,
        description="Ground truth function calls for BFCL evaluation.",
    )
    ground_truth_comparable_state: Optional[Dict[str, Any]] = Field(
        None,
        description="Ground truth comparable state for BFCL evaluation.",
    )

    # Future: Could include other complex evaluation logic or references
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
class TaskDefinitionModel(BaseModel):
|
|
284
|
+
"""
|
|
285
|
+
Pydantic model for validating the structure of a V2 agent evaluation task definition file (YAML/JSON).
|
|
286
|
+
"""
|
|
287
|
+
|
|
288
|
+
name: str = Field(description="Unique name for the task.")
|
|
289
|
+
description: Optional[str] = Field(default=None, description="A brief description of the task.")
|
|
290
|
+
|
|
291
|
+
resource_type: str = Field(
|
|
292
|
+
description="The type of ForkableResource to use (e.g., 'SQLResource', 'PythonStateResource', 'FileSystemResource', 'DockerResource')."
|
|
293
|
+
)
|
|
294
|
+
base_resource_config: Dict[str, Any] = Field(
|
|
295
|
+
default_factory=dict,
|
|
296
|
+
description="Configuration dictionary passed to the base resource's setup() method.",
|
|
297
|
+
)
|
|
298
|
+
|
|
299
|
+
tools_module_path: Optional[str] = Field(
|
|
300
|
+
default=None,
|
|
301
|
+
description="Optional Python import path to a module containing custom tool functions for this task.",
|
|
302
|
+
)
|
|
303
|
+
reward_function_path: str = Field(
|
|
304
|
+
description="Python import path to the reward function (e.g., 'my_module.my_reward_func')."
|
|
305
|
+
)
|
|
306
|
+
|
|
307
|
+
goal_description: Optional[str] = Field(
|
|
308
|
+
default=None,
|
|
309
|
+
description="A human-readable description of the agent's goal for this task.",
|
|
310
|
+
)
|
|
311
|
+
evaluation_criteria: Optional[EvaluationCriteriaModel] = Field(
|
|
312
|
+
default=None,
|
|
313
|
+
description="Criteria used by the Orchestrator to determine if the primary goal was achieved.",
|
|
314
|
+
)
|
|
315
|
+
|
|
316
|
+
initial_user_prompt: Optional[str] = Field(
|
|
317
|
+
default=None,
|
|
318
|
+
description="The initial prompt or message to start the agent interaction. Deprecated if 'messages' field is used for multi-turn.",
|
|
319
|
+
)
|
|
320
|
+
messages: Optional[List[Dict[str, Any]]] = Field( # Explicit field for initial/multi-turn messages
|
|
321
|
+
default=None,
|
|
322
|
+
description="A list of messages to start the conversation, can represent multiple user turns for sequential processing.",
|
|
323
|
+
)
|
|
324
|
+
|
|
325
|
+
# PoC / Task specific parameters
|
|
326
|
+
poc_max_turns: int = Field(
|
|
327
|
+
default=3,
|
|
328
|
+
ge=1,
|
|
329
|
+
description="For PoC Orchestrator, the maximum number of interaction turns.",
|
|
330
|
+
)
|
|
331
|
+
|
|
332
|
+
# Allow other custom fields to be captured if needed by specific tasks or resources
|
|
333
|
+
# These will be accessible via `model_extra` if `model_config` has `extra = 'allow'`
|
|
334
|
+
# Or define a specific field:
|
|
335
|
+
# custom_task_params: Dict[str, Any] = Field(default_factory=dict)
|
|
336
|
+
resource_server: Optional[ResourceServerConfig] = Field(
|
|
337
|
+
default=None,
|
|
338
|
+
description="Configuration for a background server required for the task.",
|
|
339
|
+
)
|
|
340
|
+
|
|
341
|
+
num_rollouts: int = Field(
|
|
342
|
+
default=1,
|
|
343
|
+
ge=1,
|
|
344
|
+
description="Number of parallel rollouts to execute for this task definition.",
|
|
345
|
+
)
|
|
346
|
+
|
|
347
|
+
# Data-driven evaluation fields
|
|
348
|
+
dataset_path: Optional[str] = Field(
|
|
349
|
+
default=None,
|
|
350
|
+
description="Path to dataset file (JSONL) containing experimental conditions for data-driven evaluation.",
|
|
351
|
+
)
|
|
352
|
+
num_rollouts_per_sample: int = Field(
|
|
353
|
+
default=1,
|
|
354
|
+
ge=1,
|
|
355
|
+
description="Number of rollouts to execute per sample from the dataset.",
|
|
356
|
+
)
|
|
357
|
+
|
|
358
|
+
class Config:
|
|
359
|
+
extra = "allow" # Allow and capture extra fields not explicitly defined
|
|
360
|
+
# For Pydantic v2, it's model_config = {"extra": "allow"}
|
|
361
|
+
# Assuming Pydantic v1 style for now based on existing file, can update if needed.
|
|
362
|
+
# If using Pydantic v1, `Config.extra = "allow"` is correct.
|
|
363
|
+
# For Pydantic v2, this should be:
|
|
364
|
+
# from pydantic import ConfigDict
|
|
365
|
+
# model_config = ConfigDict(extra='allow')
|
|
366
|
+
# For Pydantic v1, `Config.extra = "allow"` is correct.
|