eval-protocol 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- development/__init__.py +1 -0
- development/normalize_sandbox_fusion.py +628 -0
- development/utils/__init__.py +1 -0
- development/utils/generate_api_key.py +31 -0
- development/utils/subprocess_manager.py +481 -0
- eval_protocol/__init__.py +86 -0
- eval_protocol/__main__.py +10 -0
- eval_protocol/_version.py +21 -0
- eval_protocol/adapters/__init__.py +1 -0
- eval_protocol/adapters/braintrust.py +8 -0
- eval_protocol/adapters/trl.py +8 -0
- eval_protocol/agent/__init__.py +29 -0
- eval_protocol/agent/models.py +69 -0
- eval_protocol/agent/orchestrator.py +893 -0
- eval_protocol/agent/resource_abc.py +89 -0
- eval_protocol/agent/resource_pool.py +184 -0
- eval_protocol/agent/resources/__init__.py +44 -0
- eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
- eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
- eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
- eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
- eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
- eval_protocol/agent/resources/docker_resource.py +479 -0
- eval_protocol/agent/resources/filesystem_resource.py +371 -0
- eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
- eval_protocol/agent/resources/http_rollout_resource.py +325 -0
- eval_protocol/agent/resources/python_state_resource.py +170 -0
- eval_protocol/agent/resources/sql_resource.py +271 -0
- eval_protocol/agent/task_manager.py +1064 -0
- eval_protocol/agent/tool_registry.py +111 -0
- eval_protocol/auth.py +156 -0
- eval_protocol/cli.py +425 -0
- eval_protocol/cli_commands/__init__.py +1 -0
- eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
- eval_protocol/cli_commands/common.py +242 -0
- eval_protocol/cli_commands/deploy.py +486 -0
- eval_protocol/cli_commands/deploy_mcp.py +287 -0
- eval_protocol/cli_commands/preview.py +186 -0
- eval_protocol/cli_commands/run_eval_cmd.py +202 -0
- eval_protocol/common_utils.py +36 -0
- eval_protocol/config.py +180 -0
- eval_protocol/datasets/__init__.py +1 -0
- eval_protocol/datasets/loader.py +521 -0
- eval_protocol/evaluation.py +1045 -0
- eval_protocol/execution/__init__.py +1 -0
- eval_protocol/execution/pipeline.py +920 -0
- eval_protocol/gcp_tools.py +484 -0
- eval_protocol/generation/cache.py +141 -0
- eval_protocol/generation/clients/base.py +67 -0
- eval_protocol/generation/clients.py +248 -0
- eval_protocol/generic_server.py +165 -0
- eval_protocol/integrations/__init__.py +12 -0
- eval_protocol/integrations/braintrust.py +51 -0
- eval_protocol/integrations/deepeval.py +106 -0
- eval_protocol/integrations/openeval.py +40 -0
- eval_protocol/integrations/trl.py +187 -0
- eval_protocol/mcp/__init__.py +48 -0
- eval_protocol/mcp/adapter.py +131 -0
- eval_protocol/mcp/client/__init__.py +12 -0
- eval_protocol/mcp/client/connection.py +499 -0
- eval_protocol/mcp/clients.py +195 -0
- eval_protocol/mcp/execution/__init__.py +23 -0
- eval_protocol/mcp/execution/base_policy.py +227 -0
- eval_protocol/mcp/execution/fireworks_policy.py +209 -0
- eval_protocol/mcp/execution/manager.py +506 -0
- eval_protocol/mcp/execution/policy.py +421 -0
- eval_protocol/mcp/grid_renderer.py +54 -0
- eval_protocol/mcp/mcpgym.py +637 -0
- eval_protocol/mcp/process_manager.py +177 -0
- eval_protocol/mcp/session/__init__.py +11 -0
- eval_protocol/mcp/session/manager.py +228 -0
- eval_protocol/mcp/simple_process_manager.py +291 -0
- eval_protocol/mcp/simulation_server.py +458 -0
- eval_protocol/mcp/types.py +80 -0
- eval_protocol/mcp_agent/__init__.py +1 -0
- eval_protocol/mcp_agent/config.py +147 -0
- eval_protocol/mcp_agent/intermediary_server.py +542 -0
- eval_protocol/mcp_agent/main.py +210 -0
- eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
- eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
- eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
- eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
- eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
- eval_protocol/mcp_agent/session.py +79 -0
- eval_protocol/mcp_env.py +304 -0
- eval_protocol/models.py +366 -0
- eval_protocol/packaging.py +219 -0
- eval_protocol/platform_api.py +360 -0
- eval_protocol/playback_policy.py +396 -0
- eval_protocol/resources.py +128 -0
- eval_protocol/reward_function.py +410 -0
- eval_protocol/rewards/__init__.py +94 -0
- eval_protocol/rewards/accuracy.py +454 -0
- eval_protocol/rewards/accuracy_length.py +173 -0
- eval_protocol/rewards/apps_coding_reward.py +331 -0
- eval_protocol/rewards/apps_execution_utils.py +149 -0
- eval_protocol/rewards/apps_testing_util.py +559 -0
- eval_protocol/rewards/bfcl_reward.py +313 -0
- eval_protocol/rewards/code_execution.py +1620 -0
- eval_protocol/rewards/code_execution_utils.py +72 -0
- eval_protocol/rewards/cpp_code.py +861 -0
- eval_protocol/rewards/deepcoder_reward.py +161 -0
- eval_protocol/rewards/format.py +129 -0
- eval_protocol/rewards/function_calling.py +541 -0
- eval_protocol/rewards/json_schema.py +422 -0
- eval_protocol/rewards/language_consistency.py +700 -0
- eval_protocol/rewards/lean_prover.py +479 -0
- eval_protocol/rewards/length.py +375 -0
- eval_protocol/rewards/list_comparison_math_reward.py +221 -0
- eval_protocol/rewards/math.py +762 -0
- eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
- eval_protocol/rewards/reasoning_steps.py +249 -0
- eval_protocol/rewards/repetition.py +342 -0
- eval_protocol/rewards/tag_count.py +162 -0
- eval_protocol/rl_processing.py +82 -0
- eval_protocol/server.py +271 -0
- eval_protocol/typed_interface.py +260 -0
- eval_protocol/utils/__init__.py +8 -0
- eval_protocol/utils/batch_evaluation.py +217 -0
- eval_protocol/utils/batch_transformation.py +205 -0
- eval_protocol/utils/dataset_helpers.py +112 -0
- eval_protocol/utils/module_loader.py +56 -0
- eval_protocol/utils/packaging_utils.py +108 -0
- eval_protocol/utils/static_policy.py +305 -0
- eval_protocol-0.0.3.dist-info/METADATA +635 -0
- eval_protocol-0.0.3.dist-info/RECORD +130 -0
- eval_protocol-0.0.3.dist-info/WHEEL +5 -0
- eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
- eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
- eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,325 @@
|
|
|
1
|
+
"""
|
|
2
|
+
HTTP Rollout Resource implementation for the agent evaluation framework.
|
|
3
|
+
|
|
4
|
+
This resource bridges the HTTP rollout protocol with the ForkableResource interface,
|
|
5
|
+
allowing HTTP-based environments to be used in agent evaluations.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
import uuid
|
|
10
|
+
from typing import Any, Dict, List, Optional
|
|
11
|
+
|
|
12
|
+
import httpx
|
|
13
|
+
|
|
14
|
+
from ..resource_abc import ForkableResource
|
|
15
|
+
from .http_rollout_protocol import (
|
|
16
|
+
EndEpisodeRequest,
|
|
17
|
+
GameObservation,
|
|
18
|
+
HttpRolloutConfig,
|
|
19
|
+
StartEpisodeRequest,
|
|
20
|
+
StartEpisodeResponse,
|
|
21
|
+
StepRequest,
|
|
22
|
+
StepResponse,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class HttpRolloutResource(ForkableResource):
    """
    A ForkableResource implementation that communicates with HTTP rollout servers.

    This resource allows the agent evaluation framework to interact with
    HTTP-based environments through a standardized rollout protocol.

    Lifecycle:
        1. ``setup(config)`` validates the config as an ``HttpRolloutConfig``
           and creates the ``httpx.Client``.
        2. ``initialize(**kwargs)`` starts an episode on the server
           (``start_episode_endpoint``).
        3. ``step("step", {"action": ...})`` posts actions to ``step_endpoint``.
        4. ``cleanup()`` / ``close()`` end the episode on the server
           (best effort) and close the HTTP client.

    Note: the prompt text and the tool schema are specific to a Frozen Lake
    style grid game whose actions are left/down/right/up, sent to the server as
    the integers 0/1/2/3.
    """

    # String actions exposed to the agent -> integer action ids sent to the server.
    # get_tools_spec() derives its enum from this mapping so the two never drift.
    _ACTION_MAP: Dict[str, int] = {"left": 0, "down": 1, "right": 2, "up": 3}

    def __init__(self):
        """Initialize the HTTP rollout resource (no network activity until setup/initialize)."""
        super().__init__()
        self.config: Optional[HttpRolloutConfig] = None
        self.episode_id: Optional[str] = None
        self.current_observation: Optional[Dict[str, Any]] = None
        self.is_episode_active = False
        self.client: Optional[httpx.Client] = None

        # Set up logging
        import logging

        self.logger = logging.getLogger(f"{self.__class__.__name__}")

    def _require_client(self) -> httpx.Client:
        """Return the HTTP client, raising RuntimeError if setup() has not run (or cleanup() closed it)."""
        if self.config is None or self.client is None:
            raise RuntimeError("Resource not set up. Call setup() first.")
        return self.client

    async def setup(self, config: Dict[str, Any]) -> None:
        """
        Set up the resource with the provided configuration.

        Args:
            config: Configuration dictionary from the task definition; it is
                validated by constructing an ``HttpRolloutConfig`` from it.
        """
        self.config = HttpRolloutConfig(**config)
        # A second setup() call would otherwise leak the previously created client.
        if self.client is not None:
            self.client.close()
        self.client = httpx.Client(timeout=self.config.timeout)

    async def fork(self) -> "HttpRolloutResource":
        """
        Create a new independent instance of this resource.

        For HTTP rollout, forking means creating a new resource instance
        that will start its own episode when initialized.

        Raises:
            RuntimeError: If setup() has not been called on this instance.
        """
        if not self.config:
            raise RuntimeError("Resource not set up. Call setup() first.")

        # Create a new instance with the same config
        new_resource = HttpRolloutResource()
        await new_resource.setup(self.config.model_dump())
        return new_resource

    async def get_state(self) -> Dict[str, Any]:
        """
        Get the current state of the resource.

        Returns the current observation and episode metadata.
        """
        return {
            "episode_id": self.episode_id,
            "observation": self.current_observation,
            "is_episode_active": self.is_episode_active,
            "type": "http_rollout",
        }

    async def initialize(self, **kwargs) -> None:
        """
        Initialize the resource by starting a new episode.

        Passes any provided kwargs (like seed) to the server in the request body.

        Raises:
            RuntimeError: If the resource is not set up, the request fails, or
                the response lacks ``episode_id`` / ``observation``.
        """
        client = self._require_client()
        url = f"{self.config.base_url}{self.config.start_episode_endpoint}"

        try:
            # Include any sample data (like seed) in the request body
            if kwargs:
                self.logger.info("Sending initialization data to server: %s", kwargs)
                response = client.post(url, json=kwargs)
            else:
                response = client.post(url)
            response.raise_for_status()

            episode_data = response.json()
            # Read both fields before mutating self so a malformed response
            # cannot leave a half-initialized episode behind.
            episode_id = episode_data["episode_id"]
            observation = episode_data["observation"]
        except Exception as e:
            raise RuntimeError(f"Failed to start HTTP rollout episode: {e}") from e

        self.episode_id = episode_id
        self.current_observation = observation
        self.is_episode_active = True

    async def get_initial_state_description(self) -> str:
        """
        Get a formatted description of the initial game state for the agent.

        Starts an episode first if none is active, then builds the prompt from
        the observation returned by start_episode.
        """
        # Start episode to get current game state
        if not self.is_episode_active:
            await self.initialize()

        if not self.current_observation:
            return "No initial state available."

        obs = self.current_observation

        # Build comprehensive game prompt
        content = """🎮 FROZEN LAKE GAME - AUTONOMOUS PLAY MODE

🎯 OBJECTIVE: Navigate from S to G without hitting H

📋 GAME RULES: S=start, F=safe, H=hole(death), G=goal(win)

🤖 AUTONOMOUS MODE INSTRUCTIONS:
- You are playing this game AUTONOMOUSLY until completion
- KEEP MAKING MOVES using the step tool until you reach G or hit H
- DO NOT ask for user input or wait for confirmation
- DO NOT stop after one move - continue until the game ends
- Each move should be followed immediately by another move
- Game only ends when you reach G (win) or hit H (lose)

🎮 ACTION: Use step tool with: "left", "right", "up", or "down"

⚡ START NOW - Make your first move and continue until the game is complete!"""

        description_parts = [content]

        if obs.get("message"):
            description_parts.append(f"\nEnvironment: {obs['message']}")

        if obs.get("visual"):
            description_parts.append(f"\nGame Board:\n{obs['visual']}")

        if obs.get("position"):
            description_parts.append(f"\nStarting Position: {obs['position']}")

        description_parts.append("\nGame Rules:")
        description_parts.append("- S = Start position")
        description_parts.append("- F = Frozen (safe to step on)")
        description_parts.append("- H = Hole (game over if you step here)")
        description_parts.append("- G = Goal (reach this to win)")
        description_parts.append("- [X] = Your current position")

        return "\n".join(description_parts)

    async def cleanup(self) -> None:
        """
        Clean up the resource by ending the current episode and closing the client.

        Ending the episode on the server is best effort: failures are logged,
        not raised. Safe to call more than once and before setup().
        """
        if self.is_episode_active and self.episode_id:
            try:
                client = self._require_client()
                url = f"{self.config.base_url}{self.config.end_episode_endpoint}"
                response = client.post(url, json={"episode_id": self.episode_id})
                response.raise_for_status()
            except Exception as e:
                # Log but don't raise - cleanup should be best effort
                self.logger.warning("Failed to properly end episode %s: %s", self.episode_id, e)
            finally:
                self.episode_id = None
                self.current_observation = None
                self.is_episode_active = False

        # Close the HTTP client (if one was ever created) and drop the reference
        # so repeated cleanup()/close() calls and __del__ become no-ops.
        if self.client is not None:
            self.client.close()
            self.client = None

    async def get_tools_spec(self) -> List[Dict[str, Any]]:
        """
        Get the list of available tools for this resource.

        For HTTP rollout, this returns the 'step' tool that allows
        the agent to take actions in the environment.
        """
        return [
            {
                "name": "step",
                "description": "Take a step in the Frozen Lake game by choosing a direction to move",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "action": {
                            "type": "string",
                            "enum": list(self._ACTION_MAP),
                            "description": "The direction to move in the game: 'left', 'down', 'right', or 'up'",
                        }
                    },
                    "required": ["action"],
                },
            }
        ]

    async def step(self, action_name: str, action_params: Dict[str, Any]) -> Any:
        """
        Execute a tool call on this resource.

        For HTTP rollout, this handles the 'step' tool by sending
        the action to the HTTP rollout server. If no episode is active
        (including after a finished game), a new episode is started first.

        Raises:
            ValueError: If ``action_name`` is not ``"step"``.
            RuntimeError: If starting the episode or executing the step fails.
        """
        # Reject unknown tools before touching the server, so an invalid call
        # does not start a server-side episode as a side effect.
        if action_name != "step":
            raise ValueError(f"Unknown action: {action_name}")

        if not self.is_episode_active or not self.episode_id:
            # If no active episode, start one first
            await self.initialize()

        return await self._handle_step_tool(action_params.get("action"))

    async def get_observation(self) -> Any:
        """
        Get the current observation from the environment.

        Returns a placeholder message dict when no (non-empty) observation is held.
        """
        if self.current_observation:
            return self.current_observation
        return {"message": "No observation available. Start an episode first."}

    async def checkpoint(self) -> Dict[str, Any]:
        """
        Create a checkpoint of the current resource state.

        For HTTP rollout, this saves the episode ID and current observation.
        """
        return {
            "episode_id": self.episode_id,
            "current_observation": self.current_observation,
            "is_episode_active": self.is_episode_active,
        }

    async def restore(self, state_data: Dict[str, Any]) -> None:
        """
        Restore the resource state from a checkpoint.

        Note: This is limited for HTTP rollout since we can't restore
        arbitrary server-side state; only the local episode bookkeeping is restored.
        """
        self.episode_id = state_data.get("episode_id")
        self.current_observation = state_data.get("current_observation")
        self.is_episode_active = state_data.get("is_episode_active", False)

    async def close(self) -> None:
        """
        Clean up and close the resource.
        """
        await self.cleanup()

    def _to_numeric_action(self, action: Any) -> Any:
        """Map a string action (case/whitespace-insensitive) to its server id; pass non-strings through."""
        if action is None:
            raise ValueError("Missing 'action'. Must be one of: left, down, right, up")
        if isinstance(action, str):
            key = action.strip().lower()
            if key not in self._ACTION_MAP:
                raise ValueError(f"Invalid action '{action}'. Must be one of: left, down, right, up")
            return self._ACTION_MAP[key]
        # Backward compatibility with numeric actions
        return action

    async def _handle_step_tool(self, action: Any) -> Dict[str, Any]:
        """
        Handle the 'step' tool by sending an action to the HTTP rollout server.

        Returns:
            A tool-result dict of the form
            ``{"content": [{"type": "text", "text": ...}]}`` summarizing the
            environment message, board, position, and done/result flags.

        Raises:
            RuntimeError: Wrapping any validation, HTTP, or response-format error.
        """
        try:
            numeric_action = self._to_numeric_action(action)
            client = self._require_client()

            url = f"{self.config.base_url}{self.config.step_endpoint}"
            step_data = {"episode_id": self.episode_id, "action": numeric_action}

            response = client.post(url, json=step_data)
            response.raise_for_status()

            step_result = response.json()
            observation = step_result["observation"]
            is_done = step_result.get("is_done", False)
            self.current_observation = observation

            # If the episode is done, mark it as inactive
            if is_done:
                self.is_episode_active = False

            # Format the response for the agent
            message = observation.get("message", "")
            visual = observation.get("visual", "")

            response_content = []
            if message:
                response_content.append(f"Environment: {message}")
            if visual:
                response_content.append(f"Visual State:\n{visual}")

            # Add structured data
            response_content.append(f"Position: {observation.get('position', 'unknown')}")
            response_content.append(f"Done: {is_done}")

            if is_done:
                won = observation.get("won", False)
                response_content.append(f"Result: {'Victory!' if won else 'Game Over'}")

            return {"content": [{"type": "text", "text": "\n".join(response_content)}]}

        except Exception as e:
            raise RuntimeError(f"Failed to execute step: {e}") from e

    def __del__(self):
        """Ensure the HTTP client is closed on deletion (errors are ignored)."""
        client = getattr(self, "client", None)
        if client:
            try:
                client.close()
            except Exception:
                pass  # Ignore cleanup errors during deletion
|
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
"""
|
|
2
|
+
PythonStateResource: A ForkableResource that manages state as a Python dictionary.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import copy
|
|
6
|
+
import pickle
|
|
7
|
+
from typing import Any, Dict, List, Optional
|
|
8
|
+
|
|
9
|
+
from ..resource_abc import ForkableResource
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class PythonStateResource(ForkableResource):
    """
    A ForkableResource that manages its state as an in-memory Python dictionary.

    This resource is useful for tasks where the environment's state can be
    represented and manipulated directly as Python objects. All public
    accessors hand out deep copies, and all inputs are deep-copied on the way
    in, so callers can never alias the internal state.

    Attributes:
        _state (Dict[str, Any]): The internal dictionary holding the resource's state.
        _config (Dict[str, Any]): The configuration passed during setup.
    """

    def __init__(self) -> None:
        super().__init__()
        self._state: Dict[str, Any] = {}
        self._config: Dict[str, Any] = {}

    async def setup(self, config: Dict[str, Any]) -> None:
        """
        Initializes the resource with a given configuration.

        The configuration can specify an 'initial_state' dictionary.

        Args:
            config: Configuration dictionary.
                Expected keys:
                    - 'initial_state' (Optional[Dict[str, Any]]):
                        A dictionary to set as the initial state. A missing
                        key or an explicit None both yield an empty state.
        """
        self._config = copy.deepcopy(config)
        initial_state = self._config.get("initial_state")
        # An explicit ``None`` would otherwise become the state and break update_state.
        self._state = copy.deepcopy(initial_state) if initial_state is not None else {}

    async def fork(self) -> "PythonStateResource":
        """
        Creates and returns a new, independent instance of this resource
        with an identical copy of the current state.
        """
        forked_resource = PythonStateResource()
        forked_resource._config = copy.deepcopy(self._config)
        forked_resource._state = copy.deepcopy(self._state)
        return forked_resource

    async def checkpoint(self) -> bytes:
        """
        Returns a serializable representation of the resource's current state
        using pickle.
        """
        return pickle.dumps(self._state)

    async def restore(self, state_data: bytes) -> None:
        """
        Restores the resource's state from previously checkpointed state_data
        (pickle format).

        Security: ``pickle.loads`` can execute arbitrary code. Only pass bytes
        produced by ``checkpoint()`` of a trusted instance, never external input.
        """
        self._state = pickle.loads(state_data)

    async def step(self, action_name: str, action_params: Dict[str, Any]) -> Any:
        """
        Executes a named action with given parameters on the resource.

        Supported actions (matching get_tools_spec()):
            - 'update_state': merges key-value pairs into the state. The pairs
              may be given directly as ``action_params`` or, as advertised by
              the tool spec, wrapped as ``{"updates": {...}}`` (the wrapper is
              unwrapped only when "updates" is the sole parameter and is a dict).
              Returns a deep copy of the updated state.
            - 'get_value': returns a deep copy of the value stored under
              ``action_params["key"]`` (None if the key is absent).

        Subclasses could override this for more specific actions.

        Args:
            action_name: The name of the action to perform.
            action_params: A dictionary of parameters for the action.

        Returns:
            The result of the action as described above.

        Raises:
            ValueError: If 'get_value' is called without a 'key'.
            NotImplementedError: If action_name is not a supported action.
        """
        if action_name == "update_state":
            updates = action_params
            # The tool spec exposes a single "updates" object parameter; unwrap it
            # so spec-following tool calls update the intended keys.
            if set(action_params) == {"updates"} and isinstance(action_params["updates"], dict):
                updates = action_params["updates"]
            # Deep-copy so caller-owned mutable values are not aliased into state.
            self._state.update(copy.deepcopy(updates))
            return copy.deepcopy(self._state)
        elif action_name == "get_value":
            key = action_params.get("key")
            if key is None:
                raise ValueError("Missing 'key' in action_params for 'get_value'")
            # Return a copy, consistent with get_observation()/get_state().
            return copy.deepcopy(self._state.get(key))
        else:
            raise NotImplementedError(f"Action '{action_name}' is not implemented for PythonStateResource.")

    async def get_observation(self) -> Dict[str, Any]:
        """
        Returns a deep copy of the current observable state of the resource.
        """
        return copy.deepcopy(self._state)

    def get_state(self) -> Dict[str, Any]:
        """
        Returns a deep copy of the current state dictionary.
        This is a synchronous version of get_observation for compatibility with test tasks.
        """
        return copy.deepcopy(self._state)

    def set_state(self, state: Dict[str, Any]) -> None:
        """
        Sets the resource's state to a deep copy of the provided dictionary.

        Args:
            state: A dictionary containing the new state.
        """
        self._state = copy.deepcopy(state)

    async def get_tools_spec(self) -> List[Dict[str, Any]]:
        """
        Returns a list of tool specifications available for this resource.

        Provides generic 'update_state' and 'get_value' tools.
        """
        return [
            {
                "type": "function",
                "function": {
                    "name": "update_state",
                    "description": "Updates the current state dictionary with the provided key-value pairs.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "updates": {
                                "type": "object",
                                "description": "Key-value pairs to update in the state.",
                            }
                        },
                        "required": ["updates"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "get_value",
                    "description": "Retrieves a value from the state dictionary for a given key.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "key": {
                                "type": "string",
                                "description": "The key of the value to retrieve.",
                            }
                        },
                        "required": ["key"],
                    },
                },
            },
        ]

    async def close(self) -> None:
        """
        Performs cleanup for the resource.

        There are no external resources to release; this simply resets the
        state and configuration to empty dictionaries.
        """
        self._state = {}
        self._config = {}
|