eval-protocol 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- development/__init__.py +1 -0
- development/normalize_sandbox_fusion.py +628 -0
- development/utils/__init__.py +1 -0
- development/utils/generate_api_key.py +31 -0
- development/utils/subprocess_manager.py +481 -0
- eval_protocol/__init__.py +86 -0
- eval_protocol/__main__.py +10 -0
- eval_protocol/_version.py +21 -0
- eval_protocol/adapters/__init__.py +1 -0
- eval_protocol/adapters/braintrust.py +8 -0
- eval_protocol/adapters/trl.py +8 -0
- eval_protocol/agent/__init__.py +29 -0
- eval_protocol/agent/models.py +69 -0
- eval_protocol/agent/orchestrator.py +893 -0
- eval_protocol/agent/resource_abc.py +89 -0
- eval_protocol/agent/resource_pool.py +184 -0
- eval_protocol/agent/resources/__init__.py +44 -0
- eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
- eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
- eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
- eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
- eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
- eval_protocol/agent/resources/docker_resource.py +479 -0
- eval_protocol/agent/resources/filesystem_resource.py +371 -0
- eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
- eval_protocol/agent/resources/http_rollout_resource.py +325 -0
- eval_protocol/agent/resources/python_state_resource.py +170 -0
- eval_protocol/agent/resources/sql_resource.py +271 -0
- eval_protocol/agent/task_manager.py +1064 -0
- eval_protocol/agent/tool_registry.py +111 -0
- eval_protocol/auth.py +156 -0
- eval_protocol/cli.py +425 -0
- eval_protocol/cli_commands/__init__.py +1 -0
- eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
- eval_protocol/cli_commands/common.py +242 -0
- eval_protocol/cli_commands/deploy.py +486 -0
- eval_protocol/cli_commands/deploy_mcp.py +287 -0
- eval_protocol/cli_commands/preview.py +186 -0
- eval_protocol/cli_commands/run_eval_cmd.py +202 -0
- eval_protocol/common_utils.py +36 -0
- eval_protocol/config.py +180 -0
- eval_protocol/datasets/__init__.py +1 -0
- eval_protocol/datasets/loader.py +521 -0
- eval_protocol/evaluation.py +1045 -0
- eval_protocol/execution/__init__.py +1 -0
- eval_protocol/execution/pipeline.py +920 -0
- eval_protocol/gcp_tools.py +484 -0
- eval_protocol/generation/cache.py +141 -0
- eval_protocol/generation/clients/base.py +67 -0
- eval_protocol/generation/clients.py +248 -0
- eval_protocol/generic_server.py +165 -0
- eval_protocol/integrations/__init__.py +12 -0
- eval_protocol/integrations/braintrust.py +51 -0
- eval_protocol/integrations/deepeval.py +106 -0
- eval_protocol/integrations/openeval.py +40 -0
- eval_protocol/integrations/trl.py +187 -0
- eval_protocol/mcp/__init__.py +48 -0
- eval_protocol/mcp/adapter.py +131 -0
- eval_protocol/mcp/client/__init__.py +12 -0
- eval_protocol/mcp/client/connection.py +499 -0
- eval_protocol/mcp/clients.py +195 -0
- eval_protocol/mcp/execution/__init__.py +23 -0
- eval_protocol/mcp/execution/base_policy.py +227 -0
- eval_protocol/mcp/execution/fireworks_policy.py +209 -0
- eval_protocol/mcp/execution/manager.py +506 -0
- eval_protocol/mcp/execution/policy.py +421 -0
- eval_protocol/mcp/grid_renderer.py +54 -0
- eval_protocol/mcp/mcpgym.py +637 -0
- eval_protocol/mcp/process_manager.py +177 -0
- eval_protocol/mcp/session/__init__.py +11 -0
- eval_protocol/mcp/session/manager.py +228 -0
- eval_protocol/mcp/simple_process_manager.py +291 -0
- eval_protocol/mcp/simulation_server.py +458 -0
- eval_protocol/mcp/types.py +80 -0
- eval_protocol/mcp_agent/__init__.py +1 -0
- eval_protocol/mcp_agent/config.py +147 -0
- eval_protocol/mcp_agent/intermediary_server.py +542 -0
- eval_protocol/mcp_agent/main.py +210 -0
- eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
- eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
- eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
- eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
- eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
- eval_protocol/mcp_agent/session.py +79 -0
- eval_protocol/mcp_env.py +304 -0
- eval_protocol/models.py +366 -0
- eval_protocol/packaging.py +219 -0
- eval_protocol/platform_api.py +360 -0
- eval_protocol/playback_policy.py +396 -0
- eval_protocol/resources.py +128 -0
- eval_protocol/reward_function.py +410 -0
- eval_protocol/rewards/__init__.py +94 -0
- eval_protocol/rewards/accuracy.py +454 -0
- eval_protocol/rewards/accuracy_length.py +173 -0
- eval_protocol/rewards/apps_coding_reward.py +331 -0
- eval_protocol/rewards/apps_execution_utils.py +149 -0
- eval_protocol/rewards/apps_testing_util.py +559 -0
- eval_protocol/rewards/bfcl_reward.py +313 -0
- eval_protocol/rewards/code_execution.py +1620 -0
- eval_protocol/rewards/code_execution_utils.py +72 -0
- eval_protocol/rewards/cpp_code.py +861 -0
- eval_protocol/rewards/deepcoder_reward.py +161 -0
- eval_protocol/rewards/format.py +129 -0
- eval_protocol/rewards/function_calling.py +541 -0
- eval_protocol/rewards/json_schema.py +422 -0
- eval_protocol/rewards/language_consistency.py +700 -0
- eval_protocol/rewards/lean_prover.py +479 -0
- eval_protocol/rewards/length.py +375 -0
- eval_protocol/rewards/list_comparison_math_reward.py +221 -0
- eval_protocol/rewards/math.py +762 -0
- eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
- eval_protocol/rewards/reasoning_steps.py +249 -0
- eval_protocol/rewards/repetition.py +342 -0
- eval_protocol/rewards/tag_count.py +162 -0
- eval_protocol/rl_processing.py +82 -0
- eval_protocol/server.py +271 -0
- eval_protocol/typed_interface.py +260 -0
- eval_protocol/utils/__init__.py +8 -0
- eval_protocol/utils/batch_evaluation.py +217 -0
- eval_protocol/utils/batch_transformation.py +205 -0
- eval_protocol/utils/dataset_helpers.py +112 -0
- eval_protocol/utils/module_loader.py +56 -0
- eval_protocol/utils/packaging_utils.py +108 -0
- eval_protocol/utils/static_policy.py +305 -0
- eval_protocol-0.0.3.dist-info/METADATA +635 -0
- eval_protocol-0.0.3.dist-info/RECORD +130 -0
- eval_protocol-0.0.3.dist-info/WHEEL +5 -0
- eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
- eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
- eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,396 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Playback policy base class for record-and-replay functionality.
|
|
3
|
+
|
|
4
|
+
This module implements the abstract base class that handles all playback logic,
|
|
5
|
+
allowing concrete policy classes to inherit replay functionality while focusing
|
|
6
|
+
on their specific implementation details.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
import logging
|
|
11
|
+
import os
|
|
12
|
+
from abc import ABC, abstractmethod
|
|
13
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
14
|
+
|
|
15
|
+
from .mcp.types import LLMUsageStats, MCPToolCall
|
|
16
|
+
|
|
17
|
+
# Module-level logger, named via __name__ so it sits in the package's logging hierarchy.
logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class PlaybackPolicyBase(ABC):
    """
    Abstract base class for policies that support record-and-playback functionality.

    This class handles all playback logic including trajectory loading, parsing,
    and step management. Concrete policy classes inherit from this to get
    replay functionality while implementing their own live mode logic in
    ``_generate_live_tool_calls``.

    Playback data is keyed by the *string* form of the environment index
    (e.g. ``"0"``); every lookup converts ``env_index`` with ``str()``.
    """

    def __init__(
        self,
        _playback_actions: Optional[Dict[str, List[Dict[str, Any]]]] = None,
        **kwargs,
    ):
        """
        Initialize policy with optional playback actions.

        If ``_playback_actions`` is not given and the ``EP_PLAYBACK_FILE``
        environment variable is set, that trajectory file is loaded and
        playback mode is enabled when it yields at least one valid entry.

        Args:
            _playback_actions: Pre-parsed playback actions organized by environment.
                Format: {env_index: [{"step": int, "messages": [...]}]}.
                Keys are normalized to ``str`` so both ``0`` and ``"0"`` work.
            **kwargs: Additional arguments passed to concrete implementations
        """
        # Normalize keys to str: all lookups below use str(env_index), so
        # int-keyed input would otherwise never match.
        if _playback_actions is not None:
            _playback_actions = {str(key): actions for key, actions in _playback_actions.items()}

        # Playback state management
        self._playback_actions = _playback_actions
        self._is_playback = _playback_actions is not None
        self._playback_step_counters: Dict[str, int] = {}  # {env_key: current_step}

        # Environment variable override
        playback_file = os.environ.get("EP_PLAYBACK_FILE")
        if playback_file and not self._is_playback:
            logger.info(f"🎬 Auto-enabling playback mode from environment variable: {playback_file}")
            self._playback_actions = self._load_trajectory_file(playback_file)
            self._is_playback = self._playback_actions is not None

        # Initialize step counters if in playback mode
        if self._is_playback and self._playback_actions:
            for env_key in self._playback_actions:
                self._playback_step_counters[env_key] = 0

        logger.debug(f"PlaybackPolicyBase initialized: playback_mode={self._is_playback}")

    @staticmethod
    def _load_trajectory_file(
        filepath: str,
    ) -> Optional[Dict[str, List[Dict[str, Any]]]]:
        """
        Load and parse trajectory file into organized playback actions.

        Expected JSONL format per design document:
        {"env_index": 0, "step": 0, "messages": [{..}, {..}]}
        {"env_index": 1, "step": 0, "messages": [{..}, {..}]}
        {"env_index": 0, "step": 1, "messages": [{..}, {..}]}

        Malformed lines are logged and skipped: invalid JSON, non-object
        entries, missing fields, a non-integer ``step``, or a non-list
        ``messages``.

        Args:
            filepath: Path to trajectory JSONL file

        Returns:
            Organized playback actions: {env_index: [{"step": int, "messages": [...]}]},
            with each environment's list sorted by step. Returns None if the file
            is missing, unreadable, or contains no valid entries.
        """
        if not os.path.exists(filepath):
            logger.error(f"Trajectory file not found: {filepath}")
            return None

        try:
            playback_actions: Dict[str, List[Dict[str, Any]]] = {}
            valid_entries = 0

            with open(filepath, "r", encoding="utf-8") as f:
                for line_num, raw_line in enumerate(f, 1):
                    line = raw_line.strip()
                    if not line:
                        continue

                    try:
                        entry = json.loads(line)
                    except json.JSONDecodeError as e:
                        logger.warning(f"Line {line_num}: Invalid JSON - {e}")
                        continue

                    # Validate required fields
                    if not isinstance(entry, dict):
                        logger.warning(f"Line {line_num}: Entry is not a dictionary, skipping")
                        continue

                    env_index = entry.get("env_index")
                    step = entry.get("step")
                    messages = entry.get("messages")

                    if env_index is None or step is None or messages is None:
                        logger.warning(
                            f"Line {line_num}: Missing required fields (env_index, step, messages), skipping"
                        )
                        continue

                    # Steps are compared against an int counter and sorted, so a
                    # non-int step could never be replayed (and mixed types would
                    # make the sort below fail). bool is rejected explicitly
                    # because it subclasses int.
                    if not isinstance(step, int) or isinstance(step, bool):
                        logger.warning(f"Line {line_num}: 'step' must be an integer, skipping")
                        continue

                    # Replay iterates messages in reverse, so it must be a list.
                    if not isinstance(messages, list):
                        logger.warning(f"Line {line_num}: 'messages' must be a list, skipping")
                        continue

                    # Convert env_index to string for consistent dictionary keys
                    playback_actions.setdefault(str(env_index), []).append({"step": step, "messages": messages})
                    valid_entries += 1

            # Sort each environment's actions by step
            for actions in playback_actions.values():
                actions.sort(key=lambda action: action["step"])

            if playback_actions:
                logger.info(f"✅ Loaded {valid_entries} trajectory entries for {len(playback_actions)} environments")
                return playback_actions

            logger.warning(
                f"⚠️ Trajectory file {filepath} exists but contains no valid entries. "
                f"Falling back to recording mode. Please check file format - expected JSONL with "
                f"'env_index', 'step', and 'messages' fields."
            )
            return None

        except Exception as e:
            # Best-effort loader: any failure (I/O, decoding, ...) falls back to live mode.
            logger.error(f"Error loading trajectory file {filepath}: {e}")
            return None

    def _get_playback_messages(self, env_index: int) -> Optional[List[Dict[str, Any]]]:
        """
        Get the next playback messages for the specified environment.

        On success the environment's step counter is advanced by one.

        Args:
            env_index: Environment index

        Returns:
            Messages list for the current step, or None if no more steps
        """
        if not self._is_playback or not self._playback_actions:
            return None

        env_key = str(env_index)
        env_actions = self._playback_actions.get(env_key)
        if env_actions is None:
            logger.warning(f"No playback data for environment {env_index}")
            return None

        current_step = self._playback_step_counters.get(env_key, 0)

        # Find action for current step
        action = next((a for a in env_actions if a["step"] == current_step), None)
        if action is None:
            # No more recorded actions available
            logger.debug(f"🎬 Environment {env_index}: No more playback data (step {current_step})")
            return None

        # Increment step counter for next call
        self._playback_step_counters[env_key] = current_step + 1
        logger.debug(f"🎬 Environment {env_index}: Returning playback messages for step {current_step}")
        return action["messages"]

    def has_more_playback_data(self, env_index: int) -> bool:
        """
        Check if there are more playback actions available for an environment.

        Args:
            env_index: Environment index

        Returns:
            True if an action is recorded for the environment's current step,
            False otherwise (including when not in playback mode).
        """
        if not self._is_playback or not self._playback_actions:
            return False

        env_key = str(env_index)
        env_actions = self._playback_actions.get(env_key)
        if env_actions is None:
            return False

        current_step = self._playback_step_counters.get(env_key, 0)
        return any(action["step"] == current_step for action in env_actions)

    @abstractmethod
    async def _generate_live_tool_calls(
        self,
        tool_schemas: List[Dict],
        env_index: int,
        conversation_history: List[Dict[str, Any]],
    ) -> Tuple[List["MCPToolCall"], LLMUsageStats]:
        """
        Generate tool calls in live mode. Concrete classes must implement this.

        Args:
            tool_schemas: Available tools for this environment
            env_index: Environment index
            conversation_history: Current conversation history for this environment

        Returns:
            List of ToolCall objects and LLM interaction usage stats
        """
        pass

    async def __call__(
        self,
        tool_schemas: List[Dict],
        env_index: int,
        conversation_history: List[Dict[str, Any]],
    ) -> Tuple[List[MCPToolCall], Optional[LLMUsageStats]]:
        """
        Main policy call method. Delegates to playback or live mode.

        Args:
            tool_schemas: Available tools for this environment (used in live mode)
            env_index: Environment index
            conversation_history: Current conversation history for this environment
                (used in live mode)

        Returns:
            A ``(tool_calls, usage_stats)`` tuple. In playback mode ``usage_stats``
            is None. When the recording for the environment is exhausted,
            ``tool_calls`` holds a single ``_playback_terminate`` call to signal
            early termination.
        """
        if not self._is_playback:
            # Live mode - generate tool call using provided conversation history
            return await self._generate_live_tool_calls(tool_schemas, env_index, conversation_history)

        # In playback mode, get recorded messages
        messages = self._get_playback_messages(env_index)

        if messages is None:
            # No more recorded actions - signal early termination. Returned as a
            # (tool_calls, usage) tuple like every other path so callers can unpack it.
            terminate_call = MCPToolCall(
                "_playback_terminate",
                {"reason": "no_more_recorded_actions"},
            )
            return [terminate_call], None

        # Return the recorded tool call
        return self._extract_tool_call_from_messages(messages, env_index), None

    def _extract_tool_call_from_messages(self, messages: List[Dict[str, Any]], env_index: int) -> List[MCPToolCall]:
        """
        Extract tool calls from recorded conversation messages.

        Uses the last assistant message that has a non-empty ``tool_calls`` list
        and converts ALL of its tool calls.

        Args:
            messages: List of conversation messages
            env_index: Environment index for logging

        Returns:
            List of MCPToolCall objects. If no tool calls are found, a single
            ``MCPToolCall("unknown", {})`` is returned.
        """
        # Look for the last assistant message with tool_calls
        for message in reversed(messages):
            if message.get("role") != "assistant":
                continue
            tool_calls = message.get("tool_calls")
            if not tool_calls:
                continue

            mcp_tool_calls = []
            for tool_call in tool_calls:
                # `or {}` also guards against explicit nulls in the recording.
                function = tool_call.get("function") or {}
                tool_name = function.get("name", "unknown")
                tool_call_id = tool_call.get("id", "unknown")

                # Arguments are usually recorded as a JSON-encoded string.
                arguments = function.get("arguments") or {}
                if isinstance(arguments, str):
                    try:
                        arguments = json.loads(arguments)
                    except json.JSONDecodeError:
                        logger.warning(f"🎬 Environment {env_index}: Failed to parse tool call arguments: {arguments}")
                        arguments = {}

                mcp_tool_calls.append(MCPToolCall(tool_name, arguments, tool_call_id))

            logger.debug(f"🎬 Environment {env_index}: Extracted {len(mcp_tool_calls)} tool calls")
            return mcp_tool_calls

        # Fallback if no tool calls found
        logger.warning(f"🎬 Environment {env_index}: No tool calls found in messages, using unknown tool")
        return [MCPToolCall("unknown", {})]

    def is_playback_mode(self) -> bool:
        """
        Check if the policy is in playback mode.

        Returns:
            True if in playback mode, False otherwise
        """
        return self._is_playback

    def get_playback_progress(self) -> Dict[str, Any]:
        """
        Get playback progress information.

        Returns:
            ``{"playback_mode": False}`` when not in playback mode; otherwise a dict
            with ``playback_mode``, ``environments`` (per-environment
            ``current_step``/``total_steps``/``completed``, keyed by the int index
            when the key is numeric, else by the raw key) and ``total_environments``.
        """
        if not self._is_playback:
            return {"playback_mode": False}

        progress: Dict[str, Any] = {
            "playback_mode": True,
            "environments": {},
            "total_environments": (len(self._playback_actions) if self._playback_actions else 0),
        }

        for env_key, actions in (self._playback_actions or {}).items():
            current_step = self._playback_step_counters.get(env_key, 0)
            total_steps = len(actions)

            # Report numeric environment indices as ints; keep other keys as-is
            # instead of crashing on int().
            try:
                report_key: Any = int(env_key)
            except ValueError:
                report_key = env_key

            progress["environments"][report_key] = {
                "current_step": current_step,
                "total_steps": total_steps,
                "completed": current_step >= total_steps,
            }

        return progress

    def log_conversation_state_for_playback(
        self, env_index: int, step: int, conversation_history: List[Dict[str, Any]]
    ):
        """
        Log the current conversation state in the format required for playback.

        Appends one JSON line to the file named by ``EP_PLAYBACK_FILE``; does
        nothing when that variable is unset. Base implementation that
        subclasses can override with specific behavior.

        Expected format: {"env_index": 0, "step": 0, "messages": [{..}, {..}]}

        Args:
            env_index: Environment index
            step: Current step number
            conversation_history: List of conversation messages
        """
        # Use EP_PLAYBACK_FILE environment variable for recording
        playback_file = os.environ.get("EP_PLAYBACK_FILE")
        if not playback_file:
            return  # No recording file specified

        playback_entry = {
            "env_index": env_index,
            "step": step,
            "messages": list(conversation_history),
        }

        with open(playback_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(playback_entry) + "\n")
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Resource management for reward functions.
|
|
3
|
+
|
|
4
|
+
This module provides resource wrappers for external services like LLMs,
|
|
5
|
+
databases, etc. Resources are automatically setup and cleaned up by the
|
|
6
|
+
reward function decorator.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import logging
|
|
10
|
+
from abc import ABC, abstractmethod
|
|
11
|
+
from typing import Any, Dict, List, TypeVar
|
|
12
|
+
|
|
13
|
+
# Module-level logger, named via __name__ so it sits in the package's logging hierarchy.
logger = logging.getLogger(__name__)

# Type definitions
# Generic type variable; not used inside this module itself (presumably kept
# for callers/importers — NOTE(review): confirm it is still needed).
T = TypeVar("T")
# Mapping from a string key to a list of resource wrappers. "ResourceWrapper"
# is a string forward reference because the class is defined further below.
ResourceDict = Dict[str, List["ResourceWrapper"]]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ResourceWrapper(ABC):
    """
    Interface that every resource wrapper implements.

    A wrapper exposes a setup/cleanup lifecycle and gives access to the
    underlying client object once the resource is ready.
    """

    @abstractmethod
    def setup(self) -> None:
        """Prepare the resource for use (e.g., start deployment, create connection)."""

    @abstractmethod
    def cleanup(self) -> None:
        """Release the resource (e.g., stop deployment, close connection)."""

    @abstractmethod
    def get_client(self) -> Any:
        """Return the client object through which the resource is used."""
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class LLMResourceWrapper(ResourceWrapper):
    """
    Resource wrapper around a Fireworks LLM deployment.

    The wrapped instance doubles as the client: after a successful ``setup()``,
    ``get_client()`` returns the same object that was passed to the constructor.
    """

    def __init__(self, llm_instance: Any):
        """
        Store the LLM instance; no deployment work happens here.

        Args:
            llm_instance: A Fireworks LLM instance from the Build SDK
        """
        self.llm_instance = llm_instance
        # Both fields are populated by setup() and reset by cleanup().
        self._client = None
        self._is_setup = False

    def setup(self) -> None:
        """
        Prepare the LLM deployment; repeated calls after success are no-ops.

        When the instance reports ``deployment_type == "on-demand"``, its
        ``apply()`` method is invoked before the client is recorded.

        Raises:
            Exception: whatever the instance raises during setup is logged and
                re-raised unchanged.
        """
        llm = self.llm_instance
        if self._is_setup:
            logger.debug(f"LLM resource already setup for model: {llm.model}")
            return

        try:
            logger.debug(f"Setting up LLM deployment for model: {llm.model}")

            # Only on-demand deployments need an explicit apply() step.
            is_on_demand = getattr(llm, "deployment_type", None) == "on-demand"
            if is_on_demand:
                logger.info("Applying on-demand LLM deployment...")
                llm.apply()
                logger.info("On-demand LLM deployment applied successfully")

            self._client = llm
            self._is_setup = True

            logger.info(f"LLM resource setup completed for model: {llm.model}")
        except Exception as exc:
            logger.error(f"Failed to setup LLM resource: {exc}")
            raise

    def cleanup(self) -> None:
        """
        Drop the local client reference and mark the wrapper as not set up.

        Errors are logged rather than raised so that cleanup never masks an
        exception that triggered it.
        """
        if self._is_setup:
            try:
                logger.debug("Cleaning up LLM resource")
                # No remote teardown is issued here: per the original note,
                # Build SDK deployments are managed by the platform.
                self._client, self._is_setup = None, False
                logger.debug("LLM resource cleanup completed")
            except Exception as exc:
                logger.error(f"Error during LLM resource cleanup: {exc}")
        else:
            logger.debug("LLM resource not setup, nothing to cleanup")

    def get_client(self) -> Any:
        """
        Return the LLM client for making API calls.

        Raises:
            RuntimeError: if ``setup()`` has not completed successfully.
        """
        if self._is_setup and self._client is not None:
            return self._client
        raise RuntimeError("LLM resource not setup. Call setup() first.")
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def create_llm_resource(llm_instance: Any) -> LLMResourceWrapper:
    """
    Wrap a Fireworks LLM instance in an ``LLMResourceWrapper``.

    The wrapper is returned un-setup; call ``setup()`` on it (or let the
    reward function machinery do so) before requesting the client.

    Args:
        llm_instance: A Fireworks LLM instance from the Build SDK

    Returns:
        LLMResourceWrapper instance

    Example:
        ```python
        from fireworks import LLM
        from eval_protocol import create_llm_resource

        llm = LLM(
            model="accounts/fireworks/models/llama-v3p1-8b-instruct",
            deployment_type="on-demand",
        )

        llm_resource = create_llm_resource(llm)
        ```
    """
    wrapper = LLMResourceWrapper(llm_instance)
    return wrapper
|