synth-ai 0.2.4.dev4__py3-none-any.whl → 0.2.4.dev6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
- synth_ai/environments/examples/crafter_classic/engine.py +579 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +266 -0
- synth_ai/environments/examples/crafter_classic/environment.py +364 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +233 -0
- synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +229 -0
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +298 -0
- synth_ai/environments/examples/crafter_custom/__init__.py +4 -0
- synth_ai/environments/examples/crafter_custom/crafter/__init__.py +7 -0
- synth_ai/environments/examples/crafter_custom/crafter/config.py +182 -0
- synth_ai/environments/examples/crafter_custom/crafter/constants.py +8 -0
- synth_ai/environments/examples/crafter_custom/crafter/engine.py +269 -0
- synth_ai/environments/examples/crafter_custom/crafter/env.py +266 -0
- synth_ai/environments/examples/crafter_custom/crafter/objects.py +418 -0
- synth_ai/environments/examples/crafter_custom/crafter/recorder.py +187 -0
- synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +119 -0
- synth_ai/environments/examples/crafter_custom/dataset_builder.py +373 -0
- synth_ai/environments/examples/crafter_custom/environment.py +312 -0
- synth_ai/environments/examples/crafter_custom/run_dataset.py +305 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- synth_ai/environments/examples/wordle/__init__.py +29 -0
- synth_ai/environments/examples/wordle/engine.py +391 -0
- synth_ai/environments/examples/wordle/environment.py +154 -0
- synth_ai/environments/examples/wordle/helpers/generate_instances_wordfreq.py +75 -0
- synth_ai/environments/examples/wordle/taskset.py +222 -0
- synth_ai/environments/service/app.py +8 -0
- synth_ai/environments/service/core_routes.py +38 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +163 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +201 -0
- synth_ai/learning/prompts/mipro.py +273 -1
- synth_ai/learning/prompts/random_search.py +247 -0
- synth_ai/learning/prompts/run_mipro_banking77.py +160 -0
- synth_ai/learning/prompts/run_random_search_banking77.py +305 -0
- synth_ai/lm/injection.py +81 -0
- synth_ai/lm/overrides.py +204 -0
- synth_ai/lm/provider_support/anthropic.py +39 -12
- synth_ai/lm/provider_support/openai.py +31 -4
- synth_ai/lm/vendors/core/anthropic_api.py +16 -0
- synth_ai/lm/vendors/openai_standard.py +35 -5
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/METADATA +2 -1
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/RECORD +123 -13
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,312 @@
|
|
1
|
+
"""CrafterCustomEnvironment — Custom Crafter with configurable world generation."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
from typing import List, Optional, Any, Dict, Union
|
6
|
+
import dataclasses
|
7
|
+
import logging
|
8
|
+
import time
|
9
|
+
|
10
|
+
# Import logging configuration to suppress JAX debug messages
|
11
|
+
from synth_ai.environments.examples.crafter_classic.config_logging import safe_compare
|
12
|
+
|
13
|
+
# Import tracing abstractions
|
14
|
+
from synth_ai.tracing_v3.abstractions import (
|
15
|
+
RuntimeEvent,
|
16
|
+
SessionEventMarkovBlanketMessage,
|
17
|
+
TimeRecord,
|
18
|
+
)
|
19
|
+
|
20
|
+
logger = logging.getLogger(__name__)
|
21
|
+
|
22
|
+
# Import the base Crafter components
|
23
|
+
from synth_ai.environments.examples.crafter_classic.engine import (
|
24
|
+
CrafterEngine,
|
25
|
+
CrafterPrivateState,
|
26
|
+
CrafterPublicState,
|
27
|
+
CrafterEngineSnapshot,
|
28
|
+
)
|
29
|
+
from synth_ai.environments.examples.crafter_classic.taskset import CrafterTaskInstance
|
30
|
+
from synth_ai.environments.environment.shared_engine import (
|
31
|
+
GetObservationCallable,
|
32
|
+
InternalObservation,
|
33
|
+
)
|
34
|
+
from synth_ai.environments.reproducibility.core import ReproducibleEnvironment
|
35
|
+
from synth_ai.environments.stateful.core import StatefulEnvironment
|
36
|
+
from synth_ai.environments.environment.tools import (
|
37
|
+
AbstractTool,
|
38
|
+
EnvToolCall,
|
39
|
+
ToolResult,
|
40
|
+
TOOL_REGISTRY,
|
41
|
+
register_tool,
|
42
|
+
)
|
43
|
+
from pydantic import BaseModel, Field
|
44
|
+
|
45
|
+
|
46
|
+
# Use the same tool and observation classes as CrafterClassic
|
47
|
+
from synth_ai.environments.examples.crafter_classic.environment import (
|
48
|
+
CrafterActionInput,
|
49
|
+
CrafterInteractTool,
|
50
|
+
SynthCrafterObservationCallable,
|
51
|
+
)
|
52
|
+
|
53
|
+
|
54
|
+
class CrafterCustomEnvironment(StatefulEnvironment, ReproducibleEnvironment[CrafterEngine]):
    """Custom Crafter environment with configurable world generation.

    Wraps a :class:`CrafterEngine` built from ``task_instance`` and exposes the
    stateful-environment lifecycle (``initialize`` / ``step`` / ``checkpoint`` /
    ``terminate``). When a ``session_tracer`` with a truthy ``current_session``
    is supplied, tool-call validation and observation generation are recorded
    as :class:`RuntimeEvent` entries on that session.
    """

    def __init__(
        self,
        task_instance: "CrafterTaskInstance",
        custom_step_obs: Optional[GetObservationCallable] = None,
        custom_ckpt_obs: Optional[GetObservationCallable] = None,
        session_tracer: Optional[Any] = None,  # SessionTracer from higher level
    ) -> None:
        """Build the engine and interact tool for ``task_instance``.

        Args:
            task_instance: Task passed unchanged to ``CrafterEngine``; its
                ``metadata.world_config`` (if any) is logged.
            custom_step_obs: Observation callable used by ``initialize``,
                ``step`` and ``terminate``; defaults to
                ``SynthCrafterObservationCallable()``.
            custom_ckpt_obs: Observation callable used by ``checkpoint``;
                defaults to ``SynthCrafterObservationCallable()``.
            session_tracer: Optional tracer; runtime events are only recorded
                when it exposes a truthy ``current_session``.
        """
        self.name = "CrafterCustom"
        self.task_instance = task_instance
        self.custom_step_observation_callable = custom_step_obs or SynthCrafterObservationCallable()
        self.custom_checkpoint_observation_callable = (
            custom_ckpt_obs or SynthCrafterObservationCallable()
        )

        # Log which world configuration this environment is being built with.
        if hasattr(task_instance, "metadata"):
            logger.info(
                f"Creating CrafterCustom with world_config: {getattr(task_instance.metadata, 'world_config', 'default')}"
            )

        self.engine = CrafterEngine(task_instance)
        self.session_tracer = session_tracer  # Store tracer for runtime events

        self._interact_tool = CrafterInteractTool(self.engine, session_tracer=session_tracer)

        # Register the tool under an environment-specific name. Only the first
        # instance to get here is renamed and registered; later instances keep
        # the tool's default name.
        # NOTE(review): the registered tool stays bound to the first
        # instance's engine — confirm that is intended.
        tool_name = f"{self.name.lower()}_interact"
        if tool_name not in TOOL_REGISTRY:
            self._interact_tool.name = tool_name
            register_tool(self._interact_tool)

    # ────────────────────────────────────────────────────────────────────
    # Lifecycle helpers
    # ────────────────────────────────────────────────────────────────────

    async def initialize(self) -> InternalObservation:  # type: ignore[override]
        """Reset the engine and return the initial step observation."""
        priv, pub = await self.engine._reset_engine()
        return await self._to_observation(priv, pub, self.custom_step_observation_callable)

    async def terminate(self) -> InternalObservation:  # type: ignore[override]
        """Mark the episode as terminated and return a final observation.

        The observation is augmented with ``{"status": "Environment terminated."}``
        when the observation callable produces a dict.
        """
        pub = self.engine._get_public_state_from_env()
        priv = self.engine._get_private_state_from_env(0, True, False)  # Terminated state
        priv.terminated = True
        obs_dict = {"status": "Environment terminated."}
        return await self._to_observation(
            priv, pub, self.custom_step_observation_callable, extra_obs=obs_dict
        )

    # ────────────────────────────────────────────────────────────────────
    # Step + checkpoint
    # ────────────────────────────────────────────────────────────────────

    def validate_tool_calls(
        self, tool_calls: Union[EnvToolCall, List[EnvToolCall], List[List[EnvToolCall]]]
    ) -> EnvToolCall:
        """Normalize ``tool_calls`` to a single ``EnvToolCall`` and validate it.

        Accepts a bare call, a list of calls (first one used), or a list of
        lists (first call of the first inner list used).

        Raises:
            ValueError: If a list (or the first inner list) is empty, or the
                tool name is neither ``"interact"`` nor ``"craftercustom_interact"``.
            TypeError: If the input or the selected element is not an ``EnvToolCall``.
        """
        # Keep the raw input so the trace shows what the agent actually sent.
        state_before = {"tool_calls": tool_calls}

        if isinstance(tool_calls, list):
            if not tool_calls:
                raise ValueError("Received empty list of tool calls.")
            if isinstance(tool_calls[0], list):
                if not tool_calls[0]:
                    raise ValueError("Received empty inner list of tool calls.")
                agent_call = tool_calls[0][0]
            else:
                agent_call = tool_calls[0]
        elif isinstance(tool_calls, EnvToolCall):
            agent_call = tool_calls
        else:
            raise TypeError(f"Unexpected type for tool_calls: {type(tool_calls)}")

        if not isinstance(agent_call, EnvToolCall):
            raise TypeError(f"Processed call is not EnvToolCall: {type(agent_call)}")

        # Accept both "interact" and "craftercustom_interact"
        if agent_call.tool not in ["interact", f"{self.name.lower()}_interact"]:
            raise ValueError(
                f"Unknown tool: {agent_call.tool}. Expected 'interact' or '{self.name.lower()}_interact'."
            )

        self._record_runtime_event(
            "crafter_custom_environment",
            state_before,
            {"validated_call": agent_call},
            {"validation_step": "tool_call_validation"},
        )

        return agent_call

    async def step(
        self, tool_calls: Union[EnvToolCall, List[EnvToolCall], List[List[EnvToolCall]]]
    ) -> InternalObservation:  # type: ignore[override]
        """Execute one validated tool call and return the resulting observation.

        States are taken from the tool result's payload when they have the
        expected types; otherwise they are read directly from the engine.
        ``pub_state.error_info`` is set when the payload states were invalid on
        a successful call, or to ``tool_result.error`` on a failed call.
        """
        agent_call = self.validate_tool_calls(tool_calls)
        tool_result: ToolResult = await self._interact_tool(agent_call)

        # A failed tool call may carry no payload at all; treat it as empty so
        # we fall back to the engine's current state instead of crashing on
        # ``None.get``.
        payload_dict = tool_result.payload or {}
        priv_state = payload_dict.get("private_state")
        pub_state = payload_dict.get("public_state")
        states_valid = isinstance(priv_state, CrafterPrivateState) and isinstance(
            pub_state, CrafterPublicState
        )

        if tool_result.ok:
            if not states_valid:
                logger.error(
                    f"Invalid state types in payload: priv={type(priv_state)}, pub={type(pub_state)}"
                )
                priv_state, pub_state = self._current_states()
                pub_state.error_info = "Invalid state types in tool result"
        else:
            if not states_valid:
                priv_state, pub_state = self._current_states()
            if tool_result.error:
                pub_state.error_info = tool_result.error

        return await self._to_observation(
            priv_state, pub_state, self.custom_step_observation_callable
        )

    async def checkpoint(self) -> InternalObservation:  # type: ignore[override]
        """Return a checkpoint observation, embedding the serialized engine.

        The snapshot is attached under ``"engine_snapshot_data"`` only when the
        checkpoint observation callable returns a dict.
        """
        engine_snapshot: CrafterEngineSnapshot = await self.engine._serialize_engine()
        priv = self.engine._get_private_state_from_env(0, False, False)  # Get current state for obs
        pub = self.engine._get_public_state_from_env()
        obs_data = await self._to_observation(
            priv, pub, self.custom_checkpoint_observation_callable
        )
        if isinstance(obs_data, dict):
            obs_data["engine_snapshot_data"] = engine_snapshot.model_dump()
        return obs_data

    async def get_metadata(self) -> Dict[str, Any]:
        """Get metadata about the current environment configuration.

        Includes engine size/step info, world-generation fields from the task
        instance's metadata (``None`` when absent), and a per-class count of
        the objects currently in the world.
        """
        metadata = {
            "environment_type": "CrafterCustom",
            "engine_seed": getattr(self.engine.env, "_seed", None),
            "world_area": self.engine.env._area,
            "max_steps": self.engine.env._length,
            "current_step": self.engine.env._step,
        }

        if hasattr(self.task_instance, "metadata"):
            task_metadata = self.task_instance.metadata
            metadata.update(
                {
                    "difficulty": getattr(task_metadata, "difficulty", None),
                    "world_config": getattr(task_metadata, "world_config", None),
                    "world_config_path": getattr(task_metadata, "world_config_path", None),
                    "num_trees_radius": getattr(task_metadata, "num_trees_radius", None),
                    "num_cows_radius": getattr(task_metadata, "num_cows_radius", None),
                    "num_hostiles_radius": getattr(task_metadata, "num_hostiles_radius", None),
                }
            )

        if hasattr(self.engine, "env") and hasattr(self.engine.env, "_world"):
            object_counts: Dict[str, int] = {}
            # ``_objects`` can contain ``None`` slots; skip them.
            for obj in self.engine.env._world._objects:
                if obj is None:
                    continue
                obj_type = type(obj).__name__
                object_counts[obj_type] = object_counts.get(obj_type, 0) + 1
            metadata["world_object_counts"] = object_counts

        return metadata

    # ────────────────────────────────────────────────────────────────────
    # Helpers
    # ────────────────────────────────────────────────────────────────────

    def _current_states(self) -> tuple[CrafterPrivateState, CrafterPublicState]:
        """Read (private, public) state directly from the engine.

        Used when the interact tool did not return usable state objects.
        """
        pub_state = self.engine._get_public_state_from_env()
        health_dead = safe_compare(0, self.engine.env._player.health, ">=")
        step_exceeded = safe_compare(self.engine.env._length, self.engine.env._step, "<=")
        priv_state = self.engine._get_private_state_from_env(0, health_dead, step_exceeded)
        return priv_state, pub_state

    def _tracing_active(self) -> bool:
        """Return True when a tracer with a truthy ``current_session`` is attached."""
        return bool(
            self.session_tracer
            and hasattr(self.session_tracer, "current_session")
            and self.session_tracer.current_session
        )

    def _record_runtime_event(
        self,
        system_instance_id: str,
        state_before: Dict[str, Any],
        state_after: Dict[str, Any],
        metadata: Dict[str, Any],
    ) -> None:
        """Append a RuntimeEvent to the tracer's current session if tracing is active."""
        if not self._tracing_active():
            return
        event = RuntimeEvent()
        event.time_record = TimeRecord()
        event.time_record.event_time = time.time()
        event.time_record.message_time = None
        event.system_instance_id = system_instance_id
        event.system_state_before = state_before
        event.system_state_after = state_after
        event.metadata = metadata
        # Add directly to event history, bypassing timestep requirement
        self.session_tracer.current_session.add_event(event)

    async def _to_observation(
        self,
        priv: CrafterPrivateState,
        pub: CrafterPublicState,
        obs_cb: Optional[GetObservationCallable],
        extra_obs: Optional[Dict[str, Any]] = None,
    ) -> InternalObservation:
        """Convert engine states to an observation via ``obs_cb``.

        ``extra_obs`` is merged in only when the callable returns a dict. The
        conversion is recorded as a runtime event when tracing is active.
        """
        state_before = {"private_state": priv, "public_state": pub}

        active_obs_cb = obs_cb or SynthCrafterObservationCallable()
        observation = await active_obs_cb.get_observation(pub, priv)
        if extra_obs and isinstance(observation, dict):
            observation.update(extra_obs)

        self._record_runtime_event(
            "observation_generator",
            state_before,
            {"observation": observation},
            {"observation_step": "state_to_obs_conversion"},
        )

        return observation

    # ────────────────────────────────────────────────────────────────────
    # ReproducibleEnvironment plumbing
    # ────────────────────────────────────────────────────────────────────

    async def _serialize_engine(self) -> CrafterEngineSnapshot:
        """Serialize the underlying engine."""
        return await self.engine._serialize_engine()

    @classmethod
    async def _deserialize_engine(
        cls, snapshot: CrafterEngineSnapshot, task_instance: "CrafterTaskInstance"
    ) -> "CrafterCustomEnvironment":
        """Rebuild an environment around an engine restored from ``snapshot``.

        NOTE(review): ``cls(task_instance)`` builds a throwaway engine that is
        immediately replaced, and the new environment has no session tracer.
        """
        eng = await CrafterEngine._deserialize_engine(snapshot, task_instance)
        env = cls(task_instance)
        env.engine = eng
        # CRITICAL: Update the interact tool to use the new engine!
        env._interact_tool.engine = eng
        return env
|
@@ -0,0 +1,305 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
"""
|
3
|
+
Run script for Crafter dataset instances
|
4
|
+
"""
|
5
|
+
|
6
|
+
import json
|
7
|
+
import argparse
|
8
|
+
import random
|
9
|
+
from pathlib import Path
|
10
|
+
from typing import List, Optional, Dict, Any
|
11
|
+
import uuid
|
12
|
+
import os
|
13
|
+
import sys
|
14
|
+
|
15
|
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))
|
16
|
+
|
17
|
+
from crafter import Env
|
18
|
+
|
19
|
+
|
20
|
+
class CrafterDatasetRunner:
    """Run Crafter instances from a dataset.

    A dataset named ``name`` lives in ``<dataset_path>/<name>/``. It consists
    of a ``metadata.json`` file (which may contain ``split_info``) and an
    ``instances.json`` file.
    """

    def __init__(self, dataset_path: Path = Path("dataset")):
        self.dataset_path = dataset_path

    def load_dataset(self, dataset_name: str) -> Dict[str, Any]:
        """Load a dataset from disk.

        Returns:
            ``{"metadata": <metadata.json>, "instances": <instances.json>}``.

        Raises:
            FileNotFoundError: If either JSON file is missing.
            json.JSONDecodeError: If either file is not valid JSON.
        """
        dataset_dir = self.dataset_path / dataset_name

        with open(dataset_dir / "metadata.json", "r", encoding="utf-8") as f:
            metadata = json.load(f)

        with open(dataset_dir / "instances.json", "r", encoding="utf-8") as f:
            instances = json.load(f)

        return {"metadata": metadata, "instances": instances}

    def filter_instances(
        self,
        instances: List[Dict[str, Any]],
        difficulties: Optional[List[str]] = None,
        impetus_types: Optional[List[str]] = None,
        split: Optional[str] = None,
        split_info: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Filter instances by difficulty, impetus type and dataset split.

        ``"train"`` means every instance not listed in the val or test ids.
        A split filter needs ``split_info``. Without it the split filter
        cannot be applied, and a warning is printed.
        """
        filtered = instances

        if difficulties:
            filtered = [inst for inst in filtered if inst["metadata"]["difficulty"] in difficulties]

        if impetus_types:
            filtered = [inst for inst in filtered if self._get_impetus_type(inst) in impetus_types]

        if split and not split_info:
            print(f"Warning: no split_info in dataset metadata; ignoring split '{split}'")
        elif split and split_info:
            if split == "train":
                excluded = set(split_info["val_instance_ids"]) | set(
                    split_info["test_instance_ids"]
                )
                filtered = [inst for inst in filtered if inst["id"] not in excluded]
            elif split == "val":
                val_ids = set(split_info["val_instance_ids"])
                filtered = [inst for inst in filtered if inst["id"] in val_ids]
            elif split == "test":
                test_ids = set(split_info["test_instance_ids"])
                filtered = [inst for inst in filtered if inst["id"] in test_ids]

        return filtered

    def _get_impetus_type(self, instance: Dict[str, Any]) -> str:
        """Classify an instance's instructions as speedrun, focused or general."""
        instructions = instance["impetus"]["instructions"].lower()
        if "speedrun" in instructions:
            return "speedrun"
        elif "focus on" in instructions:
            return "focused"
        else:
            return "general"

    def run_instance(
        self, instance: Dict[str, Any], render: bool = False, max_steps: int = 1000, agent_fn=None
    ):
        """Run a single instance and return a result summary dict.

        Args:
            instance: Dataset instance containing ``id``, ``metadata``,
                ``impetus`` and ``intent``.
            render: Currently unused; no rendering is performed.
            max_steps: Upper bound on environment steps.
            agent_fn: ``agent_fn(obs, instance) -> action``. When it is None,
                random actions are sampled.

        Returns:
            Dict with ``instance_id``, ``difficulty``, ``seed``, ``steps``
            (the number of ``env.step`` calls made), ``total_reward``,
            ``achievements`` and ``success``.
        """
        difficulty = instance["metadata"]["difficulty"]
        seed = instance["metadata"]["world_seed"]

        print(f"\n{'=' * 60}")
        print(f"Running instance: {instance['id']}")
        print(f"Difficulty: {difficulty}")
        print(f"Seed: {seed}")
        print(f"Instructions: {instance['impetus']['instructions']}")
        if instance["impetus"].get("achievement_focus"):
            print(f"Focus: {', '.join(instance['impetus']['achievement_focus'])}")
        print(f"{'=' * 60}")

        env = Env(seed=seed, world_config=difficulty)
        obs = env.reset()

        total_reward = 0
        achievements = set()
        # Count steps explicitly. The loop index is off by one, and it is
        # undefined when max_steps == 0.
        steps_taken = 0

        for _ in range(max_steps):
            if agent_fn:
                action = agent_fn(obs, instance)
            else:
                action = env.action_space.sample()

            obs, reward, done, info = env.step(action)
            steps_taken += 1
            total_reward += reward

            if "achievements" in info:
                for ach, unlocked in info["achievements"].items():
                    if unlocked:
                        achievements.add(ach)

            if done:
                break

        success = self._evaluate_instance(instance, achievements, total_reward, steps_taken)

        print("\nResults:")
        print(f"Steps: {steps_taken}")
        print(f"Total reward: {total_reward}")
        print(f"Achievements: {len(achievements)} - {list(achievements)}")
        print(f"Success: {success}")

        return {
            "instance_id": instance["id"],
            "difficulty": difficulty,
            "seed": seed,
            "steps": steps_taken,
            "total_reward": total_reward,
            "achievements": list(achievements),
            "success": success,
        }

    def _evaluate_instance(
        self, instance: Dict[str, Any], achievements: set, total_reward: float, steps: int
    ) -> bool:
        """Evaluate whether an instance succeeded according to its ``intent``.

        The instance fails when fewer than ``minimum_score`` achievements were
        unlocked. It also fails when ``target_achievements`` is given and none
        of them was unlocked. ``total_reward`` and ``steps`` are currently
        unused.
        """
        intent = instance["intent"]

        if intent.get("minimum_score"):
            if len(achievements) < intent["minimum_score"]:
                return False

        if intent.get("target_achievements"):
            targets = set(intent["target_achievements"])
            if not achievements.intersection(targets):
                return False

        return True

    def run_batch(
        self,
        dataset_name: str,
        num_instances: int = 10,
        difficulties: Optional[List[str]] = None,
        impetus_types: Optional[List[str]] = None,
        split: Optional[str] = None,
        render: bool = False,
        max_steps: int = 1000,
        agent_fn=None,
    ):
        """Run a random sample of filtered instances and print a summary.

        Returns:
            List of per-instance result dicts. The list is empty when no
            instance matches the filters.
        """
        dataset = self.load_dataset(dataset_name)
        instances = dataset["instances"]

        filtered = self.filter_instances(
            instances,
            difficulties=difficulties,
            impetus_types=impetus_types,
            split=split,
            split_info=dataset["metadata"].get("split_info"),
        )

        if not filtered:
            print("No instances match the filter criteria!")
            return []

        if num_instances > len(filtered):
            print(f"Only {len(filtered)} instances available, running all")
            selected = filtered
        else:
            selected = random.sample(filtered, num_instances)

        print(f"\nRunning {len(selected)} instances from {dataset_name}")
        print(f"Difficulties: {difficulties or 'all'}")
        print(f"Impetus types: {impetus_types or 'all'}")
        print(f"Split: {split or 'all'}")

        results = [
            self.run_instance(instance, render=render, max_steps=max_steps, agent_fn=agent_fn)
            for instance in selected
        ]

        self._print_summary(results)

        return results

    def _print_summary(self, results: List[Dict[str, Any]]):
        """Print per-difficulty and overall statistics for ``results``."""
        print(f"\n{'=' * 60}")
        print("SUMMARY")
        print(f"{'=' * 60}")

        # Guard against division by zero below.
        if not results:
            print("No results to summarize.")
            return

        by_difficulty: Dict[str, List[Dict[str, Any]]] = {}
        for result in results:
            by_difficulty.setdefault(result["difficulty"], []).append(result)

        print("\nResults by difficulty:")
        print(
            f"{'Difficulty':<15} {'Count':<8} {'Success':<10} {'Avg Steps':<12} {'Avg Achievements'}"
        )
        print("-" * 60)

        for diff in sorted(by_difficulty.keys()):
            diff_results = by_difficulty[diff]
            count = len(diff_results)
            success_rate = sum(1 for r in diff_results if r["success"]) / count
            avg_steps = sum(r["steps"] for r in diff_results) / count
            avg_achievements = sum(len(r["achievements"]) for r in diff_results) / count

            print(
                f"{diff:<15} {count:<8} {success_rate:<10.1%} {avg_steps:<12.1f} {avg_achievements:.1f}"
            )

        total_success = sum(1 for r in results if r["success"])
        print(
            f"\nOverall success rate: {total_success}/{len(results)} ({total_success / len(results):.1%})"
        )
|
262
|
+
|
263
|
+
|
264
|
+
def main():
    """Command-line entry point: parse options and run one batch of instances."""
    cli = argparse.ArgumentParser(description="Run Crafter dataset instances")
    cli.add_argument("dataset", help="Dataset name")
    cli.add_argument(
        "-n",
        "--num-instances",
        type=int,
        default=10,
        help="Number of instances to run",
    )
    cli.add_argument(
        "-d",
        "--difficulties",
        nargs="+",
        choices=["easy", "normal", "hard", "peaceful", "resource_rich"],
        help="Filter by difficulties",
    )
    cli.add_argument(
        "-t",
        "--impetus-types",
        nargs="+",
        choices=["general", "focused", "speedrun"],
        help="Filter by impetus types",
    )
    cli.add_argument(
        "-s",
        "--split",
        choices=["train", "val", "test"],
        help="Filter by dataset split",
    )
    cli.add_argument("--render", action="store_true", help="Render the environment")
    cli.add_argument(
        "--max-steps", type=int, default=1000, help="Maximum steps per episode"
    )

    opts = cli.parse_args()

    # Map parsed CLI options onto run_batch's keyword arguments.
    batch_kwargs = dict(
        dataset_name=opts.dataset,
        num_instances=opts.num_instances,
        difficulties=opts.difficulties,
        impetus_types=opts.impetus_types,
        split=opts.split,
        render=opts.render,
        max_steps=opts.max_steps,
    )
    CrafterDatasetRunner().run_batch(**batch_kwargs)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import.
    main()
|